Unverified Commit 7fb8d068 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Compose transform keeps BC (#6391)

* [proto] Compose keeps BC

* Compose -> Compose(Transform)
parent ae831144
......@@ -1083,3 +1083,22 @@ class TestToTensor:
fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
class TestCompose:
def test_assertions(self):
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
transforms.Compose(123)
@pytest.mark.parametrize(
"trfms",
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x],
],
)
def test_ctor(self, trfms):
c = transforms.Compose(trfms)
inpt = torch.rand(1, 3, 32, 32)
output = c(inpt)
assert isinstance(output, torch.Tensor)
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Sequence
import torch
from torchvision.prototype.transforms import Transform
......@@ -7,11 +7,11 @@ from ._transform import _RandomApplyTransform
class Compose(Transform):
def __init__(self, *transforms: Transform) -> None:
def __init__(self, transforms: Sequence[Callable]) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment