Unverified Commit 7fb8d068 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Compose transform keeps BC (#6391)

* [proto] Compose keeps BC

* Compose -> Compose(Transform)
parent ae831144
...@@ -1083,3 +1083,22 @@ class TestToTensor: ...@@ -1083,3 +1083,22 @@ class TestToTensor:
fn.call_count == 0 fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt) fn.assert_called_once_with(inpt)
class TestCompose:
def test_assertions(self):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transforms.Compose(123)
@pytest.mark.parametrize(
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x],
],
)
def test_ctor(self, trfms):
c = transforms.Compose(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, Sequence
import torch import torch
from torchvision.prototype.transforms import Transform from torchvision.prototype.transforms import Transform
...@@ -7,11 +7,11 @@ from ._transform import _RandomApplyTransform ...@@ -7,11 +7,11 @@ from ._transform import _RandomApplyTransform
class Compose(Transform): class Compose(Transform):
def __init__(self, *transforms: Transform) -> None: def __init__(self, transforms: Sequence[Callable]) -> None:
super().__init__() super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
self.transforms = transforms self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment