Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9c4f7389
Unverified
Commit
9c4f7389
authored
Aug 22, 2023
by
vfdev
Committed by
GitHub
Aug 22, 2023
Browse files
Fixed issue with jitted AA transforms in v2 and added tests (#7839)
parent
37081ee6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
108 additions
and
5 deletions
+108
-5
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+94
-0
torchvision/transforms/v2/_auto_augment.py
torchvision/transforms/v2/_auto_augment.py
+14
-5
No files found.
test/test_transforms_v2_consistency.py
View file @
9c4f7389
...
@@ -927,6 +927,29 @@ class TestAATransforms:
...
@@ -927,6 +927,29 @@ class TestAATransforms:
assert_close
(
expected_output
,
output
,
atol
=
1
,
rtol
=
0.1
)
assert_close
(
expected_output
,
output
,
atol
=
1
,
rtol
=
0.1
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
def
test_randaug_jit
(
self
,
interpolation
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
)
t
=
v2_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"inpt"
,
"inpt"
,
[
[
...
@@ -979,6 +1002,29 @@ class TestAATransforms:
...
@@ -979,6 +1002,29 @@ class TestAATransforms:
assert_close
(
expected_output
,
output
,
atol
=
1
,
rtol
=
0.1
)
assert_close
(
expected_output
,
output
,
atol
=
1
,
rtol
=
0.1
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
def
test_trivial_aug_jit
(
self
,
interpolation
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
)
t
=
v2_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"inpt"
,
"inpt"
,
[
[
...
@@ -1032,6 +1078,30 @@ class TestAATransforms:
...
@@ -1032,6 +1078,30 @@ class TestAATransforms:
assert_equal
(
expected_output
,
output
)
assert_equal
(
expected_output
,
output
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
def
test_augmix_jit
(
self
,
interpolation
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
)
t
=
v2_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"inpt"
,
"inpt"
,
[
[
...
@@ -1061,6 +1131,30 @@ class TestAATransforms:
...
@@ -1061,6 +1131,30 @@ class TestAATransforms:
assert_equal
(
expected_output
,
output
)
assert_equal
(
expected_output
,
output
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
def
test_aa_jit
(
self
,
interpolation
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
aa_policy
=
legacy_transforms
.
AutoAugmentPolicy
(
"imagenet"
)
t_ref
=
legacy_transforms
.
AutoAugment
(
aa_policy
,
interpolation
=
interpolation
)
t
=
v2_transforms
.
AutoAugment
(
aa_policy
,
interpolation
=
interpolation
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
def
import_transforms_from_references
(
reference
):
def
import_transforms_from_references
(
reference
):
HERE
=
Path
(
__file__
).
parent
HERE
=
Path
(
__file__
).
parent
...
...
torchvision/transforms/v2/_auto_augment.py
View file @
9c4f7389
...
@@ -28,7 +28,16 @@ class _AutoAugmentBase(Transform):
...
@@ -28,7 +28,16 @@ class _AutoAugmentBase(Transform):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
fill
=
_setup_fill_arg
(
fill
)
self
.
fill
=
fill
self
.
_fill
=
_setup_fill_arg
(
fill
)
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
f
"
{
type
(
self
).
__name__
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
."
)
return
params
def
_get_random_item
(
self
,
dct
:
Dict
[
str
,
Tuple
[
Callable
,
bool
]])
->
Tuple
[
str
,
Tuple
[
Callable
,
bool
]]:
def
_get_random_item
(
self
,
dct
:
Dict
[
str
,
Tuple
[
Callable
,
bool
]])
->
Tuple
[
str
,
Tuple
[
Callable
,
bool
]]:
keys
=
tuple
(
dct
.
keys
())
keys
=
tuple
(
dct
.
keys
())
...
@@ -335,7 +344,7 @@ class AutoAugment(_AutoAugmentBase):
...
@@ -335,7 +344,7 @@ class AutoAugment(_AutoAugmentBase):
magnitude
=
0.0
magnitude
=
0.0
image_or_video
=
self
.
_apply_image_or_video_transform
(
image_or_video
=
self
.
_apply_image_or_video_transform
(
image_or_video
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
fill
image_or_video
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
_
fill
)
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
image_or_video
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
image_or_video
)
...
@@ -419,7 +428,7 @@ class RandAugment(_AutoAugmentBase):
...
@@ -419,7 +428,7 @@ class RandAugment(_AutoAugmentBase):
else
:
else
:
magnitude
=
0.0
magnitude
=
0.0
image_or_video
=
self
.
_apply_image_or_video_transform
(
image_or_video
=
self
.
_apply_image_or_video_transform
(
image_or_video
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
fill
image_or_video
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
_
fill
)
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
image_or_video
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
image_or_video
)
...
@@ -491,7 +500,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
...
@@ -491,7 +500,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
magnitude
=
0.0
magnitude
=
0.0
image_or_video
=
self
.
_apply_image_or_video_transform
(
image_or_video
=
self
.
_apply_image_or_video_transform
(
image_or_video
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
fill
image_or_video
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
_
fill
)
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
image_or_video
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
image_or_video
)
...
@@ -614,7 +623,7 @@ class AugMix(_AutoAugmentBase):
...
@@ -614,7 +623,7 @@ class AugMix(_AutoAugmentBase):
magnitude
=
0.0
magnitude
=
0.0
aug
=
self
.
_apply_image_or_video_transform
(
aug
=
self
.
_apply_image_or_video_transform
(
aug
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
fill
aug
,
transform_id
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
self
.
_
fill
)
)
mix
.
add_
(
combined_weights
[:,
i
].
reshape
(
batch_dims
)
*
aug
)
mix
.
add_
(
combined_weights
[:,
i
].
reshape
(
batch_dims
)
*
aug
)
mix
=
mix
.
reshape
(
orig_dims
).
to
(
dtype
=
image_or_video
.
dtype
)
mix
=
mix
.
reshape
(
orig_dims
).
to
(
dtype
=
image_or_video
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment