"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f088027e937b2ee1acef1f6b2776b7b2fee7ffd6"
Unverified Commit f9966d22 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Restored BC for RandomChoice and RandomOrder (#6488)

parent 020eafe1
...@@ -1092,20 +1092,18 @@ class TestToTensor: ...@@ -1092,20 +1092,18 @@ class TestToTensor:
fn.assert_called_once_with(inpt) fn.assert_called_once_with(inpt)
class TestCompose: class TestContainers:
def test_assertions(self): @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
def test_assertions(self, transform_cls):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"): with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transforms.Compose(123) transform_cls(transforms.RandomCrop(28))
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"trfms", "trfms", [[transforms.Pad(2), transforms.RandomCrop(28)], [lambda x: 2.0 * x, transforms.RandomCrop(28)]]
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x],
],
) )
def test_ctor(self, trfms): def test_ctor(self, transform_cls, trfms):
c = transforms.Compose(trfms) c = transform_cls(trfms)
inpt = torch.rand(1, 3, 32, 32) inpt = torch.rand(1, 3, 32, 32)
output = c(inpt) output = c(inpt)
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
......
...@@ -33,7 +33,9 @@ class RandomApply(_RandomApplyTransform): ...@@ -33,7 +33,9 @@ class RandomApply(_RandomApplyTransform):
class RandomChoice(Transform): class RandomChoice(Transform):
def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None: def __init__(self, transforms: Sequence[Callable], probabilities: Optional[List[float]] = None) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if probabilities is None: if probabilities is None:
probabilities = [1] * len(transforms) probabilities = [1] * len(transforms)
elif len(probabilities) != len(transforms): elif len(probabilities) != len(transforms):
...@@ -45,9 +47,6 @@ class RandomChoice(Transform): ...@@ -45,9 +47,6 @@ class RandomChoice(Transform):
super().__init__() super().__init__()
self.transforms = transforms self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
total = sum(probabilities) total = sum(probabilities)
self.probabilities = [p / total for p in probabilities] self.probabilities = [p / total for p in probabilities]
...@@ -58,11 +57,11 @@ class RandomChoice(Transform): ...@@ -58,11 +57,11 @@ class RandomChoice(Transform):
class RandomOrder(Transform): class RandomOrder(Transform):
def __init__(self, *transforms: Transform) -> None: def __init__(self, transforms: Sequence[Callable]) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
super().__init__() super().__init__()
self.transforms = transforms self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
for idx in torch.randperm(len(self.transforms)): for idx in torch.randperm(len(self.transforms)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment