Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d60d5e71
Unverified
Commit
d60d5e71
authored
Aug 31, 2023
by
Philip Meier
Committed by
GitHub
Aug 31, 2023
Browse files
[CHERRYPICK] allow sequence fill for v2 AA scripted (#7920)
parent
f588fd1a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
11 deletions
+14
-11
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+12
-9
torchvision/transforms/v2/_auto_augment.py
torchvision/transforms/v2/_auto_augment.py
+2
-2
No files found.
test/test_transforms_v2_consistency.py
View file @
d60d5e71
...
@@ -755,10 +755,11 @@ class TestAATransforms:
...
@@ -755,10 +755,11 @@ class TestAATransforms:
v2_transforms
.
InterpolationMode
.
BILINEAR
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
],
)
)
def
test_randaug_jit
(
self
,
interpolation
):
@
pytest
.
mark
.
parametrize
(
"fill"
,
[
None
,
85
,
(
10
,
-
10
,
10
),
0.7
,
[
0.0
,
0.0
,
0.0
],
[
1
],
1
])
def
test_randaug_jit
(
self
,
interpolation
,
fill
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
)
t_ref
=
legacy_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
,
fill
=
fill
)
t
=
v2_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
)
t
=
v2_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
,
fill
=
fill
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
tt
=
torch
.
jit
.
script
(
t
)
...
@@ -830,10 +831,11 @@ class TestAATransforms:
...
@@ -830,10 +831,11 @@ class TestAATransforms:
v2_transforms
.
InterpolationMode
.
BILINEAR
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
],
)
)
def
test_trivial_aug_jit
(
self
,
interpolation
):
@
pytest
.
mark
.
parametrize
(
"fill"
,
[
None
,
85
,
(
10
,
-
10
,
10
),
0.7
,
[
0.0
,
0.0
,
0.0
],
[
1
],
1
])
def
test_trivial_aug_jit
(
self
,
interpolation
,
fill
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
)
t_ref
=
legacy_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
,
fill
=
fill
)
t
=
v2_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
)
t
=
v2_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
,
fill
=
fill
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
tt
=
torch
.
jit
.
script
(
t
)
...
@@ -906,11 +908,12 @@ class TestAATransforms:
...
@@ -906,11 +908,12 @@ class TestAATransforms:
v2_transforms
.
InterpolationMode
.
BILINEAR
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
],
)
)
def
test_augmix_jit
(
self
,
interpolation
):
@
pytest
.
mark
.
parametrize
(
"fill"
,
[
None
,
85
,
(
10
,
-
10
,
10
),
0.7
,
[
0.0
,
0.0
,
0.0
],
[
1
],
1
])
def
test_augmix_jit
(
self
,
interpolation
,
fill
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
)
t_ref
=
legacy_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
,
fill
=
fill
)
t
=
v2_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
)
t
=
v2_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
,
fill
=
fill
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
tt
=
torch
.
jit
.
script
(
t
)
...
...
torchvision/transforms/v2/_auto_augment.py
View file @
d60d5e71
...
@@ -33,8 +33,8 @@ class _AutoAugmentBase(Transform):
...
@@ -33,8 +33,8 @@ class _AutoAugmentBase(Transform):
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
params
=
super
().
_extract_params_for_v1_transform
()
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))
):
if
isinstance
(
params
[
"fill"
],
dict
):
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can
only
be scripted for
a scalar `fill`, but got
{
self
.
fill
}
."
)
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can
not
be scripted for
when `fill` is a dictionary
."
)
return
params
return
params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment