Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
e7c0abac
Unverified
Commit
e7c0abac
authored
Sep 29, 2020
by
vfdev
Committed by
GitHub
Sep 29, 2020
Browse files
Removed type from exception error (#2729)
Otherwise, torch jit scripted function raises exception on save
parent
53ccd538
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
8 deletions
+39
-8
test/test_transforms_tensor.py
test/test_transforms_tensor.py
+32
-1
torchvision/transforms/functional_tensor.py
torchvision/transforms/functional_tensor.py
+7
-7
No files found.
test/test_transforms_tensor.py
View file @
e7c0abac
import
os
import
torch
import
torch
from
torchvision
import
transforms
as
T
from
torchvision
import
transforms
as
T
from
torchvision.transforms
import
functional
as
F
from
torchvision.transforms
import
functional
as
F
...
@@ -8,7 +9,7 @@ import numpy as np
...
@@ -8,7 +9,7 @@ import numpy as np
import
unittest
import
unittest
from
common_utils
import
TransformsTester
from
common_utils
import
TransformsTester
,
get_tmp_dir
class
Tester
(
TransformsTester
):
class
Tester
(
TransformsTester
):
...
@@ -73,6 +74,9 @@ class Tester(TransformsTester):
...
@@ -73,6 +74,9 @@ class Tester(TransformsTester):
batch_tensors
=
self
.
_create_data_batch
(
height
=
23
,
width
=
34
,
channels
=
3
,
num_samples
=
4
,
device
=
self
.
device
)
batch_tensors
=
self
.
_create_data_batch
(
height
=
23
,
width
=
34
,
channels
=
3
,
num_samples
=
4
,
device
=
self
.
device
)
self
.
_test_transform_vs_scripted_on_batch
(
f
,
scripted_fn
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
f
,
scripted_fn
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_{}.pt"
.
format
(
method
)))
def
_test_op
(
self
,
func
,
method
,
fn_kwargs
=
None
,
meth_kwargs
=
None
):
def
_test_op
(
self
,
func
,
method
,
fn_kwargs
=
None
,
meth_kwargs
=
None
):
self
.
_test_functional_op
(
func
,
fn_kwargs
)
self
.
_test_functional_op
(
func
,
fn_kwargs
)
self
.
_test_class_op
(
method
,
meth_kwargs
)
self
.
_test_class_op
(
method
,
meth_kwargs
)
...
@@ -188,6 +192,9 @@ class Tester(TransformsTester):
...
@@ -188,6 +192,9 @@ class Tester(TransformsTester):
scripted_fn
=
torch
.
jit
.
script
(
f
)
scripted_fn
=
torch
.
jit
.
script
(
f
)
scripted_fn
(
tensor
)
scripted_fn
(
tensor
)
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_center_crop.pt"
))
def
_test_op_list_output
(
self
,
func
,
method
,
out_length
,
fn_kwargs
=
None
,
meth_kwargs
=
None
):
def
_test_op_list_output
(
self
,
func
,
method
,
out_length
,
fn_kwargs
=
None
,
meth_kwargs
=
None
):
if
fn_kwargs
is
None
:
if
fn_kwargs
is
None
:
fn_kwargs
=
{}
fn_kwargs
=
{}
...
@@ -231,6 +238,9 @@ class Tester(TransformsTester):
...
@@ -231,6 +238,9 @@ class Tester(TransformsTester):
self
.
assertTrue
(
transformed_img
.
equal
(
transformed_batch
[
i
,
...]),
self
.
assertTrue
(
transformed_img
.
equal
(
transformed_batch
[
i
,
...]),
msg
=
"{} vs {}"
.
format
(
transformed_img
,
transformed_batch
[
i
,
...]))
msg
=
"{} vs {}"
.
format
(
transformed_img
,
transformed_batch
[
i
,
...]))
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_op_list_{}.pt"
.
format
(
method
)))
def
test_five_crop
(
self
):
def
test_five_crop
(
self
):
fn_kwargs
=
meth_kwargs
=
{
"size"
:
(
5
,)}
fn_kwargs
=
meth_kwargs
=
{
"size"
:
(
5
,)}
self
.
_test_op_list_output
(
self
.
_test_op_list_output
(
...
@@ -294,6 +304,9 @@ class Tester(TransformsTester):
...
@@ -294,6 +304,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
script_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_resize.pt"
))
def
test_resized_crop
(
self
):
def
test_resized_crop
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
@@ -309,6 +322,9 @@ class Tester(TransformsTester):
...
@@ -309,6 +322,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_resized_crop.pt"
))
def
test_random_affine
(
self
):
def
test_random_affine
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
@@ -327,6 +343,9 @@ class Tester(TransformsTester):
...
@@ -327,6 +343,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_random_affine.pt"
))
def
test_random_rotate
(
self
):
def
test_random_rotate
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
@@ -343,6 +362,9 @@ class Tester(TransformsTester):
...
@@ -343,6 +362,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_random_rotate.pt"
))
def
test_random_perspective
(
self
):
def
test_random_perspective
(
self
):
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
tensor
=
torch
.
randint
(
0
,
255
,
size
=
(
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
batch_tensors
=
torch
.
randint
(
0
,
255
,
size
=
(
4
,
3
,
44
,
56
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
...
@@ -358,6 +380,9 @@ class Tester(TransformsTester):
...
@@ -358,6 +380,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted
(
transform
,
s_transform
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
transform
,
s_transform
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
s_transform
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_perspective.pt"
))
def
test_to_grayscale
(
self
):
def
test_to_grayscale
(
self
):
meth_kwargs
=
{
"num_output_channels"
:
1
}
meth_kwargs
=
{
"num_output_channels"
:
1
}
...
@@ -388,6 +413,9 @@ class Tester(TransformsTester):
...
@@ -388,6 +413,9 @@ class Tester(TransformsTester):
self
.
_test_transform_vs_scripted
(
fn
,
scripted_fn
,
tensor
)
self
.
_test_transform_vs_scripted
(
fn
,
scripted_fn
,
tensor
)
self
.
_test_transform_vs_scripted_on_batch
(
fn
,
scripted_fn
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
fn
,
scripted_fn
,
batch_tensors
)
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_norm.pt"
))
def
test_linear_transformation
(
self
):
def
test_linear_transformation
(
self
):
c
,
h
,
w
=
3
,
24
,
32
c
,
h
,
w
=
3
,
24
,
32
...
@@ -410,6 +438,9 @@ class Tester(TransformsTester):
...
@@ -410,6 +438,9 @@ class Tester(TransformsTester):
s_transformed_batch
=
scripted_fn
(
batch_tensors
)
s_transformed_batch
=
scripted_fn
(
batch_tensors
)
self
.
assertTrue
(
transformed_batch
.
equal
(
s_transformed_batch
))
self
.
assertTrue
(
transformed_batch
.
equal
(
s_transformed_batch
))
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_norm.pt"
))
def
test_compose
(
self
):
def
test_compose
(
self
):
tensor
,
_
=
self
.
_create_data
(
26
,
34
,
device
=
self
.
device
)
tensor
,
_
=
self
.
_create_data
(
26
,
34
,
device
=
self
.
device
)
tensor
=
tensor
.
to
(
dtype
=
torch
.
float32
)
/
255.0
tensor
=
tensor
.
to
(
dtype
=
torch
.
float32
)
/
255.0
...
...
torchvision/transforms/functional_tensor.py
View file @
e7c0abac
...
@@ -15,7 +15,7 @@ def _get_image_size(img: Tensor) -> List[int]:
...
@@ -15,7 +15,7 @@ def _get_image_size(img: Tensor) -> List[int]:
"""Returns (w, h) of tensor image"""
"""Returns (w, h) of tensor image"""
if
_is_tensor_a_torch_image
(
img
):
if
_is_tensor_a_torch_image
(
img
):
return
[
img
.
shape
[
-
1
],
img
.
shape
[
-
2
]]
return
[
img
.
shape
[
-
1
],
img
.
shape
[
-
2
]]
raise
TypeError
(
"Unexpected
type {}"
.
format
(
type
(
img
))
)
raise
TypeError
(
"Unexpected
input type"
)
def
_get_image_num_channels
(
img
:
Tensor
)
->
int
:
def
_get_image_num_channels
(
img
:
Tensor
)
->
int
:
...
@@ -24,7 +24,7 @@ def _get_image_num_channels(img: Tensor) -> int:
...
@@ -24,7 +24,7 @@ def _get_image_num_channels(img: Tensor) -> int:
elif
img
.
ndim
>
2
:
elif
img
.
ndim
>
2
:
return
img
.
shape
[
-
3
]
return
img
.
shape
[
-
3
]
raise
TypeError
(
"
Unexpected type
{}"
.
format
(
type
(
img
)
))
raise
TypeError
(
"
Input ndim should be 2 or more. Got
{}"
.
format
(
img
.
ndim
))
def
vflip
(
img
:
Tensor
)
->
Tensor
:
def
vflip
(
img
:
Tensor
)
->
Tensor
:
...
@@ -223,7 +223,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
...
@@ -223,7 +223,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
raise
ValueError
(
'hue_factor ({}) is not in [-0.5, 0.5].'
.
format
(
hue_factor
))
raise
ValueError
(
'hue_factor ({}) is not in [-0.5, 0.5].'
.
format
(
hue_factor
))
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
raise
TypeError
(
'img should be Tensor image
. Got {}'
.
format
(
type
(
img
))
)
raise
TypeError
(
'
Input
img should be Tensor image
'
)
orig_dtype
=
img
.
dtype
orig_dtype
=
img
.
dtype
if
img
.
dtype
==
torch
.
uint8
:
if
img
.
dtype
==
torch
.
uint8
:
...
@@ -294,7 +294,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
...
@@ -294,7 +294,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
"""
"""
if
not
isinstance
(
img
,
torch
.
Tensor
):
if
not
isinstance
(
img
,
torch
.
Tensor
):
raise
TypeError
(
'img should be a Tensor.
Got {}'
.
format
(
type
(
img
))
)
raise
TypeError
(
'
Input
img should be a Tensor.
'
)
if
gamma
<
0
:
if
gamma
<
0
:
raise
ValueError
(
'Gamma should be a non-negative real number'
)
raise
ValueError
(
'Gamma should be a non-negative real number'
)
...
@@ -763,10 +763,10 @@ def _assert_grid_transform_inputs(
...
@@ -763,10 +763,10 @@ def _assert_grid_transform_inputs(
coeffs
:
Optional
[
List
[
float
]]
=
None
,
coeffs
:
Optional
[
List
[
float
]]
=
None
,
):
):
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
raise
TypeError
(
"img should be Tensor Image
. Got {}"
.
format
(
type
(
img
))
)
raise
TypeError
(
"
Input
img should be Tensor Image
"
)
if
matrix
is
not
None
and
not
isinstance
(
matrix
,
list
):
if
matrix
is
not
None
and
not
isinstance
(
matrix
,
list
):
raise
TypeError
(
"Argument matrix should be a list
. Got {}"
.
format
(
type
(
matrix
))
)
raise
TypeError
(
"Argument matrix should be a list
"
)
if
matrix
is
not
None
and
len
(
matrix
)
!=
6
:
if
matrix
is
not
None
and
len
(
matrix
)
!=
6
:
raise
ValueError
(
"Argument matrix should have 6 float values"
)
raise
ValueError
(
"Argument matrix should have 6 float values"
)
...
@@ -989,7 +989,7 @@ def perspective(
...
@@ -989,7 +989,7 @@ def perspective(
Tensor: transformed image.
Tensor: transformed image.
"""
"""
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
if
not
(
isinstance
(
img
,
torch
.
Tensor
)
and
_is_tensor_a_torch_image
(
img
)):
raise
TypeError
(
'img should be Tensor Image
. Got {}'
.
format
(
type
(
img
))
)
raise
TypeError
(
'
Input
img should be Tensor Image
'
)
_interpolation_modes
=
{
_interpolation_modes
=
{
0
:
"nearest"
,
0
:
"nearest"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment