Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c5958862
Unverified
Commit
c5958862
authored
Nov 14, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Nov 14, 2022
Browse files
Fix bug on prototype `pad` (#6949)
parent
deba0562
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
34 deletions
+61
-34
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+61
-34
No files found.
torchvision/prototype/transforms/functional/_geometry.py
View file @
c5958862
...
...
def _pad_with_scalar_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
) -> torch.Tensor:
    """Pad a ``(..., C, H, W)`` tensor with a single scalar fill value.

    ``torch_padding`` is in ``torch_pad`` (``torch.nn.functional.pad``) order,
    i.e. ``[left, right, top, bottom]`` — established by the
    ``left, right, top, bottom = torch_padding`` unpacking in the pre-fix code
    of this hunk.

    ``padding_mode`` accepts the PIL-style names ``"constant"``, ``"edge"``,
    ``"reflect"`` and ``"symmetric"``.
    """
    # NOTE(review): the parameter list above is reconstructed from the diff
    # hunk header (`def _pad_with_scalar_fill(`) and the call sites visible in
    # this excerpt — confirm names/defaults against the full file.
    shape = image.shape
    num_channels, height, width = shape[-3:]

    # Flatten every leading dim into one explicit batch dim. Computing the
    # batch size by hand (instead of `reshape(-1, ...)`) keeps the reshape
    # valid for degenerate inputs with zero elements — this is what the
    # removed `if image.numel() > 0:` special case used to work around.
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s
    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s and PIL's padding modes
        # don't have the same names. Thus, we map the PIL name for the padding
        # mode, which we are also using for our API, to the corresponding
        # `torch_pad` name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for
        # floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False

        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
        image = _FT._pad_symmetric(image, torch_padding)

    new_height, new_width = image.shape[-2:]

    # Restore the original leading dims around the padded spatial dims.
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
...
...
@@ -868,7 +867,24 @@ def pad(
return
pad_image_pil
(
inpt
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a ``(..., H, W)`` tensor to the box given by ``top``/``left``/``height``/``width``.

    The requested box may extend past the image borders (negative ``top`` /
    ``left``, or a bottom/right edge beyond the image). In that case the
    in-bounds part is sliced out and the out-of-bounds remainder is filled with
    zeros via constant padding; fully in-bounds boxes are a plain slice.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Keep only the portion of the requested box that overlaps the image...
        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
        # ...and pad the rest back to the requested size. The padding list is
        # in `torch_pad` order: [left, right, top, bottom]; each entry is the
        # amount by which the box sticks out past the corresponding border
        # (clamped to 0 when it does not).
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")

    return image[..., top:bottom, left:right]
# PIL images: delegate cropping to the stable functional implementation.
crop_image_pil = _FP.crop
...
...
@@ -893,7 +909,18 @@ def crop_bounding_box(
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a segmentation mask to the box given by ``top``/``left``/``height``/``width``.

    Masks may come without a channel dim (the ``ndim < 3`` check here); a
    temporary leading dim is added so the shared tensor-crop kernel — which
    this code path previously called directly on the raw mask — sees at least
    ``(C, H, W)``, and it is removed again afterwards so the output rank
    matches the input rank.
    """
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = crop_image_tensor(mask, top, left, height, width)

    if needs_squeeze:
        output = output.squeeze(0)

    return output
def
crop_video
(
video
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment