Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
96f6e0a1
Unverified
Commit
96f6e0a1
authored
Aug 26, 2021
by
Vasilis Vryniotis
Committed by
GitHub
Aug 26, 2021
Browse files
Make get_image_size and get_image_num_channels public. (#4321)
parent
37a9ee5b
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
70 additions
and
40 deletions
+70
-40
references/detection/transforms.py
references/detection/transforms.py
+4
-4
test/test_functional_tensor.py
test/test_functional_tensor.py
+19
-1
torchvision/transforms/autoaugment.py
torchvision/transforms/autoaugment.py
+3
-3
torchvision/transforms/functional.py
torchvision/transforms/functional.py
+25
-13
torchvision/transforms/functional_pil.py
torchvision/transforms/functional_pil.py
+3
-3
torchvision/transforms/functional_tensor.py
torchvision/transforms/functional_tensor.py
+7
-7
torchvision/transforms/transforms.py
torchvision/transforms/transforms.py
+9
-9
No files found.
references/detection/transforms.py
View file @
96f6e0a1
...
@@ -33,7 +33,7 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
...
@@ -33,7 +33,7 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
if
torch
.
rand
(
1
)
<
self
.
p
:
if
torch
.
rand
(
1
)
<
self
.
p
:
image
=
F
.
hflip
(
image
)
image
=
F
.
hflip
(
image
)
if
target
is
not
None
:
if
target
is
not
None
:
width
,
_
=
F
.
_
get_image_size
(
image
)
width
,
_
=
F
.
get_image_size
(
image
)
target
[
"boxes"
][:,
[
0
,
2
]]
=
width
-
target
[
"boxes"
][:,
[
2
,
0
]]
target
[
"boxes"
][:,
[
0
,
2
]]
=
width
-
target
[
"boxes"
][:,
[
2
,
0
]]
if
"masks"
in
target
:
if
"masks"
in
target
:
target
[
"masks"
]
=
target
[
"masks"
].
flip
(
-
1
)
target
[
"masks"
]
=
target
[
"masks"
].
flip
(
-
1
)
...
@@ -76,7 +76,7 @@ class RandomIoUCrop(nn.Module):
...
@@ -76,7 +76,7 @@ class RandomIoUCrop(nn.Module):
elif
image
.
ndimension
()
==
2
:
elif
image
.
ndimension
()
==
2
:
image
=
image
.
unsqueeze
(
0
)
image
=
image
.
unsqueeze
(
0
)
orig_w
,
orig_h
=
F
.
_
get_image_size
(
image
)
orig_w
,
orig_h
=
F
.
get_image_size
(
image
)
while
True
:
while
True
:
# sample an option
# sample an option
...
@@ -157,7 +157,7 @@ class RandomZoomOut(nn.Module):
...
@@ -157,7 +157,7 @@ class RandomZoomOut(nn.Module):
if
torch
.
rand
(
1
)
<
self
.
p
:
if
torch
.
rand
(
1
)
<
self
.
p
:
return
image
,
target
return
image
,
target
orig_w
,
orig_h
=
F
.
_
get_image_size
(
image
)
orig_w
,
orig_h
=
F
.
get_image_size
(
image
)
r
=
self
.
side_range
[
0
]
+
torch
.
rand
(
1
)
*
(
self
.
side_range
[
1
]
-
self
.
side_range
[
0
])
r
=
self
.
side_range
[
0
]
+
torch
.
rand
(
1
)
*
(
self
.
side_range
[
1
]
-
self
.
side_range
[
0
])
canvas_width
=
int
(
orig_w
*
r
)
canvas_width
=
int
(
orig_w
*
r
)
...
@@ -226,7 +226,7 @@ class RandomPhotometricDistort(nn.Module):
...
@@ -226,7 +226,7 @@ class RandomPhotometricDistort(nn.Module):
image
=
self
.
_contrast
(
image
)
image
=
self
.
_contrast
(
image
)
if
r
[
6
]
<
self
.
p
:
if
r
[
6
]
<
self
.
p
:
channels
=
F
.
_
get_image_num_channels
(
image
)
channels
=
F
.
get_image_num_channels
(
image
)
permutation
=
torch
.
randperm
(
channels
)
permutation
=
torch
.
randperm
(
channels
)
is_pil
=
F
.
_is_pil_image
(
image
)
is_pil
=
F
.
_is_pil_image
(
image
)
...
...
test/test_functional_tensor.py
View file @
96f6e0a1
...
@@ -31,6 +31,24 @@ from typing import Dict, List, Sequence, Tuple
...
@@ -31,6 +31,24 @@ from typing import Dict, List, Sequence, Tuple
NEAREST
,
BILINEAR
,
BICUBIC
=
InterpolationMode
.
NEAREST
,
InterpolationMode
.
BILINEAR
,
InterpolationMode
.
BICUBIC
NEAREST
,
BILINEAR
,
BICUBIC
=
InterpolationMode
.
NEAREST
,
InterpolationMode
.
BILINEAR
,
InterpolationMode
.
BICUBIC
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
'fn'
,
[
F
.
get_image_size
,
F
.
get_image_num_channels
])
def
test_image_sizes
(
device
,
fn
):
script_F
=
torch
.
jit
.
script
(
fn
)
img_tensor
,
pil_img
=
_create_data
(
16
,
18
,
3
,
device
=
device
)
value_img
=
fn
(
img_tensor
)
value_pil_img
=
fn
(
pil_img
)
assert
value_img
==
value_pil_img
value_img_script
=
script_F
(
img_tensor
)
assert
value_img
==
value_img_script
batch_tensors
=
_create_data_batch
(
16
,
18
,
3
,
num_samples
=
4
,
device
=
device
)
value_img_batch
=
fn
(
batch_tensors
)
assert
value_img
==
value_img_batch
@
needs_cuda
@
needs_cuda
def
test_scale_channel
():
def
test_scale_channel
():
"""Make sure that _scale_channel gives the same results on CPU and GPU as
"""Make sure that _scale_channel gives the same results on CPU and GPU as
...
@@ -908,7 +926,7 @@ def test_resized_crop(device, mode):
...
@@ -908,7 +926,7 @@ def test_resized_crop(device, mode):
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
'func, args'
,
[
@
pytest
.
mark
.
parametrize
(
'func, args'
,
[
(
F_t
.
_
get_image_size
,
()),
(
F_t
.
vflip
,
()),
(
F_t
.
get_image_size
,
()),
(
F_t
.
vflip
,
()),
(
F_t
.
hflip
,
()),
(
F_t
.
crop
,
(
1
,
2
,
4
,
5
)),
(
F_t
.
hflip
,
()),
(
F_t
.
crop
,
(
1
,
2
,
4
,
5
)),
(
F_t
.
adjust_brightness
,
(
0.
,
)),
(
F_t
.
adjust_contrast
,
(
1.
,
)),
(
F_t
.
adjust_brightness
,
(
0.
,
)),
(
F_t
.
adjust_contrast
,
(
1.
,
)),
(
F_t
.
adjust_hue
,
(
-
0.5
,
)),
(
F_t
.
adjust_saturation
,
(
2.
,
)),
(
F_t
.
adjust_hue
,
(
-
0.5
,
)),
(
F_t
.
adjust_saturation
,
(
2.
,
)),
...
...
torchvision/transforms/autoaugment.py
View file @
96f6e0a1
...
@@ -188,7 +188,7 @@ class AutoAugment(torch.nn.Module):
...
@@ -188,7 +188,7 @@ class AutoAugment(torch.nn.Module):
fill
=
self
.
fill
fill
=
self
.
fill
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
fill
,
(
int
,
float
)):
if
isinstance
(
fill
,
(
int
,
float
)):
fill
=
[
float
(
fill
)]
*
F
.
_
get_image_num_channels
(
img
)
fill
=
[
float
(
fill
)]
*
F
.
get_image_num_channels
(
img
)
elif
fill
is
not
None
:
elif
fill
is
not
None
:
fill
=
[
float
(
f
)
for
f
in
fill
]
fill
=
[
float
(
f
)
for
f
in
fill
]
...
@@ -209,10 +209,10 @@ class AutoAugment(torch.nn.Module):
...
@@ -209,10 +209,10 @@ class AutoAugment(torch.nn.Module):
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
math
.
degrees
(
magnitude
)],
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
math
.
degrees
(
magnitude
)],
interpolation
=
self
.
interpolation
,
fill
=
fill
)
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"TranslateX"
:
elif
op_name
==
"TranslateX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
int
(
F
.
_
get_image_size
(
img
)[
0
]
*
magnitude
),
0
],
scale
=
1.0
,
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
int
(
F
.
get_image_size
(
img
)[
0
]
*
magnitude
),
0
],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"TranslateY"
:
elif
op_name
==
"TranslateY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
int
(
F
.
_
get_image_size
(
img
)[
1
]
*
magnitude
)],
scale
=
1.0
,
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
int
(
F
.
get_image_size
(
img
)[
1
]
*
magnitude
)],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"Rotate"
:
elif
op_name
==
"Rotate"
:
img
=
F
.
rotate
(
img
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
img
=
F
.
rotate
(
img
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
...
...
torchvision/transforms/functional.py
View file @
96f6e0a1
...
@@ -58,22 +58,34 @@ pil_modes_mapping = {
...
@@ -58,22 +58,34 @@ pil_modes_mapping = {
_is_pil_image
=
F_pil
.
_is_pil_image
_is_pil_image
=
F_pil
.
_is_pil_image
def
_get_image_size
(
img
:
Tensor
)
->
List
[
int
]:
def
get_image_size
(
img
:
Tensor
)
->
List
[
int
]:
"""Returns image size as [w, h]
"""Returns the size of an image as [width, height].
Args:
img (PIL Image or Tensor): The image to be checked.
Returns:
List[int]: The image size.
"""
"""
if
isinstance
(
img
,
torch
.
Tensor
):
if
isinstance
(
img
,
torch
.
Tensor
):
return
F_t
.
_
get_image_size
(
img
)
return
F_t
.
get_image_size
(
img
)
return
F_pil
.
_
get_image_size
(
img
)
return
F_pil
.
get_image_size
(
img
)
def
_get_image_num_channels
(
img
:
Tensor
)
->
int
:
def
get_image_num_channels
(
img
:
Tensor
)
->
int
:
"""Returns number of image channels
"""Returns the number of channels of an image.
Args:
img (PIL Image or Tensor): The image to be checked.
Returns:
int: The number of channels.
"""
"""
if
isinstance
(
img
,
torch
.
Tensor
):
if
isinstance
(
img
,
torch
.
Tensor
):
return
F_t
.
_
get_image_num_channels
(
img
)
return
F_t
.
get_image_num_channels
(
img
)
return
F_pil
.
_
get_image_num_channels
(
img
)
return
F_pil
.
get_image_num_channels
(
img
)
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
...
@@ -500,7 +512,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
...
@@ -500,7 +512,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
elif
isinstance
(
output_size
,
(
tuple
,
list
))
and
len
(
output_size
)
==
1
:
elif
isinstance
(
output_size
,
(
tuple
,
list
))
and
len
(
output_size
)
==
1
:
output_size
=
(
output_size
[
0
],
output_size
[
0
])
output_size
=
(
output_size
[
0
],
output_size
[
0
])
image_width
,
image_height
=
_
get_image_size
(
img
)
image_width
,
image_height
=
get_image_size
(
img
)
crop_height
,
crop_width
=
output_size
crop_height
,
crop_width
=
output_size
if
crop_width
>
image_width
or
crop_height
>
image_height
:
if
crop_width
>
image_width
or
crop_height
>
image_height
:
...
@@ -511,7 +523,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
...
@@ -511,7 +523,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
(
crop_height
-
image_height
+
1
)
//
2
if
crop_height
>
image_height
else
0
,
(
crop_height
-
image_height
+
1
)
//
2
if
crop_height
>
image_height
else
0
,
]
]
img
=
pad
(
img
,
padding_ltrb
,
fill
=
0
)
# PIL uses fill value 0
img
=
pad
(
img
,
padding_ltrb
,
fill
=
0
)
# PIL uses fill value 0
image_width
,
image_height
=
_
get_image_size
(
img
)
image_width
,
image_height
=
get_image_size
(
img
)
if
crop_width
==
image_width
and
crop_height
==
image_height
:
if
crop_width
==
image_width
and
crop_height
==
image_height
:
return
img
return
img
...
@@ -696,7 +708,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
...
@@ -696,7 +708,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
if
len
(
size
)
!=
2
:
if
len
(
size
)
!=
2
:
raise
ValueError
(
"Please provide only two dimensions (h, w) for size."
)
raise
ValueError
(
"Please provide only two dimensions (h, w) for size."
)
image_width
,
image_height
=
_
get_image_size
(
img
)
image_width
,
image_height
=
get_image_size
(
img
)
crop_height
,
crop_width
=
size
crop_height
,
crop_width
=
size
if
crop_width
>
image_width
or
crop_height
>
image_height
:
if
crop_width
>
image_width
or
crop_height
>
image_height
:
msg
=
"Requested crop size {} is bigger than input size {}"
msg
=
"Requested crop size {} is bigger than input size {}"
...
@@ -993,7 +1005,7 @@ def rotate(
...
@@ -993,7 +1005,7 @@ def rotate(
center_f
=
[
0.0
,
0.0
]
center_f
=
[
0.0
,
0.0
]
if
center
is
not
None
:
if
center
is
not
None
:
img_size
=
_
get_image_size
(
img
)
img_size
=
get_image_size
(
img
)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f
=
[
1.0
*
(
c
-
s
*
0.5
)
for
c
,
s
in
zip
(
center
,
img_size
)]
center_f
=
[
1.0
*
(
c
-
s
*
0.5
)
for
c
,
s
in
zip
(
center
,
img_size
)]
...
@@ -1094,7 +1106,7 @@ def affine(
...
@@ -1094,7 +1106,7 @@ def affine(
if
len
(
shear
)
!=
2
:
if
len
(
shear
)
!=
2
:
raise
ValueError
(
"Shear should be a sequence containing two values. Got {}"
.
format
(
shear
))
raise
ValueError
(
"Shear should be a sequence containing two values. Got {}"
.
format
(
shear
))
img_size
=
_
get_image_size
(
img
)
img_size
=
get_image_size
(
img
)
if
not
isinstance
(
img
,
torch
.
Tensor
):
if
not
isinstance
(
img
,
torch
.
Tensor
):
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# it is visually better to estimate the center without 0.5 offset
...
...
torchvision/transforms/functional_pil.py
View file @
96f6e0a1
...
@@ -20,14 +20,14 @@ def _is_pil_image(img: Any) -> bool:
...
@@ -20,14 +20,14 @@ def _is_pil_image(img: Any) -> bool:
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
_
get_image_size
(
img
:
Any
)
->
List
[
int
]:
def
get_image_size
(
img
:
Any
)
->
List
[
int
]:
if
_is_pil_image
(
img
):
if
_is_pil_image
(
img
):
return
img
.
size
return
list
(
img
.
size
)
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
_
get_image_num_channels
(
img
:
Any
)
->
int
:
def
get_image_num_channels
(
img
:
Any
)
->
int
:
if
_is_pil_image
(
img
):
if
_is_pil_image
(
img
):
return
1
if
img
.
mode
==
'L'
else
3
return
1
if
img
.
mode
==
'L'
else
3
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
...
...
torchvision/transforms/functional_tensor.py
View file @
96f6e0a1
...
@@ -16,13 +16,13 @@ def _assert_image_tensor(img: Tensor) -> None:
...
@@ -16,13 +16,13 @@ def _assert_image_tensor(img: Tensor) -> None:
raise
TypeError
(
"Tensor is not a torch image."
)
raise
TypeError
(
"Tensor is not a torch image."
)
def
_
get_image_size
(
img
:
Tensor
)
->
List
[
int
]:
def
get_image_size
(
img
:
Tensor
)
->
List
[
int
]:
# Returns (w, h) of tensor image
# Returns (w, h) of tensor image
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
return
[
img
.
shape
[
-
1
],
img
.
shape
[
-
2
]]
return
[
img
.
shape
[
-
1
],
img
.
shape
[
-
2
]]
def
_
get_image_num_channels
(
img
:
Tensor
)
->
int
:
def
get_image_num_channels
(
img
:
Tensor
)
->
int
:
if
img
.
ndim
==
2
:
if
img
.
ndim
==
2
:
return
1
return
1
elif
img
.
ndim
>
2
:
elif
img
.
ndim
>
2
:
...
@@ -50,7 +50,7 @@ def _max_value(dtype: torch.dtype) -> float:
...
@@ -50,7 +50,7 @@ def _max_value(dtype: torch.dtype) -> float:
def
_assert_channels
(
img
:
Tensor
,
permitted
:
List
[
int
])
->
None
:
def
_assert_channels
(
img
:
Tensor
,
permitted
:
List
[
int
])
->
None
:
c
=
_
get_image_num_channels
(
img
)
c
=
get_image_num_channels
(
img
)
if
c
not
in
permitted
:
if
c
not
in
permitted
:
raise
TypeError
(
"Input image tensor permitted channel values are {}, but found {}"
.
format
(
permitted
,
c
))
raise
TypeError
(
"Input image tensor permitted channel values are {}, but found {}"
.
format
(
permitted
,
c
))
...
@@ -122,7 +122,7 @@ def hflip(img: Tensor) -> Tensor:
...
@@ -122,7 +122,7 @@ def hflip(img: Tensor) -> Tensor:
def
crop
(
img
:
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
Tensor
:
def
crop
(
img
:
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
Tensor
:
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
w
,
h
=
_
get_image_size
(
img
)
w
,
h
=
get_image_size
(
img
)
right
=
left
+
width
right
=
left
+
width
bottom
=
top
+
height
bottom
=
top
+
height
...
@@ -187,7 +187,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
...
@@ -187,7 +187,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
_assert_image_tensor
(
img
)
_assert_image_tensor
(
img
)
_assert_channels
(
img
,
[
1
,
3
])
_assert_channels
(
img
,
[
1
,
3
])
if
_
get_image_num_channels
(
img
)
==
1
:
# Match PIL behaviour
if
get_image_num_channels
(
img
)
==
1
:
# Match PIL behaviour
return
img
return
img
orig_dtype
=
img
.
dtype
orig_dtype
=
img
.
dtype
...
@@ -513,7 +513,7 @@ def resize(
...
@@ -513,7 +513,7 @@ def resize(
if
antialias
and
interpolation
not
in
[
"bilinear"
,
"bicubic"
]:
if
antialias
and
interpolation
not
in
[
"bilinear"
,
"bicubic"
]:
raise
ValueError
(
"Antialias option is supported for bilinear and bicubic interpolation modes only"
)
raise
ValueError
(
"Antialias option is supported for bilinear and bicubic interpolation modes only"
)
w
,
h
=
_
get_image_size
(
img
)
w
,
h
=
get_image_size
(
img
)
if
isinstance
(
size
,
int
)
or
len
(
size
)
==
1
:
# specified size only for the smallest edge
if
isinstance
(
size
,
int
)
or
len
(
size
)
==
1
:
# specified size only for the smallest edge
short
,
long
=
(
w
,
h
)
if
w
<=
h
else
(
h
,
w
)
short
,
long
=
(
w
,
h
)
if
w
<=
h
else
(
h
,
w
)
...
@@ -586,7 +586,7 @@ def _assert_grid_transform_inputs(
...
@@ -586,7 +586,7 @@ def _assert_grid_transform_inputs(
warnings
.
warn
(
"Argument fill should be either int, float, tuple or list"
)
warnings
.
warn
(
"Argument fill should be either int, float, tuple or list"
)
# Check fill
# Check fill
num_channels
=
_
get_image_num_channels
(
img
)
num_channels
=
get_image_num_channels
(
img
)
if
isinstance
(
fill
,
(
tuple
,
list
))
and
(
len
(
fill
)
>
1
and
len
(
fill
)
!=
num_channels
):
if
isinstance
(
fill
,
(
tuple
,
list
))
and
(
len
(
fill
)
>
1
and
len
(
fill
)
!=
num_channels
):
msg
=
(
"The number of elements in 'fill' cannot broadcast to match the number of "
msg
=
(
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
"channels of the image ({} != {})"
)
...
...
torchvision/transforms/transforms.py
View file @
96f6e0a1
...
@@ -575,7 +575,7 @@ class RandomCrop(torch.nn.Module):
...
@@ -575,7 +575,7 @@ class RandomCrop(torch.nn.Module):
Returns:
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
"""
w
,
h
=
F
.
_
get_image_size
(
img
)
w
,
h
=
F
.
get_image_size
(
img
)
th
,
tw
=
output_size
th
,
tw
=
output_size
if
h
+
1
<
th
or
w
+
1
<
tw
:
if
h
+
1
<
th
or
w
+
1
<
tw
:
...
@@ -613,7 +613,7 @@ class RandomCrop(torch.nn.Module):
...
@@ -613,7 +613,7 @@ class RandomCrop(torch.nn.Module):
if
self
.
padding
is
not
None
:
if
self
.
padding
is
not
None
:
img
=
F
.
pad
(
img
,
self
.
padding
,
self
.
fill
,
self
.
padding_mode
)
img
=
F
.
pad
(
img
,
self
.
padding
,
self
.
fill
,
self
.
padding_mode
)
width
,
height
=
F
.
_
get_image_size
(
img
)
width
,
height
=
F
.
get_image_size
(
img
)
# pad the width if needed
# pad the width if needed
if
self
.
pad_if_needed
and
width
<
self
.
size
[
1
]:
if
self
.
pad_if_needed
and
width
<
self
.
size
[
1
]:
padding
=
[
self
.
size
[
1
]
-
width
,
0
]
padding
=
[
self
.
size
[
1
]
-
width
,
0
]
...
@@ -742,12 +742,12 @@ class RandomPerspective(torch.nn.Module):
...
@@ -742,12 +742,12 @@ class RandomPerspective(torch.nn.Module):
fill
=
self
.
fill
fill
=
self
.
fill
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
fill
,
(
int
,
float
)):
if
isinstance
(
fill
,
(
int
,
float
)):
fill
=
[
float
(
fill
)]
*
F
.
_
get_image_num_channels
(
img
)
fill
=
[
float
(
fill
)]
*
F
.
get_image_num_channels
(
img
)
else
:
else
:
fill
=
[
float
(
f
)
for
f
in
fill
]
fill
=
[
float
(
f
)
for
f
in
fill
]
if
torch
.
rand
(
1
)
<
self
.
p
:
if
torch
.
rand
(
1
)
<
self
.
p
:
width
,
height
=
F
.
_
get_image_size
(
img
)
width
,
height
=
F
.
get_image_size
(
img
)
startpoints
,
endpoints
=
self
.
get_params
(
width
,
height
,
self
.
distortion_scale
)
startpoints
,
endpoints
=
self
.
get_params
(
width
,
height
,
self
.
distortion_scale
)
return
F
.
perspective
(
img
,
startpoints
,
endpoints
,
self
.
interpolation
,
fill
)
return
F
.
perspective
(
img
,
startpoints
,
endpoints
,
self
.
interpolation
,
fill
)
return
img
return
img
...
@@ -858,7 +858,7 @@ class RandomResizedCrop(torch.nn.Module):
...
@@ -858,7 +858,7 @@ class RandomResizedCrop(torch.nn.Module):
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
sized crop.
"""
"""
width
,
height
=
F
.
_
get_image_size
(
img
)
width
,
height
=
F
.
get_image_size
(
img
)
area
=
height
*
width
area
=
height
*
width
log_ratio
=
torch
.
log
(
torch
.
tensor
(
ratio
))
log_ratio
=
torch
.
log
(
torch
.
tensor
(
ratio
))
...
@@ -1280,7 +1280,7 @@ class RandomRotation(torch.nn.Module):
...
@@ -1280,7 +1280,7 @@ class RandomRotation(torch.nn.Module):
fill
=
self
.
fill
fill
=
self
.
fill
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
fill
,
(
int
,
float
)):
if
isinstance
(
fill
,
(
int
,
float
)):
fill
=
[
float
(
fill
)]
*
F
.
_
get_image_num_channels
(
img
)
fill
=
[
float
(
fill
)]
*
F
.
get_image_num_channels
(
img
)
else
:
else
:
fill
=
[
float
(
f
)
for
f
in
fill
]
fill
=
[
float
(
f
)
for
f
in
fill
]
angle
=
self
.
get_params
(
self
.
degrees
)
angle
=
self
.
get_params
(
self
.
degrees
)
...
@@ -1439,11 +1439,11 @@ class RandomAffine(torch.nn.Module):
...
@@ -1439,11 +1439,11 @@ class RandomAffine(torch.nn.Module):
fill
=
self
.
fill
fill
=
self
.
fill
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
img
,
Tensor
):
if
isinstance
(
fill
,
(
int
,
float
)):
if
isinstance
(
fill
,
(
int
,
float
)):
fill
=
[
float
(
fill
)]
*
F
.
_
get_image_num_channels
(
img
)
fill
=
[
float
(
fill
)]
*
F
.
get_image_num_channels
(
img
)
else
:
else
:
fill
=
[
float
(
f
)
for
f
in
fill
]
fill
=
[
float
(
f
)
for
f
in
fill
]
img_size
=
F
.
_
get_image_size
(
img
)
img_size
=
F
.
get_image_size
(
img
)
ret
=
self
.
get_params
(
self
.
degrees
,
self
.
translate
,
self
.
scale
,
self
.
shear
,
img_size
)
ret
=
self
.
get_params
(
self
.
degrees
,
self
.
translate
,
self
.
scale
,
self
.
shear
,
img_size
)
...
@@ -1529,7 +1529,7 @@ class RandomGrayscale(torch.nn.Module):
...
@@ -1529,7 +1529,7 @@ class RandomGrayscale(torch.nn.Module):
Returns:
Returns:
PIL Image or Tensor: Randomly grayscaled image.
PIL Image or Tensor: Randomly grayscaled image.
"""
"""
num_output_channels
=
F
.
_
get_image_num_channels
(
img
)
num_output_channels
=
F
.
get_image_num_channels
(
img
)
if
torch
.
rand
(
1
)
<
self
.
p
:
if
torch
.
rand
(
1
)
<
self
.
p
:
return
F
.
rgb_to_grayscale
(
img
,
num_output_channels
=
num_output_channels
)
return
F
.
rgb_to_grayscale
(
img
,
num_output_channels
=
num_output_channels
)
return
img
return
img
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment