Commit 7cf0f4cc (unverified) in OpenDAS/vision

make transforms v2 JIT scriptable (#7135)

Authored Jan 31, 2023 by Philip Meier, committed via GitHub on Jan 31, 2023
Parent: 170160a5
Showing 8 changed files with 206 additions and 17 deletions.
test/test_prototype_transforms_consistency.py      +58  -13
torchvision/prototype/transforms/_augment.py        +9   -1
torchvision/prototype/transforms/_auto_augment.py   +7   -1
torchvision/prototype/transforms/_color.py         +22   -1
torchvision/prototype/transforms/_geometry.py      +53   -0
torchvision/prototype/transforms/_meta.py           +3   -0
torchvision/prototype/transforms/_misc.py           +6   -0
torchvision/prototype/transforms/_transform.py     +48   -1
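The commit's effect in one line: a v2 (prototype) transform instance can now be handed directly to torch.jit.script, which transparently scripts the equivalent v1 transform in its place. A minimal usage sketch, assuming a torchvision checkout from around this commit (the prototype namespace was later renamed to torchvision.transforms.v2):

    import torch
    from torchvision.prototype import transforms as transforms_v2

    # A list-valued size keeps the transform scriptable; the plain-int variants
    # are exactly the cases marked NotScriptableArgsKwargs in the tests below.
    resize = transforms_v2.Resize([32])
    scripted = torch.jit.script(resize)  # scripts the equivalent v1 Resize under the hood

    image = torch.randint(0, 256, (3, 64, 48), dtype=torch.uint8)
    print(scripted(image).shape)  # torch.Size([3, 42, 32]): smaller edge resized to 32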
test/test_prototype_transforms_consistency.py
@@ -34,6 +34,15 @@ from torchvision.transforms import functional as legacy_F

 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])


+class NotScriptableArgsKwargs(ArgsKwargs):
+    """
+    This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
+    thus will be tested there, but will be skipped by the JIT tests.
+    """
+
+    pass
+
+
 class ConsistencyConfig:
     def __init__(
         self,
@@ -73,7 +82,7 @@ CONSISTENCY_CONFIGS = [
         prototype_transforms.Resize,
         legacy_transforms.Resize,
         [
-            ArgsKwargs(32),
+            NotScriptableArgsKwargs(32),
             ArgsKwargs([32]),
             ArgsKwargs((32, 29)),
             ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
@@ -84,8 +93,10 @@ CONSISTENCY_CONFIGS = [
             # ArgsKwargs((30, 27), interpolation=0),
             # ArgsKwargs((35, 29), interpolation=2),
             # ArgsKwargs((34, 25), interpolation=3),
-            ArgsKwargs(31, max_size=32),
-            ArgsKwargs(30, max_size=100),
+            NotScriptableArgsKwargs(31, max_size=32),
+            ArgsKwargs([31], max_size=32),
+            NotScriptableArgsKwargs(30, max_size=100),
+            ArgsKwargs([30], max_size=100),
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
@@ -121,14 +132,15 @@ CONSISTENCY_CONFIGS = [
         prototype_transforms.Pad,
         legacy_transforms.Pad,
         [
-            ArgsKwargs(3),
+            NotScriptableArgsKwargs(3),
             ArgsKwargs([3]),
             ArgsKwargs([2, 3]),
             ArgsKwargs([3, 2, 1, 4]),
-            ArgsKwargs(5, fill=1, padding_mode="constant"),
-            ArgsKwargs(5, padding_mode="edge"),
-            ArgsKwargs(5, padding_mode="reflect"),
-            ArgsKwargs(5, padding_mode="symmetric"),
+            NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
+            ArgsKwargs([5], fill=1, padding_mode="constant"),
+            NotScriptableArgsKwargs(5, padding_mode="edge"),
+            NotScriptableArgsKwargs(5, padding_mode="reflect"),
+            NotScriptableArgsKwargs(5, padding_mode="symmetric"),
         ],
     ),
     ConsistencyConfig(
@@ -170,7 +182,7 @@ CONSISTENCY_CONFIGS = [
     ConsistencyConfig(
         prototype_transforms.ToPILImage,
         legacy_transforms.ToPILImage,
-        [ArgsKwargs()],
+        [NotScriptableArgsKwargs()],
         make_images_kwargs=dict(color_spaces=["GRAY",
@@ -186,7 +198,7 @@ CONSISTENCY_CONFIGS = [
         prototype_transforms.Lambda,
         legacy_transforms.Lambda,
         [
-            ArgsKwargs(lambda image: image / 2),
+            NotScriptableArgsKwargs(lambda image: image / 2),
         ],
         # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
         # images given that the transform does nothing but call it anyway.
@@ -380,14 +392,15 @@ CONSISTENCY_CONFIGS = [
         [
             ArgsKwargs(12),
             ArgsKwargs((15, 17)),
-            ArgsKwargs(11, padding=1),
+            NotScriptableArgsKwargs(11, padding=1),
+            ArgsKwargs(11, padding=[1]),
             ArgsKwargs((8, 13), padding=(2, 3)),
             ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
             ArgsKwargs(36, pad_if_needed=True),
             ArgsKwargs((7, 8), fill=1),
-            ArgsKwargs(5, fill=(1, 2, 3)),
+            NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
             ArgsKwargs(12),
-            ArgsKwargs(15, padding=2, padding_mode="edge"),
+            NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
             ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
             ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
         ],
@@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
     )


+@pytest.mark.parametrize(
+    ("config", "args_kwargs"),
+    [
+        pytest.param(
+            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
+        )
+        for config in CONSISTENCY_CONFIGS
+        for idx, args_kwargs in enumerate(config.args_kwargs)
+        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
+    ],
+)
+def test_jit_consistency(config, args_kwargs):
+    args, kwargs = args_kwargs
+
+    prototype_transform_eager = config.prototype_cls(*args, **kwargs)
+    legacy_transform_eager = config.legacy_cls(*args, **kwargs)
+
+    legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
+    prototype_transform_scripted = torch.jit.script(prototype_transform_eager)
+
+    for image in make_images(**config.make_images_kwargs):
+        image = image.as_subclass(torch.Tensor)
+
+        torch.manual_seed(0)
+        output_legacy_scripted = legacy_transform_scripted(image)
+
+        torch.manual_seed(0)
+        output_prototype_scripted = prototype_transform_scripted(image)
+
+        assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
+
+
 class TestContainerTransforms:
     """
     Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
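A side note on the id= string in the parametrization above, since the nested format spec is easy to misread: the inner {len(str(len(config.args_kwargs)))} computes a zero-pad width from the number of parametrizations, so the generated test ids sort lexicographically. The same expression in isolation (the "Resize" label is just a stand-in):

    args_kwargs = list(range(12))            # stand-in for config.args_kwargs
    width = len(str(len(args_kwargs)))       # 12 entries -> pad to 2 digits
    ids = [f"Resize-{idx:0{width}d}" for idx in range(len(args_kwargs))]
    print(ids[0], ids[-1])                   # Resize-00 Resize-11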
torchvision/prototype/transforms/_augment.py
@@ -6,7 +6,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union

 import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision import transforms as _transforms
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
@@ -16,6 +16,14 @@ from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size

 class RandomErasing(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomErasing
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        return dict(
+            super()._extract_params_for_v1_transform(),
+            value="random" if self.value is None else self.value,
+        )
+
     _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

     def __init__(
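The _extract_params_for_v1_transform override above exists because the two APIs encode "erase with random values" differently: v2 uses value=None, while v1's RandomErasing expects the string sentinel "random". A hypothetical standalone version of just that translation:

    from typing import Any, Dict, Optional

    def v1_value_param(value: Optional[Any]) -> Dict[str, Any]:
        # hypothetical helper mirroring the override above
        return dict(value="random" if value is None else value)

    print(v1_value_param(None))  # {'value': 'random'}
    print(v1_value_param(0.0))   # {'value': 0.0}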
torchvision/prototype/transforms/_auto_augment.py
@@ -5,7 +5,7 @@ import PIL.Image

 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+from torchvision import transforms as _transforms
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.prototype.transforms.functional._meta import get_spatial_size
@@ -161,6 +161,8 @@ class _AutoAugmentBase(Transform):

 class AutoAugment(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.AutoAugment
+
     _AUGMENTATION_SPACE = {
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -315,6 +317,7 @@ class AutoAugment(_AutoAugmentBase):

 class RandAugment(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.RandAugment
     _AUGMENTATION_SPACE = {
         "Identity": (lambda num_bins, height, width: None, False),
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -375,6 +378,7 @@ class RandAugment(_AutoAugmentBase):

 class TrivialAugmentWide(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.TrivialAugmentWide
     _AUGMENTATION_SPACE = {
         "Identity": (lambda num_bins, height, width: None, False),
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
@@ -425,6 +429,8 @@ class TrivialAugmentWide(_AutoAugmentBase):

 class AugMix(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.AugMix
+
     _PARTIAL_AUGMENTATION_SPACE = {
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
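For readers skimming the _AUGMENTATION_SPACE dicts above: each entry maps an op name to a (magnitudes_fn, signed) pair, where magnitudes_fn produces the magnitude bins (or None for magnitude-free ops like "Identity") and signed says whether the sampled magnitude may be negated. A sketch of how such an entry is consumed, with illustrative variable names:

    import torch

    magnitudes_fn, signed = (
        lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins),
        True,
    )

    magnitudes = magnitudes_fn(31, 224, 224)   # 31 bins spanning [0.0, 0.3]
    magnitude = float(magnitudes[17])
    if signed and torch.rand(()) <= 0.5:       # signed ops flip direction half the time
        magnitude *= -1.0
    print(magnitude)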
torchvision/prototype/transforms/_color.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import PIL.Image
 import torch
+from torchvision import transforms as _transforms
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
@@ -12,6 +12,8 @@ from .utils import is_simple_tensor, query_chw

 class Grayscale(Transform):
+    _v1_transform_cls = _transforms.Grayscale
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -28,6 +30,8 @@ class Grayscale(Transform):

 class RandomGrayscale(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomGrayscale
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -47,6 +51,11 @@ class RandomGrayscale(_RandomApplyTransform):

 class ColorJitter(Transform):
+    _v1_transform_cls = _transforms.ColorJitter
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}
+
     def __init__(
         self,
         brightness: Optional[Union[float, Sequence[float]]] = None,
@@ -194,16 +203,22 @@ class RandomPhotometricDistort(Transform):

 class RandomEqualize(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomEqualize
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.equalize(inpt)


 class RandomInvert(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomInvert
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.invert(inpt)


 class RandomPosterize(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomPosterize
+
     def __init__(self, bits: int, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.bits = bits
@@ -213,6 +228,8 @@ class RandomPosterize(_RandomApplyTransform):

 class RandomSolarize(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomSolarize
+
     def __init__(self, threshold: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.threshold = threshold
@@ -222,11 +239,15 @@ class RandomSolarize(_RandomApplyTransform):

 class RandomAutocontrast(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomAutocontrast
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.autocontrast(inpt)


 class RandomAdjustSharpness(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomAdjustSharpness
+
     def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.sharpness_factor = sharpness_factor
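The ColorJitter override above exists because v2 stores unset jitter factors as None, while v1's constructor expects 0 to mean "leave this property alone". The `value or 0` comprehension performs exactly that mapping; standalone, with made-up parameter values:

    params = {"brightness": (0.875, 1.125), "contrast": None, "saturation": None, "hue": None}
    v1_params = {attr: value or 0 for attr, value in params.items()}
    print(v1_params)
    # {'brightness': (0.875, 1.125), 'contrast': 0, 'saturation': 0, 'hue': 0}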
torchvision/prototype/transforms/_geometry.py
@@ -6,6 +6,7 @@ from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Ty

 import PIL.Image
 import torch
+from torchvision import transforms as _transforms
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
@@ -25,16 +26,22 @@ from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query

 class RandomHorizontalFlip(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomHorizontalFlip
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.horizontal_flip(inpt)


 class RandomVerticalFlip(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomVerticalFlip
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.vertical_flip(inpt)


 class Resize(Transform):
+    _v1_transform_cls = _transforms.Resize
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
@@ -69,6 +76,8 @@ class Resize(Transform):

 class CenterCrop(Transform):
+    _v1_transform_cls = _transforms.CenterCrop
+
     def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -78,6 +87,8 @@ class CenterCrop(Transform):

 class RandomResizedCrop(Transform):
+    _v1_transform_cls = _transforms.RandomResizedCrop
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
@@ -174,6 +185,8 @@ class FiveCrop(Transform):
     torch.Size([5])
     """

+    _v1_transform_cls = _transforms.FiveCrop
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -200,6 +213,8 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """

+    _v1_transform_cls = _transforms.TenCrop
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -223,6 +238,18 @@ class TenCrop(Transform):

 class Pad(Transform):
+    _v1_transform_cls = _transforms.Pad
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        params = super()._extract_params_for_v1_transform()
+
+        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
+            raise ValueError(
+                f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
+            )
+
+        return params
+
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
@@ -285,6 +312,8 @@ class RandomZoomOut(_RandomApplyTransform):

 class RandomRotation(Transform):
+    _v1_transform_cls = _transforms.RandomRotation
+
     def __init__(
         self,
         degrees: Union[numbers.Number, Sequence],
@@ -322,6 +351,8 @@ class RandomRotation(Transform):

 class RandomAffine(Transform):
+    _v1_transform_cls = _transforms.RandomAffine
+
     def __init__(
         self,
         degrees: Union[numbers.Number, Sequence],
@@ -399,6 +430,24 @@ class RandomAffine(Transform):

 class RandomCrop(Transform):
+    _v1_transform_cls = _transforms.RandomCrop
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        params = super()._extract_params_for_v1_transform()
+
+        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
+            raise ValueError(
+                f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
+            )
+
+        padding = self.padding
+        if padding is not None:
+            pad_left, pad_right, pad_top, pad_bottom = padding
+            padding = [pad_left, pad_top, pad_right, pad_bottom]
+        params["padding"] = padding
+
+        return params
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
@@ -491,6 +540,8 @@ class RandomCrop(Transform):

 class RandomPerspective(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomPerspective
+
     def __init__(
         self,
         distortion_scale: float = 0.5,
@@ -550,6 +601,8 @@ class RandomPerspective(_RandomApplyTransform):

 class ElasticTransform(Transform):
+    _v1_transform_cls = _transforms.ElasticTransform
+
     def __init__(
         self,
         alpha: Union[float, Sequence[float]] = 50.0,
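Note the ordering fix in RandomCrop._extract_params_for_v1_transform above: v2 normalizes padding to (left, right, top, bottom) internally, while v1's RandomCrop expects (left, top, right, bottom). The same reordering in isolation, with made-up values:

    v2_padding = [1, 2, 3, 4]                  # left, right, top, bottom (v2 internal order)
    pad_left, pad_right, pad_top, pad_bottom = v2_padding
    v1_padding = [pad_left, pad_top, pad_right, pad_bottom]
    print(v1_padding)                          # [1, 3, 2, 4]: left, top, right, bottom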
torchvision/prototype/transforms/_meta.py
@@ -2,6 +2,7 @@ from typing import Any, Dict, Union

 import torch
+from torchvision import transforms as _transforms
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
@@ -27,6 +28,8 @@ class ConvertBoundingBoxFormat(Transform):

 class ConvertDtype(Transform):
+    _v1_transform_cls = _transforms.ConvertImageDtype
+
     _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
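ConvertDtype is the one place in this diff where the v1 class has a different name (ConvertImageDtype), which illustrates why the mapping is an explicit _v1_transform_cls attribute rather than something derived from the class name. A minimal sketch, again assuming a torchvision checkout from around this commit:

    import torch
    from torchvision.prototype import transforms as transforms_v2

    convert = transforms_v2.ConvertDtype(torch.float32)
    scripted = torch.jit.script(convert)  # scripts a v1 ConvertImageDtype under the hood

    image = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
    print(scripted(image).dtype)          # torch.float32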
torchvision/prototype/transforms/_misc.py
@@ -4,6 +4,7 @@ import PIL.Image

 import torch
+from torchvision import transforms as _transforms
 from torchvision.ops import remove_small_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
@@ -39,6 +40,8 @@ class Lambda(Transform):

 class LinearTransformation(Transform):
+    _v1_transform_cls = _transforms.LinearTransformation
+
     _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
@@ -94,6 +97,7 @@ class LinearTransformation(Transform):

 class Normalize(Transform):
+    _v1_transform_cls = _transforms.Normalize
     _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)

     def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
@@ -113,6 +117,8 @@ class Normalize(Transform):

 class GaussianBlur(Transform):
+    _v1_transform_cls = _transforms.GaussianBlur
+
     def __init__(
         self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
     ) -> None:
torchvision/prototype/transforms/_transform.py
 from __future__ import annotations

 import enum
-from typing import Any, Callable, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

 import PIL.Image
 import torch
@@ -54,6 +56,51 @@ class Transform(nn.Module):
         return ", ".join(extra)

+    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation
+    # to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details.
+    _v1_transform_cls: Optional[Type[nn.Module]] = None
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
+        # v2 transform instance. It does two things:
+        # 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
+        # 2. If available, handle the `fill` attribute for v1 compatibility (see below for details)
+        # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
+        # if the v2 transform introduced new parameters that are not supported by the v1 transform.
+
+        common_attrs = nn.Module().__dict__.keys()
+        params = {
+            attr: value
+            for attr, value in self.__dict__.items()
+            if not attr.startswith("_") and attr not in common_attrs
+        }
+
+        # transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
+        # with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
+        # for the different datapoint types. Below we extract the value for tensors and return that together with the
+        # other params.
+        # This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
+        # `RandomRotation`
+        if "fill" in params:
+            fill_type_defaultdict = params.pop("fill")
+            params["fill"] = fill_type_defaultdict[torch.Tensor]
+
+        return params
+
+    def __prepare_scriptable__(self) -> nn.Module:
+        # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the
+        # return value is used for scripting over the original object that should have been scripted. Since the v1
+        # transforms are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we
+        # just return the equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as
+        # long as transforms v1 is around.
+        if self._v1_transform_cls is None:
+            raise RuntimeError(
+                f"Transform {type(self).__name__} cannot be JIT scripted. "
+                "JIT scripting is only supported for backward compatibility with transforms that already exist in v1. "
+                "For torchscript support (on tensors only), you can use the functional API instead."
+            )
+
+        return self._v1_transform_cls(**self._extract_params_for_v1_transform())
+

 class _RandomApplyTransform(Transform):
     def __init__(self, p: float = 0.5) -> None:
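The heart of the commit is the __prepare_scriptable__ hook: torch.jit.script checks for this method before scripting an nn.Module and, if present, scripts its return value instead of the original object. A self-contained sketch of the mechanism (all class names here are hypothetical; only the hook itself is real TorchScript behavior):

    import torch
    from torch import nn


    class ScriptableV1(nn.Module):
        # stands in for a v1 transform: fully typed, scriptable as-is
        def __init__(self, factor: float) -> None:
            super().__init__()
            self.factor = factor

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.factor


    class WrapperV2(nn.Module):
        # stands in for a v2 transform: flexible forward that TorchScript cannot handle
        _v1_transform_cls = ScriptableV1

        def __init__(self, factor: float) -> None:
            super().__init__()
            self.factor = factor

        def forward(self, x):
            return x * self.factor

        def __prepare_scriptable__(self) -> nn.Module:
            # torch.jit.script calls this first and scripts the returned module instead
            return self._v1_transform_cls(factor=self.factor)


    scripted = torch.jit.script(WrapperV2(2.0))
    print(scripted.original_name)       # ScriptableV1
    print(scripted(torch.ones(2)))      # tensor([2., 2.])

This also explains the caveat in the comment above: the scripted artifact is literally the v1 module, so v2 stays scriptable only for as long as v1 ships.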