"src/partition/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f2c80b440e80226441dc6c11a95ade10defaaf11"
Unverified commit 3080082d authored by Nicolas Hug, committed by GitHub

Make RandomApply torchscriptable in V2 (#7256)

parent 316cc25c
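With this change, a v2 `RandomApply` built on an `nn.ModuleList` can be passed through `torch.jit.script`; under the hood it is swapped for its stable v1 counterpart. A minimal usage sketch under that assumption (the inner `ColorJitter` transform is illustrative, not part of this diff):

```python
import torch
from torch import nn
from torchvision import transforms as v1
from torchvision.prototype import transforms as v2  # prototype namespace at the time of this commit

# As in v1, the wrapped transforms must live in an nn.ModuleList (not a plain
# list) for the result to be scriptable.
transform = v2.RandomApply(nn.ModuleList([v1.ColorJitter(brightness=0.5)]), p=0.5)

scripted = torch.jit.script(transform)  # dispatches to torchvision.transforms.RandomApply
out = scripted(torch.rand(1, 3, 300, 300))
```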
@@ -806,6 +806,11 @@ class TestContainerTransforms:
         check_call_consistency(prototype_transform, legacy_transform)

+        if sequence_type is nn.ModuleList:
+            # quick and dirty test that it is jit-scriptable
+            scripted = torch.jit.script(prototype_transform)
+            scripted(torch.rand(1, 3, 300, 300))
+
     # We can't test other values for `p` since the random parameter generation is different
     @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
     def test_random_choice(self, probabilities):
...
 import warnings
-from typing import Any, Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union

 import torch
 from torch import nn
+from torchvision import transforms as _transforms
 from torchvision.prototype.transforms import Transform
@@ -28,6 +29,8 @@ class Compose(Transform):
 class RandomApply(Transform):
+    _v1_transform_cls = _transforms.RandomApply
+
     def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
         super().__init__()
@@ -39,6 +42,9 @@ class RandomApply(Transform):
             raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

         self.p = p

+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        return {"transforms": self.transforms, "p": self.p}
+
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
...
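The two additions form a small contract: `_v1_transform_cls` declares which stable v1 class to fall back to, and `_extract_params_for_v1_transform` returns the keyword arguments needed to rebuild an equivalent v1 instance. A self-contained sketch of that contract, with hypothetical stand-in classes rather than torchvision's:

```python
from typing import Any, Dict, Optional, Type

from torch import nn


class V1Flip(nn.Module):
    """Stand-in for a stable, scriptable v1 transform."""

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p


class V2Flip(nn.Module):
    """Stand-in for a v2 transform that opts into the v1 fallback."""

    # Declaring a v1 counterpart is what opts the class into scripting support.
    _v1_transform_cls: Optional[Type[nn.Module]] = V1Flip

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Keys must match the v1 constructor's keyword arguments exactly.
        return {"p": self.p}


v2_t = V2Flip(p=0.25)
v1_t = v2_t._v1_transform_cls(**v2_t._extract_params_for_v1_transform())
assert isinstance(v1_t, V1Flip) and v1_t.p == 0.25
```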
@@ -141,8 +141,9 @@ class Transform(nn.Module):
         if self._v1_transform_cls is None:
             raise RuntimeError(
                 f"Transform {type(self).__name__} cannot be JIT scripted. "
-                f"This is only support for backward compatibility with transforms which already in v1."
-                f"For torchscript support (on tensors only), you can use the functional API instead."
+                "torchscript is only supported for backward compatibility with transforms "
+                "which are already in torchvision.transforms. "
+                "For torchscript support (on tensors only), you can use the functional API instead."
             )

         return self._v1_transform_cls(**self._extract_params_for_v1_transform())
...
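The hunk does not show the enclosing method's name, but the pattern matches PyTorch's `__prepare_scriptable__` hook: when `torch.jit.script` receives a module that defines it, it scripts the hook's return value instead of the module itself. A minimal sketch of that mechanism, again with hypothetical classes:

```python
import torch
from torch import nn


class ScriptFriendly(nn.Module):
    """Scriptable stand-in that gets scripted in place of the outer module."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


class NotDirectlyScriptable(nn.Module):
    def forward(self, x):
        return x * 2

    def __prepare_scriptable__(self) -> nn.Module:
        # torch.jit.script calls this hook (when defined) and scripts its
        # return value instead of `self`.
        return ScriptFriendly()


scripted = torch.jit.script(NotDirectlyScriptable())
print(scripted(torch.rand(2)))  # executes the ScriptFriendly stand-in
```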