Unverified Commit 528651a0 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fix bug with Compose and PR 6504 (#6510)

* [proto] Fix bug with Compose and PR 6504

* Added tests and fixed other bugs
parent 7245dc9e
...@@ -1108,13 +1108,18 @@ class TestContainers: ...@@ -1108,13 +1108,18 @@ class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]) @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"trfms", [[transforms.Pad(2), transforms.RandomCrop(28)], [lambda x: 2.0 * x, transforms.RandomCrop(28)]] "trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
],
) )
def test_ctor(self, transform_cls, trfms): def test_ctor(self, transform_cls, trfms):
c = transform_cls(trfms) c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32) inpt = torch.rand(1, 3, 32, 32)
output = c(inpt) output = c(inpt)
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
assert output.ndim == 4
class TestRandomChoice: class TestRandomChoice:
......
...@@ -15,9 +15,10 @@ class Compose(Transform): ...@@ -15,9 +15,10 @@ class Compose(Transform):
self.transforms = transforms self.transforms = transforms
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms: for transform in self.transforms:
inputs = transform(*inputs) sample = transform(sample)
return inputs return sample
class RandomApply(_RandomApplyTransform): class RandomApply(_RandomApplyTransform):
...@@ -76,7 +77,8 @@ class RandomOrder(Transform): ...@@ -76,7 +77,8 @@ class RandomOrder(Transform):
self.transforms = transforms self.transforms = transforms
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for idx in torch.randperm(len(self.transforms)): for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx] transform = self.transforms[idx]
inputs = transform(*inputs) sample = transform(sample)
return inputs return sample
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment