Unverified Commit 528651a0 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fix bug with Compose and PR 6504 (#6510)

* [proto] Fix bug with Compose and PR 6504

* Added tests and fixed other bugs
parent 7245dc9e
......@@ -1108,13 +1108,18 @@ class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize(
"trfms", [[transforms.Pad(2), transforms.RandomCrop(28)], [lambda x: 2.0 * x, transforms.RandomCrop(28)]]
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
],
)
def test_ctor(self, transform_cls, trfms):
c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
assert output.ndim == 4
class TestRandomChoice:
......
......@@ -15,9 +15,10 @@ class Compose(Transform):
self.transforms = transforms
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms:
inputs = transform(*inputs)
return inputs
sample = transform(sample)
return sample
class RandomApply(_RandomApplyTransform):
......@@ -76,7 +77,8 @@ class RandomOrder(Transform):
self.transforms = transforms
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
inputs = transform(*inputs)
return inputs
sample = transform(sample)
return sample
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment