OpenDAS / vision

Commit bb3aae7b (unverified)
Authored Jul 13, 2023 by Nicolas Hug; committed by GitHub on Jul 13, 2023

Add --backend and --use-v2 support to detection refs (#7732)

Co-authored-by: Philip Meier <github.pmeier@posteo.de>

Parent: 08c9938f

Showing 8 changed files with 166 additions and 106 deletions (+166 -106):
references/classification/presets.py           +25 -22
references/detection/coco_utils.py             +20 -15
references/detection/engine.py                  +2  -2
references/detection/group_by_aspect_ratio.py   +3  -1
references/detection/presets.py                +89 -53
references/detection/train.py                  +19 -10
references/detection/transforms.py              +7  -2
test/test_transforms_v2_consistency.py          +1  -1
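In short: the detection reference scripts gain `--backend` (pil, tensor, or datapoint; case-insensitive) to choose the input type the transform pipelines operate on, and `--use-v2` to switch the pipelines to `torchvision.transforms.v2`. An illustrative invocation (all other flags elided, paths are placeholders) would be `python train.py --dataset coco --data-path /path/to/coco --backend tensor --use-v2`.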
references/classification/presets.py

--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -15,6 +15,9 @@ def get_module(use_v2):
 
 
 class ClassificationPresetTrain:
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter. We may change that in the
+    # future though, if we change the output type from the dataset.
     def __init__(
         self,
         *,
@@ -30,42 +33,42 @@ class ClassificationPresetTrain:
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)
 
         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
-        transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
+        transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            transforms.append(module.RandomHorizontalFlip(hflip_prob))
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
+                transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
+                transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
-                transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
+                transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))
         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
 
         transforms.extend(
             [
-                module.ConvertImageDtype(torch.float),
-                module.Normalize(mean=mean, std=std),
+                T.ConvertImageDtype(torch.float),
+                T.Normalize(mean=mean, std=std),
             ]
         )
         if random_erase_prob > 0:
-            transforms.append(module.RandomErasing(p=random_erase_prob))
+            transforms.append(T.RandomErasing(p=random_erase_prob))
 
-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
@@ -83,28 +86,28 @@ class ClassificationPresetEval:
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)
         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
         transforms += [
-            module.Resize(resize_size, interpolation=interpolation, antialias=True),
-            module.CenterCrop(crop_size),
+            T.Resize(resize_size, interpolation=interpolation, antialias=True),
+            T.CenterCrop(crop_size),
         ]
         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
 
         transforms += [
-            module.ConvertImageDtype(torch.float),
-            module.Normalize(mean=mean, std=std),
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
         ]
 
-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
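Note: the hunks above only show call sites of `get_module`; its body is outside the diff context. A hypothetical sketch, assuming it simply selects between the v1 and v2 transforms namespaces (the detection counterpart `get_modules` below follows this pattern):

```python
# Hypothetical sketch of get_module(use_v2); the actual body is not shown in
# the hunks above. It returns the namespace that the presets alias as `T`.
def get_module(use_v2):
    if use_v2:
        import torchvision.transforms.v2

        return torchvision.transforms.v2
    else:
        import torchvision.transforms

        return torchvision.transforms
```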
references/detection/coco_utils.py

--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -7,6 +7,7 @@ import torchvision
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
+from torchvision.datasets import wrap_dataset_for_transforms_v2
 
 
 class FilterAndRemapCocoCategories:
@@ -49,7 +50,6 @@ class ConvertCocoPolysToMask:
         w, h = image.size
 
         image_id = target["image_id"]
-        image_id = torch.tensor([image_id])
 
         anno = target["annotations"]
@@ -126,10 +126,6 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
             return True
         return False
 
-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -196,12 +192,15 @@ def convert_to_coco_api(ds):
 def get_coco_api_from_dataset(dataset):
+    # FIXME: This is... awful?
     for _ in range(10):
         if isinstance(dataset, torchvision.datasets.CocoDetection):
             break
         if isinstance(dataset, torch.utils.data.Subset):
             dataset = dataset.dataset
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return dataset.coco
     return convert_to_coco_api(dataset)
@@ -220,7 +219,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
         return img, target
 
 
-def get_coco(root, image_set, transforms, mode="instances"):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -228,17 +227,21 @@ def get_coco(root, image_set, transforms, mode="instances"):
         # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
     }
 
-    t = [ConvertCocoPolysToMask()]
-    if transforms is not None:
-        t.append(transforms)
-    transforms = T.Compose(t)
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)
 
-    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
+    if use_v2:
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        # TODO: need to update target_keys to handle masks for segmentation!
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+    else:
+        t = [ConvertCocoPolysToMask()]
+        if transforms is not None:
+            t.append(transforms)
+        transforms = T.Compose(t)
+        dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
 
     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset)
@@ -248,5 +251,7 @@ def get_coco(root, image_set, transforms, mode="instances"):
     return dataset
 
 
-def get_coco_kp(root, image_set, transforms):
+def get_coco_kp(root, image_set, transforms, use_v2=False):
+    if use_v2:
+        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
     return get_coco(root, image_set, transforms, mode="person_keypoints")
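For context, `wrap_dataset_for_transforms_v2` is torchvision's public helper for making a dataset return v2-friendly targets. A minimal standalone sketch of what the v2 branch above does (paths are placeholders):

```python
import torchvision
from torchvision.datasets import wrap_dataset_for_transforms_v2

# Placeholder paths; the wrapper restricts the target dict to the requested keys.
dataset = torchvision.datasets.CocoDetection("path/to/train2017", "path/to/instances_train2017.json")
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})

img, target = dataset[0]
# target is a dict with "boxes", "labels", "image_id"; note that "image_id"
# comes back as a plain int, which is why engine.py drops its .item() call.
```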
references/detection/engine.py

--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +97,7 @@ def evaluate(model, data_loader, device):
         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
         model_time = time.time() - model_time
 
-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
         evaluator_time = time.time()
         coco_evaluator.update(res)
         evaluator_time = time.time() - evaluator_time
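The `train_one_epoch` change exists because v2-style targets can contain non-tensor values (for example, `image_id` is a plain `int` once the `torch.tensor([image_id])` conversion is removed in coco_utils.py above). A self-contained illustration of the guarded device move:

```python
import torch

device = "cpu"  # placeholder
targets = [{"boxes": torch.rand(3, 4), "labels": torch.tensor([1, 2, 3]), "image_id": 42}]

# Move only tensor values; leave plain Python values such as the int image_id alone.
targets = [
    {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
    for t in targets
]
```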
references/detection/group_by_aspect_ratio.py

--- a/references/detection/group_by_aspect_ratio.py
+++ b/references/detection/group_by_aspect_ratio.py
@@ -164,7 +164,9 @@ def compute_aspect_ratios(dataset, indices=None):
     if hasattr(dataset, "get_height_and_width"):
         return _compute_aspect_ratios_custom_dataset(dataset, indices)
 
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return _compute_aspect_ratios_coco_dataset(dataset, indices)
 
     if isinstance(dataset, torchvision.datasets.VOCDetection):
references/detection/presets.py

--- a/references/detection/presets.py
+++ b/references/detection/presets.py
+from collections import defaultdict
+
 import torch
-import transforms as T
+import transforms as reference_transforms
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+
+        return torchvision.transforms.v2, torchvision.datapoints
+    else:
+        return reference_transforms, None
 
 
 class DetectionPresetTrain:
-    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter.
+    def __init__(
+        self,
+        *,
+        data_augmentation,
+        hflip_prob=0.5,
+        mean=(123.0, 117.0, 104.0),
+        backend="pil",
+        use_v2=False,
+    ):
+
+        T, datapoints = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
         if data_augmentation == "hflip":
-            self.transforms = T.Compose(
-                [
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
         elif data_augmentation == "lsj":
-            self.transforms = T.Compose(
-                [
-                    T.ScaleJitter(target_size=(1024, 1024)),
-                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "multiscale":
-            self.transforms = T.Compose(
-                [
-                    T.RandomShortestSize(
-                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
-                    ),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssd":
-            self.transforms = T.Compose(
-                [
-                    T.RandomPhotometricDistort(),
-                    T.RandomZoomOut(fill=list(mean)),
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=fill),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssdlite":
-            self.transforms = T.Compose(
-                [
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
 
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2.
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+
+        transforms += [T.ConvertImageDtype(torch.float)]
+
+        if use_v2:
+            transforms += [
+                T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
+                T.SanitizeBoundingBox(),
+            ]
+
+        self.transforms = T.Compose(transforms)
+
     def __call__(self, img, target):
         return self.transforms(img, target)
 
 
 class DetectionPresetEval:
-    def __init__(self):
-        self.transforms = T.Compose(
-            [
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-            ]
-        )
+    def __init__(self, backend="pil", use_v2=False):
+        T, _ = get_modules(use_v2)
+        transforms = []
+        backend = backend.lower()
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+        elif backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        else:
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.ConvertImageDtype(torch.float)]
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img, target):
         return self.transforms(img, target)
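A hedged usage sketch of the rewritten presets (to be run from inside `references/detection/`, since `presets.py` imports the local `transforms` module, with a torchvision build that ships `transforms.v2` and `datapoints`):

```python
import presets

# "datapoint" only makes sense with use_v2=True (train.py enforces this).
# With use_v2=True the pipeline also converts boxes to XYXY format and
# sanitizes degenerate ones at the end.
train_tf = presets.DetectionPresetTrain(data_augmentation="ssd", backend="datapoint", use_v2=True)
eval_tf = presets.DetectionPresetEval(backend="tensor", use_v2=False)
```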
references/detection/train.py

--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -40,23 +40,26 @@ def copypaste_collate_fn(batch):
     return copypaste(*utils.collate_fn(batch))
 
 
-def get_dataset(name, image_set, transform, data_path):
-    paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[name]
+def get_dataset(is_train, args):
+    image_set = "train" if is_train else "val"
+    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
+    p, ds_fn, num_classes = paths[args.dataset]
 
-    ds = ds_fn(p, image_set=image_set, transforms=transform)
+    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
     return ds, num_classes
 
 
-def get_transform(train, args):
-    if train:
-        return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation)
+def get_transform(is_train, args):
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
+        )
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
         return lambda img, target: (trans(img), target)
     else:
-        return presets.DetectionPresetEval()
+        return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)
 
 
 def get_args_parser(add_help=True):
@@ -159,10 +162,16 @@ def get_args_parser(add_help=True):
         help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
     )
 
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
+
     return parser
 
 
 def main(args):
+    if args.backend.lower() == "datapoint" and not args.use_v2:
+        raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -177,8 +186,8 @@ def main(args):
     # Data loading code
     print("Loading data")
 
-    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
-    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)
+    dataset, num_classes = get_dataset(is_train=True, args=args)
+    dataset_test, _ = get_dataset(is_train=False, args=args)
 
     print("Creating data loaders")
     if args.distributed:
references/detection/transforms.py

--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -293,11 +293,13 @@ class ScaleJitter(nn.Module):
         target_size: Tuple[int, int],
         scale_range: Tuple[float, float] = (0.1, 2.0),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias=True,
     ):
         super().__init__()
         self.target_size = target_size
         self.scale_range = scale_range
         self.interpolation = interpolation
+        self.antialias = antialias
 
     def forward(
         self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
@@ -315,14 +317,17 @@ class ScaleJitter(nn.Module):
         new_width = int(orig_width * r)
         new_height = int(orig_height * r)
 
-        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)
 
         if target is not None:
             target["boxes"][:, 0::2] *= new_width / orig_width
             target["boxes"][:, 1::2] *= new_height / orig_height
             if "masks" in target:
                 target["masks"] = F.resize(
-                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                    target["masks"],
+                    [new_height, new_width],
+                    interpolation=InterpolationMode.NEAREST,
+                    antialias=self.antialias,
                 )
 
         return image, target
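The new `antialias` flag is simply forwarded to `F.resize`; making it explicit keeps the PIL and tensor backends visually comparable, since antialiased resizing was historically the behavior for PIL images but not for tensors. A small self-contained illustration (sizes are arbitrary):

```python
import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 400, 600)  # arbitrary image tensor
# antialias=True makes tensor downscaling match PIL's antialiased resizing
# much more closely than the historical tensor default.
out = F.resize(img, [224, 224], antialias=True)
```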
test/test_transforms_v2_consistency.py

--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -1133,7 +1133,7 @@ class TestRefDetTransforms:
             {"with_mask": False},
         ),
         (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
-        (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}),
+        (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
         (
             det_transforms.RandomShortestSize(
                 min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333