"tests/vscode:/vscode.git/clone" did not exist on "13c754c15d5952f9e160b952d4177f1b7b329a67"
Unverified Commit f9966d22 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Restored BC for RandomChoice and RandomOrder (#6488)

parent 020eafe1
......@@ -1092,20 +1092,18 @@ class TestToTensor:
fn.assert_called_once_with(inpt)
class TestCompose:
def test_assertions(self):
class TestContainers:
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transforms.Compose(123)
transform_cls(transforms.RandomCrop(28))
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize(
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x],
],
"trfms", [[transforms.Pad(2), transforms.RandomCrop(28)], [lambda x: 2.0 * x, transforms.RandomCrop(28)]]
)
def test_ctor(self, trfms):
c = transforms.Compose(trfms)
def test_ctor(self, transform_cls, trfms):
c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
......
......@@ -33,7 +33,9 @@ class RandomApply(_RandomApplyTransform):
class RandomChoice(Transform):
def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None:
def __init__(self, transforms: Sequence[Callable], probabilities: Optional[List[float]] = None) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if probabilities is None:
probabilities = [1] * len(transforms)
elif len(probabilities) != len(transforms):
......@@ -45,9 +47,6 @@ class RandomChoice(Transform):
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
total = sum(probabilities)
self.probabilities = [p / total for p in probabilities]
......@@ -58,11 +57,11 @@ class RandomChoice(Transform):
class RandomOrder(Transform):
def __init__(self, *transforms: Transform) -> None:
def __init__(self, transforms: Sequence[Callable]) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any:
for idx in torch.randperm(len(self.transforms)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment