Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
f3c89cc6
Unverified
Commit
f3c89cc6
authored
Aug 03, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 03, 2023
Browse files
Remove cutmix and mixup from prototype (#7787)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
cab9fba8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
142 deletions
+3
-142
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+1
-47
torchvision/prototype/transforms/__init__.py
torchvision/prototype/transforms/__init__.py
+1
-1
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+1
-94
No files found.
test/test_prototype_transforms.py
View file @
f3c89cc6
...
...
@@ -12,13 +12,10 @@ from common_utils import (
make_bounding_box
,
make_detection_mask
,
make_image
,
make_images
,
make_segmentation_mask
,
make_video
,
make_videos
,
)
from
prototype_common_utils
import
make_label
,
make_one_hot_labels
from
prototype_common_utils
import
make_label
from
torchvision.datapoints
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
from
torchvision.prototype
import
datapoints
,
transforms
...
...
@@ -44,49 +41,6 @@ def parametrize(transforms_with_inputs):
)
@parametrize(
    [
        (
            transform,
            [
                dict(inpt=inpt, one_hot_label=one_hot_label)
                # Pair every batched image/video input with a batched one-hot label.
                for inpt, one_hot_label in itertools.product(
                    itertools.chain(
                        make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                        make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                    ),
                    make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                )
            ],
        )
        for transform in [
            transforms.RandomMixUp(alpha=1.0),
            transforms.RandomCutMix(alpha=1.0),
        ]
    ]
)
def test_mixup_cutmix(transform, input):
    """Smoke-test RandomMixUp / RandomCutMix on batched tensor inputs.

    Checks that the transforms run on a valid sample, that extra
    pass-through entries (str, int) in the sample dict are tolerated, and
    that unsupported sample types raise a ``TypeError``.
    """
    transform(input)

    # Extra non-tensor entries in the sample must be passed through untouched.
    input_copy = dict(input)
    input_copy["path"] = "/path/to/somewhere"
    input_copy["num"] = 1234
    transform(input_copy)

    # Check if we raise an error if sample contains bbox or mask or label
    err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
    input_copy = dict(input)
    for unsup_data in [
        make_label(),
        make_bounding_box(format="XYXY"),
        make_detection_mask(),
        make_segmentation_mask(),
    ]:
        input_copy["unsupported"] = unsup_data
        with pytest.raises(TypeError, match=err_msg):
            transform(input_copy)
class
TestSimpleCopyPaste
:
def
create_fake_image
(
self
,
mocker
,
image_type
):
if
image_type
==
PIL
.
Image
.
Image
:
...
...
torchvision/prototype/transforms/__init__.py
View file @
f3c89cc6
from
._presets
import
StereoMatching
# usort: skip
from
._augment
import
RandomCutMix
,
RandomMixUp
,
SimpleCopyPaste
from
._augment
import
SimpleCopyPaste
from
._geometry
import
FixedSizeCrop
from
._misc
import
PermuteDimensions
,
TransposeDimensions
from
._type_conversion
import
LabelToOneHot
torchvision/prototype/transforms/_augment.py
View file @
f3c89cc6
import
math
from
typing
import
Any
,
cast
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
PIL.Image
...
...
@@ -9,100 +8,8 @@ from torchvision.ops import masks_to_boxes
from
torchvision.prototype
import
datapoints
as
proto_datapoints
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
,
Transform
from
torchvision.transforms.v2._transform
import
_RandomApplyTransform
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
from
torchvision.transforms.v2.utils
import
has_any
,
is_simple_tensor
,
query_size
class _BaseMixUpCutMix(_RandomApplyTransform):
    """Shared machinery for MixUp/CutMix: Beta(alpha, alpha) mixing-coefficient
    sampling, input validation, and one-hot-label blending."""

    def __init__(self, alpha: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.alpha = alpha
        # Symmetric Beta distribution; lam ~ Beta(alpha, alpha) is the mixing weight.
        concentration = torch.tensor([alpha])
        self._dist = torch.distributions.Beta(concentration, concentration)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # The sample must carry at least one tensor image/video AND a one-hot label.
        has_media = has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
        has_one_hot = has_any(flat_inputs, proto_datapoints.OneHotLabel)
        if not has_media or not has_one_hot:
            raise TypeError(
                f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels."
            )
        # Mixing is undefined for these sample types, so reject them outright.
        if has_any(
            flat_inputs,
            PIL.Image.Image,
            datapoints.BoundingBoxes,
            datapoints.Mask,
            proto_datapoints.Label,
        ):
            raise TypeError(
                f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
            )

    def _mixup_onehotlabel(
        self, inpt: proto_datapoints.OneHotLabel, lam: float
    ) -> proto_datapoints.OneHotLabel:
        """Blend each label with its roll-by-one batch neighbor:
        ``lam * inpt + (1 - lam) * rolled``."""
        if inpt.ndim < 2:
            raise ValueError("Need a batch of one hot labels")
        mixed = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
        return proto_datapoints.OneHotLabel.wrap_like(inpt, mixed)
class RandomMixUp(_BaseMixUpCutMix):
    """MixUp: replace each batch element with a convex combination of itself
    and its roll-by-one batch neighbor, using a single lam for the whole batch."""

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]

        # One-hot labels are mixed with the shared helper.
        if isinstance(inpt, proto_datapoints.OneHotLabel):
            return self._mixup_onehotlabel(inpt, lam)

        is_wrapped = isinstance(inpt, (datapoints.Image, datapoints.Video))
        if not is_wrapped and not is_simple_tensor(inpt):
            # Anything else (str, int, ...) passes through untouched.
            return inpt

        # Videos carry an extra frames dim: (B, T, C, H, W) vs (B, C, H, W).
        batched_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
        if inpt.ndim < batched_ndim:
            raise ValueError("The transform expects a batched input")

        mixed = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
        if is_wrapped:
            mixed = type(inpt).wrap_like(inpt, mixed)  # type: ignore[arg-type]
        return mixed
class RandomCutMix(_BaseMixUpCutMix):
    """CutMix: paste a random rectangle from the roll-by-one batch neighbor into
    each sample, and mix one-hot labels by the pasted-area fraction.

    NOTE: `_get_params` draws from the RNG in a fixed order
    (Beta sample, then randint(W), then randint(H)); do not reorder.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Raw mixing coefficient from Beta(alpha, alpha).
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        H, W = query_size(flat_inputs)

        # Random box center.
        r_x = torch.randint(W, ())
        r_y = torch.randint(H, ())

        # Half box size so that the box area fraction is ~ (1 - lam).
        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        # Clamp the box to the image bounds.
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        box = (x1, y1, x2, y2)

        # Recompute lam from the clamped box so labels match the actual pasted area.
        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return dict(box=box, lam_adjusted=lam_adjusted)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
            box = params["box"]

            # Videos carry an extra frames dim: (B, T, C, H, W) vs (B, C, H, W).
            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
            if inpt.ndim < expected_ndim:
                raise ValueError("The transform expects a batched input")

            x1, y1, x2, y2 = box
            rolled = inpt.roll(1, 0)
            # Clone, then overwrite the box region with the neighbor's pixels.
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]

            return output
        elif isinstance(inpt, proto_datapoints.OneHotLabel):
            lam_adjusted = params["lam_adjusted"]
            return self._mixup_onehotlabel(inpt, lam_adjusted)
        else:
            # Unsupported-but-harmless sample entries pass through untouched.
            return inpt
from
torchvision.transforms.v2.utils
import
is_simple_tensor
class
SimpleCopyPaste
(
Transform
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment