"vscode:/vscode.git/clone" did not exist on "cac0c44a7546fac99d84ec0c59bb613db8c7f9a1"
Unverified Commit f1f6a972 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Added RandomApply as scriptable transform (#2769)

parent 6eff0a43
......@@ -179,7 +179,7 @@
"Next, we show how to combine input transformations and model's forward pass and use `torch.jit.script` to obtain a single scripted module.\n",
"\n",
"**Note:** we have to use only scriptable transformations that should be derived from `torch.nn.Module`. \n",
"Since v0.8.0, all transformations are scriptable except `Compose`, `RandomApply`, `RandomChoice`, `RandomOrder`, `Lambda` and those applied on PIL images. \n",
"Since v0.8.0, all transformations are scriptable except `Compose`, `RandomChoice`, `RandomOrder`, `Lambda` and those applied on PIL images.\n",
"The transformations like `Compose` are kept for backward compatibility and can be easily replaced by existing torch modules, like `nn.Sequential`.\n",
"\n",
"Let's define a module `Predictor` that transforms input tensor and applies ImageNet pretrained resnet18 model on it."
......
......@@ -466,6 +466,35 @@ class Tester(TransformsTester):
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
torch.jit.script(t)
def test_random_apply(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.RandomApply([
T.RandomHorizontalFlip(),
T.ColorJitter(),
], p=0.4)
s_transforms = T.RandomApply(torch.nn.ModuleList([
T.RandomHorizontalFlip(),
T.ColorJitter(),
]), p=0.4)
scripted_fn = torch.jit.script(s_transforms)
torch.manual_seed(12)
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
if torch.device(self.device).type == "cpu":
# Can't check this twice, otherwise
# "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
transforms = T.RandomApply([
T.ColorJitter(),
], p=0.3)
with self.assertRaisesRegex(RuntimeError, r"Module 'RandomApply' has no attribute 'transforms'"):
torch.jit.script(transforms)
def test_gaussian_blur(self):
tol = 1.0 + 1e-10
self._test_class_op(
......
......@@ -400,7 +400,8 @@ class RandomTransforms:
"""
def __init__(self, transforms):
assert isinstance(transforms, (list, tuple))
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence")
self.transforms = transforms
def __call__(self, *args, **kwargs):
......@@ -415,21 +416,33 @@ class RandomTransforms:
return format_string
class RandomApply(RandomTransforms):
class RandomApply(torch.nn.Module):
"""Apply randomly a list of transformations with a given probability.
This transform does not support torchscript.
.. note::
In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
transforms as shown below:
>>> transforms = transforms.RandomApply(torch.nn.ModuleList([
>>> transforms.ColorJitter(),
>>> ]), p=0.3)
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
Args:
transforms (list or tuple): list of transformations
transforms (list or tuple or torch.nn.Module): list of transformations
p (float): probability
"""
def __init__(self, transforms, p=0.5):
super().__init__(transforms)
super().__init__()
self.transforms = transforms
self.p = p
def __call__(self, img):
if self.p < random.random():
def forward(self, img):
if self.p < torch.rand(1):
return img
for t in self.transforms:
img = t(img)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment