Unverified Commit d75a5241 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow nn.ModuleList in RandomApply (#7197)

parent 539c6e29
......@@ -23,6 +23,7 @@ from prototype_common_utils import (
make_label,
make_segmentation_mask,
)
from torch import nn
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import datapoints, transforms as prototype_transforms
......@@ -761,19 +762,24 @@ class TestContainerTransforms:
check_call_consistency(prototype_transform, legacy_transform)
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
def test_random_apply(self, p):
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
prototype_transform = prototype_transforms.RandomApply(
sequence_type(
[
prototype_transforms.Resize(256),
prototype_transforms.CenterCrop(224),
],
]
),
p=p,
)
legacy_transform = legacy_transforms.RandomApply(
sequence_type(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
]
),
p=p,
)
......
import warnings
from typing import Any, Callable, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence, Union
import torch
from torch import nn
from torchvision.prototype.transforms import Transform
......@@ -25,9 +27,13 @@ class Compose(Transform):
return "\n".join(format_string)
class RandomApply(Compose):
def __init__(self, transforms: Sequence[Callable], p: float = 0.5) -> None:
super().__init__(transforms)
class RandomApply(Transform):
def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
super().__init__()
if not isinstance(transforms, (Sequence, nn.ModuleList)):
raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
self.transforms = transforms
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
......@@ -39,7 +45,15 @@ class RandomApply(Compose):
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
for transform in self.transforms:
sample = transform(sample)
return sample
def extra_repr(self) -> str:
format_string = []
for t in self.transforms:
format_string.append(f" {t}")
return "\n".join(format_string)
class RandomChoice(Transform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment