Unverified Commit 3080082d authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Make RandomApply torchscriptable in V2 (#7256)

parent 316cc25c
......@@ -806,6 +806,11 @@ class TestContainerTransforms:
check_call_consistency(prototype_transform, legacy_transform)
if sequence_type is nn.ModuleList:
# quick and dirty test that it is jit-scriptable
scripted = torch.jit.script(prototype_transform)
scripted(torch.rand(1, 3, 300, 300))
# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
def test_random_choice(self, probabilities):
......
import warnings
from typing import Any, Callable, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import torch
from torch import nn
from torchvision import transforms as _transforms
from torchvision.prototype.transforms import Transform
......@@ -28,6 +29,8 @@ class Compose(Transform):
class RandomApply(Transform):
_v1_transform_cls = _transforms.RandomApply
def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
super().__init__()
......@@ -39,6 +42,9 @@ class RandomApply(Transform):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
self.p = p
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {"transforms": self.transforms, "p": self.p}
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
......
......@@ -141,8 +141,9 @@ class Transform(nn.Module):
if self._v1_transform_cls is None:
raise RuntimeError(
f"Transform {type(self).__name__} cannot be JIT scripted. "
f"This is only support for backward compatibility with transforms which already in v1."
f"For torchscript support (on tensors only), you can use the functional API instead."
"torchscript is only supported for backward compatibility with transforms "
"which are already in torchvision.transforms. "
"For torchscript support (on tensors only), you can use the functional API instead."
)
return self._v1_transform_cls(**self._extract_params_for_v1_transform())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment