Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6a9b5492
Unverified
Commit
6a9b5492
authored
Jan 18, 2024
by
vfdev
Committed by
GitHub
Jan 18, 2024
Browse files
Enabled torch compile on _compute_affine_output_size (#8218)
parent
1de7a74a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
0 deletions
+30
-0
torchvision/transforms/v2/functional/_geometry.py
torchvision/transforms/v2/functional/_geometry.py
+30
-0
No files found.
torchvision/transforms/v2/functional/_geometry.py
View file @
6a9b5492
...
@@ -525,6 +525,13 @@ def _get_inverse_affine_matrix(
...
@@ -525,6 +525,13 @@ def _get_inverse_affine_matrix(
def
_compute_affine_output_size
(
matrix
:
List
[
float
],
w
:
int
,
h
:
int
)
->
Tuple
[
int
,
int
]:
def
_compute_affine_output_size
(
matrix
:
List
[
float
],
w
:
int
,
h
:
int
)
->
Tuple
[
int
,
int
]:
if
torch
.
_dynamo
.
is_compiling
()
and
not
torch
.
jit
.
is_scripting
():
return
_compute_affine_output_size_python
(
matrix
,
w
,
h
)
else
:
return
_compute_affine_output_size_tensor
(
matrix
,
w
,
h
)
def
_compute_affine_output_size_tensor
(
matrix
:
List
[
float
],
w
:
int
,
h
:
int
)
->
Tuple
[
int
,
int
]:
# Inspired of PIL implementation:
# Inspired of PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
...
@@ -559,6 +566,29 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
...
@@ -559,6 +566,29 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
return
int
(
size
[
0
]),
int
(
size
[
1
])
# w, h
return
int
(
size
[
0
]),
int
(
size
[
1
])
# w, h
def
_compute_affine_output_size_python
(
matrix
:
List
[
float
],
w
:
int
,
h
:
int
)
->
Tuple
[
int
,
int
]:
# Mostly copied from PIL implementation:
# The only difference is with transformed points as input matrix has zero translation part here and
# PIL has a centered translation part.
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
a
,
b
,
c
,
d
,
e
,
f
=
matrix
xx
=
[]
yy
=
[]
half_w
=
0.5
*
w
half_h
=
0.5
*
h
for
x
,
y
in
((
-
half_w
,
-
half_h
),
(
half_w
,
-
half_h
),
(
half_w
,
half_h
),
(
-
half_w
,
half_h
)):
nx
=
a
*
x
+
b
*
y
+
c
ny
=
d
*
x
+
e
*
y
+
f
xx
.
append
(
nx
+
half_w
)
yy
.
append
(
ny
+
half_h
)
nw
=
math
.
ceil
(
max
(
xx
))
-
math
.
floor
(
min
(
xx
))
nh
=
math
.
ceil
(
max
(
yy
))
-
math
.
floor
(
min
(
yy
))
return
int
(
nw
),
int
(
nh
)
# w, h
def
_apply_grid_transform
(
img
:
torch
.
Tensor
,
grid
:
torch
.
Tensor
,
mode
:
str
,
fill
:
_FillTypeJIT
)
->
torch
.
Tensor
:
def
_apply_grid_transform
(
img
:
torch
.
Tensor
,
grid
:
torch
.
Tensor
,
mode
:
str
,
fill
:
_FillTypeJIT
)
->
torch
.
Tensor
:
input_shape
=
img
.
shape
input_shape
=
img
.
shape
output_height
,
output_width
=
grid
.
shape
[
1
],
grid
.
shape
[
2
]
output_height
,
output_width
=
grid
.
shape
[
1
],
grid
.
shape
[
2
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment