OpenDAS / vision · Commits

Unverified commit d367a01a, authored Oct 28, 2021 by Jirka Borovec, committed by GitHub Oct 28, 2021

    Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)

    Co-authored-by: Nicolas Hug <nicolashug@fb.com>

Parent: 50dfe207
Changes: 136 files in the commit overall; this page shows 16 of them.
Showing 16 changed files with 174 additions and 188 deletions (+174 −188):

torchvision/ops/ps_roi_align.py                      +1  −1
torchvision/ops/ps_roi_pool.py                       +1  −1
torchvision/ops/roi_align.py                         +1  −1
torchvision/ops/roi_pool.py                          +1  −1
torchvision/ops/stochastic_depth.py                  +2  −2
torchvision/prototype/datasets/utils/_dataset.py     +1  −1
torchvision/prototype/datasets/utils/_internal.py    +1  −1
torchvision/prototype/models/_api.py                 +2  −2
torchvision/transforms/_functional_video.py          +1  −1
torchvision/transforms/_transforms_video.py          +11 −10
torchvision/transforms/autoaugment.py                +3  −3
torchvision/transforms/functional.py                 +31 −33
torchvision/transforms/functional_pil.py             +26 −28
torchvision/transforms/functional_tensor.py          +24 −25
torchvision/transforms/transforms.py                 +66 −76
torchvision/utils.py                                 +2  −2
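For context (not stated on this page): changes of this shape are what pyupgrade produces; it is commonly run via pre-commit or as `pyupgrade --py36-plus <files>`, though the exact invocation used for this commit isn't recorded here. The core rewrite is purely syntactic, as a quick sketch using the message from the stochastic_depth hunk below shows:

# Sketch: the .format() -> f-string rewrite changes spelling, not output.
p = 1.5

old_style = "drop probability has to be between 0 and 1, but got {}".format(p)
new_style = f"drop probability has to be between 0 and 1, but got {p}"

assert old_style == new_style  # identical strings; the f-string just reads better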
torchvision/ops/ps_roi_align.py

@@ -65,7 +65,7 @@ class PSRoIAlign(nn.Module):
         spatial_scale: float,
         sampling_ratio: int,
     ):
-        super(PSRoIAlign, self).__init__()
+        super().__init__()
         self.output_size = output_size
         self.spatial_scale = spatial_scale
         self.sampling_ratio = sampling_ratio
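A note on the `super()` change repeated across the ops files: in Python 3 the zero-argument form resolves the enclosing class and instance automatically, so the rewrite is behavior-preserving. A minimal sketch (class names are illustrative, not from the diff):

# Sketch: both super() spellings call the same base-class __init__.
class Base:
    def __init__(self):
        self.tag = "base"

class Explicit(Base):
    def __init__(self):
        super(Explicit, self).__init__()  # old, Python-2-compatible form

class Implicit(Base):
    def __init__(self):
        super().__init__()  # equivalent zero-argument form

assert Explicit().tag == Implicit().tag == "base"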
torchvision/ops/ps_roi_pool.py

@@ -52,7 +52,7 @@ class PSRoIPool(nn.Module):
     """
 
     def __init__(self, output_size: int, spatial_scale: float):
-        super(PSRoIPool, self).__init__()
+        super().__init__()
         self.output_size = output_size
         self.spatial_scale = spatial_scale
torchvision/ops/roi_align.py

@@ -72,7 +72,7 @@ class RoIAlign(nn.Module):
         sampling_ratio: int,
         aligned: bool = False,
     ):
-        super(RoIAlign, self).__init__()
+        super().__init__()
         self.output_size = output_size
         self.spatial_scale = spatial_scale
         self.sampling_ratio = sampling_ratio
torchvision/ops/roi_pool.py

@@ -54,7 +54,7 @@ class RoIPool(nn.Module):
     """
 
    def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
-        super(RoIPool, self).__init__()
+        super().__init__()
         self.output_size = output_size
         self.spatial_scale = spatial_scale
torchvision/ops/stochastic_depth.py

@@ -22,9 +22,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
         Tensor[N, ...]: The randomly zeroed tensor.
     """
     if p < 0.0 or p > 1.0:
-        raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p))
+        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
     if mode not in ["batch", "row"]:
-        raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode))
+        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
     if not training or p == 0.0:
         return input
torchvision/prototype/datasets/utils/_dataset.py

@@ -71,7 +71,7 @@ class DatasetInfo:
     @staticmethod
     def read_categories_file(path: pathlib.Path) -> List[List[str]]:
-        with open(path, "r", newline="") as file:
+        with open(path, newline="") as file:
             return [row for row in csv.reader(file)]
 
     @property
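The `open()` change is likewise behavior-preserving: `"r"` is already the default mode. A small sketch with a hypothetical CSV file:

# Sketch: open(path, newline="") is the same call as open(path, "r", newline="").
import csv
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write("cat,0\ndog,1\n")
    path = f.name  # hypothetical categories file

with open(path, newline="") as file:
    rows = [row for row in csv.reader(file)]

assert rows == [["cat", "0"], ["dog", "1"]]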
torchvision/prototype/datasets/utils/_internal.py

@@ -63,7 +63,7 @@ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if len(seq) == 1:
         return f"'{seq[0]}'"
 
-    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', """ f"""{separate_last}'{seq[-1]}'."""
+    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""
 
 
 def add_suggestion(
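The two adjacent f-strings here are merged into one. The triple-quoted form is presumably used because, before Python 3.12, an f-string could not reuse its own quote character inside the braces. A sketch of the merged pattern:

# Sketch: triple quotes let the replacement field contain both ' and ".
seq = ["foo", "bar", "baz"]
separate_last = "and "

merged = f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""
assert merged == "'foo', 'bar', and 'baz'."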
torchvision/prototype/models/_api.py

@@ -52,7 +52,7 @@ class Weights(Enum):
             obj = cls.from_str(obj)
         elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry):
             raise TypeError(
-                f"Invalid Weight class provided; expected {cls.__name__} " f"but received {obj.__class__.__name__}."
+                f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
             )
         return obj

@@ -106,7 +106,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
     if weights_class is None:
         raise ValueError(
-            "The weight class for the specific method couldn't be retrieved. Make sure the typing info is " "correct."
+            "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
         )
 
     return weights_class.from_str(weight_name)
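Merging implicitly concatenated literals, as both hunks above do, removes a classic hazard: adjacent string literals are joined at compile time, and the separating space is easy to drop when wrapping lines. A sketch:

# Sketch: adjacent literals concatenate silently; note the trailing space
# on the first piece is the only thing keeping the words apart.
msg = (
    "The weight class for the specific method couldn't be retrieved. "
    "Make sure the typing info is correct."
)
assert msg == "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."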
torchvision/transforms/_functional_video.py

@@ -101,4 +101,4 @@ def hflip(clip):
         flipped clip (torch.tensor): Size is (C, T, H, W)
     """
     assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
-    return clip.flip((-1))
+    return clip.flip(-1)
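This change is purely cosmetic: a parenthesized `-1` is still the int `-1`, not a one-element tuple, so both spellings pass the same argument. A sketch:

# Sketch: parentheses alone do not make a tuple; the comma does.
assert (-1) == -1 and not isinstance((-1), tuple)
assert (-1,) != -1 and isinstance((-1,), tuple)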
torchvision/transforms/_transforms_video.py

@@ -44,7 +44,7 @@ class RandomCropVideo(RandomCrop):
         return F.crop(clip, i, j, h, w)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(size={0})".format(self.size)
+        return self.__class__.__name__ + f"(size={self.size})"
 
 
 class RandomResizedCropVideo(RandomResizedCrop):

@@ -77,12 +77,13 @@ class RandomResizedCropVideo(RandomResizedCrop):
         return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format(
-            self.size, self.interpolation_mode, self.scale, self.ratio
+        return (
+            self.__class__.__name__
+            + f"(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
         )
 
 
-class CenterCropVideo(object):
+class CenterCropVideo:
     def __init__(self, crop_size):
         if isinstance(crop_size, numbers.Number):
             self.crop_size = (int(crop_size), int(crop_size))

@@ -100,10 +101,10 @@ class CenterCropVideo(object):
         return F.center_crop(clip, self.crop_size)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(crop_size={0})".format(self.crop_size)
+        return self.__class__.__name__ + f"(crop_size={self.crop_size})"
 
 
-class NormalizeVideo(object):
+class NormalizeVideo:
     """
     Normalize the video clip by mean subtraction and division by standard deviation
     Args:

@@ -125,10 +126,10 @@ class NormalizeVideo(object):
         return F.normalize(clip, self.mean, self.std, self.inplace)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(self.mean, self.std, self.inplace)
+        return self.__class__.__name__ + f"(mean={self.mean}, std={self.std}, inplace={self.inplace})"
 
 
-class ToTensorVideo(object):
+class ToTensorVideo:
     """
     Convert tensor data type from uint8 to float, divide value by 255.0 and
     permute the dimensions of clip tensor

@@ -150,7 +151,7 @@ class ToTensorVideo(object):
         return self.__class__.__name__
 
 
-class RandomHorizontalFlipVideo(object):
+class RandomHorizontalFlipVideo:
     """
     Flip the video clip along the horizonal direction with a given probability
     Args:

@@ -172,4 +173,4 @@ class RandomHorizontalFlipVideo(object):
         return clip
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={0})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
torchvision/transforms/autoaugment.py

@@ -76,7 +76,7 @@ def _apply_op(
     elif op_name == "Identity":
         pass
     else:
-        raise ValueError("The provided operator {} is not recognized.".format(op_name))
+        raise ValueError(f"The provided operator {op_name} is not recognized.")
     return img

@@ -208,7 +208,7 @@ class AutoAugment(torch.nn.Module):
                 (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
             ]
         else:
-            raise ValueError("The provided policy {} is not recognized.".format(policy))
+            raise ValueError(f"The provided policy {policy} is not recognized.")
 
     def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
         return {

@@ -270,7 +270,7 @@ class AutoAugment(torch.nn.Module):
         return img
 
     def __repr__(self) -> str:
-        return self.__class__.__name__ + "(policy={}, fill={})".format(self.policy, self.fill)
+        return self.__class__.__name__ + f"(policy={self.policy}, fill={self.fill})"
 
 
 class RandAugment(torch.nn.Module):
torchvision/transforms/functional.py

@@ -111,10 +111,10 @@ def to_tensor(pic):
         Tensor: Converted image.
     """
     if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
-        raise TypeError("pic should be PIL Image or ndarray. Got {}".format(type(pic)))
+        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
 
     if _is_numpy(pic) and not _is_numpy_image(pic):
-        raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim))
+        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
 
     default_float_dtype = torch.get_default_dtype()

@@ -167,7 +167,7 @@ def pil_to_tensor(pic):
         Tensor: Converted image.
     """
     if not F_pil._is_pil_image(pic):
-        raise TypeError("pic should be PIL Image. Got {}".format(type(pic)))
+        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")
 
     if accimage is not None and isinstance(pic, accimage.Image):
         # accimage format is always uint8 internally, so always return uint8 here

@@ -226,11 +226,11 @@ def to_pil_image(pic, mode=None):
         PIL Image: Image converted to PIL Image.
     """
     if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
-        raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic)))
+        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
 
     elif isinstance(pic, torch.Tensor):
         if pic.ndimension() not in {2, 3}:
-            raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndimension()))
+            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")
 
         elif pic.ndimension() == 2:
             # if 2D image, add channel dimension (CHW)

@@ -238,11 +238,11 @@ def to_pil_image(pic, mode=None):
         # check number of channels
         if pic.shape[-3] > 4:
-            raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-3]))
+            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")
 
     elif isinstance(pic, np.ndarray):
         if pic.ndim not in {2, 3}:
-            raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim))
+            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
 
         elif pic.ndim == 2:
             # if 2D image, add channel dimension (HWC)

@@ -250,7 +250,7 @@ def to_pil_image(pic, mode=None):
         # check number of channels
         if pic.shape[-1] > 4:
-            raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-1]))
+            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
 
     npimg = pic
     if isinstance(pic, torch.Tensor):

@@ -259,7 +259,7 @@ def to_pil_image(pic, mode=None):
         npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
 
     if not isinstance(npimg, np.ndarray):
-        raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, " + "not {}".format(type(npimg)))
+        raise TypeError(f"Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")
 
     if npimg.shape[2] == 1:
         expected_mode = None

@@ -273,15 +273,13 @@ def to_pil_image(pic, mode=None):
         elif npimg.dtype == np.float32:
             expected_mode = "F"
         if mode is not None and mode != expected_mode:
-            raise ValueError(
-                "Incorrect mode ({}) supplied for input type {}. Should be {}".format(mode, np.dtype, expected_mode)
-            )
+            raise ValueError(f"Incorrect mode ({mode}) supplied for input type {np.dtype}. Should be {expected_mode}")
         mode = expected_mode
 
     elif npimg.shape[2] == 2:
         permitted_2_channel_modes = ["LA"]
         if mode is not None and mode not in permitted_2_channel_modes:
-            raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
+            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")
 
         if mode is None and npimg.dtype == np.uint8:
             mode = "LA"

@@ -289,19 +287,19 @@ def to_pil_image(pic, mode=None):
     elif npimg.shape[2] == 4:
         permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
         if mode is not None and mode not in permitted_4_channel_modes:
-            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
+            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")
 
         if mode is None and npimg.dtype == np.uint8:
             mode = "RGBA"
     else:
         permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
         if mode is not None and mode not in permitted_3_channel_modes:
-            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
+            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
         if mode is None and npimg.dtype == np.uint8:
             mode = "RGB"
 
     if mode is None:
-        raise TypeError("Input type {} is not supported".format(npimg.dtype))
+        raise TypeError(f"Input type {npimg.dtype} is not supported")
 
     return Image.fromarray(npimg, mode=mode)

@@ -325,15 +323,14 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
         Tensor: Normalized Tensor image.
     """
     if not isinstance(tensor, torch.Tensor):
-        raise TypeError("Input tensor should be a torch tensor. Got {}.".format(type(tensor)))
+        raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")
 
     if not tensor.is_floating_point():
-        raise TypeError("Input tensor should be a float tensor. Got {}.".format(tensor.dtype))
+        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
 
     if tensor.ndim < 3:
         raise ValueError(
-            "Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = "
-            "{}.".format(tensor.size())
+            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
         )
 
     if not inplace:

@@ -343,7 +340,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
     std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
     if (std == 0).any():
-        raise ValueError("std evaluated to zero after conversion to {}, leading to division by zero.".format(dtype))
+        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
     if mean.ndim == 1:
         mean = mean.view(-1, 1, 1)
     if std.ndim == 1:

@@ -425,7 +422,7 @@ def resize(
 def scale(*args, **kwargs):
-    warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.")
+    warnings.warn("The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.")
     return resize(*args, **kwargs)

@@ -923,7 +920,8 @@ def _get_inverse_affine_matrix(
     # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
 
     rot = math.radians(angle)
-    sx, sy = [math.radians(s) for s in shear]
+    sx = math.radians(shear[0])
+    sy = math.radians(shear[1])
 
     cx, cy = center
     tx, ty = translate

@@ -1121,7 +1119,7 @@ def affine(
         shear = [shear[0], shear[0]]
 
     if len(shear) != 2:
-        raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))
+        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
 
     img_size = get_image_size(img)
     if not isinstance(img, torch.Tensor):

@@ -1201,7 +1199,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
         Tensor Image: Erased image.
     """
     if not isinstance(img, torch.Tensor):
-        raise TypeError("img should be Tensor Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be Tensor Image. Got {type(img)}")
 
     if not inplace:
         img = img.clone()

@@ -1237,34 +1235,34 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
         PIL Image or Tensor: Gaussian Blurred version of the image.
     """
     if not isinstance(kernel_size, (int, list, tuple)):
-        raise TypeError("kernel_size should be int or a sequence of integers. Got {}".format(type(kernel_size)))
+        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
     if isinstance(kernel_size, int):
         kernel_size = [kernel_size, kernel_size]
     if len(kernel_size) != 2:
-        raise ValueError("If kernel_size is a sequence its length should be 2. Got {}".format(len(kernel_size)))
+        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
     for ksize in kernel_size:
         if ksize % 2 == 0 or ksize < 0:
-            raise ValueError("kernel_size should have odd and positive integers. Got {}".format(kernel_size))
+            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
 
     if sigma is None:
         sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
 
     if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
-        raise TypeError("sigma should be either float or sequence of floats. Got {}".format(type(sigma)))
+        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
     if isinstance(sigma, (int, float)):
         sigma = [float(sigma), float(sigma)]
     if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
         sigma = [sigma[0], sigma[0]]
     if len(sigma) != 2:
-        raise ValueError("If sigma is a sequence, its length should be 2. Got {}".format(len(sigma)))
+        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
     for s in sigma:
         if s <= 0.0:
-            raise ValueError("sigma should have positive values. Got {}".format(sigma))
+            raise ValueError(f"sigma should have positive values. Got {sigma}")
 
     t_img = img
     if not isinstance(img, torch.Tensor):
         if not F_pil._is_pil_image(img):
-            raise TypeError("img should be PIL Image or Tensor. Got {}".format(type(img)))
+            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
 
         t_img = to_tensor(img)

@@ -1307,7 +1305,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
         PIL Image or Tensor: Posterized image.
     """
     if not (0 <= bits <= 8):
-        raise ValueError("The number if bits should be between 0 and 8. Got {}".format(bits))
+        raise ValueError(f"The number if bits should be between 0 and 8. Got {bits}")
 
     if not isinstance(img, torch.Tensor):
         return F_pil.posterize(img, bits)
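Most hunks in this file are mechanical `.format()` to f-string rewrites; the one structurally different change is the shear unpacking in `_get_inverse_affine_matrix`, which is also equivalent. A sketch of that equivalence (values are illustrative):

# Sketch: unpacking a two-element comprehension vs. indexing twice.
import math

shear = [10.0, 20.0]

sx_a, sy_a = [math.radians(s) for s in shear]  # old form
sx_b = math.radians(shear[0])                  # new form
sy_b = math.radians(shear[1])

assert (sx_a, sy_a) == (sx_b, sy_b)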
torchvision/transforms/functional_pil.py

@@ -23,20 +23,20 @@ def _is_pil_image(img: Any) -> bool:
 def get_image_size(img: Any) -> List[int]:
     if _is_pil_image(img):
         return list(img.size)
-    raise TypeError("Unexpected type {}".format(type(img)))
+    raise TypeError(f"Unexpected type {type(img)}")
 
 
 @torch.jit.unused
 def get_image_num_channels(img: Any) -> int:
     if _is_pil_image(img):
         return 1 if img.mode == "L" else 3
-    raise TypeError("Unexpected type {}".format(type(img)))
+    raise TypeError(f"Unexpected type {type(img)}")
 
 
 @torch.jit.unused
 def hflip(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return img.transpose(Image.FLIP_LEFT_RIGHT)

@@ -44,7 +44,7 @@ def hflip(img: Image.Image) -> Image.Image:
 @torch.jit.unused
 def vflip(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return img.transpose(Image.FLIP_TOP_BOTTOM)

@@ -52,7 +52,7 @@ def vflip(img: Image.Image) -> Image.Image:
 @torch.jit.unused
 def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     enhancer = ImageEnhance.Brightness(img)
     img = enhancer.enhance(brightness_factor)

@@ -62,7 +62,7 @@ def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image
 @torch.jit.unused
 def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     enhancer = ImageEnhance.Contrast(img)
     img = enhancer.enhance(contrast_factor)

@@ -72,7 +72,7 @@ def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
 @torch.jit.unused
 def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     enhancer = ImageEnhance.Color(img)
     img = enhancer.enhance(saturation_factor)

@@ -82,10 +82,10 @@ def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image
 @torch.jit.unused
 def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
     if not (-0.5 <= hue_factor <= 0.5):
-        raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor))
+        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
 
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     input_mode = img.mode
     if input_mode in {"L", "1", "I", "F"}:

@@ -111,7 +111,7 @@ def adjust_gamma(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     if gamma < 0:
         raise ValueError("Gamma should be a non-negative real number")

@@ -134,7 +134,7 @@ def pad(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     if not isinstance(padding, (numbers.Number, tuple, list)):
         raise TypeError("Got inappropriate padding arg")

@@ -147,9 +147,7 @@ def pad(
         padding = tuple(padding)
 
     if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
-        raise ValueError(
-            "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
-        )
+        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
 
     if isinstance(padding, tuple) and len(padding) == 1:
         # Compatibility with `functional_tensor.pad`

@@ -217,7 +215,7 @@ def crop(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return img.crop((left, top, left + width, top + height))

@@ -231,9 +229,9 @@ def resize(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
     if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
-        raise TypeError("Got inappropriate size arg: {}".format(size))
+        raise TypeError(f"Got inappropriate size arg: {size}")
 
     if isinstance(size, Sequence) and len(size) == 1:
         size = size[0]

@@ -281,7 +279,7 @@ def _parse_fill(
         fill = tuple([fill] * num_bands)
     if isinstance(fill, (list, tuple)):
         if len(fill) != num_bands:
-            msg = "The number of elements in 'fill' does not match the number of " "bands of the image ({} != {})"
+            msg = "The number of elements in 'fill' does not match the number of bands of the image ({} != {})"
             raise ValueError(msg.format(len(fill), num_bands))
 
         fill = tuple(fill)

@@ -298,7 +296,7 @@ def affine(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     output_size = img.size
     opts = _parse_fill(fill, img)

@@ -316,7 +314,7 @@ def rotate(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     opts = _parse_fill(fill, img)
     return img.rotate(angle, interpolation, expand, center, **opts)

@@ -331,7 +329,7 @@ def perspective(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     opts = _parse_fill(fill, img)

@@ -341,7 +339,7 @@ def perspective(
 @torch.jit.unused
 def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     if num_output_channels == 1:
         img = img.convert("L")

@@ -359,28 +357,28 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
 @torch.jit.unused
 def invert(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return ImageOps.invert(img)
 
 
 @torch.jit.unused
 def posterize(img: Image.Image, bits: int) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return ImageOps.posterize(img, bits)
 
 
 @torch.jit.unused
 def solarize(img: Image.Image, threshold: int) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return ImageOps.solarize(img, threshold)
 
 
 @torch.jit.unused
 def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     enhancer = ImageEnhance.Sharpness(img)
     img = enhancer.enhance(sharpness_factor)

@@ -390,12 +388,12 @@ def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
 @torch.jit.unused
 def autocontrast(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return ImageOps.autocontrast(img)
 
 
 @torch.jit.unused
 def equalize(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+        raise TypeError(f"img should be PIL Image. Got {type(img)}")
 
     return ImageOps.equalize(img)
torchvision/transforms/functional_tensor.py

@@ -28,7 +28,7 @@ def get_image_num_channels(img: Tensor) -> int:
     elif img.ndim > 2:
         return img.shape[-3]
 
-    raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))
+    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
 
 
 def _max_value(dtype: torch.dtype) -> float:

@@ -52,7 +52,7 @@ def _max_value(dtype: torch.dtype) -> float:
 def _assert_channels(img: Tensor, permitted: List[int]) -> None:
     c = get_image_num_channels(img)
     if c not in permitted:
-        raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c))
+        raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
 
 
 def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:

@@ -134,7 +134,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
 def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     if img.ndim < 3:
-        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
     _assert_channels(img, [3])
 
     if num_output_channels not in (1, 3):

@@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
 def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     if brightness_factor < 0:
-        raise ValueError("brightness_factor ({}) is not non-negative.".format(brightness_factor))
+        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
 
     _assert_image_tensor(img)

@@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
 def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     if contrast_factor < 0:
-        raise ValueError("contrast_factor ({}) is not non-negative.".format(contrast_factor))
+        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
 
     _assert_image_tensor(img)

@@ -182,7 +182,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
 def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     if not (-0.5 <= hue_factor <= 0.5):
-        raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor))
+        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
 
     if not (isinstance(img, torch.Tensor)):
         raise TypeError("Input img should be Tensor image")

@@ -211,7 +211,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
 def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     if saturation_factor < 0:
-        raise ValueError("saturation_factor ({}) is not non-negative.".format(saturation_factor))
+        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
 
     _assert_image_tensor(img)

@@ -246,7 +246,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
 def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
     """DEPRECATED"""
     warnings.warn(
-        "This method is deprecated and will be removed in future releases. " "Please, use ``F.center_crop`` instead."
+        "This method is deprecated and will be removed in future releases. Please, use ``F.center_crop`` instead."
     )
 
     _assert_image_tensor(img)

@@ -268,7 +268,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
 def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
     """DEPRECATED"""
     warnings.warn(
-        "This method is deprecated and will be removed in future releases. " "Please, use ``F.five_crop`` instead."
+        "This method is deprecated and will be removed in future releases. Please, use ``F.five_crop`` instead."
     )
 
     _assert_image_tensor(img)

@@ -293,7 +293,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
 def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
     """DEPRECATED"""
     warnings.warn(
-        "This method is deprecated and will be removed in future releases. " "Please, use ``F.ten_crop`` instead."
+        "This method is deprecated and will be removed in future releases. Please, use ``F.ten_crop`` instead."
     )
 
     _assert_image_tensor(img)

@@ -382,7 +382,8 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
     # crop if needed
     if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
-        crop_left, crop_right, crop_top, crop_bottom = [-min(x, 0) for x in padding]
+        neg_min_padding = [-min(x, 0) for x in padding]
+        crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
         img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
         padding = [max(x, 0) for x in padding]

@@ -421,9 +422,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
         padding = list(padding)
 
     if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
-        raise ValueError(
-            "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
-        )
+        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
 
     if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
         raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

@@ -501,7 +500,7 @@ def resize(
     if isinstance(size, list):
         if len(size) not in [1, 2]:
             raise ValueError(
-                "Size must be an int or a 1 or 2 element tuple/list, not a " "{} element tuple/list".format(len(size))
+                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
             )
         if max_size is not None and len(size) != 1:
             raise ValueError(

@@ -597,7 +596,7 @@ def _assert_grid_transform_inputs(
             raise ValueError(msg.format(len(fill), num_channels))
 
     if interpolation not in supported_interpolation_modes:
-        raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation))
+        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
 
 
 def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:

@@ -823,7 +822,7 @@ def _get_gaussian_kernel2d(
 def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
     if not (isinstance(img, torch.Tensor)):
-        raise TypeError("img should be Tensor. Got {}".format(type(img)))
+        raise TypeError(f"img should be Tensor. Got {type(img)}")
 
     _assert_image_tensor(img)

@@ -852,7 +851,7 @@ def invert(img: Tensor) -> Tensor:
     _assert_image_tensor(img)
 
     if img.ndim < 3:
-        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
 
     _assert_channels(img, [1, 3])

@@ -865,9 +864,9 @@ def posterize(img: Tensor, bits: int) -> Tensor:
     _assert_image_tensor(img)
 
     if img.ndim < 3:
-        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
     if img.dtype != torch.uint8:
-        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
+        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
 
     _assert_channels(img, [1, 3])
     mask = -int(2 ** (8 - bits))  # JIT-friendly for: ~(2 ** (8 - bits) - 1)

@@ -879,7 +878,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
     _assert_image_tensor(img)
 
     if img.ndim < 3:
-        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
 
     _assert_channels(img, [1, 3])

@@ -912,7 +911,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
 def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
     if sharpness_factor < 0:
-        raise ValueError("sharpness_factor ({}) is not non-negative.".format(sharpness_factor))
+        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
 
     _assert_image_tensor(img)

@@ -929,7 +928,7 @@ def autocontrast(img: Tensor) -> Tensor:
     _assert_image_tensor(img)
 
     if img.ndim < 3:
-        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
 
     _assert_channels(img, [1, 3])

@@ -976,9 +975,9 @@ def equalize(img: Tensor) -> Tensor:
     _assert_image_tensor(img)
 
     if not (3 <= img.ndim <= 4):
-        raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))
+        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
     if img.dtype != torch.uint8:
-        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
+        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
 
     _assert_channels(img, [1, 3])
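The `_pad_symmetric` hunk only names an intermediate list before unpacking; behavior is unchanged. A sketch of the equivalence, with illustrative padding values:

# Sketch: naming the comprehension result before unpacking is equivalent
# to unpacking it directly.
padding = [-2, 3, -1, 0]  # hypothetical (left, right, top, bottom) padding

neg_min_padding = [-min(x, 0) for x in padding]
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding

assert (crop_left, crop_right, crop_top, crop_bottom) == (2, 0, 1, 0)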
torchvision/transforms/transforms.py

@@ -98,7 +98,7 @@ class Compose:
         format_string = self.__class__.__name__ + "("
         for t in self.transforms:
             format_string += "\n"
-            format_string += "    {0}".format(t)
+            format_string += f"    {t}"
         format_string += "\n)"
         return format_string

@@ -220,7 +220,7 @@ class ToPILImage:
     def __repr__(self):
         format_string = self.__class__.__name__ + "("
         if self.mode is not None:
-            format_string += "mode={0}".format(self.mode)
+            format_string += f"mode={self.mode}"
         format_string += ")"
         return format_string

@@ -260,7 +260,7 @@ class Normalize(torch.nn.Module):
         return F.normalize(tensor, self.mean, self.std, self.inplace)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(mean={0}, std={1})".format(self.mean, self.std)
+        return self.__class__.__name__ + f"(mean={self.mean}, std={self.std})"
 
 
 class Resize(torch.nn.Module):

@@ -310,7 +310,7 @@ class Resize(torch.nn.Module):
     def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
         super().__init__()
         if not isinstance(size, (int, Sequence)):
-            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
         if isinstance(size, Sequence) and len(size) not in (1, 2):
             raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size

@@ -338,10 +338,8 @@ class Resize(torch.nn.Module):
         return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
 
     def __repr__(self):
-        interpolate_str = self.interpolation.value
-        return self.__class__.__name__ + "(size={0}, interpolation={1}, max_size={2}, antialias={3})".format(
-            self.size, interpolate_str, self.max_size, self.antialias
-        )
+        detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
+        return self.__class__.__name__ + detail
 
 
 class Scale(Resize):

@@ -350,10 +348,8 @@ class Scale(Resize):
     """
 
     def __init__(self, *args, **kwargs):
-        warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.")
-        super(Scale, self).__init__(*args, **kwargs)
+        warnings.warn("The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.")
+        super().__init__(*args, **kwargs)
 
 
 class CenterCrop(torch.nn.Module):

@@ -383,7 +379,7 @@ class CenterCrop(torch.nn.Module):
         return F.center_crop(img, self.size)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(size={0})".format(self.size)
+        return self.__class__.__name__ + f"(size={self.size})"
 
 
 class Pad(torch.nn.Module):

@@ -437,7 +433,7 @@ class Pad(torch.nn.Module):
         if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
             raise ValueError(
-                "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
+                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
             )
 
         self.padding = padding

@@ -455,9 +451,7 @@ class Pad(torch.nn.Module):
         return F.pad(img, self.padding, self.fill, self.padding_mode)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(padding={0}, fill={1}, padding_mode={2})".format(
-            self.padding, self.fill, self.padding_mode
-        )
+        return self.__class__.__name__ + f"(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
 
 
 class Lambda:

@@ -469,7 +463,7 @@ class Lambda:
     def __init__(self, lambd):
         if not callable(lambd):
-            raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))
+            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
         self.lambd = lambd
 
     def __call__(self, img):

@@ -498,7 +492,7 @@ class RandomTransforms:
         format_string = self.__class__.__name__ + "("
         for t in self.transforms:
             format_string += "\n"
-            format_string += "    {0}".format(t)
+            format_string += f"    {t}"
         format_string += "\n)"
         return format_string

@@ -537,10 +531,10 @@ class RandomApply(torch.nn.Module):
     def __repr__(self):
         format_string = self.__class__.__name__ + "("
-        format_string += "\n    p={}".format(self.p)
+        format_string += f"\n    p={self.p}"
         for t in self.transforms:
             format_string += "\n"
-            format_string += "    {0}".format(t)
+            format_string += f"    {t}"
         format_string += "\n)"
         return format_string

@@ -571,7 +565,7 @@ class RandomChoice(RandomTransforms):
     def __repr__(self):
         format_string = super().__repr__()
-        format_string += "(p={0})".format(self.p)
+        format_string += f"(p={self.p})"
         return format_string

@@ -634,7 +628,7 @@ class RandomCrop(torch.nn.Module):
         th, tw = output_size
 
         if h + 1 < th or w + 1 < tw:
-            raise ValueError("Required crop size {} is larger then input image size {}".format((th, tw), (h, w)))
+            raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")
 
         if w == tw and h == th:
             return 0, 0, h, w

@@ -679,7 +673,7 @@ class RandomCrop(torch.nn.Module):
         return F.crop(img, i, j, h, w)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)
+        return self.__class__.__name__ + f"(size={self.size}, padding={self.padding})"
 
 
 class RandomHorizontalFlip(torch.nn.Module):

@@ -709,7 +703,7 @@ class RandomHorizontalFlip(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
 
 
 class RandomVerticalFlip(torch.nn.Module):

@@ -739,7 +733,7 @@ class RandomVerticalFlip(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
 
 
 class RandomPerspective(torch.nn.Module):

@@ -839,7 +833,7 @@ class RandomPerspective(torch.nn.Module):
         return startpoints, endpoints
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
 
 
 class RandomResizedCrop(torch.nn.Module):

@@ -951,10 +945,10 @@ class RandomResizedCrop(torch.nn.Module):
     def __repr__(self):
         interpolate_str = self.interpolation.value
-        format_string = self.__class__.__name__ + "(size={0}".format(self.size)
-        format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
-        format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
-        format_string += ", interpolation={0})".format(interpolate_str)
+        format_string = self.__class__.__name__ + f"(size={self.size}"
+        format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
+        format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
+        format_string += f", interpolation={interpolate_str})"
         return format_string

@@ -968,7 +962,7 @@ class RandomSizedCrop(RandomResizedCrop):
             "The use of the transforms.RandomSizedCrop transform is deprecated, "
             + "please use transforms.RandomResizedCrop instead."
         )
-        super(RandomSizedCrop, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
 
 class FiveCrop(torch.nn.Module):

@@ -1014,7 +1008,7 @@ class FiveCrop(torch.nn.Module):
         return F.five_crop(img, self.size)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(size={0})".format(self.size)
+        return self.__class__.__name__ + f"(size={self.size})"
 
 
 class TenCrop(torch.nn.Module):

@@ -1063,7 +1057,7 @@ class TenCrop(torch.nn.Module):
         return F.ten_crop(img, self.size, self.vertical_flip)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(size={0}, vertical_flip={1})".format(self.size, self.vertical_flip)
+        return self.__class__.__name__ + f"(size={self.size}, vertical_flip={self.vertical_flip})"
 
 
 class LinearTransformation(torch.nn.Module):

@@ -1090,22 +1084,18 @@ class LinearTransformation(torch.nn.Module):
         if transformation_matrix.size(0) != transformation_matrix.size(1):
             raise ValueError(
                 "transformation_matrix should be square. Got "
-                + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())
+                f"{tuple(transformation_matrix.size())} rectangular matrix."
             )
 
         if mean_vector.size(0) != transformation_matrix.size(0):
             raise ValueError(
-                "mean_vector should have the same length {}".format(mean_vector.size(0))
-                + " as any one of the dimensions of the transformation_matrix [{}]".format(
-                    tuple(transformation_matrix.size())
-                )
+                f"mean_vector should have the same length {mean_vector.size(0)}"
+                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
             )
 
         if transformation_matrix.device != mean_vector.device:
             raise ValueError(
-                "Input tensors should be on the same device. Got {} and {}".format(
-                    transformation_matrix.device, mean_vector.device
-                )
+                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
             )
 
         self.transformation_matrix = transformation_matrix

@@ -1124,14 +1114,14 @@ class LinearTransformation(torch.nn.Module):
         if n != self.transformation_matrix.shape[0]:
             raise ValueError(
                 "Input tensor and transformation matrix have incompatible shape."
-                + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1])
-                + "{}".format(self.transformation_matrix.shape[0])
+                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
+                + f"{self.transformation_matrix.shape[0]}"
             )
 
         if tensor.device.type != self.mean_vector.device.type:
             raise ValueError(
                 "Input tensor should be on the same device as transformation matrix and mean vector. "
-                "Got {} vs {}".format(tensor.device, self.mean_vector.device)
+                f"Got {tensor.device} vs {self.mean_vector.device}"
             )
 
         flat_tensor = tensor.view(-1, n) - self.mean_vector

@@ -1178,15 +1168,15 @@ class ColorJitter(torch.nn.Module):
     def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
         if isinstance(value, numbers.Number):
             if value < 0:
-                raise ValueError("If {} is a single number, it must be non negative.".format(name))
+                raise ValueError(f"If {name} is a single number, it must be non negative.")
             value = [center - float(value), center + float(value)]
             if clip_first_on_zero:
                 value[0] = max(value[0], 0.0)
         elif isinstance(value, (tuple, list)) and len(value) == 2:
             if not bound[0] <= value[0] <= value[1] <= bound[1]:
-                raise ValueError("{} values should be between {}".format(name, bound))
+                raise ValueError(f"{name} values should be between {bound}")
         else:
-            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
+            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
 
         # if value is 0 or (1., 1.) for brightness/contrast/saturation
         # or (0., 0.) for hue, do nothing

@@ -1252,10 +1242,10 @@ class ColorJitter(torch.nn.Module):
     def __repr__(self):
         format_string = self.__class__.__name__ + "("
-        format_string += "brightness={0}".format(self.brightness)
-        format_string += ", contrast={0}".format(self.contrast)
-        format_string += ", saturation={0}".format(self.saturation)
-        format_string += ", hue={0})".format(self.hue)
+        format_string += f"brightness={self.brightness}"
+        format_string += f", contrast={self.contrast}"
+        format_string += f", saturation={self.saturation}"
+        format_string += f", hue={self.hue})"
         return format_string

@@ -1352,13 +1342,13 @@ class RandomRotation(torch.nn.Module):
     def __repr__(self):
         interpolate_str = self.interpolation.value
-        format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees)
-        format_string += ", interpolation={0}".format(interpolate_str)
-        format_string += ", expand={0}".format(self.expand)
+        format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
+        format_string += f", interpolation={interpolate_str}"
+        format_string += f", expand={self.expand}"
         if self.center is not None:
-            format_string += ", center={0}".format(self.center)
+            format_string += f", center={self.center}"
         if self.fill is not None:
-            format_string += ", fill={0}".format(self.fill)
+            format_string += f", fill={self.fill}"
         format_string += ")"
         return format_string

@@ -1568,7 +1558,7 @@ class Grayscale(torch.nn.Module):
         return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
 
     def __repr__(self):
-        return self.__class__.__name__ + "(num_output_channels={0})".format(self.num_output_channels)
+        return self.__class__.__name__ + f"(num_output_channels={self.num_output_channels})"
 
 
 class RandomGrayscale(torch.nn.Module):

@@ -1605,7 +1595,7 @@ class RandomGrayscale(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={0})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
 
 
 class RandomErasing(torch.nn.Module):

@@ -1726,7 +1716,7 @@ class RandomErasing(torch.nn.Module):
             if value is not None and not (len(value) in (1, img.shape[-3])):
                 raise ValueError(
                     "If value is a sequence, it should have either a single value or "
-                    "{} (number of input channels)".format(img.shape[-3])
+                    f"{img.shape[-3]} (number of input channels)"
                 )
 
             x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)

@@ -1734,11 +1724,11 @@ class RandomErasing(torch.nn.Module):
         return img
 
     def __repr__(self):
-        s = "(p={}, ".format(self.p)
-        s += "scale={}, ".format(self.scale)
-        s += "ratio={}, ".format(self.ratio)
-        s += "value={}, ".format(self.value)
-        s += "inplace={})".format(self.inplace)
+        s = f"(p={self.p}, "
+        s += f"scale={self.scale}, "
+        s += f"ratio={self.ratio}, "
+        s += f"value={self.value}, "
+        s += f"inplace={self.inplace})"
         return self.__class__.__name__ + s

@@ -1803,8 +1793,8 @@ class GaussianBlur(torch.nn.Module):
         return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
 
     def __repr__(self):
-        s = "(kernel_size={}, ".format(self.kernel_size)
-        s += "sigma={})".format(self.sigma)
+        s = f"(kernel_size={self.kernel_size}, "
+        s += f"sigma={self.sigma})"
         return self.__class__.__name__ + s

@@ -1824,15 +1814,15 @@ def _setup_size(size, error_msg):
 def _check_sequence_input(x, name, req_sizes):
     msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
     if not isinstance(x, Sequence):
-        raise TypeError("{} should be a sequence of length {}.".format(name, msg))
+        raise TypeError(f"{name} should be a sequence of length {msg}.")
     if len(x) not in req_sizes:
-        raise ValueError("{} should be sequence of length {}.".format(name, msg))
+        raise ValueError(f"{name} should be sequence of length {msg}.")
 
 
 def _setup_angle(x, name, req_sizes=(2,)):
     if isinstance(x, numbers.Number):
         if x < 0:
-            raise ValueError("If {} is a single number, it must be positive.".format(name))
+            raise ValueError(f"If {name} is a single number, it must be positive.")
         x = [-x, x]
     else:
         _check_sequence_input(x, name, req_sizes)

@@ -1867,7 +1857,7 @@ class RandomInvert(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
 
 
 class RandomPosterize(torch.nn.Module):

@@ -1899,7 +1889,7 @@ class RandomPosterize(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(bits={},p={})".format(self.bits, self.p)
+        return self.__class__.__name__ + f"(bits={self.bits},p={self.p})"
 
 
 class RandomSolarize(torch.nn.Module):

@@ -1931,7 +1921,7 @@ class RandomSolarize(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(threshold={},p={})".format(self.threshold, self.p)
+        return self.__class__.__name__ + f"(threshold={self.threshold},p={self.p})"
 
 
 class RandomAdjustSharpness(torch.nn.Module):

@@ -1963,7 +1953,7 @@ class RandomAdjustSharpness(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(sharpness_factor={},p={})".format(self.sharpness_factor, self.p)
+        return self.__class__.__name__ + f"(sharpness_factor={self.sharpness_factor},p={self.p})"
 
 
 class RandomAutocontrast(torch.nn.Module):

@@ -1993,7 +1983,7 @@ class RandomAutocontrast(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
 
 
 class RandomEqualize(torch.nn.Module):

@@ -2023,4 +2013,4 @@ class RandomEqualize(torch.nn.Module):
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "(p={})".format(self.p)
+        return self.__class__.__name__ + f"(p={self.p})"
torchvision/utils.py

 import math
 import pathlib
 import warnings
-from typing import Union, Optional, List, Tuple, Text, BinaryIO
+from typing import Union, Optional, List, Tuple, BinaryIO
 
 import numpy as np
 import torch

@@ -114,7 +114,7 @@ def make_grid(
 @torch.no_grad()
 def save_image(
     tensor: Union[torch.Tensor, List[torch.Tensor]],
-    fp: Union[Text, pathlib.Path, BinaryIO],
+    fp: Union[str, pathlib.Path, BinaryIO],
     format: Optional[str] = None,
     **kwargs,
 ) -> None:
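The `Text` to `str` swap is safe because `typing.Text` is, on Python 3, a plain alias of `str`, so the two annotations are interchangeable:

# Sketch: typing.Text is literally str on Python 3.
from typing import Text

assert Text is str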