Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7fb8d068
Unverified
Commit
7fb8d068
authored
Aug 11, 2022
by
vfdev
Committed by
GitHub
Aug 11, 2022
Browse files
[proto] Compose transform keeps BC (#6391)
* [proto] Compose keeps BC * Compose -> Compose(Transform)
parent
ae831144
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
4 deletions
+23
-4
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+19
-0
torchvision/prototype/transforms/_container.py
torchvision/prototype/transforms/_container.py
+4
-4
No files found.
test/test_prototype_transforms.py
View file @
7fb8d068
...
@@ -1083,3 +1083,22 @@ class TestToTensor:
...
@@ -1083,3 +1083,22 @@ class TestToTensor:
fn
.
call_count
==
0
fn
.
call_count
==
0
else
:
else
:
fn
.
assert_called_once_with
(
inpt
)
fn
.
assert_called_once_with
(
inpt
)
class
TestCompose
:
def
test_assertions
(
self
):
with
pytest
.
raises
(
TypeError
,
match
=
"Argument transforms should be a sequence of callables"
):
transforms
.
Compose
(
123
)
@
pytest
.
mark
.
parametrize
(
"trfms"
,
[
[
transforms
.
Pad
(
2
),
transforms
.
RandomCrop
(
28
)],
[
lambda
x
:
2.0
*
x
],
],
)
def
test_ctor
(
self
,
trfms
):
c
=
transforms
.
Compose
(
trfms
)
inpt
=
torch
.
rand
(
1
,
3
,
32
,
32
)
output
=
c
(
inpt
)
assert
isinstance
(
output
,
torch
.
Tensor
)
torchvision/prototype/transforms/_container.py
View file @
7fb8d068
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
import
torch
import
torch
from
torchvision.prototype.transforms
import
Transform
from
torchvision.prototype.transforms
import
Transform
...
@@ -7,11 +7,11 @@ from ._transform import _RandomApplyTransform
...
@@ -7,11 +7,11 @@ from ._transform import _RandomApplyTransform
class
Compose
(
Transform
):
class
Compose
(
Transform
):
def
__init__
(
self
,
*
transforms
:
Transform
)
->
None
:
def
__init__
(
self
,
transforms
:
Sequence
[
Callable
]
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
not
isinstance
(
transforms
,
Sequence
):
raise
TypeError
(
"Argument transforms should be a sequence of callables"
)
self
.
transforms
=
transforms
self
.
transforms
=
transforms
for
idx
,
transform
in
enumerate
(
transforms
):
self
.
add_module
(
str
(
idx
),
transform
)
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
def
forward
(
self
,
*
inputs
:
Any
)
->
Any
:
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
sample
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment