Unverified Commit d75a5241 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow nn.ModuleList in RandomApply (#7197)

parent 539c6e29
...@@ -23,6 +23,7 @@ from prototype_common_utils import ( ...@@ -23,6 +23,7 @@ from prototype_common_utils import (
make_label, make_label,
make_segmentation_mask, make_segmentation_mask,
) )
from torch import nn
from torchvision import transforms as legacy_transforms from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import datapoints, transforms as prototype_transforms from torchvision.prototype import datapoints, transforms as prototype_transforms
...@@ -761,19 +762,24 @@ class TestContainerTransforms: ...@@ -761,19 +762,24 @@ class TestContainerTransforms:
check_call_consistency(prototype_transform, legacy_transform) check_call_consistency(prototype_transform, legacy_transform)
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1]) @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
def test_random_apply(self, p): @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
prototype_transform = prototype_transforms.RandomApply( prototype_transform = prototype_transforms.RandomApply(
[ sequence_type(
prototype_transforms.Resize(256), [
prototype_transforms.CenterCrop(224), prototype_transforms.Resize(256),
], prototype_transforms.CenterCrop(224),
]
),
p=p, p=p,
) )
legacy_transform = legacy_transforms.RandomApply( legacy_transform = legacy_transforms.RandomApply(
[ sequence_type(
legacy_transforms.Resize(256), [
legacy_transforms.CenterCrop(224), legacy_transforms.Resize(256),
], legacy_transforms.CenterCrop(224),
]
),
p=p, p=p,
) )
......
import warnings import warnings
from typing import Any, Callable, List, Optional, Sequence from typing import Any, Callable, List, Optional, Sequence, Union
import torch import torch
from torch import nn
from torchvision.prototype.transforms import Transform from torchvision.prototype.transforms import Transform
...@@ -25,9 +27,13 @@ class Compose(Transform): ...@@ -25,9 +27,13 @@ class Compose(Transform):
return "\n".join(format_string) return "\n".join(format_string)
class RandomApply(Compose): class RandomApply(Transform):
def __init__(self, transforms: Sequence[Callable], p: float = 0.5) -> None: def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
super().__init__(transforms) super().__init__()
if not isinstance(transforms, (Sequence, nn.ModuleList)):
raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
self.transforms = transforms
if not (0.0 <= p <= 1.0): if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
...@@ -39,7 +45,15 @@ class RandomApply(Compose): ...@@ -39,7 +45,15 @@ class RandomApply(Compose):
if torch.rand(1) >= self.p: if torch.rand(1) >= self.p:
return sample return sample
return super().forward(sample) for transform in self.transforms:
sample = transform(sample)
return sample
def extra_repr(self) -> str:
format_string = []
for t in self.transforms:
format_string.append(f" {t}")
return "\n".join(format_string)
class RandomChoice(Transform): class RandomChoice(Transform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment