Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
e7c0abac
Unverified
Commit
e7c0abac
authored
Sep 29, 2020
by
vfdev
Committed by
GitHub
Sep 29, 2020
Browse files
Removed type from exception error (#2729)
Otherwise, torch jit scripted function raises exception on save
parent
53ccd538
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
8 deletions
+39
-8
test/test_transforms_tensor.py
test/test_transforms_tensor.py
+32
-1
torchvision/transforms/functional_tensor.py
torchvision/transforms/functional_tensor.py
+7
-7
No files found.
test/test_transforms_tensor.py
View file @
e7c0abac
import
os
import
torch
from
torchvision
import
transforms
as
T
from
torchvision.transforms
import
functional
as
F
...
...
@@ -8,7 +9,7 @@ import numpy as np
import
unittest
from
common_utils
import
TransformsTester
from
common_utils
import
TransformsTester
,
get_tmp_dir
class
Tester
(
TransformsTester
):
...
...
@@ -73,6 +74,9 @@ class Tester(TransformsTester):
batch_tensors
=
self
.
_create_data_batch
(
height
=
23
,
width
=
34
,
channels
=
3
,
num_samples
=
4
,
device
=
self
.
device
)
self
.
_test_transform_vs_scripted_on_batch
(
f
,
scripted_fn
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_{}.pt"
.
format
(
method
)))
def
_test_op
(
self
,
func
,
method
,
fn_kwargs
=
None
,
meth_kwargs
=
None
):
self
.
_test_functional_op
(
func
,
fn_kwargs
)
self
.
_test_class_op
(
method
,
meth_kwargs
)
...
...
@@ -188,6 +192,9 @@ class Tester(TransformsTester):
scripted_fn
=
torch
.
jit
.
script
(
f
)
scripted_fn
(
tensor
)
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_center_crop.pt"
))
def
_test_op_list_output
(
self
,
func
,
method
,
out_length
,
fn_kwargs
=
None
,
meth_kwargs
=
None
):
if
fn_kwargs
is
None
:
fn_kwargs
=
{}
...
...
@@ -231,6 +238,9 @@ class Tester(TransformsTester):
self
.
assertTrue
(
transformed_img
.
equal
(
transformed_batch
[
i
,
...]),
msg
=
"{} vs {}"
.
format
(
transformed_img
,
transformed_batch
[
i
,
...]))
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_op_list_{}.pt"
.
format
(
method
)))
def
test_five_crop
(
self
):
fn_kwargs
=
meth_kwargs
=
{
"size"
:
(
5
,)}
self
.
_test_op_list_output
(
...
...
@@ -294,6 +304,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
script_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_resize.pt"
))
def
test_resized_crop
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
...
@@ -309,6 +322,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_resized_crop.pt"
))
def
test_random_affine
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
...
@@ -327,6 +343,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_random_affine.pt"
))
def
test_random_rotate
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
...
@@ -343,6 +362,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_random_rotate.pt"
))
def
test_random_perspective
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
...
@@ -358,6 +380,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_perspective.pt"
))
def
test_to_grayscale
(
self
):
meth_kwargs
=
{
"num_output_channels"
:
1
}
...
...
@@ -388,6 +413,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
fn
,
scripted_fn
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
fn
,
scripted_fn
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_norm.pt"
))
def
test_linear_transformation
(
self
):
c
,
h
,
w
=
3
,
24
,
32
...
...
@@ -410,6 +438,9 @@ class Tester(TransformsTester):
s_transformed_batch
=
scripted_fn
(
batch_tensors
)
self
.
assertTrue
(
transformed_batch
.
equal
(
s_transformed_batch
))
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_norm.pt"
))
def
test_compose
(
self
):
tensor
,
_
=
self
.
_create_data
(
26
,
34
,
device
=
self
.
device
)
tensor
=
tensor
.
to
(
dtype
=
torch
.
float32
)
/
255.0
...
...
torchvision/transforms/functional_tensor.py
View file @
e7c0abac
...
...
@@ -15,7 +15,7 @@ def _get_image_size(img: Tensor) -> List[int]:
"""Returns (w, h) of tensor image"""
if
_is_tensor_a_torch_image
(
img
):
return
[
img
.
shape
[
-
1
],
img
.
shape
[
-
2
]]
raise
TypeError
(
"Unexpected
type {}"
.
format
(
type
(
img
))
)
raise
TypeError
(
"Unexpected
input type"
)
def
_get_image_num_channels
(
img
:
Tensor
)
->
int
:
...
...
@@ -24,7 +24,7 @@ def _get_image_num_channels(img: Tensor) -> int:
elif
img
.
ndim
>
2
:
return
img
.
shape
[
-
3
]
raise
TypeError
(
"
Unexpected type
{}"
.
format
(
type
(
img
)
))
raise
TypeError
(
"
Input ndim should be 2 or more. Got
{}"
.
format
(
img
.
ndim
))
def
vflip
(
img
:
Tensor
)
->
Tensor
:
...
...
@@ -223,7 +223,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
raise
ValueError
(
'hue_factor ({}) is not in [-0.5, 0.5].'
.
format
(
hue_factor
))
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
raise
TypeError
(
'img should be Tensor image
. Got {}'
.
format
(
type
(
img
))
)
raise
TypeError
(
'
Input
img should be Tensor image
'
)
orig_dtype
=
img
.
dtype
if
img
.
dtype
==
torch
.
uint8
:
...
...
@@ -294,7 +294,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
"""
if
not
isinstance
(
img
,
torch
.
Tensor
):
raise
TypeError
(
'img should be a Tensor.
Got {}'
.
format
(
type
(
img
))
)
raise
TypeError
(
'
Input
img should be a Tensor.
'
)
if
gamma
<
0
:
raise
ValueError
(
'Gamma should be a non-negative real number'
)
...
...
@@ -763,10 +763,10 @@ def _assert_grid_transform_inputs(
coeffs
:
Optional
[
List
[
float
]]
=
None
,
):
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
raise
TypeError
(
"img should be Tensor Image
. Got {}"
.
format
(
type
(
img
))
)
raise
TypeError
(
"
Input
img should be Tensor Image
"
)
if
matrix
is
not
None
and
not
isinstance
(
matrix
,
list
):
raise
TypeError
(
"Argument matrix should be a list
. Got {}"
.
format
(
type
(
matrix
))
)
raise
TypeError
(
"Argument matrix should be a list
"
)
if
matrix
is
not
None
and
len
(
matrix
)
!=
6
:
raise
ValueError
(
"Argument matrix should have 6 float values"
)
...
...
@@ -989,7 +989,7 @@ def perspective(
Tensor: transformed image.
"""
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
raise
TypeError
(
'img should be Tensor Image
. Got {}'
.
format
(
type
(
img
))
)
raise
TypeError
(
'
Input
img should be Tensor Image
'
)
_interpolation_modes
=
{
0
:
"nearest"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment