OpenDAS / vision · Commit c5958862 (unverified)
Authored Nov 14, 2022 by Vasilis Vryniotis; committed via GitHub on Nov 14, 2022
Fix bug on prototype `pad` (#6949)
Parent: deba0562
Showing 1 changed file with 61 additions and 34 deletions.

torchvision/prototype/transforms/functional/_geometry.py (+61, -34)
@@ -727,39 +727,38 @@ def _pad_with_scalar_fill(
     shape = image.shape
     num_channels, height, width = shape[-3:]
 
-    batch_size = 1
-    for s in shape[:-3]:
-        batch_size *= s
-
-    image = image.reshape(batch_size, num_channels, height, width)
-
-    if padding_mode == "edge":
-        # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
-        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
-        # name.
-        padding_mode = "replicate"
-
-    if padding_mode == "constant":
-        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
-    elif padding_mode in ("reflect", "replicate"):
-        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
-        # TODO: See https://github.com/pytorch/pytorch/issues/40763
-        dtype = image.dtype
-        if not image.is_floating_point():
-            needs_cast = True
-            image = image.to(torch.float32)
-        else:
-            needs_cast = False
-
-        image = torch_pad(image, torch_padding, mode=padding_mode)
-
-        if needs_cast:
-            image = image.to(dtype)
-    else:  # padding_mode == "symmetric"
-        image = _FT._pad_symmetric(image, torch_padding)
-
-    new_height, new_width = image.shape[-2:]
+    if image.numel() > 0:
+        image = image.reshape(-1, num_channels, height, width)
+
+        if padding_mode == "edge":
+            # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
+            # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
+            # name.
+            padding_mode = "replicate"
+
+        if padding_mode == "constant":
+            image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
+        elif padding_mode in ("reflect", "replicate"):
+            # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
+            # TODO: See https://github.com/pytorch/pytorch/issues/40763
+            dtype = image.dtype
+            if not image.is_floating_point():
+                needs_cast = True
+                image = image.to(torch.float32)
+            else:
+                needs_cast = False
+
+            image = torch_pad(image, torch_padding, mode=padding_mode)
+
+            if needs_cast:
+                image = image.to(dtype)
+        else:  # padding_mode == "symmetric"
+            image = _FT._pad_symmetric(image, torch_padding)
+
+        new_height, new_width = image.shape[-2:]
+    else:
+        left, right, top, bottom = torch_padding
+        new_height = height + top + bottom
+        new_width = width + left + right
 
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
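The gist of this hunk: the old kernel always flattened the leading dims into an explicit batch_size and ran the result through torch_pad, which does not cope with inputs that contain zero elements. The new code pads only when image.numel() > 0 and otherwise just computes the padded output shape arithmetically. Below is a minimal sketch of that degenerate-input path; the zero-batch tensor and the padding values are invented for illustration and are not part of the commit.

import torch

# Hypothetical degenerate input: a batch containing zero images, so numel() == 0.
image = torch.empty(0, 3, 32, 32)
shape = image.shape
num_channels, height, width = shape[-3:]

# Padding in [left, right, top, bottom] order, as used by the prototype kernel.
torch_padding = [1, 2, 3, 4]

# The fixed code skips torch_pad entirely here; it only does shape arithmetic
# and reshapes the (still empty) tensor to the padded size.
left, right, top, bottom = torch_padding
new_height = height + top + bottom
new_width = width + left + right
out = image.reshape(shape[:-3] + (num_channels, new_height, new_width))
print(out.shape)  # torch.Size([0, 3, 39, 35])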
@@ -868,7 +867,24 @@ def pad(
     return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
 
 
-crop_image_tensor = _FT.crop
+def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
+    h, w = image.shape[-2:]
+
+    right = left + width
+    bottom = top + height
+
+    if left < 0 or top < 0 or right > w or bottom > h:
+        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
+        torch_padding = [
+            max(min(right, 0) - left, 0),
+            max(right - max(w, left), 0),
+            max(min(bottom, 0) - top, 0),
+            max(bottom - max(h, top), 0),
+        ]
+        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
+
+    return image[..., top:bottom, left:right]
+
+
 crop_image_pil = _FP.crop
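This hunk replaces the plain alias crop_image_tensor = _FT.crop with a kernel that zero-pads whatever part of the crop box falls outside the image instead of failing. The following standalone sketch reproduces the same idea; crop_with_zero_padding is a made-up name, torch.nn.functional.pad with a constant fill stands in for the private _pad_with_scalar_fill helper, and the toy tensor and crop box are invented for illustration.

import torch
import torch.nn.functional as F

def crop_with_zero_padding(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    # Out-of-bounds parts of the crop box are filled with zeros.
    h, w = image.shape[-2:]
    right, bottom = left + width, top + height
    if left < 0 or top < 0 or right > w or bottom > h:
        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
        padding = [  # [left, right, top, bottom]
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return F.pad(image, padding, mode="constant", value=0.0)
    return image[..., top:bottom, left:right]

img = torch.arange(16.0).reshape(1, 4, 4)
out = crop_with_zero_padding(img, top=-1, left=-1, height=3, width=3)
print(out.shape)  # torch.Size([1, 3, 3]); the first row and first column are zeros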
@@ -893,7 +909,18 @@ def crop_bounding_box(
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
-    return crop_image_tensor(mask, top, left, height, width)
+    if mask.ndim < 3:
+        mask = mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = crop_image_tensor(mask, top, left, height, width)
+
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
 
 
 def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
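The last hunk fixes crop_mask for masks without a channel dimension: a plain (H, W) mask is temporarily unsqueezed so it matches the (..., C, H, W) layout expected by crop_image_tensor, then squeezed back on the way out. A minimal sketch of that bookkeeping, with a plain slice standing in for the real crop kernel and made-up mask sizes:

import torch

def crop_2d_or_3d(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    # Same unsqueeze/squeeze bookkeeping as the new crop_mask.
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)
    output = mask[..., top : top + height, left : left + width]  # stand-in for crop_image_tensor
    return output.squeeze(0) if needs_squeeze else output

mask = torch.zeros(8, 8, dtype=torch.uint8)                 # 2D segmentation mask
print(crop_2d_or_3d(mask, 2, 2, 4, 4).shape)                # torch.Size([4, 4])
print(crop_2d_or_3d(mask.unsqueeze(0), 2, 2, 4, 4).shape)   # torch.Size([1, 4, 4])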