Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7cf0f4cc
Unverified
Commit
7cf0f4cc
authored
Jan 31, 2023
by
Philip Meier
Committed by
GitHub
Jan 31, 2023
Browse files
make transforms v2 JIT scriptable (#7135)
parent
170160a5
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
206 additions
and
17 deletions
+206
-17
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+58
-13
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+9
-1
torchvision/prototype/transforms/_auto_augment.py
torchvision/prototype/transforms/_auto_augment.py
+7
-1
torchvision/prototype/transforms/_color.py
torchvision/prototype/transforms/_color.py
+22
-1
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+53
-0
torchvision/prototype/transforms/_meta.py
torchvision/prototype/transforms/_meta.py
+3
-0
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+6
-0
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+48
-1
No files found.
test/test_prototype_transforms_consistency.py
View file @
7cf0f4cc
...
...
@@ -34,6 +34,15 @@ from torchvision.transforms import functional as legacy_F
DEFAULT_MAKE_IMAGES_KWARGS
=
dict
(
color_spaces
=
[
"RGB"
],
extra_dims
=
[(
4
,)])
class
NotScriptableArgsKwargs
(
ArgsKwargs
):
"""
This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
thus will be tested there, but will be skipped by the JIT tests.
"""
pass
class
ConsistencyConfig
:
def
__init__
(
self
,
...
...
@@ -73,7 +82,7 @@ CONSISTENCY_CONFIGS = [
prototype_transforms
.
Resize
,
legacy_transforms
.
Resize
,
[
ArgsKwargs
(
32
),
NotScriptable
ArgsKwargs
(
32
),
ArgsKwargs
([
32
]),
ArgsKwargs
((
32
,
29
)),
ArgsKwargs
((
31
,
28
),
interpolation
=
prototype_transforms
.
InterpolationMode
.
NEAREST
),
...
...
@@ -84,8 +93,10 @@ CONSISTENCY_CONFIGS = [
# ArgsKwargs((30, 27), interpolation=0),
# ArgsKwargs((35, 29), interpolation=2),
# ArgsKwargs((34, 25), interpolation=3),
ArgsKwargs
(
31
,
max_size
=
32
),
ArgsKwargs
(
30
,
max_size
=
100
),
NotScriptableArgsKwargs
(
31
,
max_size
=
32
),
ArgsKwargs
([
31
],
max_size
=
32
),
NotScriptableArgsKwargs
(
30
,
max_size
=
100
),
ArgsKwargs
([
31
],
max_size
=
32
),
ArgsKwargs
((
29
,
32
),
antialias
=
False
),
ArgsKwargs
((
28
,
31
),
antialias
=
True
),
],
...
...
@@ -121,14 +132,15 @@ CONSISTENCY_CONFIGS = [
prototype_transforms
.
Pad
,
legacy_transforms
.
Pad
,
[
ArgsKwargs
(
3
),
NotScriptable
ArgsKwargs
(
3
),
ArgsKwargs
([
3
]),
ArgsKwargs
([
2
,
3
]),
ArgsKwargs
([
3
,
2
,
1
,
4
]),
ArgsKwargs
(
5
,
fill
=
1
,
padding_mode
=
"constant"
),
ArgsKwargs
(
5
,
padding_mode
=
"edge"
),
ArgsKwargs
(
5
,
padding_mode
=
"reflect"
),
ArgsKwargs
(
5
,
padding_mode
=
"symmetric"
),
NotScriptableArgsKwargs
(
5
,
fill
=
1
,
padding_mode
=
"constant"
),
ArgsKwargs
([
5
],
fill
=
1
,
padding_mode
=
"constant"
),
NotScriptableArgsKwargs
(
5
,
padding_mode
=
"edge"
),
NotScriptableArgsKwargs
(
5
,
padding_mode
=
"reflect"
),
NotScriptableArgsKwargs
(
5
,
padding_mode
=
"symmetric"
),
],
),
ConsistencyConfig
(
...
...
@@ -170,7 +182,7 @@ CONSISTENCY_CONFIGS = [
ConsistencyConfig
(
prototype_transforms
.
ToPILImage
,
legacy_transforms
.
ToPILImage
,
[
ArgsKwargs
()],
[
NotScriptable
ArgsKwargs
()],
make_images_kwargs
=
dict
(
color_spaces
=
[
"GRAY"
,
...
...
@@ -186,7 +198,7 @@ CONSISTENCY_CONFIGS = [
prototype_transforms
.
Lambda
,
legacy_transforms
.
Lambda
,
[
ArgsKwargs
(
lambda
image
:
image
/
2
),
NotScriptable
ArgsKwargs
(
lambda
image
:
image
/
2
),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
...
...
@@ -380,14 +392,15 @@ CONSISTENCY_CONFIGS = [
[
ArgsKwargs
(
12
),
ArgsKwargs
((
15
,
17
)),
ArgsKwargs
(
11
,
padding
=
1
),
NotScriptableArgsKwargs
(
11
,
padding
=
1
),
ArgsKwargs
(
11
,
padding
=
[
1
]),
ArgsKwargs
((
8
,
13
),
padding
=
(
2
,
3
)),
ArgsKwargs
((
14
,
9
),
padding
=
(
0
,
2
,
1
,
0
)),
ArgsKwargs
(
36
,
pad_if_needed
=
True
),
ArgsKwargs
((
7
,
8
),
fill
=
1
),
ArgsKwargs
(
5
,
fill
=
(
1
,
2
,
3
)),
NotScriptable
ArgsKwargs
(
5
,
fill
=
(
1
,
2
,
3
)),
ArgsKwargs
(
12
),
ArgsKwargs
(
15
,
padding
=
2
,
padding_mode
=
"edge"
),
NotScriptable
ArgsKwargs
(
15
,
padding
=
2
,
padding_mode
=
"edge"
),
ArgsKwargs
(
17
,
padding
=
(
1
,
0
),
padding_mode
=
"reflect"
),
ArgsKwargs
(
8
,
padding
=
(
3
,
0
,
0
,
1
),
padding_mode
=
"symmetric"
),
],
...
...
@@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
)
@
pytest
.
mark
.
parametrize
(
(
"config"
,
"args_kwargs"
),
[
pytest
.
param
(
config
,
args_kwargs
,
id
=
f
"
{
config
.
legacy_cls
.
__name__
}
-
{
idx
:
0
{
len
(
str
(
len
(
config
.
args_kwargs
)))
}
d
}
"
)
for
config
in
CONSISTENCY_CONFIGS
for
idx
,
args_kwargs
in
enumerate
(
config
.
args_kwargs
)
if
not
isinstance
(
args_kwargs
,
NotScriptableArgsKwargs
)
],
)
def
test_jit_consistency
(
config
,
args_kwargs
):
args
,
kwargs
=
args_kwargs
prototype_transform_eager
=
config
.
prototype_cls
(
*
args
,
**
kwargs
)
legacy_transform_eager
=
config
.
legacy_cls
(
*
args
,
**
kwargs
)
legacy_transform_scripted
=
torch
.
jit
.
script
(
legacy_transform_eager
)
prototype_transform_scripted
=
torch
.
jit
.
script
(
prototype_transform_eager
)
for
image
in
make_images
(
**
config
.
make_images_kwargs
):
image
=
image
.
as_subclass
(
torch
.
Tensor
)
torch
.
manual_seed
(
0
)
output_legacy_scripted
=
legacy_transform_scripted
(
image
)
torch
.
manual_seed
(
0
)
output_prototype_scripted
=
prototype_transform_scripted
(
image
)
assert_close
(
output_prototype_scripted
,
output_legacy_scripted
,
**
config
.
closeness_kwargs
)
class
TestContainerTransforms
:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
...
...
torchvision/prototype/transforms/_augment.py
View file @
7cf0f4cc
...
...
@@ -6,7 +6,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union
import
PIL.Image
import
torch
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torchvision
import
transforms
as
_transforms
from
torchvision.ops
import
masks_to_boxes
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
functional
as
F
,
InterpolationMode
,
Transform
...
...
@@ -16,6 +16,14 @@ from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
class
RandomErasing
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomErasing
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
return
dict
(
super
().
_extract_params_for_v1_transform
(),
value
=
"random"
if
self
.
value
is
None
else
self
.
value
,
)
_transformed_types
=
(
is_simple_tensor
,
datapoints
.
Image
,
PIL
.
Image
.
Image
,
datapoints
.
Video
)
def
__init__
(
...
...
torchvision/prototype/transforms/_auto_augment.py
View file @
7cf0f4cc
...
...
@@ -5,7 +5,7 @@ import PIL.Image
import
torch
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
,
TreeSpec
from
torchvision
import
transforms
as
_transforms
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
AutoAugmentPolicy
,
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.prototype.transforms.functional._meta
import
get_spatial_size
...
...
@@ -161,6 +161,8 @@ class _AutoAugmentBase(Transform):
class
AutoAugment
(
_AutoAugmentBase
):
_v1_transform_cls
=
_transforms
.
AutoAugment
_AUGMENTATION_SPACE
=
{
"ShearX"
:
(
lambda
num_bins
,
height
,
width
:
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"ShearY"
:
(
lambda
num_bins
,
height
,
width
:
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
...
...
@@ -315,6 +317,7 @@ class AutoAugment(_AutoAugmentBase):
class
RandAugment
(
_AutoAugmentBase
):
_v1_transform_cls
=
_transforms
.
RandAugment
_AUGMENTATION_SPACE
=
{
"Identity"
:
(
lambda
num_bins
,
height
,
width
:
None
,
False
),
"ShearX"
:
(
lambda
num_bins
,
height
,
width
:
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
...
...
@@ -375,6 +378,7 @@ class RandAugment(_AutoAugmentBase):
class
TrivialAugmentWide
(
_AutoAugmentBase
):
_v1_transform_cls
=
_transforms
.
TrivialAugmentWide
_AUGMENTATION_SPACE
=
{
"Identity"
:
(
lambda
num_bins
,
height
,
width
:
None
,
False
),
"ShearX"
:
(
lambda
num_bins
,
height
,
width
:
torch
.
linspace
(
0.0
,
0.99
,
num_bins
),
True
),
...
...
@@ -425,6 +429,8 @@ class TrivialAugmentWide(_AutoAugmentBase):
class
AugMix
(
_AutoAugmentBase
):
_v1_transform_cls
=
_transforms
.
AugMix
_PARTIAL_AUGMENTATION_SPACE
=
{
"ShearX"
:
(
lambda
num_bins
,
height
,
width
:
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"ShearY"
:
(
lambda
num_bins
,
height
,
width
:
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
...
...
torchvision/prototype/transforms/_color.py
View file @
7cf0f4cc
...
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import
PIL.Image
import
torch
from
torchvision
import
transforms
as
_transforms
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
...
...
@@ -12,6 +12,8 @@ from .utils import is_simple_tensor, query_chw
class
Grayscale
(
Transform
):
_v1_transform_cls
=
_transforms
.
Grayscale
_transformed_types
=
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
...
...
@@ -28,6 +30,8 @@ class Grayscale(Transform):
class
RandomGrayscale
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomGrayscale
_transformed_types
=
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
...
...
@@ -47,6 +51,11 @@ class RandomGrayscale(_RandomApplyTransform):
class
ColorJitter
(
Transform
):
_v1_transform_cls
=
_transforms
.
ColorJitter
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
return
{
attr
:
value
or
0
for
attr
,
value
in
super
().
_extract_params_for_v1_transform
().
items
()}
def
__init__
(
self
,
brightness
:
Optional
[
Union
[
float
,
Sequence
[
float
]]]
=
None
,
...
...
@@ -194,16 +203,22 @@ class RandomPhotometricDistort(Transform):
class
RandomEqualize
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomEqualize
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
equalize
(
inpt
)
class
RandomInvert
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomInvert
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
invert
(
inpt
)
class
RandomPosterize
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomPosterize
def
__init__
(
self
,
bits
:
int
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
(
p
=
p
)
self
.
bits
=
bits
...
...
@@ -213,6 +228,8 @@ class RandomPosterize(_RandomApplyTransform):
class
RandomSolarize
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomSolarize
def
__init__
(
self
,
threshold
:
float
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
(
p
=
p
)
self
.
threshold
=
threshold
...
...
@@ -222,11 +239,15 @@ class RandomSolarize(_RandomApplyTransform):
class
RandomAutocontrast
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomAutocontrast
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
autocontrast
(
inpt
)
class
RandomAdjustSharpness
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomAdjustSharpness
def
__init__
(
self
,
sharpness_factor
:
float
,
p
:
float
=
0.5
)
->
None
:
super
().
__init__
(
p
=
p
)
self
.
sharpness_factor
=
sharpness_factor
...
...
torchvision/prototype/transforms/_geometry.py
View file @
7cf0f4cc
...
...
@@ -6,6 +6,7 @@ from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Ty
import
PIL.Image
import
torch
from
torchvision
import
transforms
as
_transforms
from
torchvision.ops.boxes
import
box_iou
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
functional
as
F
,
InterpolationMode
,
Transform
...
...
@@ -25,16 +26,22 @@ from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query
class
RandomHorizontalFlip
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomHorizontalFlip
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
horizontal_flip
(
inpt
)
class
RandomVerticalFlip
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomVerticalFlip
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
vertical_flip
(
inpt
)
class
Resize
(
Transform
):
_v1_transform_cls
=
_transforms
.
Resize
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
...
...
@@ -69,6 +76,8 @@ class Resize(Transform):
class
CenterCrop
(
Transform
):
_v1_transform_cls
=
_transforms
.
CenterCrop
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]]):
super
().
__init__
()
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
...
...
@@ -78,6 +87,8 @@ class CenterCrop(Transform):
class
RandomResizedCrop
(
Transform
):
_v1_transform_cls
=
_transforms
.
RandomResizedCrop
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
...
...
@@ -174,6 +185,8 @@ class FiveCrop(Transform):
torch.Size([5])
"""
_v1_transform_cls
=
_transforms
.
FiveCrop
_transformed_types
=
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
...
...
@@ -200,6 +213,8 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""
_v1_transform_cls
=
_transforms
.
TenCrop
_transformed_types
=
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
...
...
@@ -223,6 +238,18 @@ class TenCrop(Transform):
class
Pad
(
Transform
):
_v1_transform_cls
=
_transforms
.
Pad
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
return
params
def
__init__
(
self
,
padding
:
Union
[
int
,
Sequence
[
int
]],
...
...
@@ -285,6 +312,8 @@ class RandomZoomOut(_RandomApplyTransform):
class
RandomRotation
(
Transform
):
_v1_transform_cls
=
_transforms
.
RandomRotation
def
__init__
(
self
,
degrees
:
Union
[
numbers
.
Number
,
Sequence
],
...
...
@@ -322,6 +351,8 @@ class RandomRotation(Transform):
class
RandomAffine
(
Transform
):
_v1_transform_cls
=
_transforms
.
RandomAffine
def
__init__
(
self
,
degrees
:
Union
[
numbers
.
Number
,
Sequence
],
...
...
@@ -399,6 +430,24 @@ class RandomAffine(Transform):
class
RandomCrop
(
Transform
):
_v1_transform_cls
=
_transforms
.
RandomCrop
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
params
=
super
().
_extract_params_for_v1_transform
()
if
not
(
params
[
"fill"
]
is
None
or
isinstance
(
params
[
"fill"
],
(
int
,
float
))):
raise
ValueError
(
f
"
{
type
(
self
.
__name__
)
}
() can only be scripted for a scalar `fill`, but got
{
self
.
fill
}
for images."
)
padding
=
self
.
padding
if
padding
is
not
None
:
pad_left
,
pad_right
,
pad_top
,
pad_bottom
=
padding
padding
=
[
pad_left
,
pad_top
,
pad_right
,
pad_bottom
]
params
[
"padding"
]
=
padding
return
params
def
__init__
(
self
,
size
:
Union
[
int
,
Sequence
[
int
]],
...
...
@@ -491,6 +540,8 @@ class RandomCrop(Transform):
class
RandomPerspective
(
_RandomApplyTransform
):
_v1_transform_cls
=
_transforms
.
RandomPerspective
def
__init__
(
self
,
distortion_scale
:
float
=
0.5
,
...
...
@@ -550,6 +601,8 @@ class RandomPerspective(_RandomApplyTransform):
class
ElasticTransform
(
Transform
):
_v1_transform_cls
=
_transforms
.
ElasticTransform
def
__init__
(
self
,
alpha
:
Union
[
float
,
Sequence
[
float
]]
=
50.0
,
...
...
torchvision/prototype/transforms/_meta.py
View file @
7cf0f4cc
...
...
@@ -2,6 +2,7 @@ from typing import Any, Dict, Union
import
torch
from
torchvision
import
transforms
as
_transforms
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
...
...
@@ -27,6 +28,8 @@ class ConvertBoundingBoxFormat(Transform):
class
ConvertDtype
(
Transform
):
_v1_transform_cls
=
_transforms
.
ConvertImageDtype
_transformed_types
=
(
is_simple_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
def
__init__
(
self
,
dtype
:
torch
.
dtype
=
torch
.
float32
)
->
None
:
...
...
torchvision/prototype/transforms/_misc.py
View file @
7cf0f4cc
...
...
@@ -4,6 +4,7 @@ import PIL.Image
import
torch
from
torchvision
import
transforms
as
_transforms
from
torchvision.ops
import
remove_small_boxes
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms
import
functional
as
F
,
Transform
...
...
@@ -39,6 +40,8 @@ class Lambda(Transform):
class
LinearTransformation
(
Transform
):
_v1_transform_cls
=
_transforms
.
LinearTransformation
_transformed_types
=
(
is_simple_tensor
,
datapoints
.
Image
,
datapoints
.
Video
)
def
__init__
(
self
,
transformation_matrix
:
torch
.
Tensor
,
mean_vector
:
torch
.
Tensor
):
...
...
@@ -94,6 +97,7 @@ class LinearTransformation(Transform):
class
Normalize
(
Transform
):
_v1_transform_cls
=
_transforms
.
Normalize
_transformed_types
=
(
datapoints
.
Image
,
is_simple_tensor
,
datapoints
.
Video
)
def
__init__
(
self
,
mean
:
Sequence
[
float
],
std
:
Sequence
[
float
],
inplace
:
bool
=
False
):
...
...
@@ -113,6 +117,8 @@ class Normalize(Transform):
class
GaussianBlur
(
Transform
):
_v1_transform_cls
=
_transforms
.
GaussianBlur
def
__init__
(
self
,
kernel_size
:
Union
[
int
,
Sequence
[
int
]],
sigma
:
Union
[
int
,
float
,
Sequence
[
float
]]
=
(
0.1
,
2.0
)
)
->
None
:
...
...
torchvision/prototype/transforms/_transform.py
View file @
7cf0f4cc
from
__future__
import
annotations
import
enum
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
PIL.Image
import
torch
...
...
@@ -54,6 +56,51 @@ class Transform(nn.Module):
return
", "
.
join
(
extra
)
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation
# to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details.
_v1_transform_cls
:
Optional
[
Type
[
nn
.
Module
]]
=
None
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things:
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not support by the v1 transform.
common_attrs
=
nn
.
Module
().
__dict__
.
keys
()
params
=
{
attr
:
value
for
attr
,
value
in
self
.
__dict__
.
items
()
if
not
attr
.
startswith
(
"_"
)
and
attr
not
in
common_attrs
}
# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
# for the different datapoint types. Below we extract the value for tensors and return that together with the
# other params.
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
# `RandomRotation`
if
"fill"
in
params
:
fill_type_defaultdict
=
params
.
pop
(
"fill"
)
params
[
"fill"
]
=
fill_type_defaultdict
[
torch
.
Tensor
]
return
params
def
__prepare_scriptable__
(
self
)
->
nn
.
Module
:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
# are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we just return the
# equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1
# is around.
if
self
.
_v1_transform_cls
is
None
:
raise
RuntimeError
(
f
"Transform
{
type
(
self
.
__name__
)
}
cannot be JIT scripted. "
f
"This is only support for backward compatibility with transforms which already in v1."
f
"For torchscript support (on tensors only), you can use the functional API instead."
)
return
self
.
_v1_transform_cls
(
**
self
.
_extract_params_for_v1_transform
())
class
_RandomApplyTransform
(
Transform
):
def
__init__
(
self
,
p
:
float
=
0.5
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment