Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
3080082d
"src/partition/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f2c80b440e80226441dc6c11a95ade10defaaf11"
Unverified
Commit
3080082d
authored
Feb 15, 2023
by
Nicolas Hug
Committed by
GitHub
Feb 15, 2023
Browse files
Make RandomApply torchscriptable in V2 (#7256)
parent
316cc25c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
3 deletions
+15
-3
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+5
-0
torchvision/prototype/transforms/_container.py
torchvision/prototype/transforms/_container.py
+7
-1
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+3
-2
No files found.
test/test_prototype_transforms_consistency.py
View file @
3080082d
...
@@ -806,6 +806,11 @@ class TestContainerTransforms:
...
@@ -806,6 +806,11 @@ class TestContainerTransforms:
check_call_consistency
(
prototype_transform
,
legacy_transform
)
check_call_consistency
(
prototype_transform
,
legacy_transform
)
if
sequence_type
is
nn
.
ModuleList
:
# quick and dirty test that it is jit-scriptable
scripted
=
torch
.
jit
.
script
(
prototype_transform
)
scripted
(
torch
.
rand
(
1
,
3
,
300
,
300
))
# We can't test other values for `p` since the random parameter generation is different
# We can't test other values for `p` since the random parameter generation is different
@
pytest
.
mark
.
parametrize
(
"probabilities"
,
[(
0
,
1
),
(
1
,
0
)])
@
pytest
.
mark
.
parametrize
(
"probabilities"
,
[(
0
,
1
),
(
1
,
0
)])
def
test_random_choice
(
self
,
probabilities
):
def
test_random_choice
(
self
,
probabilities
):
...
...
torchvision/prototype/transforms/_container.py
View file @
3080082d
import
warnings
import
warnings
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torchvision
import
transforms
as
_transforms
from
torchvision.prototype.transforms
import
Transform
from
torchvision.prototype.transforms
import
Transform
...
@@ -28,6 +29,8 @@ class Compose(Transform):
...
@@ -28,6 +29,8 @@ class Compose(Transform):
class
RandomApply
(
Transform
):
class
RandomApply
(
Transform
):
_v1_transform_cls
=
_transforms
.
RandomApply
def
__init__
(
self
,
transforms
:
Union
[
Sequence
[
Callable
],
nn
.
ModuleList
],
p
:
float
=
0.5
)
->
None
:
def
__init__
(
self
,
transforms
:
Union
[
Sequence
[
Callable
],
nn
.
ModuleList
],
p
:
float
=
0.5
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -39,6 +42,9 @@ class RandomApply(Transform):
...
@@ -39,6 +42,9 @@ class RandomApply(Transform):
raise
ValueError
(
"`p` should be a floating point value in the interval [0.0, 1.0]."
)
raise
ValueError
(
"`p` should be a floating point value in the interval [0.0, 1.0]."
)
self
.
p
=
p
self
.
p
=
p
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
return
{
"transforms"
:
self
.
transforms
,
"p"
:
self
.
p
}
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
...
...
torchvision/prototype/transforms/_transform.py
View file @
3080082d
...
@@ -141,8 +141,9 @@ class Transform(nn.Module):
...
@@ -141,8 +141,9 @@ class Transform(nn.Module):
if
self
.
_v1_transform_cls
is
None
:
if
self
.
_v1_transform_cls
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Transform
{
type
(
self
).
__name__
}
cannot be JIT scripted. "
f
"Transform
{
type
(
self
).
__name__
}
cannot be JIT scripted. "
f
"This is only support for backward compatibility with transforms which already in v1."
"torchscript is only supported for backward compatibility with transforms "
f
"For torchscript support (on tensors only), you can use the functional API instead."
"which are already in torchvision.transforms. "
"For torchscript support (on tensors only), you can use the functional API instead."
)
)
return
self
.
_v1_transform_cls
(
**
self
.
_extract_params_for_v1_transform
())
return
self
.
_v1_transform_cls
(
**
self
.
_extract_params_for_v1_transform
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment