Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d367a01a
Unverified
Commit
d367a01a
authored
Oct 28, 2021
by
Jirka Borovec
Committed by
GitHub
Oct 28, 2021
Browse files
Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)
Co-authored-by:
Nicolas Hug
<
nicolashug@fb.com
>
parent
50dfe207
Changes
136
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
174 additions
and
188 deletions
+174
-188
torchvision/ops/ps_roi_align.py
torchvision/ops/ps_roi_align.py
+1
-1
torchvision/ops/ps_roi_pool.py
torchvision/ops/ps_roi_pool.py
+1
-1
torchvision/ops/roi_align.py
torchvision/ops/roi_align.py
+1
-1
torchvision/ops/roi_pool.py
torchvision/ops/roi_pool.py
+1
-1
torchvision/ops/stochastic_depth.py
torchvision/ops/stochastic_depth.py
+2
-2
torchvision/prototype/datasets/utils/_dataset.py
torchvision/prototype/datasets/utils/_dataset.py
+1
-1
torchvision/prototype/datasets/utils/_internal.py
torchvision/prototype/datasets/utils/_internal.py
+1
-1
torchvision/prototype/models/_api.py
torchvision/prototype/models/_api.py
+2
-2
torchvision/transforms/_functional_video.py
torchvision/transforms/_functional_video.py
+1
-1
torchvision/transforms/_transforms_video.py
torchvision/transforms/_transforms_video.py
+11
-10
torchvision/transforms/autoaugment.py
torchvision/transforms/autoaugment.py
+3
-3
torchvision/transforms/functional.py
torchvision/transforms/functional.py
+31
-33
torchvision/transforms/functional_pil.py
torchvision/transforms/functional_pil.py
+26
-28
torchvision/transforms/functional_tensor.py
torchvision/transforms/functional_tensor.py
+24
-25
torchvision/transforms/transforms.py
torchvision/transforms/transforms.py
+66
-76
torchvision/utils.py
torchvision/utils.py
+2
-2
No files found.
torchvision/ops/ps_roi_align.py
View file @
d367a01a
...
@@ -65,7 +65,7 @@ class PSRoIAlign(nn.Module):
...
@@ -65,7 +65,7 @@ class PSRoIAlign(nn.Module):
spatial_scale
:
float
,
spatial_scale
:
float
,
sampling_ratio
:
int
,
sampling_ratio
:
int
,
):
):
super
(
PSRoIAlign
,
self
).
__init__
()
super
().
__init__
()
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
spatial_scale
=
spatial_scale
self
.
spatial_scale
=
spatial_scale
self
.
sampling_ratio
=
sampling_ratio
self
.
sampling_ratio
=
sampling_ratio
...
...
torchvision/ops/ps_roi_pool.py
View file @
d367a01a
...
@@ -52,7 +52,7 @@ class PSRoIPool(nn.Module):
...
@@ -52,7 +52,7 @@ class PSRoIPool(nn.Module):
"""
"""
def
__init__
(
self
,
output_size
:
int
,
spatial_scale
:
float
):
def
__init__
(
self
,
output_size
:
int
,
spatial_scale
:
float
):
super
(
PSRoIPool
,
self
).
__init__
()
super
().
__init__
()
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
spatial_scale
=
spatial_scale
self
.
spatial_scale
=
spatial_scale
...
...
torchvision/ops/roi_align.py
View file @
d367a01a
...
@@ -72,7 +72,7 @@ class RoIAlign(nn.Module):
...
@@ -72,7 +72,7 @@ class RoIAlign(nn.Module):
sampling_ratio
:
int
,
sampling_ratio
:
int
,
aligned
:
bool
=
False
,
aligned
:
bool
=
False
,
):
):
super
(
RoIAlign
,
self
).
__init__
()
super
().
__init__
()
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
spatial_scale
=
spatial_scale
self
.
spatial_scale
=
spatial_scale
self
.
sampling_ratio
=
sampling_ratio
self
.
sampling_ratio
=
sampling_ratio
...
...
torchvision/ops/roi_pool.py
View file @
d367a01a
...
@@ -54,7 +54,7 @@ class RoIPool(nn.Module):
...
@@ -54,7 +54,7 @@ class RoIPool(nn.Module):
"""
"""
def
__init__
(
self
,
output_size
:
BroadcastingList2
[
int
],
spatial_scale
:
float
):
def
__init__
(
self
,
output_size
:
BroadcastingList2
[
int
],
spatial_scale
:
float
):
super
(
RoIPool
,
self
).
__init__
()
super
().
__init__
()
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
spatial_scale
=
spatial_scale
self
.
spatial_scale
=
spatial_scale
...
...
torchvision/ops/stochastic_depth.py
View file @
d367a01a
...
@@ -22,9 +22,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
...
@@ -22,9 +22,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
Tensor[N, ...]: The randomly zeroed tensor.
Tensor[N, ...]: The randomly zeroed tensor.
"""
"""
if
p
<
0.0
or
p
>
1.0
:
if
p
<
0.0
or
p
>
1.0
:
raise
ValueError
(
"drop probability has to be between 0 and 1, but got {}"
.
format
(
p
)
)
raise
ValueError
(
f
"drop probability has to be between 0 and 1, but got
{
p
}
"
)
if
mode
not
in
[
"batch"
,
"row"
]:
if
mode
not
in
[
"batch"
,
"row"
]:
raise
ValueError
(
"mode has to be either 'batch' or 'row', but got {
}"
.
format
(
mode
)
)
raise
ValueError
(
f
"mode has to be either 'batch' or 'row', but got
{
mode
}
"
)
if
not
training
or
p
==
0.0
:
if
not
training
or
p
==
0.0
:
return
input
return
input
...
...
torchvision/prototype/datasets/utils/_dataset.py
View file @
d367a01a
...
@@ -71,7 +71,7 @@ class DatasetInfo:
...
@@ -71,7 +71,7 @@ class DatasetInfo:
@
staticmethod
@
staticmethod
def
read_categories_file
(
path
:
pathlib
.
Path
)
->
List
[
List
[
str
]]:
def
read_categories_file
(
path
:
pathlib
.
Path
)
->
List
[
List
[
str
]]:
with
open
(
path
,
"r"
,
newline
=
""
)
as
file
:
with
open
(
path
,
newline
=
""
)
as
file
:
return
[
row
for
row
in
csv
.
reader
(
file
)]
return
[
row
for
row
in
csv
.
reader
(
file
)]
@
property
@
property
...
...
torchvision/prototype/datasets/utils/_internal.py
View file @
d367a01a
...
@@ -63,7 +63,7 @@ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
...
@@ -63,7 +63,7 @@ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if
len
(
seq
)
==
1
:
if
len
(
seq
)
==
1
:
return
f
"'
{
seq
[
0
]
}
'"
return
f
"'
{
seq
[
0
]
}
'"
return
f
"""'
{
"', '"
.
join
([
str
(
item
)
for
item
in
seq
[:
-
1
]])
}
',
"""
f
"""
{
separate_last
}
'
{
seq
[
-
1
]
}
'."""
return
f
"""'
{
"', '"
.
join
([
str
(
item
)
for
item
in
seq
[:
-
1
]])
}
',
{
separate_last
}
'
{
seq
[
-
1
]
}
'."""
def
add_suggestion
(
def
add_suggestion
(
...
...
torchvision/prototype/models/_api.py
View file @
d367a01a
...
@@ -52,7 +52,7 @@ class Weights(Enum):
...
@@ -52,7 +52,7 @@ class Weights(Enum):
obj
=
cls
.
from_str
(
obj
)
obj
=
cls
.
from_str
(
obj
)
elif
not
isinstance
(
obj
,
cls
)
and
not
isinstance
(
obj
,
WeightEntry
):
elif
not
isinstance
(
obj
,
cls
)
and
not
isinstance
(
obj
,
WeightEntry
):
raise
TypeError
(
raise
TypeError
(
f
"Invalid Weight class provided; expected
{
cls
.
__name__
}
"
f
"
but received
{
obj
.
__class__
.
__name__
}
."
f
"Invalid Weight class provided; expected
{
cls
.
__name__
}
but received
{
obj
.
__class__
.
__name__
}
."
)
)
return
obj
return
obj
...
@@ -106,7 +106,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
...
@@ -106,7 +106,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
if
weights_class
is
None
:
if
weights_class
is
None
:
raise
ValueError
(
raise
ValueError
(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is
"
"
correct."
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)
)
return
weights_class
.
from_str
(
weight_name
)
return
weights_class
.
from_str
(
weight_name
)
torchvision/transforms/_functional_video.py
View file @
d367a01a
...
@@ -101,4 +101,4 @@ def hflip(clip):
...
@@ -101,4 +101,4 @@ def hflip(clip):
flipped clip (torch.tensor): Size is (C, T, H, W)
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
"""
assert
_is_tensor_video_clip
(
clip
),
"clip should be a 4D torch.tensor"
assert
_is_tensor_video_clip
(
clip
),
"clip should be a 4D torch.tensor"
return
clip
.
flip
(
(
-
1
)
)
return
clip
.
flip
(
-
1
)
torchvision/transforms/_transforms_video.py
View file @
d367a01a
...
@@ -44,7 +44,7 @@ class RandomCropVideo(RandomCrop):
...
@@ -44,7 +44,7 @@ class RandomCropVideo(RandomCrop):
return
F
.
crop
(
clip
,
i
,
j
,
h
,
w
)
return
F
.
crop
(
clip
,
i
,
j
,
h
,
w
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(size={
0})"
.
format
(
self
.
size
)
return
self
.
__class__
.
__name__
+
f
"(size=
{
self
.
size
}
)"
class
RandomResizedCropVideo
(
RandomResizedCrop
):
class
RandomResizedCropVideo
(
RandomResizedCrop
):
...
@@ -77,12 +77,13 @@ class RandomResizedCropVideo(RandomResizedCrop):
...
@@ -77,12 +77,13 @@ class RandomResizedCropVideo(RandomResizedCrop):
return
F
.
resized_crop
(
clip
,
i
,
j
,
h
,
w
,
self
.
size
,
self
.
interpolation_mode
)
return
F
.
resized_crop
(
clip
,
i
,
j
,
h
,
w
,
self
.
size
,
self
.
interpolation_mode
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(size={0}, interpolation_mode={1}, scale={2}, ratio={3})"
.
format
(
return
(
self
.
size
,
self
.
interpolation_mode
,
self
.
scale
,
self
.
ratio
self
.
__class__
.
__name__
+
f
"(size=
{
self
.
size
}
, interpolation_mode=
{
self
.
interpolation_mode
}
, scale=
{
self
.
scale
}
, ratio=
{
self
.
ratio
}
)"
)
)
class
CenterCropVideo
(
object
)
:
class
CenterCropVideo
:
def
__init__
(
self
,
crop_size
):
def
__init__
(
self
,
crop_size
):
if
isinstance
(
crop_size
,
numbers
.
Number
):
if
isinstance
(
crop_size
,
numbers
.
Number
):
self
.
crop_size
=
(
int
(
crop_size
),
int
(
crop_size
))
self
.
crop_size
=
(
int
(
crop_size
),
int
(
crop_size
))
...
@@ -100,10 +101,10 @@ class CenterCropVideo(object):
...
@@ -100,10 +101,10 @@ class CenterCropVideo(object):
return
F
.
center_crop
(
clip
,
self
.
crop_size
)
return
F
.
center_crop
(
clip
,
self
.
crop_size
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(crop_size={
0})"
.
format
(
self
.
crop_size
)
return
self
.
__class__
.
__name__
+
f
"(crop_size=
{
self
.
crop_size
}
)"
class
NormalizeVideo
(
object
)
:
class
NormalizeVideo
:
"""
"""
Normalize the video clip by mean subtraction and division by standard deviation
Normalize the video clip by mean subtraction and division by standard deviation
Args:
Args:
...
@@ -125,10 +126,10 @@ class NormalizeVideo(object):
...
@@ -125,10 +126,10 @@ class NormalizeVideo(object):
return
F
.
normalize
(
clip
,
self
.
mean
,
self
.
std
,
self
.
inplace
)
return
F
.
normalize
(
clip
,
self
.
mean
,
self
.
std
,
self
.
inplace
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(mean={
0}, std={1}, inplace={2})"
.
format
(
self
.
mean
,
self
.
std
,
self
.
inplace
)
return
self
.
__class__
.
__name__
+
f
"(mean=
{
self
.
mean
}
,
std=
{
self
.
std
}
,
inplace=
{
self
.
inplace
}
)"
class
ToTensorVideo
(
object
)
:
class
ToTensorVideo
:
"""
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
permute the dimensions of clip tensor
...
@@ -150,7 +151,7 @@ class ToTensorVideo(object):
...
@@ -150,7 +151,7 @@ class ToTensorVideo(object):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
class
RandomHorizontalFlipVideo
(
object
)
:
class
RandomHorizontalFlipVideo
:
"""
"""
Flip the video clip along the horizonal direction with a given probability
Flip the video clip along the horizonal direction with a given probability
Args:
Args:
...
@@ -172,4 +173,4 @@ class RandomHorizontalFlipVideo(object):
...
@@ -172,4 +173,4 @@ class RandomHorizontalFlipVideo(object):
return
clip
return
clip
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
0})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
torchvision/transforms/autoaugment.py
View file @
d367a01a
...
@@ -76,7 +76,7 @@ def _apply_op(
...
@@ -76,7 +76,7 @@ def _apply_op(
elif
op_name
==
"Identity"
:
elif
op_name
==
"Identity"
:
pass
pass
else
:
else
:
raise
ValueError
(
"The provided operator {} is not recognized."
.
format
(
op_name
)
)
raise
ValueError
(
f
"The provided operator
{
op_name
}
is not recognized."
)
return
img
return
img
...
@@ -208,7 +208,7 @@ class AutoAugment(torch.nn.Module):
...
@@ -208,7 +208,7 @@ class AutoAugment(torch.nn.Module):
((
"ShearX"
,
0.7
,
2
),
(
"Invert"
,
0.1
,
None
)),
((
"ShearX"
,
0.7
,
2
),
(
"Invert"
,
0.1
,
None
)),
]
]
else
:
else
:
raise
ValueError
(
"The provided policy {} is not recognized."
.
format
(
policy
)
)
raise
ValueError
(
f
"The provided policy
{
policy
}
is not recognized."
)
def
_augmentation_space
(
self
,
num_bins
:
int
,
image_size
:
List
[
int
])
->
Dict
[
str
,
Tuple
[
Tensor
,
bool
]]:
def
_augmentation_space
(
self
,
num_bins
:
int
,
image_size
:
List
[
int
])
->
Dict
[
str
,
Tuple
[
Tensor
,
bool
]]:
return
{
return
{
...
@@ -270,7 +270,7 @@ class AutoAugment(torch.nn.Module):
...
@@ -270,7 +270,7 @@ class AutoAugment(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
self
.
__class__
.
__name__
+
"(policy={
}, fill={})"
.
format
(
self
.
policy
,
self
.
fill
)
return
self
.
__class__
.
__name__
+
f
"(policy=
{
self
.
policy
}
,
fill=
{
self
.
fill
}
)"
class
RandAugment
(
torch
.
nn
.
Module
):
class
RandAugment
(
torch
.
nn
.
Module
):
...
...
torchvision/transforms/functional.py
View file @
d367a01a
...
@@ -111,10 +111,10 @@ def to_tensor(pic):
...
@@ -111,10 +111,10 @@ def to_tensor(pic):
Tensor: Converted image.
Tensor: Converted image.
"""
"""
if
not
(
F_pil
.
_is_pil_image
(
pic
)
or
_is_numpy
(
pic
)):
if
not
(
F_pil
.
_is_pil_image
(
pic
)
or
_is_numpy
(
pic
)):
raise
TypeError
(
"pic should be PIL Image or ndarray. Got {
}"
.
format
(
type
(
pic
)
)
)
raise
TypeError
(
f
"pic should be PIL Image or ndarray. Got
{
type
(
pic
)
}
"
)
if
_is_numpy
(
pic
)
and
not
_is_numpy_image
(
pic
):
if
_is_numpy
(
pic
)
and
not
_is_numpy_image
(
pic
):
raise
ValueError
(
"pic should be 2/3 dimensional. Got {} dimensions."
.
format
(
pic
.
ndim
)
)
raise
ValueError
(
f
"pic should be 2/3 dimensional. Got
{
pic
.
ndim
}
dimensions."
)
default_float_dtype
=
torch
.
get_default_dtype
()
default_float_dtype
=
torch
.
get_default_dtype
()
...
@@ -167,7 +167,7 @@ def pil_to_tensor(pic):
...
@@ -167,7 +167,7 @@ def pil_to_tensor(pic):
Tensor: Converted image.
Tensor: Converted image.
"""
"""
if
not
F_pil
.
_is_pil_image
(
pic
):
if
not
F_pil
.
_is_pil_image
(
pic
):
raise
TypeError
(
"pic should be PIL Image. Got {
}"
.
format
(
type
(
pic
)
)
)
raise
TypeError
(
f
"pic should be PIL Image. Got
{
type
(
pic
)
}
"
)
if
accimage
is
not
None
and
isinstance
(
pic
,
accimage
.
Image
):
if
accimage
is
not
None
and
isinstance
(
pic
,
accimage
.
Image
):
# accimage format is always uint8 internally, so always return uint8 here
# accimage format is always uint8 internally, so always return uint8 here
...
@@ -226,11 +226,11 @@ def to_pil_image(pic, mode=None):
...
@@ -226,11 +226,11 @@ def to_pil_image(pic, mode=None):
PIL Image: Image converted to PIL Image.
PIL Image: Image converted to PIL Image.
"""
"""
if
not
(
isinstance
(
pic
,
torch
.
Tensor
)
or
isinstance
(
pic
,
np
.
ndarray
)):
if
not
(
isinstance
(
pic
,
torch
.
Tensor
)
or
isinstance
(
pic
,
np
.
ndarray
)):
raise
TypeError
(
"pic should be Tensor or ndarray. Got {
}."
.
format
(
type
(
pic
)
)
)
raise
TypeError
(
f
"pic should be Tensor or ndarray. Got
{
type
(
pic
)
}
."
)
elif
isinstance
(
pic
,
torch
.
Tensor
):
elif
isinstance
(
pic
,
torch
.
Tensor
):
if
pic
.
ndimension
()
not
in
{
2
,
3
}:
if
pic
.
ndimension
()
not
in
{
2
,
3
}:
raise
ValueError
(
"pic should be 2/3 dimensional. Got {
}
dimension
s."
.
format
(
pic
.
n
dimension
())
)
raise
ValueError
(
f
"pic should be 2/3 dimensional. Got
{
pic
.
n
dimension
()
}
dimension
s."
)
elif
pic
.
ndimension
()
==
2
:
elif
pic
.
ndimension
()
==
2
:
# if 2D image, add channel dimension (CHW)
# if 2D image, add channel dimension (CHW)
...
@@ -238,11 +238,11 @@ def to_pil_image(pic, mode=None):
...
@@ -238,11 +238,11 @@ def to_pil_image(pic, mode=None):
# check number of channels
# check number of channels
if
pic
.
shape
[
-
3
]
>
4
:
if
pic
.
shape
[
-
3
]
>
4
:
raise
ValueError
(
"pic should not have > 4 channels. Got {
} channels."
.
format
(
pic
.
shape
[
-
3
]
)
)
raise
ValueError
(
f
"pic should not have > 4 channels. Got
{
pic
.
shape
[
-
3
]
}
channels."
)
elif
isinstance
(
pic
,
np
.
ndarray
):
elif
isinstance
(
pic
,
np
.
ndarray
):
if
pic
.
ndim
not
in
{
2
,
3
}:
if
pic
.
ndim
not
in
{
2
,
3
}:
raise
ValueError
(
"pic should be 2/3 dimensional. Got {} dimensions."
.
format
(
pic
.
ndim
)
)
raise
ValueError
(
f
"pic should be 2/3 dimensional. Got
{
pic
.
ndim
}
dimensions."
)
elif
pic
.
ndim
==
2
:
elif
pic
.
ndim
==
2
:
# if 2D image, add channel dimension (HWC)
# if 2D image, add channel dimension (HWC)
...
@@ -250,7 +250,7 @@ def to_pil_image(pic, mode=None):
...
@@ -250,7 +250,7 @@ def to_pil_image(pic, mode=None):
# check number of channels
# check number of channels
if
pic
.
shape
[
-
1
]
>
4
:
if
pic
.
shape
[
-
1
]
>
4
:
raise
ValueError
(
"pic should not have > 4 channels. Got {
} channels."
.
format
(
pic
.
shape
[
-
1
]
)
)
raise
ValueError
(
f
"pic should not have > 4 channels. Got
{
pic
.
shape
[
-
1
]
}
channels."
)
npimg
=
pic
npimg
=
pic
if
isinstance
(
pic
,
torch
.
Tensor
):
if
isinstance
(
pic
,
torch
.
Tensor
):
...
@@ -259,7 +259,7 @@ def to_pil_image(pic, mode=None):
...
@@ -259,7 +259,7 @@ def to_pil_image(pic, mode=None):
npimg
=
np
.
transpose
(
pic
.
cpu
().
numpy
(),
(
1
,
2
,
0
))
npimg
=
np
.
transpose
(
pic
.
cpu
().
numpy
(),
(
1
,
2
,
0
))
if
not
isinstance
(
npimg
,
np
.
ndarray
):
if
not
isinstance
(
npimg
,
np
.
ndarray
):
raise
TypeError
(
"Input pic must be a torch.Tensor or NumPy ndarray,
"
+
"not {}"
.
format
(
type
(
npimg
)
)
)
raise
TypeError
(
"Input pic must be a torch.Tensor or NumPy ndarray,
not {
type(npimg)
}"
)
if
npimg
.
shape
[
2
]
==
1
:
if
npimg
.
shape
[
2
]
==
1
:
expected_mode
=
None
expected_mode
=
None
...
@@ -273,15 +273,13 @@ def to_pil_image(pic, mode=None):
...
@@ -273,15 +273,13 @@ def to_pil_image(pic, mode=None):
elif
npimg
.
dtype
==
np
.
float32
:
elif
npimg
.
dtype
==
np
.
float32
:
expected_mode
=
"F"
expected_mode
=
"F"
if
mode
is
not
None
and
mode
!=
expected_mode
:
if
mode
is
not
None
and
mode
!=
expected_mode
:
raise
ValueError
(
raise
ValueError
(
f
"Incorrect mode (
{
mode
}
) supplied for input type
{
np
.
dtype
}
. Should be
{
expected_mode
}
"
)
"Incorrect mode ({}) supplied for input type {}. Should be {}"
.
format
(
mode
,
np
.
dtype
,
expected_mode
)
)
mode
=
expected_mode
mode
=
expected_mode
elif
npimg
.
shape
[
2
]
==
2
:
elif
npimg
.
shape
[
2
]
==
2
:
permitted_2_channel_modes
=
[
"LA"
]
permitted_2_channel_modes
=
[
"LA"
]
if
mode
is
not
None
and
mode
not
in
permitted_2_channel_modes
:
if
mode
is
not
None
and
mode
not
in
permitted_2_channel_modes
:
raise
ValueError
(
"Only modes {} are supported for 2D inputs"
.
format
(
permitted_2_channel_modes
)
)
raise
ValueError
(
f
"Only modes
{
permitted_2_channel_modes
}
are supported for 2D inputs"
)
if
mode
is
None
and
npimg
.
dtype
==
np
.
uint8
:
if
mode
is
None
and
npimg
.
dtype
==
np
.
uint8
:
mode
=
"LA"
mode
=
"LA"
...
@@ -289,19 +287,19 @@ def to_pil_image(pic, mode=None):
...
@@ -289,19 +287,19 @@ def to_pil_image(pic, mode=None):
elif
npimg
.
shape
[
2
]
==
4
:
elif
npimg
.
shape
[
2
]
==
4
:
permitted_4_channel_modes
=
[
"RGBA"
,
"CMYK"
,
"RGBX"
]
permitted_4_channel_modes
=
[
"RGBA"
,
"CMYK"
,
"RGBX"
]
if
mode
is
not
None
and
mode
not
in
permitted_4_channel_modes
:
if
mode
is
not
None
and
mode
not
in
permitted_4_channel_modes
:
raise
ValueError
(
"Only modes {} are supported for 4D inputs"
.
format
(
permitted_4_channel_modes
)
)
raise
ValueError
(
f
"Only modes
{
permitted_4_channel_modes
}
are supported for 4D inputs"
)
if
mode
is
None
and
npimg
.
dtype
==
np
.
uint8
:
if
mode
is
None
and
npimg
.
dtype
==
np
.
uint8
:
mode
=
"RGBA"
mode
=
"RGBA"
else
:
else
:
permitted_3_channel_modes
=
[
"RGB"
,
"YCbCr"
,
"HSV"
]
permitted_3_channel_modes
=
[
"RGB"
,
"YCbCr"
,
"HSV"
]
if
mode
is
not
None
and
mode
not
in
permitted_3_channel_modes
:
if
mode
is
not
None
and
mode
not
in
permitted_3_channel_modes
:
raise
ValueError
(
"Only modes {} are supported for 3D inputs"
.
format
(
permitted_3_channel_modes
)
)
raise
ValueError
(
f
"Only modes
{
permitted_3_channel_modes
}
are supported for 3D inputs"
)
if
mode
is
None
and
npimg
.
dtype
==
np
.
uint8
:
if
mode
is
None
and
npimg
.
dtype
==
np
.
uint8
:
mode
=
"RGB"
mode
=
"RGB"
if
mode
is
None
:
if
mode
is
None
:
raise
TypeError
(
"Input type {} is not supported"
.
format
(
npimg
.
dtype
)
)
raise
TypeError
(
f
"Input type
{
npimg
.
dtype
}
is not supported"
)
return
Image
.
fromarray
(
npimg
,
mode
=
mode
)
return
Image
.
fromarray
(
npimg
,
mode
=
mode
)
...
@@ -325,15 +323,14 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
...
@@ -325,15 +323,14 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
Tensor: Normalized Tensor image.
Tensor: Normalized Tensor image.
"""
"""
if
not
isinstance
(
tensor
,
torch
.
Tensor
):
if
not
isinstance
(
tensor
,
torch
.
Tensor
):
raise
TypeError
(
"Input tensor should be a torch tensor. Got {
}."
.
format
(
type
(
tensor
)
)
)
raise
TypeError
(
f
"Input tensor should be a torch tensor. Got
{
type
(
tensor
)
}
."
)
if
not
tensor
.
is_floating_point
():
if
not
tensor
.
is_floating_point
():
raise
TypeError
(
"Input tensor should be a float tensor. Got {
}."
.
format
(
tensor
.
dtype
)
)
raise
TypeError
(
f
"Input tensor should be a float tensor. Got
{
tensor
.
dtype
}
."
)
if
tensor
.
ndim
<
3
:
if
tensor
.
ndim
<
3
:
raise
ValueError
(
raise
ValueError
(
"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = "
f
"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() =
{
tensor
.
size
()
}
"
"{}."
.
format
(
tensor
.
size
())
)
)
if
not
inplace
:
if
not
inplace
:
...
@@ -343,7 +340,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
...
@@ -343,7 +340,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
mean
=
torch
.
as_tensor
(
mean
,
dtype
=
dtype
,
device
=
tensor
.
device
)
mean
=
torch
.
as_tensor
(
mean
,
dtype
=
dtype
,
device
=
tensor
.
device
)
std
=
torch
.
as_tensor
(
std
,
dtype
=
dtype
,
device
=
tensor
.
device
)
std
=
torch
.
as_tensor
(
std
,
dtype
=
dtype
,
device
=
tensor
.
device
)
if
(
std
==
0
).
any
():
if
(
std
==
0
).
any
():
raise
ValueError
(
"std evaluated to zero after conversion to {}, leading to division by zero."
.
format
(
dtype
)
)
raise
ValueError
(
f
"std evaluated to zero after conversion to
{
dtype
}
, leading to division by zero."
)
if
mean
.
ndim
==
1
:
if
mean
.
ndim
==
1
:
mean
=
mean
.
view
(
-
1
,
1
,
1
)
mean
=
mean
.
view
(
-
1
,
1
,
1
)
if
std
.
ndim
==
1
:
if
std
.
ndim
==
1
:
...
@@ -425,7 +422,7 @@ def resize(
...
@@ -425,7 +422,7 @@ def resize(
def
scale
(
*
args
,
**
kwargs
):
def
scale
(
*
args
,
**
kwargs
):
warnings
.
warn
(
"The use of the transforms.Scale transform is deprecated,
"
+
"
please use transforms.Resize instead."
)
warnings
.
warn
(
"The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead."
)
return
resize
(
*
args
,
**
kwargs
)
return
resize
(
*
args
,
**
kwargs
)
...
@@ -923,7 +920,8 @@ def _get_inverse_affine_matrix(
...
@@ -923,7 +920,8 @@ def _get_inverse_affine_matrix(
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
rot
=
math
.
radians
(
angle
)
rot
=
math
.
radians
(
angle
)
sx
,
sy
=
[
math
.
radians
(
s
)
for
s
in
shear
]
sx
=
math
.
radians
(
shear
[
0
])
sy
=
math
.
radians
(
shear
[
1
])
cx
,
cy
=
center
cx
,
cy
=
center
tx
,
ty
=
translate
tx
,
ty
=
translate
...
@@ -1121,7 +1119,7 @@ def affine(
...
@@ -1121,7 +1119,7 @@ def affine(
shear
=
[
shear
[
0
],
shear
[
0
]]
shear
=
[
shear
[
0
],
shear
[
0
]]
if
len
(
shear
)
!=
2
:
if
len
(
shear
)
!=
2
:
raise
ValueError
(
"Shear should be a sequence containing two values. Got {
}"
.
format
(
shear
)
)
raise
ValueError
(
f
"Shear should be a sequence containing two values. Got
{
shear
}
"
)
img_size
=
get_image_size
(
img
)
img_size
=
get_image_size
(
img
)
if
not
isinstance
(
img
,
torch
.
Tensor
):
if
not
isinstance
(
img
,
torch
.
Tensor
):
...
@@ -1201,7 +1199,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
...
@@ -1201,7 +1199,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
Tensor Image: Erased image.
Tensor Image: Erased image.
"""
"""
if
not
isinstance
(
img
,
torch
.
Tensor
):
if
not
isinstance
(
img
,
torch
.
Tensor
):
raise
TypeError
(
"img should be Tensor Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be Tensor Image. Got
{
type
(
img
)
}
"
)
if
not
inplace
:
if
not
inplace
:
img
=
img
.
clone
()
img
=
img
.
clone
()
...
@@ -1237,34 +1235,34 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
...
@@ -1237,34 +1235,34 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
PIL Image or Tensor: Gaussian Blurred version of the image.
PIL Image or Tensor: Gaussian Blurred version of the image.
"""
"""
if
not
isinstance
(
kernel_size
,
(
int
,
list
,
tuple
)):
if
not
isinstance
(
kernel_size
,
(
int
,
list
,
tuple
)):
raise
TypeError
(
"kernel_size should be int or a sequence of integers. Got {
}"
.
format
(
type
(
kernel_size
)
)
)
raise
TypeError
(
f
"kernel_size should be int or a sequence of integers. Got
{
type
(
kernel_size
)
}
"
)
if
isinstance
(
kernel_size
,
int
):
if
isinstance
(
kernel_size
,
int
):
kernel_size
=
[
kernel_size
,
kernel_size
]
kernel_size
=
[
kernel_size
,
kernel_size
]
if
len
(
kernel_size
)
!=
2
:
if
len
(
kernel_size
)
!=
2
:
raise
ValueError
(
"If kernel_size is a sequence its length should be 2. Got {
}"
.
format
(
len
(
kernel_size
)
)
)
raise
ValueError
(
f
"If kernel_size is a sequence its length should be 2. Got
{
len
(
kernel_size
)
}
"
)
for
ksize
in
kernel_size
:
for
ksize
in
kernel_size
:
if
ksize
%
2
==
0
or
ksize
<
0
:
if
ksize
%
2
==
0
or
ksize
<
0
:
raise
ValueError
(
"kernel_size should have odd and positive integers. Got {
}"
.
format
(
kernel_size
)
)
raise
ValueError
(
f
"kernel_size should have odd and positive integers. Got
{
kernel_size
}
"
)
if
sigma
is
None
:
if
sigma
is
None
:
sigma
=
[
ksize
*
0.15
+
0.35
for
ksize
in
kernel_size
]
sigma
=
[
ksize
*
0.15
+
0.35
for
ksize
in
kernel_size
]
if
sigma
is
not
None
and
not
isinstance
(
sigma
,
(
int
,
float
,
list
,
tuple
)):
if
sigma
is
not
None
and
not
isinstance
(
sigma
,
(
int
,
float
,
list
,
tuple
)):
raise
TypeError
(
"sigma should be either float or sequence of floats. Got {
}"
.
format
(
type
(
sigma
)
)
)
raise
TypeError
(
f
"sigma should be either float or sequence of floats. Got
{
type
(
sigma
)
}
"
)
if
isinstance
(
sigma
,
(
int
,
float
)):
if
isinstance
(
sigma
,
(
int
,
float
)):
sigma
=
[
float
(
sigma
),
float
(
sigma
)]
sigma
=
[
float
(
sigma
),
float
(
sigma
)]
if
isinstance
(
sigma
,
(
list
,
tuple
))
and
len
(
sigma
)
==
1
:
if
isinstance
(
sigma
,
(
list
,
tuple
))
and
len
(
sigma
)
==
1
:
sigma
=
[
sigma
[
0
],
sigma
[
0
]]
sigma
=
[
sigma
[
0
],
sigma
[
0
]]
if
len
(
sigma
)
!=
2
:
if
len
(
sigma
)
!=
2
:
raise
ValueError
(
"If sigma is a sequence, its length should be 2. Got {
}"
.
format
(
len
(
sigma
)
)
)
raise
ValueError
(
f
"If sigma is a sequence, its length should be 2. Got
{
len
(
sigma
)
}
"
)
for
s
in
sigma
:
for
s
in
sigma
:
if
s
<=
0.0
:
if
s
<=
0.0
:
raise
ValueError
(
"sigma should have positive values. Got {
}"
.
format
(
sigma
)
)
raise
ValueError
(
f
"sigma should have positive values. Got
{
sigma
}
"
)
t_img
=
img
t_img
=
img
if
not
isinstance
(
img
,
torch
.
Tensor
):
if
not
isinstance
(
img
,
torch
.
Tensor
):
if
not
F_pil
.
_is_pil_image
(
img
):
if
not
F_pil
.
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image or Tensor. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image or Tensor. Got
{
type
(
img
)
}
"
)
t_img
=
to_tensor
(
img
)
t_img
=
to_tensor
(
img
)
...
@@ -1307,7 +1305,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
...
@@ -1307,7 +1305,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
PIL Image or Tensor: Posterized image.
PIL Image or Tensor: Posterized image.
"""
"""
if
not
(
0
<=
bits
<=
8
):
if
not
(
0
<=
bits
<=
8
):
raise
ValueError
(
"The number if bits should be between 0 and 8. Got {
}"
.
format
(
bits
)
)
raise
ValueError
(
f
"The number if bits should be between 0 and 8. Got
{
bits
}
"
)
if
not
isinstance
(
img
,
torch
.
Tensor
):
if
not
isinstance
(
img
,
torch
.
Tensor
):
return
F_pil
.
posterize
(
img
,
bits
)
return
F_pil
.
posterize
(
img
,
bits
)
...
...
torchvision/transforms/functional_pil.py
View file @
d367a01a
...
@@ -23,20 +23,20 @@ def _is_pil_image(img: Any) -> bool:
...
@@ -23,20 +23,20 @@ def _is_pil_image(img: Any) -> bool:
def
get_image_size
(
img
:
Any
)
->
List
[
int
]:
def
get_image_size
(
img
:
Any
)
->
List
[
int
]:
if
_is_pil_image
(
img
):
if
_is_pil_image
(
img
):
return
list
(
img
.
size
)
return
list
(
img
.
size
)
raise
TypeError
(
"Unexpected type {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"Unexpected type
{
type
(
img
)
}
"
)
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
get_image_num_channels
(
img
:
Any
)
->
int
:
def
get_image_num_channels
(
img
:
Any
)
->
int
:
if
_is_pil_image
(
img
):
if
_is_pil_image
(
img
):
return
1
if
img
.
mode
==
"L"
else
3
return
1
if
img
.
mode
==
"L"
else
3
raise
TypeError
(
"Unexpected type {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"Unexpected type
{
type
(
img
)
}
"
)
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
hflip
(
img
:
Image
.
Image
)
->
Image
.
Image
:
def
hflip
(
img
:
Image
.
Image
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
...
@@ -44,7 +44,7 @@ def hflip(img: Image.Image) -> Image.Image:
...
@@ -44,7 +44,7 @@ def hflip(img: Image.Image) -> Image.Image:
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
vflip
(
img
:
Image
.
Image
)
->
Image
.
Image
:
def
vflip
(
img
:
Image
.
Image
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
img
.
transpose
(
Image
.
FLIP_TOP_BOTTOM
)
return
img
.
transpose
(
Image
.
FLIP_TOP_BOTTOM
)
...
@@ -52,7 +52,7 @@ def vflip(img: Image.Image) -> Image.Image:
...
@@ -52,7 +52,7 @@ def vflip(img: Image.Image) -> Image.Image:
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
adjust_brightness
(
img
:
Image
.
Image
,
brightness_factor
:
float
)
->
Image
.
Image
:
def
adjust_brightness
(
img
:
Image
.
Image
,
brightness_factor
:
float
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
enhancer
=
ImageEnhance
.
Brightness
(
img
)
enhancer
=
ImageEnhance
.
Brightness
(
img
)
img
=
enhancer
.
enhance
(
brightness_factor
)
img
=
enhancer
.
enhance
(
brightness_factor
)
...
@@ -62,7 +62,7 @@ def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image
...
@@ -62,7 +62,7 @@ def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
adjust_contrast
(
img
:
Image
.
Image
,
contrast_factor
:
float
)
->
Image
.
Image
:
def
adjust_contrast
(
img
:
Image
.
Image
,
contrast_factor
:
float
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
enhancer
=
ImageEnhance
.
Contrast
(
img
)
enhancer
=
ImageEnhance
.
Contrast
(
img
)
img
=
enhancer
.
enhance
(
contrast_factor
)
img
=
enhancer
.
enhance
(
contrast_factor
)
...
@@ -72,7 +72,7 @@ def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
...
@@ -72,7 +72,7 @@ def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
adjust_saturation
(
img
:
Image
.
Image
,
saturation_factor
:
float
)
->
Image
.
Image
:
def
adjust_saturation
(
img
:
Image
.
Image
,
saturation_factor
:
float
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
enhancer
=
ImageEnhance
.
Color
(
img
)
enhancer
=
ImageEnhance
.
Color
(
img
)
img
=
enhancer
.
enhance
(
saturation_factor
)
img
=
enhancer
.
enhance
(
saturation_factor
)
...
@@ -82,10 +82,10 @@ def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image
...
@@ -82,10 +82,10 @@ def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
adjust_hue
(
img
:
Image
.
Image
,
hue_factor
:
float
)
->
Image
.
Image
:
def
adjust_hue
(
img
:
Image
.
Image
,
hue_factor
:
float
)
->
Image
.
Image
:
if
not
(
-
0.5
<=
hue_factor
<=
0.5
):
if
not
(
-
0.5
<=
hue_factor
<=
0.5
):
raise
ValueError
(
"hue_factor ({}) is not in [-0.5, 0.5]."
.
format
(
hue_factor
)
)
raise
ValueError
(
f
"hue_factor (
{
hue_factor
}
) is not in [-0.5, 0.5]."
)
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
input_mode
=
img
.
mode
input_mode
=
img
.
mode
if
input_mode
in
{
"L"
,
"1"
,
"I"
,
"F"
}:
if
input_mode
in
{
"L"
,
"1"
,
"I"
,
"F"
}:
...
@@ -111,7 +111,7 @@ def adjust_gamma(
...
@@ -111,7 +111,7 @@ def adjust_gamma(
)
->
Image
.
Image
:
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
if
gamma
<
0
:
if
gamma
<
0
:
raise
ValueError
(
"Gamma should be a non-negative real number"
)
raise
ValueError
(
"Gamma should be a non-negative real number"
)
...
@@ -134,7 +134,7 @@ def pad(
...
@@ -134,7 +134,7 @@ def pad(
)
->
Image
.
Image
:
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
if
not
isinstance
(
padding
,
(
numbers
.
Number
,
tuple
,
list
)):
if
not
isinstance
(
padding
,
(
numbers
.
Number
,
tuple
,
list
)):
raise
TypeError
(
"Got inappropriate padding arg"
)
raise
TypeError
(
"Got inappropriate padding arg"
)
...
@@ -147,9 +147,7 @@ def pad(
...
@@ -147,9 +147,7 @@ def pad(
padding
=
tuple
(
padding
)
padding
=
tuple
(
padding
)
if
isinstance
(
padding
,
tuple
)
and
len
(
padding
)
not
in
[
1
,
2
,
4
]:
if
isinstance
(
padding
,
tuple
)
and
len
(
padding
)
not
in
[
1
,
2
,
4
]:
raise
ValueError
(
raise
ValueError
(
f
"Padding must be an int or a 1, 2, or 4 element tuple, not a
{
len
(
padding
)
}
element tuple"
)
"Padding must be an int or a 1, 2, or 4 element tuple, not a "
+
"{} element tuple"
.
format
(
len
(
padding
))
)
if
isinstance
(
padding
,
tuple
)
and
len
(
padding
)
==
1
:
if
isinstance
(
padding
,
tuple
)
and
len
(
padding
)
==
1
:
# Compatibility with `functional_tensor.pad`
# Compatibility with `functional_tensor.pad`
...
@@ -217,7 +215,7 @@ def crop(
...
@@ -217,7 +215,7 @@ def crop(
)
->
Image
.
Image
:
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
img
.
crop
((
left
,
top
,
left
+
width
,
top
+
height
))
return
img
.
crop
((
left
,
top
,
left
+
width
,
top
+
height
))
...
@@ -231,9 +229,9 @@ def resize(
...
@@ -231,9 +229,9 @@ def resize(
)
->
Image
.
Image
:
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
if
not
(
isinstance
(
size
,
int
)
or
(
isinstance
(
size
,
Sequence
)
and
len
(
size
)
in
(
1
,
2
))):
if
not
(
isinstance
(
size
,
int
)
or
(
isinstance
(
size
,
Sequence
)
and
len
(
size
)
in
(
1
,
2
))):
raise
TypeError
(
"Got inappropriate size arg: {
}"
.
format
(
size
)
)
raise
TypeError
(
f
"Got inappropriate size arg:
{
size
}
"
)
if
isinstance
(
size
,
Sequence
)
and
len
(
size
)
==
1
:
if
isinstance
(
size
,
Sequence
)
and
len
(
size
)
==
1
:
size
=
size
[
0
]
size
=
size
[
0
]
...
@@ -281,7 +279,7 @@ def _parse_fill(
...
@@ -281,7 +279,7 @@ def _parse_fill(
fill
=
tuple
([
fill
]
*
num_bands
)
fill
=
tuple
([
fill
]
*
num_bands
)
if
isinstance
(
fill
,
(
list
,
tuple
)):
if
isinstance
(
fill
,
(
list
,
tuple
)):
if
len
(
fill
)
!=
num_bands
:
if
len
(
fill
)
!=
num_bands
:
msg
=
"The number of elements in 'fill' does not match the number of
"
"
bands of the image ({} != {})"
msg
=
"The number of elements in 'fill' does not match the number of bands of the image ({} != {})"
raise
ValueError
(
msg
.
format
(
len
(
fill
),
num_bands
))
raise
ValueError
(
msg
.
format
(
len
(
fill
),
num_bands
))
fill
=
tuple
(
fill
)
fill
=
tuple
(
fill
)
...
@@ -298,7 +296,7 @@ def affine(
...
@@ -298,7 +296,7 @@ def affine(
)
->
Image
.
Image
:
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
output_size
=
img
.
size
output_size
=
img
.
size
opts
=
_parse_fill
(
fill
,
img
)
opts
=
_parse_fill
(
fill
,
img
)
...
@@ -316,7 +314,7 @@ def rotate(
...
@@ -316,7 +314,7 @@ def rotate(
)
->
Image
.
Image
:
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
opts
=
_parse_fill
(
fill
,
img
)
opts
=
_parse_fill
(
fill
,
img
)
return
img
.
rotate
(
angle
,
interpolation
,
expand
,
center
,
**
opts
)
return
img
.
rotate
(
angle
,
interpolation
,
expand
,
center
,
**
opts
)
...
@@ -331,7 +329,7 @@ def perspective(
...
@@ -331,7 +329,7 @@ def perspective(
)
->
Image
.
Image
:
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
opts
=
_parse_fill
(
fill
,
img
)
opts
=
_parse_fill
(
fill
,
img
)
...
@@ -341,7 +339,7 @@ def perspective(
...
@@ -341,7 +339,7 @@ def perspective(
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
to_grayscale
(
img
:
Image
.
Image
,
num_output_channels
:
int
)
->
Image
.
Image
:
def
to_grayscale
(
img
:
Image
.
Image
,
num_output_channels
:
int
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
if
num_output_channels
==
1
:
if
num_output_channels
==
1
:
img
=
img
.
convert
(
"L"
)
img
=
img
.
convert
(
"L"
)
...
@@ -359,28 +357,28 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
...
@@ -359,28 +357,28 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
invert
(
img
:
Image
.
Image
)
->
Image
.
Image
:
def
invert
(
img
:
Image
.
Image
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
ImageOps
.
invert
(
img
)
return
ImageOps
.
invert
(
img
)
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
posterize
(
img
:
Image
.
Image
,
bits
:
int
)
->
Image
.
Image
:
def
posterize
(
img
:
Image
.
Image
,
bits
:
int
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
ImageOps
.
posterize
(
img
,
bits
)
return
ImageOps
.
posterize
(
img
,
bits
)
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
solarize
(
img
:
Image
.
Image
,
threshold
:
int
)
->
Image
.
Image
:
def
solarize
(
img
:
Image
.
Image
,
threshold
:
int
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
ImageOps
.
solarize
(
img
,
threshold
)
return
ImageOps
.
solarize
(
img
,
threshold
)
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
adjust_sharpness
(
img
:
Image
.
Image
,
sharpness_factor
:
float
)
->
Image
.
Image
:
def
adjust_sharpness
(
img
:
Image
.
Image
,
sharpness_factor
:
float
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
enhancer
=
ImageEnhance
.
Sharpness
(
img
)
enhancer
=
ImageEnhance
.
Sharpness
(
img
)
img
=
enhancer
.
enhance
(
sharpness_factor
)
img
=
enhancer
.
enhance
(
sharpness_factor
)
...
@@ -390,12 +388,12 @@ def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
...
@@ -390,12 +388,12 @@ def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
autocontrast
(
img
:
Image
.
Image
)
->
Image
.
Image
:
def
autocontrast
(
img
:
Image
.
Image
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
ImageOps
.
autocontrast
(
img
)
return
ImageOps
.
autocontrast
(
img
)
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
equalize
(
img
:
Image
.
Image
)
->
Image
.
Image
:
def
equalize
(
img
:
Image
.
Image
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
if
not
_is_pil_image
(
img
):
raise
TypeError
(
"img should be PIL Image. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
ImageOps
.
equalize
(
img
)
return
ImageOps
.
equalize
(
img
)
torchvision/transforms/functional_tensor.py
View file @
d367a01a
...
@@ -28,7 +28,7 @@ def get_image_num_channels(img: Tensor) -> int:
...
@@ -28,7 +28,7 @@ def get_image_num_channels(img: Tensor) -> int:
elif
img
.
ndim
>
2
:
elif
img
.
ndim
>
2
:
return
img
.
shape
[
-
3
]
return
img
.
shape
[
-
3
]
raise
TypeError
(
"Input ndim should be 2 or more. Got {
}"
.
format
(
img
.
ndim
)
)
raise
TypeError
(
f
"Input ndim should be 2 or more. Got
{
img
.
ndim
}
"
)
def
_max_value
(
dtype
:
torch
.
dtype
)
->
float
:
def
_max_value
(
dtype
:
torch
.
dtype
)
->
float
:
...
@@ -52,7 +52,7 @@ def _max_value(dtype: torch.dtype) -> float:
...
@@ -52,7 +52,7 @@ def _max_value(dtype: torch.dtype) -> float:
def
_assert_channels
(
img
:
Tensor
,
permitted
:
List
[
int
])
->
None
:
def
_assert_channels
(
img
:
Tensor
,
permitted
:
List
[
int
])
->
None
:
c
=
get_image_num_channels
(
img
)
c
=
get_image_num_channels
(
img
)
if
c
not
in
permitted
:
if
c
not
in
permitted
:
raise
TypeError
(
"Input image tensor permitted channel values are {}, but found {}"
.
format
(
permitted
,
c
)
)
raise
TypeError
(
f
"Input image tensor permitted channel values are
{
permitted
}
, but found
{
c
}
"
)
def
convert_image_dtype
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
)
->
torch
.
Tensor
:
def
convert_image_dtype
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
)
->
torch
.
Tensor
:
...
@@ -134,7 +134,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
...
@@ -134,7 +134,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
def
rgb_to_grayscale
(
img
:
Tensor
,
num_output_channels
:
int
=
1
)
->
Tensor
:
def
rgb_to_grayscale
(
img
:
Tensor
,
num_output_channels
:
int
=
1
)
->
Tensor
:
if
img
.
ndim
<
3
:
if
img
.
ndim
<
3
:
raise
TypeError
(
"Input image tensor should have at least 3 dimensions, but found {
}"
.
format
(
img
.
ndim
)
)
raise
TypeError
(
f
"Input image tensor should have at least 3 dimensions, but found
{
img
.
ndim
}
"
)
_assert_channels
(
img
,
[
3
])
_assert_channels
(
img
,
[
3
])
if
num_output_channels
not
in
(
1
,
3
):
if
num_output_channels
not
in
(
1
,
3
):
...
@@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
...
@@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
def
adjust_brightness
(
img
:
Tensor
,
brightness_factor
:
float
)
->
Tensor
:
def
adjust_brightness
(
img
:
Tensor
,
brightness_factor
:
float
)
->
Tensor
:
if
brightness_factor
<
0
:
if
brightness_factor
<
0
:
raise
ValueError
(
"brightness_factor ({}) is not non-negative."
.
format
(
brightness_factor
)
)
raise
ValueError
(
f
"brightness_factor (
{
brightness_factor
}
) is not non-negative."
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
...
@@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
def
adjust_contrast
(
img
:
Tensor
,
contrast_factor
:
float
)
->
Tensor
:
def
adjust_contrast
(
img
:
Tensor
,
contrast_factor
:
float
)
->
Tensor
:
if
contrast_factor
<
0
:
if
contrast_factor
<
0
:
raise
ValueError
(
"contrast_factor ({}) is not non-negative."
.
format
(
contrast_factor
)
)
raise
ValueError
(
f
"contrast_factor (
{
contrast_factor
}
) is not non-negative."
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -182,7 +182,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
...
@@ -182,7 +182,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
def
adjust_hue
(
img
:
Tensor
,
hue_factor
:
float
)
->
Tensor
:
def
adjust_hue
(
img
:
Tensor
,
hue_factor
:
float
)
->
Tensor
:
if
not
(
-
0.5
<=
hue_factor
<=
0.5
):
if
not
(
-
0.5
<=
hue_factor
<=
0.5
):
raise
ValueError
(
"hue_factor ({}) is not in [-0.5, 0.5]."
.
format
(
hue_factor
)
)
raise
ValueError
(
f
"hue_factor (
{
hue_factor
}
) is not in [-0.5, 0.5]."
)
if
not
(
isinstance
(
img
,
torch
.
Tensor
)):
if
not
(
isinstance
(
img
,
torch
.
Tensor
)):
raise
TypeError
(
"Input img should be Tensor image"
)
raise
TypeError
(
"Input img should be Tensor image"
)
...
@@ -211,7 +211,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
...
@@ -211,7 +211,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
def
adjust_saturation
(
img
:
Tensor
,
saturation_factor
:
float
)
->
Tensor
:
def
adjust_saturation
(
img
:
Tensor
,
saturation_factor
:
float
)
->
Tensor
:
if
saturation_factor
<
0
:
if
saturation_factor
<
0
:
raise
ValueError
(
"saturation_factor ({}) is not non-negative."
.
format
(
saturation_factor
)
)
raise
ValueError
(
f
"saturation_factor (
{
saturation_factor
}
) is not non-negative."
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -246,7 +246,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
...
@@ -246,7 +246,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
def
center_crop
(
img
:
Tensor
,
output_size
:
BroadcastingList2
[
int
])
->
Tensor
:
def
center_crop
(
img
:
Tensor
,
output_size
:
BroadcastingList2
[
int
])
->
Tensor
:
"""DEPRECATED"""
"""DEPRECATED"""
warnings
.
warn
(
warnings
.
warn
(
"This method is deprecated and will be removed in future releases.
"
"
Please, use ``F.center_crop`` instead."
"This method is deprecated and will be removed in future releases. Please, use ``F.center_crop`` instead."
)
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -268,7 +268,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
...
@@ -268,7 +268,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
def
five_crop
(
img
:
Tensor
,
size
:
BroadcastingList2
[
int
])
->
List
[
Tensor
]:
def
five_crop
(
img
:
Tensor
,
size
:
BroadcastingList2
[
int
])
->
List
[
Tensor
]:
"""DEPRECATED"""
"""DEPRECATED"""
warnings
.
warn
(
warnings
.
warn
(
"This method is deprecated and will be removed in future releases.
"
"
Please, use ``F.five_crop`` instead."
"This method is deprecated and will be removed in future releases. Please, use ``F.five_crop`` instead."
)
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -293,7 +293,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
...
@@ -293,7 +293,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
def
ten_crop
(
img
:
Tensor
,
size
:
BroadcastingList2
[
int
],
vertical_flip
:
bool
=
False
)
->
List
[
Tensor
]:
def
ten_crop
(
img
:
Tensor
,
size
:
BroadcastingList2
[
int
],
vertical_flip
:
bool
=
False
)
->
List
[
Tensor
]:
"""DEPRECATED"""
"""DEPRECATED"""
warnings
.
warn
(
warnings
.
warn
(
"This method is deprecated and will be removed in future releases.
"
"
Please, use ``F.ten_crop`` instead."
"This method is deprecated and will be removed in future releases. Please, use ``F.ten_crop`` instead."
)
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -382,7 +382,8 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
...
@@ -382,7 +382,8 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# crop if needed
# crop if needed
if
padding
[
0
]
<
0
or
padding
[
1
]
<
0
or
padding
[
2
]
<
0
or
padding
[
3
]
<
0
:
if
padding
[
0
]
<
0
or
padding
[
1
]
<
0
or
padding
[
2
]
<
0
or
padding
[
3
]
<
0
:
crop_left
,
crop_right
,
crop_top
,
crop_bottom
=
[
-
min
(
x
,
0
)
for
x
in
padding
]
neg_min_padding
=
[
-
min
(
x
,
0
)
for
x
in
padding
]
crop_left
,
crop_right
,
crop_top
,
crop_bottom
=
neg_min_padding
img
=
img
[...,
crop_top
:
img
.
shape
[
-
2
]
-
crop_bottom
,
crop_left
:
img
.
shape
[
-
1
]
-
crop_right
]
img
=
img
[...,
crop_top
:
img
.
shape
[
-
2
]
-
crop_bottom
,
crop_left
:
img
.
shape
[
-
1
]
-
crop_right
]
padding
=
[
max
(
x
,
0
)
for
x
in
padding
]
padding
=
[
max
(
x
,
0
)
for
x
in
padding
]
...
@@ -421,9 +422,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
...
@@ -421,9 +422,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
padding
=
list
(
padding
)
padding
=
list
(
padding
)
if
isinstance
(
padding
,
list
)
and
len
(
padding
)
not
in
[
1
,
2
,
4
]:
if
isinstance
(
padding
,
list
)
and
len
(
padding
)
not
in
[
1
,
2
,
4
]:
raise
ValueError
(
raise
ValueError
(
f
"Padding must be an int or a 1, 2, or 4 element tuple, not a
{
len
(
padding
)
}
element tuple"
)
"Padding must be an int or a 1, 2, or 4 element tuple, not a "
+
"{} element tuple"
.
format
(
len
(
padding
))
)
if
padding_mode
not
in
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]:
if
padding_mode
not
in
[
"constant"
,
"edge"
,
"reflect"
,
"symmetric"
]:
raise
ValueError
(
"Padding mode should be either constant, edge, reflect or symmetric"
)
raise
ValueError
(
"Padding mode should be either constant, edge, reflect or symmetric"
)
...
@@ -501,7 +500,7 @@ def resize(
...
@@ -501,7 +500,7 @@ def resize(
if
isinstance
(
size
,
list
):
if
isinstance
(
size
,
list
):
if
len
(
size
)
not
in
[
1
,
2
]:
if
len
(
size
)
not
in
[
1
,
2
]:
raise
ValueError
(
raise
ValueError
(
"Size must be an int or a 1 or 2 element tuple/list, not a
"
"{
} element tuple/list"
.
format
(
len
(
size
))
f
"Size must be an int or a 1 or 2 element tuple/list, not a
{
len
(
size
)
}
element tuple/list"
)
)
if
max_size
is
not
None
and
len
(
size
)
!=
1
:
if
max_size
is
not
None
and
len
(
size
)
!=
1
:
raise
ValueError
(
raise
ValueError
(
...
@@ -597,7 +596,7 @@ def _assert_grid_transform_inputs(
...
@@ -597,7 +596,7 @@ def _assert_grid_transform_inputs(
raise
ValueError
(
msg
.
format
(
len
(
fill
),
num_channels
))
raise
ValueError
(
msg
.
format
(
len
(
fill
),
num_channels
))
if
interpolation
not
in
supported_interpolation_modes
:
if
interpolation
not
in
supported_interpolation_modes
:
raise
ValueError
(
"Interpolation mode '{}' is unsupported with Tensor input"
.
format
(
interpolation
)
)
raise
ValueError
(
f
"Interpolation mode '
{
interpolation
}
' is unsupported with Tensor input"
)
def
_cast_squeeze_in
(
img
:
Tensor
,
req_dtypes
:
List
[
torch
.
dtype
])
->
Tuple
[
Tensor
,
bool
,
bool
,
torch
.
dtype
]:
def
_cast_squeeze_in
(
img
:
Tensor
,
req_dtypes
:
List
[
torch
.
dtype
])
->
Tuple
[
Tensor
,
bool
,
bool
,
torch
.
dtype
]:
...
@@ -823,7 +822,7 @@ def _get_gaussian_kernel2d(
...
@@ -823,7 +822,7 @@ def _get_gaussian_kernel2d(
def
gaussian_blur
(
img
:
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
List
[
float
])
->
Tensor
:
def
gaussian_blur
(
img
:
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
List
[
float
])
->
Tensor
:
if
not
(
isinstance
(
img
,
torch
.
Tensor
)):
if
not
(
isinstance
(
img
,
torch
.
Tensor
)):
raise
TypeError
(
"img should be Tensor. Got {
}"
.
format
(
type
(
img
)
)
)
raise
TypeError
(
f
"img should be Tensor. Got
{
type
(
img
)
}
"
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -852,7 +851,7 @@ def invert(img: Tensor) -> Tensor:
...
@@ -852,7 +851,7 @@ def invert(img: Tensor) -> Tensor:
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
if
img
.
ndim
<
3
:
if
img
.
ndim
<
3
:
raise
TypeError
(
"Input image tensor should have at least 3 dimensions, but found {
}"
.
format
(
img
.
ndim
)
)
raise
TypeError
(
f
"Input image tensor should have at least 3 dimensions, but found
{
img
.
ndim
}
"
)
_assert_channels
(
img
,
[
1
,
3
])
_assert_channels
(
img
,
[
1
,
3
])
...
@@ -865,9 +864,9 @@ def posterize(img: Tensor, bits: int) -> Tensor:
...
@@ -865,9 +864,9 @@ def posterize(img: Tensor, bits: int) -> Tensor:
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
if
img
.
ndim
<
3
:
if
img
.
ndim
<
3
:
raise
TypeError
(
"Input image tensor should have at least 3 dimensions, but found {
}"
.
format
(
img
.
ndim
)
)
raise
TypeError
(
f
"Input image tensor should have at least 3 dimensions, but found
{
img
.
ndim
}
"
)
if
img
.
dtype
!=
torch
.
uint8
:
if
img
.
dtype
!=
torch
.
uint8
:
raise
TypeError
(
"Only torch.uint8 image tensors are supported, but found {
}"
.
format
(
img
.
dtype
)
)
raise
TypeError
(
f
"Only torch.uint8 image tensors are supported, but found
{
img
.
dtype
}
"
)
_assert_channels
(
img
,
[
1
,
3
])
_assert_channels
(
img
,
[
1
,
3
])
mask
=
-
int
(
2
**
(
8
-
bits
))
# JIT-friendly for: ~(2 ** (8 - bits) - 1)
mask
=
-
int
(
2
**
(
8
-
bits
))
# JIT-friendly for: ~(2 ** (8 - bits) - 1)
...
@@ -879,7 +878,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
...
@@ -879,7 +878,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
if
img
.
ndim
<
3
:
if
img
.
ndim
<
3
:
raise
TypeError
(
"Input image tensor should have at least 3 dimensions, but found {
}"
.
format
(
img
.
ndim
)
)
raise
TypeError
(
f
"Input image tensor should have at least 3 dimensions, but found
{
img
.
ndim
}
"
)
_assert_channels
(
img
,
[
1
,
3
])
_assert_channels
(
img
,
[
1
,
3
])
...
@@ -912,7 +911,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
...
@@ -912,7 +911,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
def
adjust_sharpness
(
img
:
Tensor
,
sharpness_factor
:
float
)
->
Tensor
:
def
adjust_sharpness
(
img
:
Tensor
,
sharpness_factor
:
float
)
->
Tensor
:
if
sharpness_factor
<
0
:
if
sharpness_factor
<
0
:
raise
ValueError
(
"sharpness_factor ({}) is not non-negative."
.
format
(
sharpness_factor
)
)
raise
ValueError
(
f
"sharpness_factor (
{
sharpness_factor
}
) is not non-negative."
)
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
...
@@ -929,7 +928,7 @@ def autocontrast(img: Tensor) -> Tensor:
...
@@ -929,7 +928,7 @@ def autocontrast(img: Tensor) -> Tensor:
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
if
img
.
ndim
<
3
:
if
img
.
ndim
<
3
:
raise
TypeError
(
"Input image tensor should have at least 3 dimensions, but found {
}"
.
format
(
img
.
ndim
)
)
raise
TypeError
(
f
"Input image tensor should have at least 3 dimensions, but found
{
img
.
ndim
}
"
)
_assert_channels
(
img
,
[
1
,
3
])
_assert_channels
(
img
,
[
1
,
3
])
...
@@ -976,9 +975,9 @@ def equalize(img: Tensor) -> Tensor:
...
@@ -976,9 +975,9 @@ def equalize(img: Tensor) -> Tensor:
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
if
not
(
3
<=
img
.
ndim
<=
4
):
if
not
(
3
<=
img
.
ndim
<=
4
):
raise
TypeError
(
"Input image tensor should have 3 or 4 dimensions, but found {
}"
.
format
(
img
.
ndim
)
)
raise
TypeError
(
f
"Input image tensor should have 3 or 4 dimensions, but found
{
img
.
ndim
}
"
)
if
img
.
dtype
!=
torch
.
uint8
:
if
img
.
dtype
!=
torch
.
uint8
:
raise
TypeError
(
"Only torch.uint8 image tensors are supported, but found {
}"
.
format
(
img
.
dtype
)
)
raise
TypeError
(
f
"Only torch.uint8 image tensors are supported, but found
{
img
.
dtype
}
"
)
_assert_channels
(
img
,
[
1
,
3
])
_assert_channels
(
img
,
[
1
,
3
])
...
...
torchvision/transforms/transforms.py
View file @
d367a01a
...
@@ -98,7 +98,7 @@ class Compose:
...
@@ -98,7 +98,7 @@ class Compose:
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
=
self
.
__class__
.
__name__
+
"("
for
t
in
self
.
transforms
:
for
t
in
self
.
transforms
:
format_string
+=
"
\n
"
format_string
+=
"
\n
"
format_string
+=
" {
0
}"
.
format
(
t
)
format_string
+=
f
"
{
t
}
"
format_string
+=
"
\n
)"
format_string
+=
"
\n
)"
return
format_string
return
format_string
...
@@ -220,7 +220,7 @@ class ToPILImage:
...
@@ -220,7 +220,7 @@ class ToPILImage:
def
__repr__
(
self
):
def
__repr__
(
self
):
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
=
self
.
__class__
.
__name__
+
"("
if
self
.
mode
is
not
None
:
if
self
.
mode
is
not
None
:
format_string
+=
"mode={
0}"
.
format
(
self
.
mode
)
format_string
+=
f
"mode=
{
self
.
mode
}
"
format_string
+=
")"
format_string
+=
")"
return
format_string
return
format_string
...
@@ -260,7 +260,7 @@ class Normalize(torch.nn.Module):
...
@@ -260,7 +260,7 @@ class Normalize(torch.nn.Module):
return
F
.
normalize
(
tensor
,
self
.
mean
,
self
.
std
,
self
.
inplace
)
return
F
.
normalize
(
tensor
,
self
.
mean
,
self
.
std
,
self
.
inplace
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(mean={
0}, std={1})"
.
format
(
self
.
mean
,
self
.
std
)
return
self
.
__class__
.
__name__
+
f
"(mean=
{
self
.
mean
}
,
std=
{
self
.
std
}
)"
class
Resize
(
torch
.
nn
.
Module
):
class
Resize
(
torch
.
nn
.
Module
):
...
@@ -310,7 +310,7 @@ class Resize(torch.nn.Module):
...
@@ -310,7 +310,7 @@ class Resize(torch.nn.Module):
def
__init__
(
self
,
size
,
interpolation
=
InterpolationMode
.
BILINEAR
,
max_size
=
None
,
antialias
=
None
):
def
__init__
(
self
,
size
,
interpolation
=
InterpolationMode
.
BILINEAR
,
max_size
=
None
,
antialias
=
None
):
super
().
__init__
()
super
().
__init__
()
if
not
isinstance
(
size
,
(
int
,
Sequence
)):
if
not
isinstance
(
size
,
(
int
,
Sequence
)):
raise
TypeError
(
"Size should be int or sequence. Got {
}"
.
format
(
type
(
size
)
)
)
raise
TypeError
(
f
"Size should be int or sequence. Got
{
type
(
size
)
}
"
)
if
isinstance
(
size
,
Sequence
)
and
len
(
size
)
not
in
(
1
,
2
):
if
isinstance
(
size
,
Sequence
)
and
len
(
size
)
not
in
(
1
,
2
):
raise
ValueError
(
"If size is a sequence, it should have 1 or 2 values"
)
raise
ValueError
(
"If size is a sequence, it should have 1 or 2 values"
)
self
.
size
=
size
self
.
size
=
size
...
@@ -338,10 +338,8 @@ class Resize(torch.nn.Module):
...
@@ -338,10 +338,8 @@ class Resize(torch.nn.Module):
return
F
.
resize
(
img
,
self
.
size
,
self
.
interpolation
,
self
.
max_size
,
self
.
antialias
)
return
F
.
resize
(
img
,
self
.
size
,
self
.
interpolation
,
self
.
max_size
,
self
.
antialias
)
def
__repr__
(
self
):
def
__repr__
(
self
):
interpolate_str
=
self
.
interpolation
.
value
detail
=
f
"(size=
{
self
.
size
}
, interpolation=
{
self
.
interpolation
.
value
}
, max_size=
{
self
.
max_size
}
, antialias=
{
self
.
antialias
}
)"
return
self
.
__class__
.
__name__
+
"(size={0}, interpolation={1}, max_size={2}, antialias={3})"
.
format
(
return
self
.
__class__
.
__name__
+
detail
self
.
size
,
interpolate_str
,
self
.
max_size
,
self
.
antialias
)
class
Scale
(
Resize
):
class
Scale
(
Resize
):
...
@@ -350,10 +348,8 @@ class Scale(Resize):
...
@@ -350,10 +348,8 @@ class Scale(Resize):
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
warnings
.
warn
(
warnings
.
warn
(
"The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead."
)
"The use of the transforms.Scale transform is deprecated, "
+
"please use transforms.Resize instead."
super
().
__init__
(
*
args
,
**
kwargs
)
)
super
(
Scale
,
self
).
__init__
(
*
args
,
**
kwargs
)
class
CenterCrop
(
torch
.
nn
.
Module
):
class
CenterCrop
(
torch
.
nn
.
Module
):
...
@@ -383,7 +379,7 @@ class CenterCrop(torch.nn.Module):
...
@@ -383,7 +379,7 @@ class CenterCrop(torch.nn.Module):
return
F
.
center_crop
(
img
,
self
.
size
)
return
F
.
center_crop
(
img
,
self
.
size
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(size={
0})"
.
format
(
self
.
size
)
return
self
.
__class__
.
__name__
+
f
"(size=
{
self
.
size
}
)"
class
Pad
(
torch
.
nn
.
Module
):
class
Pad
(
torch
.
nn
.
Module
):
...
@@ -437,7 +433,7 @@ class Pad(torch.nn.Module):
...
@@ -437,7 +433,7 @@ class Pad(torch.nn.Module):
if
isinstance
(
padding
,
Sequence
)
and
len
(
padding
)
not
in
[
1
,
2
,
4
]:
if
isinstance
(
padding
,
Sequence
)
and
len
(
padding
)
not
in
[
1
,
2
,
4
]:
raise
ValueError
(
raise
ValueError
(
"Padding must be an int or a 1, 2, or 4 element tuple, not a
"
+
"{
} element tuple"
.
format
(
len
(
padding
))
f
"Padding must be an int or a 1, 2, or 4 element tuple, not a
{
len
(
padding
)
}
element tuple"
)
)
self
.
padding
=
padding
self
.
padding
=
padding
...
@@ -455,9 +451,7 @@ class Pad(torch.nn.Module):
...
@@ -455,9 +451,7 @@ class Pad(torch.nn.Module):
return
F
.
pad
(
img
,
self
.
padding
,
self
.
fill
,
self
.
padding_mode
)
return
F
.
pad
(
img
,
self
.
padding
,
self
.
fill
,
self
.
padding_mode
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(padding={0}, fill={1}, padding_mode={2})"
.
format
(
return
self
.
__class__
.
__name__
+
f
"(padding=
{
self
.
padding
}
, fill=
{
self
.
fill
}
, padding_mode=
{
self
.
padding_mode
}
)"
self
.
padding
,
self
.
fill
,
self
.
padding_mode
)
class
Lambda
:
class
Lambda
:
...
@@ -469,7 +463,7 @@ class Lambda:
...
@@ -469,7 +463,7 @@ class Lambda:
def
__init__
(
self
,
lambd
):
def
__init__
(
self
,
lambd
):
if
not
callable
(
lambd
):
if
not
callable
(
lambd
):
raise
TypeError
(
"Argument lambd should be callable, got {
}"
.
format
(
repr
(
type
(
lambd
).
__name__
)
)
)
raise
TypeError
(
f
"Argument lambd should be callable, got
{
repr
(
type
(
lambd
).
__name__
)
}
"
)
self
.
lambd
=
lambd
self
.
lambd
=
lambd
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
...
@@ -498,7 +492,7 @@ class RandomTransforms:
...
@@ -498,7 +492,7 @@ class RandomTransforms:
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
=
self
.
__class__
.
__name__
+
"("
for
t
in
self
.
transforms
:
for
t
in
self
.
transforms
:
format_string
+=
"
\n
"
format_string
+=
"
\n
"
format_string
+=
" {
0
}"
.
format
(
t
)
format_string
+=
f
"
{
t
}
"
format_string
+=
"
\n
)"
format_string
+=
"
\n
)"
return
format_string
return
format_string
...
@@ -537,10 +531,10 @@ class RandomApply(torch.nn.Module):
...
@@ -537,10 +531,10 @@ class RandomApply(torch.nn.Module):
def
__repr__
(
self
):
def
__repr__
(
self
):
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
+=
"
\n
p={
}"
.
format
(
self
.
p
)
format_string
+=
f
"
\n
p=
{
self
.
p
}
"
for
t
in
self
.
transforms
:
for
t
in
self
.
transforms
:
format_string
+=
"
\n
"
format_string
+=
"
\n
"
format_string
+=
" {
0
}"
.
format
(
t
)
format_string
+=
f
"
{
t
}
"
format_string
+=
"
\n
)"
format_string
+=
"
\n
)"
return
format_string
return
format_string
...
@@ -571,7 +565,7 @@ class RandomChoice(RandomTransforms):
...
@@ -571,7 +565,7 @@ class RandomChoice(RandomTransforms):
def
__repr__
(
self
):
def
__repr__
(
self
):
format_string
=
super
().
__repr__
()
format_string
=
super
().
__repr__
()
format_string
+=
"(p={
0})"
.
format
(
self
.
p
)
format_string
+=
f
"(p=
{
self
.
p
}
)"
return
format_string
return
format_string
...
@@ -634,7 +628,7 @@ class RandomCrop(torch.nn.Module):
...
@@ -634,7 +628,7 @@ class RandomCrop(torch.nn.Module):
th
,
tw
=
output_size
th
,
tw
=
output_size
if
h
+
1
<
th
or
w
+
1
<
tw
:
if
h
+
1
<
th
or
w
+
1
<
tw
:
raise
ValueError
(
"Required crop size {} is larger then input image size {
}"
.
format
((
t
h
,
t
w
)
,
(
h
,
w
))
)
raise
ValueError
(
f
"Required crop size
{
(
th
,
tw
)
}
is larger then input image size
{
(
h
,
w
)
}
"
)
if
w
==
tw
and
h
==
th
:
if
w
==
tw
and
h
==
th
:
return
0
,
0
,
h
,
w
return
0
,
0
,
h
,
w
...
@@ -679,7 +673,7 @@ class RandomCrop(torch.nn.Module):
...
@@ -679,7 +673,7 @@ class RandomCrop(torch.nn.Module):
return
F
.
crop
(
img
,
i
,
j
,
h
,
w
)
return
F
.
crop
(
img
,
i
,
j
,
h
,
w
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(size={
0
}, padding={
1})"
.
format
(
self
.
size
,
self
.
padding
)
return
self
.
__class__
.
__name__
+
f
"(size=
{
self
.
size
}
, padding=
{
self
.
padding
}
)"
class
RandomHorizontalFlip
(
torch
.
nn
.
Module
):
class
RandomHorizontalFlip
(
torch
.
nn
.
Module
):
...
@@ -709,7 +703,7 @@ class RandomHorizontalFlip(torch.nn.Module):
...
@@ -709,7 +703,7 @@ class RandomHorizontalFlip(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
class
RandomVerticalFlip
(
torch
.
nn
.
Module
):
class
RandomVerticalFlip
(
torch
.
nn
.
Module
):
...
@@ -739,7 +733,7 @@ class RandomVerticalFlip(torch.nn.Module):
...
@@ -739,7 +733,7 @@ class RandomVerticalFlip(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
class
RandomPerspective
(
torch
.
nn
.
Module
):
class
RandomPerspective
(
torch
.
nn
.
Module
):
...
@@ -839,7 +833,7 @@ class RandomPerspective(torch.nn.Module):
...
@@ -839,7 +833,7 @@ class RandomPerspective(torch.nn.Module):
return
startpoints
,
endpoints
return
startpoints
,
endpoints
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
class
RandomResizedCrop
(
torch
.
nn
.
Module
):
class
RandomResizedCrop
(
torch
.
nn
.
Module
):
...
@@ -951,10 +945,10 @@ class RandomResizedCrop(torch.nn.Module):
...
@@ -951,10 +945,10 @@ class RandomResizedCrop(torch.nn.Module):
def
__repr__
(
self
):
def
__repr__
(
self
):
interpolate_str
=
self
.
interpolation
.
value
interpolate_str
=
self
.
interpolation
.
value
format_string
=
self
.
__class__
.
__name__
+
"(size={
0}"
.
format
(
self
.
size
)
format_string
=
self
.
__class__
.
__name__
+
f
"(size=
{
self
.
size
}
"
format_string
+=
", scale={
0}"
.
format
(
tuple
(
round
(
s
,
4
)
for
s
in
self
.
scale
)
)
format_string
+=
f
", scale=
{
tuple
(
round
(
s
,
4
)
for
s
in
self
.
scale
)
}
"
format_string
+=
", ratio={
0}"
.
format
(
tuple
(
round
(
r
,
4
)
for
r
in
self
.
ratio
)
)
format_string
+=
f
", ratio=
{
tuple
(
round
(
r
,
4
)
for
r
in
self
.
ratio
)
}
"
format_string
+=
", interpolation={
0})"
.
format
(
interpolate_str
)
format_string
+=
f
", interpolation=
{
interpolate_str
}
)"
return
format_string
return
format_string
...
@@ -968,7 +962,7 @@ class RandomSizedCrop(RandomResizedCrop):
...
@@ -968,7 +962,7 @@ class RandomSizedCrop(RandomResizedCrop):
"The use of the transforms.RandomSizedCrop transform is deprecated, "
"The use of the transforms.RandomSizedCrop transform is deprecated, "
+
"please use transforms.RandomResizedCrop instead."
+
"please use transforms.RandomResizedCrop instead."
)
)
super
(
RandomSizedCrop
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
class
FiveCrop
(
torch
.
nn
.
Module
):
class
FiveCrop
(
torch
.
nn
.
Module
):
...
@@ -1014,7 +1008,7 @@ class FiveCrop(torch.nn.Module):
...
@@ -1014,7 +1008,7 @@ class FiveCrop(torch.nn.Module):
return
F
.
five_crop
(
img
,
self
.
size
)
return
F
.
five_crop
(
img
,
self
.
size
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(size={
0})"
.
format
(
self
.
size
)
return
self
.
__class__
.
__name__
+
f
"(size=
{
self
.
size
}
)"
class
TenCrop
(
torch
.
nn
.
Module
):
class
TenCrop
(
torch
.
nn
.
Module
):
...
@@ -1063,7 +1057,7 @@ class TenCrop(torch.nn.Module):
...
@@ -1063,7 +1057,7 @@ class TenCrop(torch.nn.Module):
return
F
.
ten_crop
(
img
,
self
.
size
,
self
.
vertical_flip
)
return
F
.
ten_crop
(
img
,
self
.
size
,
self
.
vertical_flip
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(size={
0
}, vertical_flip={
1})"
.
format
(
self
.
size
,
self
.
vertical_flip
)
return
self
.
__class__
.
__name__
+
f
"(size=
{
self
.
size
}
, vertical_flip=
{
self
.
vertical_flip
}
)"
class
LinearTransformation
(
torch
.
nn
.
Module
):
class
LinearTransformation
(
torch
.
nn
.
Module
):
...
@@ -1090,22 +1084,18 @@ class LinearTransformation(torch.nn.Module):
...
@@ -1090,22 +1084,18 @@ class LinearTransformation(torch.nn.Module):
if
transformation_matrix
.
size
(
0
)
!=
transformation_matrix
.
size
(
1
):
if
transformation_matrix
.
size
(
0
)
!=
transformation_matrix
.
size
(
1
):
raise
ValueError
(
raise
ValueError
(
"transformation_matrix should be square. Got "
"transformation_matrix should be square. Got "
+
"[{} x {}] rectangular matrix."
.
format
(
*
transformation_matrix
.
size
())
f
"
{
tuple
(
transformation_matrix
.
size
())
}
rectangular matrix."
)
)
if
mean_vector
.
size
(
0
)
!=
transformation_matrix
.
size
(
0
):
if
mean_vector
.
size
(
0
)
!=
transformation_matrix
.
size
(
0
):
raise
ValueError
(
raise
ValueError
(
"mean_vector should have the same length {}"
.
format
(
mean_vector
.
size
(
0
))
f
"mean_vector should have the same length
{
mean_vector
.
size
(
0
)
}
"
+
" as any one of the dimensions of the transformation_matrix [{}]"
.
format
(
f
" as any one of the dimensions of the transformation_matrix [
{
tuple
(
transformation_matrix
.
size
())
}
]"
tuple
(
transformation_matrix
.
size
())
)
)
)
if
transformation_matrix
.
device
!=
mean_vector
.
device
:
if
transformation_matrix
.
device
!=
mean_vector
.
device
:
raise
ValueError
(
raise
ValueError
(
"Input tensors should be on the same device. Got {} and {}"
.
format
(
f
"Input tensors should be on the same device. Got
{
transformation_matrix
.
device
}
and
{
mean_vector
.
device
}
"
transformation_matrix
.
device
,
mean_vector
.
device
)
)
)
self
.
transformation_matrix
=
transformation_matrix
self
.
transformation_matrix
=
transformation_matrix
...
@@ -1124,14 +1114,14 @@ class LinearTransformation(torch.nn.Module):
...
@@ -1124,14 +1114,14 @@ class LinearTransformation(torch.nn.Module):
if
n
!=
self
.
transformation_matrix
.
shape
[
0
]:
if
n
!=
self
.
transformation_matrix
.
shape
[
0
]:
raise
ValueError
(
raise
ValueError
(
"Input tensor and transformation matrix have incompatible shape."
"Input tensor and transformation matrix have incompatible shape."
+
"[{
} x {} x {}] != "
.
format
(
shape
[
-
3
]
,
shape
[
-
2
]
,
shape
[
-
1
]
)
+
f
"[
{
shape
[
-
3
]
}
x
{
shape
[
-
2
]
}
x
{
shape
[
-
1
]
}
] != "
+
"{
}"
.
format
(
self
.
transformation_matrix
.
shape
[
0
]
)
+
f
"
{
self
.
transformation_matrix
.
shape
[
0
]
}
"
)
)
if
tensor
.
device
.
type
!=
self
.
mean_vector
.
device
.
type
:
if
tensor
.
device
.
type
!=
self
.
mean_vector
.
device
.
type
:
raise
ValueError
(
raise
ValueError
(
"Input tensor should be on the same device as transformation matrix and mean vector. "
"Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {
} vs {}"
.
format
(
tensor
.
device
,
self
.
mean_vector
.
device
)
f
"Got
{
tensor
.
device
}
vs
{
self
.
mean_vector
.
device
}
"
)
)
flat_tensor
=
tensor
.
view
(
-
1
,
n
)
-
self
.
mean_vector
flat_tensor
=
tensor
.
view
(
-
1
,
n
)
-
self
.
mean_vector
...
@@ -1178,15 +1168,15 @@ class ColorJitter(torch.nn.Module):
...
@@ -1178,15 +1168,15 @@ class ColorJitter(torch.nn.Module):
def
_check_input
(
self
,
value
,
name
,
center
=
1
,
bound
=
(
0
,
float
(
"inf"
)),
clip_first_on_zero
=
True
):
def
_check_input
(
self
,
value
,
name
,
center
=
1
,
bound
=
(
0
,
float
(
"inf"
)),
clip_first_on_zero
=
True
):
if
isinstance
(
value
,
numbers
.
Number
):
if
isinstance
(
value
,
numbers
.
Number
):
if
value
<
0
:
if
value
<
0
:
raise
ValueError
(
"If {} is a single number, it must be non negative."
.
format
(
name
)
)
raise
ValueError
(
f
"If
{
name
}
is a single number, it must be non negative."
)
value
=
[
center
-
float
(
value
),
center
+
float
(
value
)]
value
=
[
center
-
float
(
value
),
center
+
float
(
value
)]
if
clip_first_on_zero
:
if
clip_first_on_zero
:
value
[
0
]
=
max
(
value
[
0
],
0.0
)
value
[
0
]
=
max
(
value
[
0
],
0.0
)
elif
isinstance
(
value
,
(
tuple
,
list
))
and
len
(
value
)
==
2
:
elif
isinstance
(
value
,
(
tuple
,
list
))
and
len
(
value
)
==
2
:
if
not
bound
[
0
]
<=
value
[
0
]
<=
value
[
1
]
<=
bound
[
1
]:
if
not
bound
[
0
]
<=
value
[
0
]
<=
value
[
1
]
<=
bound
[
1
]:
raise
ValueError
(
"{} values should be between {
}"
.
format
(
name
,
bound
)
)
raise
ValueError
(
f
"
{
name
}
values should be between
{
bound
}
"
)
else
:
else
:
raise
TypeError
(
"{} should be a single number or a list/tuple with length 2."
.
format
(
name
)
)
raise
TypeError
(
f
"
{
name
}
should be a single number or a list/tuple with length 2."
)
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
# or (0., 0.) for hue, do nothing
...
@@ -1252,10 +1242,10 @@ class ColorJitter(torch.nn.Module):
...
@@ -1252,10 +1242,10 @@ class ColorJitter(torch.nn.Module):
def
__repr__
(
self
):
def
__repr__
(
self
):
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
+=
"brightness={
0}"
.
format
(
self
.
brightness
)
format_string
+=
f
"brightness=
{
self
.
brightness
}
"
format_string
+=
", contrast={
0}"
.
format
(
self
.
contrast
)
format_string
+=
f
", contrast=
{
self
.
contrast
}
"
format_string
+=
", saturation={
0}"
.
format
(
self
.
saturation
)
format_string
+=
f
", saturation=
{
self
.
saturation
}
"
format_string
+=
", hue={
0})"
.
format
(
self
.
hue
)
format_string
+=
f
", hue=
{
self
.
hue
}
)"
return
format_string
return
format_string
...
@@ -1352,13 +1342,13 @@ class RandomRotation(torch.nn.Module):
...
@@ -1352,13 +1342,13 @@ class RandomRotation(torch.nn.Module):
def
__repr__
(
self
):
def
__repr__
(
self
):
interpolate_str
=
self
.
interpolation
.
value
interpolate_str
=
self
.
interpolation
.
value
format_string
=
self
.
__class__
.
__name__
+
"(degrees={
0}"
.
format
(
self
.
degrees
)
format_string
=
self
.
__class__
.
__name__
+
f
"(degrees=
{
self
.
degrees
}
"
format_string
+=
", interpolation={
0}"
.
format
(
interpolate_str
)
format_string
+=
f
", interpolation=
{
interpolate_str
}
"
format_string
+=
", expand={
0}"
.
format
(
self
.
expand
)
format_string
+=
f
", expand=
{
self
.
expand
}
"
if
self
.
center
is
not
None
:
if
self
.
center
is
not
None
:
format_string
+=
", center={
0}"
.
format
(
self
.
center
)
format_string
+=
f
", center=
{
self
.
center
}
"
if
self
.
fill
is
not
None
:
if
self
.
fill
is
not
None
:
format_string
+=
", fill={
0}"
.
format
(
self
.
fill
)
format_string
+=
f
", fill=
{
self
.
fill
}
"
format_string
+=
")"
format_string
+=
")"
return
format_string
return
format_string
...
@@ -1568,7 +1558,7 @@ class Grayscale(torch.nn.Module):
...
@@ -1568,7 +1558,7 @@ class Grayscale(torch.nn.Module):
return
F
.
rgb_to_grayscale
(
img
,
num_output_channels
=
self
.
num_output_channels
)
return
F
.
rgb_to_grayscale
(
img
,
num_output_channels
=
self
.
num_output_channels
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(num_output_channels={
0})"
.
format
(
self
.
num_output_channels
)
return
self
.
__class__
.
__name__
+
f
"(num_output_channels=
{
self
.
num_output_channels
}
)"
class
RandomGrayscale
(
torch
.
nn
.
Module
):
class
RandomGrayscale
(
torch
.
nn
.
Module
):
...
@@ -1605,7 +1595,7 @@ class RandomGrayscale(torch.nn.Module):
...
@@ -1605,7 +1595,7 @@ class RandomGrayscale(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
0})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
class
RandomErasing
(
torch
.
nn
.
Module
):
class
RandomErasing
(
torch
.
nn
.
Module
):
...
@@ -1726,7 +1716,7 @@ class RandomErasing(torch.nn.Module):
...
@@ -1726,7 +1716,7 @@ class RandomErasing(torch.nn.Module):
if
value
is
not
None
and
not
(
len
(
value
)
in
(
1
,
img
.
shape
[
-
3
])):
if
value
is
not
None
and
not
(
len
(
value
)
in
(
1
,
img
.
shape
[
-
3
])):
raise
ValueError
(
raise
ValueError
(
"If value is a sequence, it should have either a single value or "
"If value is a sequence, it should have either a single value or "
"{} (number of input channels)"
.
format
(
img
.
shape
[
-
3
])
f
"
{
img
.
shape
[
-
3
]
}
(number of input channels)"
)
)
x
,
y
,
h
,
w
,
v
=
self
.
get_params
(
img
,
scale
=
self
.
scale
,
ratio
=
self
.
ratio
,
value
=
value
)
x
,
y
,
h
,
w
,
v
=
self
.
get_params
(
img
,
scale
=
self
.
scale
,
ratio
=
self
.
ratio
,
value
=
value
)
...
@@ -1734,11 +1724,11 @@ class RandomErasing(torch.nn.Module):
...
@@ -1734,11 +1724,11 @@ class RandomErasing(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
s
=
"(p={
}, "
.
format
(
self
.
p
)
s
=
f
"(p=
{
self
.
p
}
, "
s
+=
"scale={
}, "
.
format
(
self
.
scale
)
s
+=
f
"scale=
{
self
.
scale
}
, "
s
+=
"ratio={
}, "
.
format
(
self
.
ratio
)
s
+=
f
"ratio=
{
self
.
ratio
}
, "
s
+=
"value={
}, "
.
format
(
self
.
value
)
s
+=
f
"value=
{
self
.
value
}
, "
s
+=
"inplace={
})"
.
format
(
self
.
inplace
)
s
+=
f
"inplace=
{
self
.
inplace
}
)"
return
self
.
__class__
.
__name__
+
s
return
self
.
__class__
.
__name__
+
s
...
@@ -1803,8 +1793,8 @@ class GaussianBlur(torch.nn.Module):
...
@@ -1803,8 +1793,8 @@ class GaussianBlur(torch.nn.Module):
return
F
.
gaussian_blur
(
img
,
self
.
kernel_size
,
[
sigma
,
sigma
])
return
F
.
gaussian_blur
(
img
,
self
.
kernel_size
,
[
sigma
,
sigma
])
def
__repr__
(
self
):
def
__repr__
(
self
):
s
=
"(kernel_size={
}, "
.
format
(
self
.
kernel_size
)
s
=
f
"(kernel_size=
{
self
.
kernel_size
}
, "
s
+=
"sigma={
})"
.
format
(
self
.
sigma
)
s
+=
f
"sigma=
{
self
.
sigma
}
)"
return
self
.
__class__
.
__name__
+
s
return
self
.
__class__
.
__name__
+
s
...
@@ -1824,15 +1814,15 @@ def _setup_size(size, error_msg):
...
@@ -1824,15 +1814,15 @@ def _setup_size(size, error_msg):
def
_check_sequence_input
(
x
,
name
,
req_sizes
):
def
_check_sequence_input
(
x
,
name
,
req_sizes
):
msg
=
req_sizes
[
0
]
if
len
(
req_sizes
)
<
2
else
" or "
.
join
([
str
(
s
)
for
s
in
req_sizes
])
msg
=
req_sizes
[
0
]
if
len
(
req_sizes
)
<
2
else
" or "
.
join
([
str
(
s
)
for
s
in
req_sizes
])
if
not
isinstance
(
x
,
Sequence
):
if
not
isinstance
(
x
,
Sequence
):
raise
TypeError
(
"{} should be a sequence of length {}."
.
format
(
name
,
msg
)
)
raise
TypeError
(
f
"
{
name
}
should be a sequence of length
{
msg
}
."
)
if
len
(
x
)
not
in
req_sizes
:
if
len
(
x
)
not
in
req_sizes
:
raise
ValueError
(
"{} should be sequence of length {}."
.
format
(
name
,
msg
)
)
raise
ValueError
(
f
"
{
name
}
should be sequence of length
{
msg
}
."
)
def
_setup_angle
(
x
,
name
,
req_sizes
=
(
2
,)):
def
_setup_angle
(
x
,
name
,
req_sizes
=
(
2
,)):
if
isinstance
(
x
,
numbers
.
Number
):
if
isinstance
(
x
,
numbers
.
Number
):
if
x
<
0
:
if
x
<
0
:
raise
ValueError
(
"If {} is a single number, it must be positive."
.
format
(
name
)
)
raise
ValueError
(
f
"If
{
name
}
is a single number, it must be positive."
)
x
=
[
-
x
,
x
]
x
=
[
-
x
,
x
]
else
:
else
:
_check_sequence_input
(
x
,
name
,
req_sizes
)
_check_sequence_input
(
x
,
name
,
req_sizes
)
...
@@ -1867,7 +1857,7 @@ class RandomInvert(torch.nn.Module):
...
@@ -1867,7 +1857,7 @@ class RandomInvert(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
class
RandomPosterize
(
torch
.
nn
.
Module
):
class
RandomPosterize
(
torch
.
nn
.
Module
):
...
@@ -1899,7 +1889,7 @@ class RandomPosterize(torch.nn.Module):
...
@@ -1899,7 +1889,7 @@ class RandomPosterize(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(bits={
},p={})"
.
format
(
self
.
bits
,
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(bits=
{
self
.
bits
}
,p=
{
self
.
p
}
)"
class
RandomSolarize
(
torch
.
nn
.
Module
):
class
RandomSolarize
(
torch
.
nn
.
Module
):
...
@@ -1931,7 +1921,7 @@ class RandomSolarize(torch.nn.Module):
...
@@ -1931,7 +1921,7 @@ class RandomSolarize(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(threshold={
},p={})"
.
format
(
self
.
threshold
,
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(threshold=
{
self
.
threshold
}
,p=
{
self
.
p
}
)"
class
RandomAdjustSharpness
(
torch
.
nn
.
Module
):
class
RandomAdjustSharpness
(
torch
.
nn
.
Module
):
...
@@ -1963,7 +1953,7 @@ class RandomAdjustSharpness(torch.nn.Module):
...
@@ -1963,7 +1953,7 @@ class RandomAdjustSharpness(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(sharpness_factor={
},p={})"
.
format
(
self
.
sharpness_factor
,
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(sharpness_factor=
{
self
.
sharpness_factor
}
,p=
{
self
.
p
}
)"
class
RandomAutocontrast
(
torch
.
nn
.
Module
):
class
RandomAutocontrast
(
torch
.
nn
.
Module
):
...
@@ -1993,7 +1983,7 @@ class RandomAutocontrast(torch.nn.Module):
...
@@ -1993,7 +1983,7 @@ class RandomAutocontrast(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
class
RandomEqualize
(
torch
.
nn
.
Module
):
class
RandomEqualize
(
torch
.
nn
.
Module
):
...
@@ -2023,4 +2013,4 @@ class RandomEqualize(torch.nn.Module):
...
@@ -2023,4 +2013,4 @@ class RandomEqualize(torch.nn.Module):
return
img
return
img
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
"(p={
})"
.
format
(
self
.
p
)
return
self
.
__class__
.
__name__
+
f
"(p=
{
self
.
p
}
)"
torchvision/utils.py
View file @
d367a01a
import
math
import
math
import
pathlib
import
pathlib
import
warnings
import
warnings
from
typing
import
Union
,
Optional
,
List
,
Tuple
,
Text
,
BinaryIO
from
typing
import
Union
,
Optional
,
List
,
Tuple
,
BinaryIO
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -114,7 +114,7 @@ def make_grid(
...
@@ -114,7 +114,7 @@ def make_grid(
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
save_image
(
def
save_image
(
tensor
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]],
tensor
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]],
fp
:
Union
[
Text
,
pathlib
.
Path
,
BinaryIO
],
fp
:
Union
[
str
,
pathlib
.
Path
,
BinaryIO
],
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment