OpenDAS / vision · Commits
"vscode:/vscode.git/clone" did not exist on "8520f0bea4a09e60a217fe3a8cf24b8f733ec16c"
Unverified
Commit c0911e31, authored Sep 16, 2022 by vfdev, committed by GitHub on Sep 16, 2022
Update typehint for fill arg in rotate (#6594)
Parent commit: 753bf186
Showing 3 changed files with 14 additions and 10 deletions:
test/test_prototype_transforms_functional.py (+5, -3)
torchvision/prototype/transforms/functional/_geometry.py (+1, -1)
torchvision/transforms/functional_tensor.py (+8, -6)
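The commit widens the accepted type of the `fill` argument from `Optional[List[float]]` to `Optional[Union[int, float, List[float]]]`, so passing a single scalar fill value no longer conflicts with the annotations. A minimal before/after sketch of the annotation change (the function names below are placeholders, not code from the commit):

```python
from typing import List, Optional, Union

# Before: the annotation only allowed a per-channel list of values (or None).
def rotate_old(img, angle: float, fill: Optional[List[float]] = None): ...

# After: a bare int/float scalar is also a valid fill, matching the runtime behaviour.
def rotate_new(img, angle: float, fill: Optional[Union[int, float, List[float]]] = None): ...
```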
test/test_prototype_transforms_functional.py (view file @ c0911e31)
```diff
@@ -102,18 +102,20 @@ def affine_mask():
 @register_kernel_info_from_sample_inputs_fn
 def rotate_image_tensor():
-    for image, angle, expand, center, fill in itertools.product(
+    for image, angle, expand, center in itertools.product(
         make_images(),
         [-87, 15, 90],  # angle
         [True, False],  # expand
         [None, [12, 23]],  # center
-        [None, [128], [12.0]],  # fill
     ):
         if center is not None and expand:
             # Skip warning: The provided center argument is ignored if expand is True
             continue
-        yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill)
+        yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=None)
+
+    for fill in [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]]:
+        yield ArgsKwargs(image, angle=23, expand=False, center=None, fill=fill)


 @register_kernel_info_from_sample_inputs_fn
```
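The test change pulls `fill` out of the exhaustive `itertools.product` sweep and exercises it in a dedicated loop covering `None`, scalar (`128.0`, `128`), single-element list, and per-channel list values. A rough standalone sketch of the same sampling pattern, with plain dicts standing in for the suite's `make_images()`/`ArgsKwargs` helpers (which are assumptions here, not reproduced):

```python
import itertools

images = ["img_a", "img_b"]  # placeholder for make_images()

samples = []
for image, angle, expand, center in itertools.product(
    images, [-87, 15, 90], [True, False], [None, [12, 23]]
):
    if center is not None and expand:
        continue  # center is ignored when expand=True
    samples.append(dict(image=image, angle=angle, expand=expand, center=center, fill=None))

# fill is exercised separately, including the newly allowed scalar values.
for fill in [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]]:
    samples.append(dict(image=images[-1], angle=23, expand=False, center=None, fill=fill))
```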
torchvision/prototype/transforms/functional/_geometry.py (view file @ c0911e31)
```diff
@@ -467,7 +467,7 @@ def rotate_image_tensor(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: Optional[List[float]] = None,
+    fill: Optional[Union[int, float, List[float]]] = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     num_channels, height, width = img.shape[-3:]
```
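With the widened annotation, the prototype kernel can be called with a scalar fill without tripping a type check. A hedged usage sketch, assuming the prototype functional namespace at this commit exposes `rotate_image_tensor`:

```python
import torch
from torchvision.prototype.transforms import functional as F  # assumed import path at this commit

img = torch.rand(3, 32, 32)

# Scalar and per-channel fills now both match the type hint.
out_scalar = F.rotate_image_tensor(img, angle=30.0, expand=False, fill=128)
out_list = F.rotate_image_tensor(img, angle=30.0, expand=False, fill=[1.0, 2.0, 3.0])
```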
torchvision/transforms/functional_tensor.py (view file @ c0911e31)
```diff
@@ -475,7 +475,7 @@ def _assert_grid_transform_inputs(
     img: Tensor,
     matrix: Optional[List[float]],
     interpolation: str,
-    fill: Optional[List[float]],
+    fill: Optional[Union[int, float, List[float]]],
     supported_interpolation_modes: List[str],
     coeffs: Optional[List[float]] = None,
 ) -> None:
@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
     # Check fill
     num_channels = get_dimensions(img)[0]
-    if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
+    if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
         msg = (
             "The number of elements in 'fill' cannot broadcast to match the number of "
             "channels of the image ({} != {})"
```
```diff
@@ -539,7 +539,9 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
     return img


-def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
+def _apply_grid_transform(
+    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
+) -> Tensor:

     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
```
```diff
@@ -559,8 +561,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
         mask = img[:, -1:, :, :]  # N * 1 * H * W
         img = img[:, :-1, :, :]  # N * C * H * W
         mask = mask.expand_as(img)
-        len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
-        fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
+        fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
+        fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
         if mode == "nearest":
             mask = mask < 0.5
             img[mask] = fill_img[mask]
```
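The new `fill_list` line is what actually makes a scalar fill work inside `_apply_grid_transform`: a bare int/float is wrapped into a one-element list before being turned into a broadcastable tensor. A self-contained sketch of that normalization (the helper name is illustrative, not torchvision API):

```python
import torch

def _broadcast_fill(img: torch.Tensor, fill) -> torch.Tensor:
    # Mirrors the normalization added in this commit: a tuple/list is used as-is,
    # a scalar becomes a one-element list that broadcasts over all channels.
    fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
    return torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)

img = torch.zeros(1, 3, 4, 4)
print(_broadcast_fill(img, 128).shape)              # torch.Size([1, 3, 4, 4]), every channel filled with 128.0
print(_broadcast_fill(img, [1.0, 2.0, 3.0]).shape)  # torch.Size([1, 3, 4, 4]), per-channel fill values
```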
```diff
@@ -648,7 +650,7 @@ def rotate(
     matrix: List[float],
     interpolation: str = "nearest",
     expand: bool = False,
-    fill: Optional[List[float]] = None,
+    fill: Optional[Union[int, float, List[float]]] = None,
 ) -> Tensor:
     _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
     w, h = img.shape[-1], img.shape[-2]
```
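At the user level, the public functional API already accepted scalar fills at runtime; this commit aligns the internal annotations with that behaviour. A hedged end-to-end example using the stable `torchvision.transforms.functional.rotate`:

```python
import torch
from torchvision.transforms import functional as F

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)

# A single number fills all channels; a per-channel list is still accepted.
rotated_scalar = F.rotate(img, angle=45, fill=128)
rotated_list = F.rotate(img, angle=45, fill=[255, 0, 0])
```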