"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "3250d3df168c956389bd16956aa458ce111570d0"
Unverified Commit cc0f9d02 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

improve UX for v2 Compose (#7758)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent a6dea861
......@@ -26,6 +26,8 @@ from common_utils import (
make_video,
set_rng_seed,
)
from torch import nn
from torch.testing import assert_close
from torchvision import datapoints
......@@ -1634,3 +1636,64 @@ class TestRotate:
def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")
class TestCompose:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt
class PackedInputTransform(nn.Module):
def forward(self, sample):
assert len(sample) == 2
return sample
class UnpackedInputTransform(nn.Module):
def forward(self, image, label):
return image, label
@pytest.mark.parametrize(
"transform_clss",
[
[BuiltinTransform],
[PackedInputTransform],
[UnpackedInputTransform],
[BuiltinTransform, BuiltinTransform],
[PackedInputTransform, PackedInputTransform],
[UnpackedInputTransform, UnpackedInputTransform],
[BuiltinTransform, PackedInputTransform, BuiltinTransform],
[BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
[PackedInputTransform, BuiltinTransform, PackedInputTransform],
[UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)
transform = transforms.Compose([cls() for cls in transform_clss])
image = make_image()
label = 3
packed_input = (image, label)
def call_transform():
if unpack:
return transform(*packed_input)
else:
return transform(packed_input)
if needs_unpacked_inputs and not unpack:
with pytest.raises(TypeError, match="missing 1 required positional argument"):
call_transform()
elif needs_packed_inputs and unpack:
with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
call_transform()
else:
output = call_transform()
assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
......@@ -43,13 +43,16 @@ class Compose(Transform):
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
elif not transforms:
raise ValueError("Pass at least one transform")
self.transforms = transforms
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
for transform in self.transforms:
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
format_string = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment