dcuai / dlexamples / Commits / 0fd8347d

Commit 0fd8347d authored Jan 08, 2023 by unknown
Add the mmclassification-0.24.1 code and remove mmclassification-speed-benchmark
parent cc567e9e
Changes: 838. Showing 20 changed files with 1397 additions and 212 deletions (+1397, −212).
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/imagenet.py  +45 −91
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/imagenet21k.py  +174 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/mnist.py  +3 −2
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/multi_label.py  +10 −13
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/__init__.py  +22 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/auto_augment.py  +91 −37
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/compose.py  +1 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/formatting.py  +195 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/loading.py  +1 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/transforms.py  +158 −68
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/samplers/__init__.py  +5 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/samplers/distributed_sampler.py  +61 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/samplers/repeat_aug.py  +106 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/stanford_cars.py  +210 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/utils.py  +153 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/voc.py  +94 −0
openmmlab_test/mmclassification-0.24.1/mmcls/models/__init__.py  +14 −0
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/__init__.py  +51 −0
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/alexnet.py  +2 −1
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/base_backbone.py  +1 −0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/imagenet.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/imagenet.py   View file @ 0fd8347d

-import os
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Union
 
-import numpy as np
-
-from .base_dataset import BaseDataset
 from .builder import DATASETS
-
-
-def has_file_allowed_extension(filename, extensions):
-    """Checks if a file is an allowed extension.
-
-    Args:
-        filename (string): path to a file
-
-    Returns:
-        bool: True if the filename ends with a known image extension
-    """
-    filename_lower = filename.lower()
-    return any(filename_lower.endswith(ext) for ext in extensions)
-
-
-def find_folders(root):
-    """Find classes by folders under a root.
-
-    Args:
-        root (string): root directory of folders
-
-    Returns:
-        folder_to_idx (dict): the map from folder name to class idx
-    """
-    folders = [
-        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
-    ]
-    folders.sort()
-    folder_to_idx = {folders[i]: i for i in range(len(folders))}
-    return folder_to_idx
-
-
-def get_samples(root, folder_to_idx, extensions):
-    """Make dataset by walking all images under a root.
-
-    Args:
-        root (string): root directory of folders
-        folder_to_idx (dict): the map from class name to class idx
-        extensions (tuple): allowed extensions
-
-    Returns:
-        samples (list): a list of tuple where each element is (image, label)
-    """
-    samples = []
-    root = os.path.expanduser(root)
-    for folder_name in sorted(os.listdir(root)):
-        _dir = os.path.join(root, folder_name)
-        if not os.path.isdir(_dir):
-            continue
-        for _, _, fns in sorted(os.walk(_dir)):
-            for fn in sorted(fns):
-                if has_file_allowed_extension(fn, extensions):
-                    path = os.path.join(folder_name, fn)
-                    item = (path, folder_to_idx[folder_name])
-                    samples.append(item)
-    return samples
+from .custom import CustomDataset
 
 
 @DATASETS.register_module()
-class ImageNet(BaseDataset):
+class ImageNet(CustomDataset):
     """`ImageNet <http://www.image-net.org>`_ Dataset.
 
-    This implementation is modified from
-    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py  # noqa: E501
-    """
+    The dataset supports two kinds of annotation format. More details can be
+    found in :class:`CustomDataset`.
+
+    Args:
+        data_prefix (str): The path of data directory.
+        pipeline (Sequence[dict]): A list of dict, where each element
+            represents a operation defined in :mod:`mmcls.datasets.pipelines`.
+            Defaults to an empty tuple.
+        classes (str | Sequence[str], optional): Specify names of classes.
+
+            - If is string, it should be a file path, and the every line of
+              the file is a name of a class.
+            - If is a sequence of string, every item is a name of class.
+            - If is None, use the default ImageNet-1k classes names.
+
+            Defaults to None.
+        ann_file (str, optional): The annotation file. If is string, read
+            samples paths from the ann_file. If is None, find samples in
+            ``data_prefix``. Defaults to None.
+        extensions (Sequence[str]): A sequence of allowed extensions. Defaults
+            to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
+        test_mode (bool): In train mode or test mode. It's only a mark and
+            won't be used in this class. Defaults to False.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            If None, automatically inference from the specified path.
+            Defaults to None.
+    """  # noqa: E501
 
     IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
     CLASSES = [
...
@@ -1075,31 +1042,18 @@ class ImageNet(BaseDataset):
         'toilet tissue, toilet paper, bathroom tissue'
     ]
 
-    def load_annotations(self):
-        if self.ann_file is None:
-            folder_to_idx = find_folders(self.data_prefix)
-            samples = get_samples(
-                self.data_prefix,
-                folder_to_idx,
-                extensions=self.IMG_EXTENSIONS)
-            if len(samples) == 0:
-                raise (RuntimeError('Found 0 files in subfolders of: '
-                                    f'{self.data_prefix}. '
-                                    'Supported extensions are: '
-                                    f'{",".join(self.IMG_EXTENSIONS)}'))
-            self.folder_to_idx = folder_to_idx
-        elif isinstance(self.ann_file, str):
-            with open(self.ann_file) as f:
-                samples = [x.strip().split(' ') for x in f.readlines()]
-        else:
-            raise TypeError('ann_file must be a str or None')
-        self.samples = samples
-
-        data_infos = []
-        for filename, gt_label in self.samples:
-            info = {'img_prefix': self.data_prefix}
-            info['img_info'] = {'filename': filename}
-            info['gt_label'] = np.array(gt_label, dtype=np.int64)
-            data_infos.append(info)
-        return data_infos
+    def __init__(self,
+                 data_prefix: str,
+                 pipeline: Sequence = (),
+                 classes: Union[str, Sequence[str], None] = None,
+                 ann_file: Optional[str] = None,
+                 test_mode: bool = False,
+                 file_client_args: Optional[dict] = None):
+        super().__init__(
+            data_prefix=data_prefix,
+            pipeline=pipeline,
+            classes=classes,
+            ann_file=ann_file,
+            extensions=self.IMG_EXTENSIONS,
+            test_mode=test_mode,
+            file_client_args=file_client_args)
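With this commit the folder-scanning and ann_file-parsing logic moves out of ImageNet and into CustomDataset. A minimal usage sketch, not part of the commit; the data path is a placeholder and it assumes mmcls 0.24.1 is installed:

from mmcls.datasets import ImageNet

# Either point data_prefix at a class-per-folder tree, or pass ann_file,
# where every line is "<relative image path> <integer label>".
dataset = ImageNet(
    data_prefix='data/imagenet/train',   # placeholder path
    ann_file=None,
    pipeline=[dict(type='LoadImageFromFile')])
print(len(dataset))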
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/imagenet21k.py   0 → 100644   View file @ 0fd8347d

# Copyright (c) OpenMMLab. All rights reserved.
import gc
import pickle
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class ImageNet21k(CustomDataset):
    """ImageNet21k Dataset.

    Since the dataset ImageNet21k is extremely big, contains 21k+ classes
    and 1.4B files. This class has improved the following points on the
    basis of the class ``ImageNet``, in order to save memory, we enable the
    ``serialize_data`` optional by default. With this option, the annotation
    won't be stored in the list ``data_infos``, but be serialized as an
    array.

    Args:
        data_prefix (str): The path of data directory.
        pipeline (Sequence[dict]): A list of dict, where each element
            represents a operation defined in :mod:`mmcls.datasets.pipelines`.
            Defaults to an empty tuple.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If is string, it should be a file path, and the every line of
              the file is a name of a class.
            - If is a sequence of string, every item is a name of class.
            - If is None, the object won't have category information.
              (Not recommended)

            Defaults to None.
        ann_file (str, optional): The annotation file. If is string, read
            samples paths from the ann_file. If is None, find samples in
            ``data_prefix``. Defaults to None.
        serialize_data (bool): Whether to hold memory using serialized objects,
            when enabled, data loader workers can use shared RAM from master
            process instead of making a copy. Defaults to True.
        multi_label (bool): Not implement by now. Use multi label or not.
            Defaults to False.
        recursion_subdir(bool): Deprecated, and the dataset will recursively
            get all images now.
        test_mode (bool): In train mode or test mode. It's only a mark and
            won't be used in this class. Defaults to False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            If None, automatically inference from the specified path.
            Defaults to None.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    CLASSES = None

    def __init__(self,
                 data_prefix: str,
                 pipeline: Sequence = (),
                 classes: Union[str, Sequence[str], None] = None,
                 ann_file: Optional[str] = None,
                 serialize_data: bool = True,
                 multi_label: bool = False,
                 recursion_subdir: bool = True,
                 test_mode=False,
                 file_client_args: Optional[dict] = None):
        assert recursion_subdir, 'The `recursion_subdir` option is ' \
            'deprecated. Now the dataset will recursively get all images.'
        if multi_label:
            raise NotImplementedError(
                'The `multi_label` option is not supported by now.')
        self.multi_label = multi_label
        self.serialize_data = serialize_data

        if ann_file is None:
            warnings.warn(
                'The ImageNet21k dataset is large, and scanning directory may '
                'consume long time. Considering to specify the `ann_file` to '
                'accelerate the initialization.', UserWarning)

        if classes is None:
            warnings.warn(
                'The CLASSES is not stored in the `ImageNet21k` class. '
                'Considering to specify the `classes` argument if you need '
                'do inference on the ImageNet-21k dataset', UserWarning)

        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            ann_file=ann_file,
            extensions=self.IMG_EXTENSIONS,
            test_mode=test_mode,
            file_client_args=file_client_args)

        if self.serialize_data:
            self.data_infos_bytes, self.data_address = self._serialize_data()
            # Empty cache for preventing making multiple copies of
            # `self.data_infos` when loading data multi-processes.
            self.data_infos.clear()
            gc.collect()

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index.
        """
        return [int(self.get_data_info(idx)['gt_label'])]

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): The index of data.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        if self.serialize_data:
            start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
            end_addr = self.data_address[idx].item()
            bytes = memoryview(self.data_infos_bytes[start_addr:end_addr])
            data_info = pickle.loads(bytes)
        else:
            data_info = self.data_infos[idx]

        return data_info

    def prepare_data(self, idx):
        data_info = self.get_data_info(idx)
        return self.pipeline(data_info)

    def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Serialize ``self.data_infos`` to save memory when launching multiple
        workers in data loading. This function will be called in ``full_init``.

        Hold memory using serialized objects, and data loader workers can use
        shared RAM from master process instead of making a copy.

        Returns:
            Tuple[np.ndarray, np.ndarray]: serialize result and corresponding
            address.
        """

        def _serialize(data):
            buffer = pickle.dumps(data, protocol=4)
            return np.frombuffer(buffer, dtype=np.uint8)

        serialized_data_infos_list = [_serialize(x) for x in self.data_infos]
        address_list = np.asarray([len(x) for x in serialized_data_infos_list],
                                  dtype=np.int64)
        data_address: np.ndarray = np.cumsum(address_list)
        serialized_data_infos = np.concatenate(serialized_data_infos_list)

        return serialized_data_infos, data_address

    def __len__(self) -> int:
        """Get the length of filtered dataset and automatically call
        ``full_init`` if the dataset has not been fully init.

        Returns:
            int: The length of filtered dataset.
        """
        if self.serialize_data:
            return len(self.data_address)
        else:
            return len(self.data_infos)
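The serialization scheme above is worth seeing in isolation. A self-contained sketch of the same trick — pickle each record to a uint8 buffer, keep cumulative end offsets, slice records back out by index; the toy data_infos list is invented for illustration:

import pickle

import numpy as np

data_infos = [{'filename': f'img_{i}.jpg', 'gt_label': i} for i in range(3)]
buffers = [np.frombuffer(pickle.dumps(d, protocol=4), dtype=np.uint8)
           for d in data_infos]
data_address = np.cumsum([len(b) for b in buffers])  # end offset per record
data_bytes = np.concatenate(buffers)  # one flat array, shareable across workers

idx = 1
start = 0 if idx == 0 else data_address[idx - 1].item()
end = data_address[idx].item()
assert pickle.loads(memoryview(data_bytes[start:end])) == data_infos[idx]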
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/mnist.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/mnist.py   View file @ 0fd8347d

 # Copyright (c) OpenMMLab. All rights reserved.
 import codecs
 import os
 import os.path as osp
...
@@ -17,8 +18,8 @@ class MNIST(BaseDataset):
     """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
 
     This implementation is modified from
-    https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py  # noqa: E501
-    """
+    https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
+    """  # noqa: E501
 
     resource_prefix = 'http://yann.lecun.com/exdb/mnist/'
     resources = {
...
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/multi_label.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/multi_label.py   View file @ 0fd8347d

-import warnings
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
 
 import numpy as np
 
...
@@ -9,25 +10,25 @@ from .base_dataset import BaseDataset
 class MultiLabelDataset(BaseDataset):
     """Multi-label Dataset."""
 
-    def get_cat_ids(self, idx):
+    def get_cat_ids(self, idx: int) -> List[int]:
         """Get category ids by index.
 
         Args:
             idx (int): Index of data.
 
         Returns:
-            np.ndarray: Image categories of specified index.
+            cat_ids (List[int]): Image categories of specified index.
         """
         gt_labels = self.data_infos[idx]['gt_label']
-        cat_ids = np.where(gt_labels == 1)[0]
+        cat_ids = np.where(gt_labels == 1)[0].tolist()
         return cat_ids
 
     def evaluate(self,
                  results,
                  metric='mAP',
                  metric_options=None,
-                 logger=None,
-                 **deprecated_kwargs):
+                 indices=None,
+                 logger=None):
         """Evaluate the dataset.
 
         Args:
...
@@ -39,19 +40,13 @@ class MultiLabelDataset(BaseDataset):
                 Allowed keys are 'k' and 'thr'. Defaults to None
             logger (logging.Logger | str, optional): Logger used for printing
                 related information during evaluation. Defaults to None.
-            deprecated_kwargs (dict): Used for containing deprecated arguments.
 
         Returns:
             dict: evaluation results
         """
-        if metric_options is None:
+        if metric_options is None or metric_options == {}:
             metric_options = {'thr': 0.5}
 
-        if deprecated_kwargs != {}:
-            warnings.warn('Option arguments for metrics has been changed to '
-                          '`metric_options`.')
-            metric_options = {**deprecated_kwargs}
-
         if isinstance(metric, str):
             metrics = [metric]
         else:
...
@@ -60,6 +55,8 @@ class MultiLabelDataset(BaseDataset):
         eval_results = {}
         results = np.vstack(results)
         gt_labels = self.get_gt_labels()
+        if indices is not None:
+            gt_labels = gt_labels[indices]
         num_imgs = len(results)
         assert len(gt_labels) == num_imgs, 'dataset testing results should ' \
             'be of the same length as gt_labels.'
...
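The get_cat_ids change is easiest to see on a toy label vector — a sketch, not from the commit:

import numpy as np

gt_label = np.array([0, 1, 0, 1, 1])           # multi-hot label of one image
cat_ids = np.where(gt_label == 1)[0].tolist()  # new behavior: a plain List[int]
assert cat_ids == [1, 3, 4]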
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/__init__.py   0 → 100644   View file @ 0fd8347d

# Copyright (c) OpenMMLab. All rights reserved.
from .auto_augment import (AutoAugment, AutoContrast, Brightness,
                           ColorTransform, Contrast, Cutout, Equalize, Invert,
                           Posterize, RandAugment, Rotate, Sharpness, Shear,
                           Solarize, SolarizeAdd, Translate)
from .compose import Compose
from .formatting import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
                         Transpose, to_tensor)
from .loading import LoadImageFromFile
from .transforms import (CenterCrop, ColorJitter, Lighting, Normalize, Pad,
                         RandomCrop, RandomErasing, RandomFlip,
                         RandomGrayscale, RandomResizedCrop, Resize)

__all__ = [
    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
    'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
    'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
    'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
    'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
    'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
    'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad'
]
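An illustrative sketch (assumes mmcls 0.24.1 and an image on disk): the names exported above are normally built from config dicts via the PIPELINES registry and composed in order.

from mmcls.datasets.pipelines import Compose

pipeline = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img']),
])
# Calling the pipeline needs a real image file, e.g. data/a.jpg:
# results = pipeline({'img_prefix': 'data', 'img_info': {'filename': 'a.jpg'}})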
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/pipelines/auto_augment.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/auto_augment.py   View file @ 0fd8347d

 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import inspect
 import random
+from math import ceil
 from numbers import Number
 from typing import Sequence
...
@@ -9,18 +12,43 @@ import numpy as np
 from ..builder import PIPELINES
 from .compose import Compose
 
+# Default hyperparameters for all Ops
+_HPARAMS_DEFAULT = dict(pad_val=128)
+
 
 def random_negative(value, random_negative_prob):
     """Randomly negate value based on random_negative_prob."""
     return -value if np.random.rand() < random_negative_prob else value
 
 
+def merge_hparams(policy: dict, hparams: dict):
+    """Merge hyperparameters into policy config.
+
+    Only merge partial hyperparameters required of the policy.
+
+    Args:
+        policy (dict): Original policy config dict.
+        hparams (dict): Hyperparameters need to be merged.
+
+    Returns:
+        dict: Policy config dict after adding ``hparams``.
+    """
+    op = PIPELINES.get(policy['type'])
+    assert op is not None, f'Invalid policy type "{policy["type"]}".'
+    for key, value in hparams.items():
+        if policy.get(key, None) is not None:
+            continue
+        if key in inspect.getfullargspec(op.__init__).args:
+            policy[key] = value
+    return policy
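A self-contained sketch of what merge_hparams does (FakeShear is an invented stand-in; the real function looks the op up in the PIPELINES registry):

import inspect


class FakeShear:  # invented stand-in for a registered pipeline op
    def __init__(self, magnitude, pad_val=128):
        pass


policy = {'type': 'Shear', 'magnitude': 0.3}
hparams = {'pad_val': 104, 'unrelated': 1}
accepted = inspect.getfullargspec(FakeShear.__init__).args
for key, value in hparams.items():
    # copy a default only if the op accepts it and the policy does not set it
    if policy.get(key, None) is None and key in accepted:
        policy[key] = value
assert policy == {'type': 'Shear', 'magnitude': 0.3, 'pad_val': 104}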
 @PIPELINES.register_module()
 class AutoAugment(object):
-    """Auto augmentation. This data augmentation is proposed in `AutoAugment:
-    Learning Augmentation Policies from Data.
-    <https://arxiv.org/abs/1805.09501>`_.
+    """Auto augmentation.
+
+    This data augmentation is proposed in `AutoAugment: Learning Augmentation
+    Policies from Data <https://arxiv.org/abs/1805.09501>`_.
 
     Args:
         policies (list[list[dict]]): The policies of auto augmentation. Each
...
@@ -28,9 +56,12 @@ class AutoAugment(object):
             composed by several augmentations (dict). When AutoAugment is
             called, a random policy in ``policies`` will be selected to
             augment images.
+        hparams (dict): Configs of hyperparameters. Hyperparameters will be
+            used in policies that require these arguments if these arguments
+            are not set in policy dicts. Defaults to use _HPARAMS_DEFAULT.
     """
 
-    def __init__(self, policies):
+    def __init__(self, policies, hparams=_HPARAMS_DEFAULT):
         assert isinstance(policies, list) and len(policies) > 0, \
             'Policies must be a non-empty list.'
         for policy in policies:
...
@@ -41,7 +72,13 @@ class AutoAugment(object):
                     'Each specific augmentation must be a dict with key' \
                     ' "type".'
 
-        self.policies = copy.deepcopy(policies)
+        self.hparams = hparams
+        policies = copy.deepcopy(policies)
+        self.policies = []
+        for sub in policies:
+            merged_sub = [merge_hparams(policy, hparams) for policy in sub]
+            self.policies.append(merged_sub)
+
         self.sub_policy = [Compose(policy) for policy in self.policies]
 
     def __call__(self, results):
...
@@ -56,9 +93,10 @@ class AutoAugment(object):
 @PIPELINES.register_module()
 class RandAugment(object):
-    """Random augmentation. This data augmentation is proposed in `RandAugment:
-    Practical automated data augmentation with a reduced search space.
+    r"""Random augmentation.
+
+    This data augmentation is proposed in `RandAugment: Practical automated
+    data augmentation with a reduced search space
     <https://arxiv.org/abs/1909.13719>`_.
 
     Args:
...
@@ -78,19 +116,26 @@ class RandAugment(object):
         total_level (int | float): Total level for the magnitude. Defaults to
             30.
         magnitude_std (Number | str): Deviation of magnitude noise applied.
-            If positive number, magnitude is sampled from normal distribution
-            (mean=magnitude, std=magnitude_std).
-            If 0 or negative number, magnitude remains unchanged.
-            If str "inf", magnitude is sampled from uniform distribution
-            (range=[min, magnitude]).
+
+            - If positive number, magnitude is sampled from normal distribution
+              (mean=magnitude, std=magnitude_std).
+            - If 0 or negative number, magnitude remains unchanged.
+            - If str "inf", magnitude is sampled from uniform distribution
+              (range=[min, magnitude]).
+        hparams (dict): Configs of hyperparameters. Hyperparameters will be
+            used in policies that require these arguments if these arguments
+            are not set in policy dicts. Defaults to use _HPARAMS_DEFAULT.
 
     Note:
         `magnitude_std` will introduce some randomness to policy, modified by
-        https://github.com/rwightman/pytorch-image-models
+        https://github.com/rwightman/pytorch-image-models.
+
         When magnitude_std=0, we calculate the magnitude as follows:
 
         .. math::
-            magnitude = magnitude_level / total_level * (val2 - val1) + val1
+            \text{magnitude} = \frac{\text{magnitude\_level}}
+            {\text{totallevel}} \times (\text{val2} - \text{val1})
+            + \text{val1}
     """
 
     def __init__(self,
...
@@ -98,7 +143,8 @@ class RandAugment(object):
                  num_policies,
                  magnitude_level,
                  magnitude_std=0.,
-                 total_level=30):
+                 total_level=30,
+                 hparams=_HPARAMS_DEFAULT):
         assert isinstance(num_policies, int), 'Number of policies must be ' \
             f'of int type, got {type(num_policies)} instead.'
         assert isinstance(magnitude_level, (int, float)), \
...
@@ -125,8 +171,10 @@ class RandAugment(object):
         self.magnitude_level = magnitude_level
         self.magnitude_std = magnitude_std
         self.total_level = total_level
-        self.policies = policies
-        self._check_policies(self.policies)
+        self.hparams = hparams
+        policies = copy.deepcopy(policies)
+        self._check_policies(policies)
+        self.policies = [merge_hparams(policy, hparams) for policy in policies]
 
     def _check_policies(self, policies):
         for policy in policies:
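A simplified reading of the magnitude sampling the new docstring describes (not the exact mmcls helper; val1/val2 stand for the op-specific magnitude range):

import numpy as np


def sample_magnitude(magnitude_level, total_level=30, magnitude_std=0.,
                     val1=0., val2=1.):
    magnitude = magnitude_level / total_level * (val2 - val1) + val1
    if magnitude_std == 'inf':
        magnitude = np.random.uniform(val1, magnitude)  # range [min, magnitude]
    elif magnitude_std > 0:
        magnitude = np.random.normal(magnitude, magnitude_std)
    return magnitude


print(sample_magnitude(15))  # 0.5 exactly when magnitude_std == 0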
...
@@ -190,8 +238,8 @@ class Shear(object):
 
     Args:
         magnitude (int | float): The magnitude used for shear.
-        pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a
-            tuple of length 3, it is used to pad_val R, G, B channels
+        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
+            If a sequence of length 3, it is used to pad_val R, G, B channels
             respectively. Defaults to 128.
         prob (float): The probability for performing Shear therefore should be
             in range [0, 1]. Defaults to 0.5.
...
@@ -214,7 +262,7 @@ class Shear(object):
             f'be int or float, but got {type(magnitude)} instead.'
         if isinstance(pad_val, int):
             pad_val = tuple([pad_val] * 3)
-        elif isinstance(pad_val, tuple):
+        elif isinstance(pad_val, Sequence):
             assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
                 f'elements, got {len(pad_val)} instead.'
             assert all(isinstance(i, int) for i in pad_val), 'pad_val as a ' \
...
@@ -229,7 +277,7 @@ class Shear(object):
             f'should be in range [0,1], got {random_negative_prob} instead.'
 
         self.magnitude = magnitude
-        self.pad_val = pad_val
+        self.pad_val = tuple(pad_val)
         self.prob = prob
         self.direction = direction
         self.random_negative_prob = random_negative_prob
...
@@ -269,9 +317,9 @@ class Translate(object):
         magnitude (int | float): The magnitude used for translate. Note that
             the offset is calculated by magnitude * size in the corresponding
             direction. With a magnitude of 1, the whole image will be moved out
             of the range.
-        pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a
-            tuple of length 3, it is used to pad_val R, G, B channels
+        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
+            If a sequence of length 3, it is used to pad_val R, G, B channels
             respectively. Defaults to 128.
         prob (float): The probability for performing translate therefore should
             be in range [0, 1]. Defaults to 0.5.
...
@@ -294,7 +342,7 @@ class Translate(object):
             f'be int or float, but got {type(magnitude)} instead.'
         if isinstance(pad_val, int):
             pad_val = tuple([pad_val] * 3)
-        elif isinstance(pad_val, tuple):
+        elif isinstance(pad_val, Sequence):
             assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
                 f'elements, got {len(pad_val)} instead.'
             assert all(isinstance(i, int) for i in pad_val), 'pad_val as a ' \
...
@@ -309,7 +357,7 @@ class Translate(object):
             f'should be in range [0,1], got {random_negative_prob} instead.'
 
         self.magnitude = magnitude
-        self.pad_val = pad_val
+        self.pad_val = tuple(pad_val)
         self.prob = prob
         self.direction = direction
         self.random_negative_prob = random_negative_prob
...
@@ -354,11 +402,11 @@ class Rotate(object):
         angle (float): The angle used for rotate. Positive values stand for
             clockwise rotation.
         center (tuple[float], optional): Center point (w, h) of the rotation in
             the source image. If None, the center of the image will be used.
-            defaults to None.
+            Defaults to None.
         scale (float): Isotropic scale factor. Defaults to 1.0.
-        pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a
-            tuple of length 3, it is used to pad_val R, G, B channels
+        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
+            If a sequence of length 3, it is used to pad_val R, G, B channels
             respectively. Defaults to 128.
         prob (float): The probability for performing Rotate therefore should be
             in range [0, 1]. Defaults to 0.5.
...
@@ -388,7 +436,7 @@ class Rotate(object):
             f'got {type(scale)} instead.'
         if isinstance(pad_val, int):
             pad_val = tuple([pad_val] * 3)
-        elif isinstance(pad_val, tuple):
+        elif isinstance(pad_val, Sequence):
             assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
                 f'elements, got {len(pad_val)} instead.'
             assert all(isinstance(i, int) for i in pad_val), 'pad_val as a ' \
...
@@ -403,7 +451,7 @@ class Rotate(object):
         self.angle = angle
         self.center = center
         self.scale = scale
-        self.pad_val = pad_val
+        self.pad_val = tuple(pad_val)
         self.prob = prob
         self.random_negative_prob = random_negative_prob
         self.interpolation = interpolation
...
@@ -621,7 +669,8 @@ class Posterize(object):
         assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
             f'got {prob} instead.'
 
-        self.bits = int(bits)
+        # To align timm version, we need to round up to integer here.
+        self.bits = ceil(bits)
         self.prob = prob
 
     def __call__(self, results):
...
@@ -692,7 +741,7 @@ class ColorTransform(object):
 
     Args:
         magnitude (int | float): The magnitude used for color transform. A
             positive magnitude would enhance the color and a negative magnitude
             would make the image grayer. A magnitude=0 gives the origin img.
         prob (float): The probability for performing ColorTransform therefore
             should be in range [0, 1]. Defaults to 0.5.
         random_negative_prob (float): The probability that turns the magnitude
...
@@ -827,8 +876,8 @@ class Cutout(object):
         shape (int | float | tuple(int | float)): Expected cutout shape (h, w).
             If given as a single value, the value will be used for
             both h and w.
-        pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If
-            it is a tuple, it must have the same length with the image
+        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
+            If it is a sequence, it must have the same length with the image
             channels. Defaults to 128.
         prob (float): The probability for performing cutout therefore should
             be in range [0, 1]. Defaults to 0.5.
...
@@ -843,11 +892,16 @@ class Cutout(object):
             raise TypeError(
                 'shape must be of '
                 f'type int, float or tuple, got {type(shape)} instead')
+        if isinstance(pad_val, int):
+            pad_val = tuple([pad_val] * 3)
+        elif isinstance(pad_val, Sequence):
+            assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
+                f'elements, got {len(pad_val)} instead.'
         assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
             f'got {prob} instead.'
 
         self.shape = shape
-        self.pad_val = pad_val
+        self.pad_val = tuple(pad_val)
         self.prob = prob
 
     def __call__(self, results):
...
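The recurring change across Shear, Translate, Rotate and Cutout, as one standalone sketch: pad_val now accepts any Sequence of three ints, not just a tuple, and is always stored as a tuple.

from typing import Sequence


def normalize_pad_val(pad_val):
    if isinstance(pad_val, int):
        pad_val = tuple([pad_val] * 3)
    elif isinstance(pad_val, Sequence):
        assert len(pad_val) == 3 and all(isinstance(i, int) for i in pad_val)
    return tuple(pad_val)


assert normalize_pad_val(128) == (128, 128, 128)
assert normalize_pad_val([104, 116, 124]) == (104, 116, 124)  # lists now work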
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/pipelines/compose.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/compose.py   View file @ 0fd8347d

+# Copyright (c) OpenMMLab. All rights reserved.
 from collections.abc import Sequence
 
 from mmcv.utils import build_from_cfg
...
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/formatting.py   0 → 100644   View file @ 0fd8347d

# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from PIL import Image

from ..builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(
            f'Type {type(data)} cannot be converted to tensor.'
            'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
            '`Sequence`, `int` and `float`')


@PIPELINES.register_module()
class ToTensor(object):

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for key in self.keys:
            results[key] = to_tensor(results[key])
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'


@PIPELINES.register_module()
class ImageToTensor(object):

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for key in self.keys:
            img = results[key]
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            results[key] = to_tensor(img.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'


@PIPELINES.register_module()
class Transpose(object):

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        for key in self.keys:
            results[key] = results[key].transpose(self.order)
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, order={self.order})'


@PIPELINES.register_module()
class ToPIL(object):

    def __init__(self):
        pass

    def __call__(self, results):
        results['img'] = Image.fromarray(results['img'])
        return results


@PIPELINES.register_module()
class ToNumpy(object):

    def __init__(self):
        pass

    def __call__(self, results):
        results['img'] = np.array(results['img'], dtype=np.float32)
        return results


@PIPELINES.register_module()
class Collect(object):
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img" and "gt_label".

    Args:
        keys (Sequence[str]): Keys of results to be collected in ``data``.
        meta_keys (Sequence[str], optional): Meta keys to be converted to
            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
            Default: ('filename', 'ori_shape', 'img_shape', 'flip',
            'flip_direction', 'img_norm_cfg')

    Returns:
        dict: The result dict contains the following keys

            - keys in ``self.keys``
            - ``img_metas`` if available
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'flip', 'flip_direction',
                            'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        data = {}
        img_meta = {}
        for key in self.meta_keys:
            if key in results:
                img_meta[key] = results[key]
        data['img_metas'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, meta_keys={self.meta_keys})'


@PIPELINES.register_module()
class WrapFieldsToLists(object):
    """Wrap fields of the data dictionary into lists for evaluation.

    This class can be used as a last step of a test or validation
    pipeline for single image evaluation or inference.

    Example:
        >>> test_pipeline = [
        >>>     dict(type='LoadImageFromFile'),
        >>>     dict(type='Normalize',
                     mean=[123.675, 116.28, 103.53],
                     std=[58.395, 57.12, 57.375],
                     to_rgb=True),
        >>>     dict(type='ImageToTensor', keys=['img']),
        >>>     dict(type='Collect', keys=['img']),
        >>>     dict(type='WrapIntoLists')
        >>> ]
    """

    def __call__(self, results):
        # Wrap dict fields into lists
        for key, val in results.items():
            results[key] = [val]
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}()'


@PIPELINES.register_module()
class ToHalf(object):

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for k in self.keys:
            if isinstance(results[k], torch.Tensor):
                results[k] = results[k].to(torch.half)
            else:
                results[k] = results[k].astype(np.float16)
        return results
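A quick sketch of the formatting stages on a dummy result dict (assumes mmcls 0.24.1, mmcv and torch are installed):

import numpy as np
from mmcls.datasets.pipelines import Collect, ImageToTensor

results = {
    'img': np.zeros((224, 224, 3), dtype=np.float32),
    'filename': 'demo.jpg',
    'img_shape': (224, 224, 3),
}
results = ImageToTensor(keys=['img'])(results)  # HWC ndarray -> CHW tensor
data = Collect(keys=['img'])(results)           # keeps 'img' plus 'img_metas'
print(data['img'].shape)                        # torch.Size([3, 224, 224])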
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/pipelines/loading.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/loading.py   View file @ 0fd8347d

+# Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 
 import mmcv
...
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/pipelines/transforms.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/pipelines/transforms.py   View file @ 0fd8347d

 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import inspect
 import math
 import random
...
@@ -36,18 +38,19 @@ class RandomCrop(object):
         pad_val (Number | Sequence[Number]): Pixel pad_val value for constant
             fill. If a tuple of length 3, it is used to pad_val R, G, B
             channels respectively. Default: 0.
-        padding_mode (str): Type of padding. Should be: constant, edge,
-            reflect or symmetric. Default: constant.
-            -constant: Pads with a constant value, this value is specified
-                with pad_val.
-            -edge: pads with the last value at the edge of the image.
-            -reflect: Pads with reflection of image without repeating the
-                last value on the edge. For example, padding [1, 2, 3, 4]
-                with 2 elements on both sides in reflect mode will result
-                in [3, 2, 1, 2, 3, 4, 3, 2].
-            -symmetric: Pads with reflection of image repeating the last
-                value on the edge. For example, padding [1, 2, 3, 4] with
-                2 elements on both sides in symmetric mode will result in
-                [2, 1, 1, 2, 3, 4, 4, 3].
+        padding_mode (str): Type of padding. Defaults to "constant". Should
+            be one of the following:
+
+            - constant: Pads with a constant value, this value is specified \
+              with pad_val.
+            - edge: pads with the last value at the edge of the image.
+            - reflect: Pads with reflection of image without repeating the \
+              last value on the edge. For example, padding [1, 2, 3, 4] \
+              with 2 elements on both sides in reflect mode will result \
+              in [3, 2, 1, 2, 3, 4, 3, 2].
+            - symmetric: Pads with reflection of image repeating the last \
+              value on the edge. For example, padding [1, 2, 3, 4] with \
+              2 elements on both sides in symmetric mode will result in \
+              [2, 1, 1, 2, 3, 4, 4, 3].
     """
...
@@ -151,7 +154,7 @@ class RandomResizedCrop(object):
             to the original image. Defaults to (0.08, 1.0).
         ratio (tuple): Range of the random aspect ratio of the cropped image
             compared to the original image. Defaults to (3. / 4., 4. / 3.).
-        max_attempts (int): Maxinum number of attempts before falling back to
+        max_attempts (int): Maximum number of attempts before falling back to
             Central Crop. Defaults to 10.
         efficientnet_style (bool): Whether to use efficientnet style Random
             ResizedCrop. Defaults to False.
...
@@ -163,7 +166,7 @@ class RandomResizedCrop(object):
         interpolation (str): Interpolation method, accepted values are
             'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
             'bilinear'.
-        backend (str): The image resize backend type, accpeted values are
+        backend (str): The image resize backend type, accepted values are
             `cv2` and `pillow`. Defaults to `cv2`.
     """
...
@@ -191,7 +194,7 @@ class RandomResizedCrop(object):
                 f'But received scale {scale} and rato {ratio}.')
         assert min_covered >= 0, 'min_covered should be no less than 0.'
         assert isinstance(max_attempts, int) and max_attempts >= 0, \
-            'max_attempts mush be of typle int and no less than 0.'
+            'max_attempts mush be int and no less than 0.'
         assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
                                  'lanczos')
         if backend not in ['cv2', 'pillow']:
...
@@ -217,7 +220,7 @@ class RandomResizedCrop(object):
             compared to the original image size.
         ratio (tuple): Range of the random aspect ratio of the cropped
             image compared to the original image area.
-        max_attempts (int): Maxinum number of attempts before falling back
+        max_attempts (int): Maximum number of attempts before falling back
             to central crop. Defaults to 10.
 
     Returns:
...
@@ -279,7 +282,7 @@ class RandomResizedCrop(object):
             compared to the original image size.
         ratio (tuple): Range of the random aspect ratio of the cropped
             image compared to the original image area.
-        max_attempts (int): Maxinum number of attempts before falling back
+        max_attempts (int): Maximum number of attempts before falling back
             to central crop. Defaults to 10.
         min_covered (Number): Minimum ratio of the cropped area to the
             original area. Only valid if efficientnet_style is true.
...
@@ -311,7 +314,7 @@ class RandomResizedCrop(object):
             max_target_height = min(max_target_height, height)
             min_target_height = min(max_target_height, min_target_height)
 
-            # slightly differs from tf inplementation
+            # slightly differs from tf implementation
             target_height = int(
                 round(random.uniform(min_target_height, max_target_height)))
             target_width = int(round(target_height * aspect_ratio))
...
@@ -393,11 +396,12 @@ class RandomGrayscale(object):
         grayscale. Default: 0.1.
 
     Returns:
-        ndarray: Grayscale version of the input image with probability
-            gray_prob and unchanged with probability (1-gray_prob).
-            - If input image is 1 channel: grayscale version is 1 channel.
-            - If input image is 3 channel: grayscale version is 3 channel
-                with r == g == b.
+        ndarray: Image after randomly grayscale transform.
+
+    Notes:
+        - If input image is 1 channel: grayscale version is 1 channel.
+        - If input image is 3 channel: grayscale version is 3 channel
+          with r == g == b.
     """
 
     def __init__(self, gray_prob=0.1):
...
@@ -484,20 +488,24 @@ class RandomErasing(object):
             if float, it will be converted to (aspect_ratio, 1/aspect_ratio)
             Default: (3/10, 10/3)
         mode (str): Fill method in erased area, can be:
-            - 'const' (default): All pixels are assign with the same value.
-            - 'rand': each pixel is assigned with a random value in [0, 255]
+
+            - const (default): All pixels are assign with the same value.
+            - rand: each pixel is assigned with a random value in [0, 255]
         fill_color (sequence | Number): Base color filled in erased area.
-            Default: (128, 128, 128)
-        fill_std (sequence | Number, optional): If set and mode='rand', fill
-            erased area with random color from normal distribution
+            Defaults to (128, 128, 128).
+        fill_std (sequence | Number, optional): If set and ``mode`` is
+            'rand', fill erased area with random color from normal distribution
             (mean=fill_color, std=fill_std); If not set, fill erased area with
-            random color from uniform distribution (0~255)
-            Default: None
+            random color from uniform distribution (0~255). Defaults to None.
 
     Note:
-        See https://arxiv.org/pdf/1708.04896.pdf
+        See `Random Erasing Data Augmentation
+        <https://arxiv.org/pdf/1708.04896.pdf>`_
+
         This paper provided 4 modes: RE-R, RE-M, RE-0, RE-255, and use RE-M as
-        default.
+        default. The config of these 4 modes are:
+
         - RE-R: RandomErasing(mode='rand')
         - RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5))
         - RE-0: RandomErasing(mode='const', fill_color=0)
...
@@ -605,6 +613,58 @@ class RandomErasing(object):
         return repr_str
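For reference, the three mode configurations visible in the docstring above, written as pipeline config dicts (the RE-255 entry is elided in the diff context):

re_r = dict(type='RandomErasing', mode='rand')
re_m = dict(type='RandomErasing', mode='const',
            fill_color=(123.67, 116.3, 103.5))
re_0 = dict(type='RandomErasing', mode='const', fill_color=0)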
+@PIPELINES.register_module()
+class Pad(object):
+    """Pad images.
+
+    Args:
+        size (tuple[int] | None): Expected padding size (h, w). Conflicts with
+            pad_to_square. Defaults to None.
+        pad_to_square (bool): Pad any image to square shape. Defaults to False.
+        pad_val (Number | Sequence[Number]): Values to be filled in padding
+            areas when padding_mode is 'constant'. Default to 0.
+        padding_mode (str): Type of padding. Should be: constant, edge,
+            reflect or symmetric. Default to "constant".
+    """
+
+    def __init__(self,
+                 size=None,
+                 pad_to_square=False,
+                 pad_val=0,
+                 padding_mode='constant'):
+        assert (size is None) ^ (pad_to_square is False), \
+            'Only one of [size, pad_to_square] should be given, ' \
+            f'but get {(size is not None) + (pad_to_square is not False)}'
+        self.size = size
+        self.pad_to_square = pad_to_square
+        self.pad_val = pad_val
+        self.padding_mode = padding_mode
+
+    def __call__(self, results):
+        for key in results.get('img_fields', ['img']):
+            img = results[key]
+            if self.pad_to_square:
+                target_size = tuple(
+                    max(img.shape[0], img.shape[1]) for _ in range(2))
+            else:
+                target_size = self.size
+            img = mmcv.impad(
+                img,
+                shape=target_size,
+                pad_val=self.pad_val,
+                padding_mode=self.padding_mode)
+            results[key] = img
+            results['img_shape'] = img.shape
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(size={self.size}, '
+        repr_str += f'(pad_val={self.pad_val}, '
+        repr_str += f'padding_mode={self.padding_mode})'
+        return repr_str
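A usage sketch for the new Pad transform (assumes mmcls 0.24.1 and mmcv; the zero array stands in for a decoded image):

import numpy as np
from mmcls.datasets.pipelines import Pad

results = {'img': np.zeros((100, 150, 3), dtype=np.uint8)}
results = Pad(pad_to_square=True, pad_val=0)(results)  # pads h up to max(h, w)
print(results['img_shape'])  # (150, 150, 3)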
@
PIPELINES
.
register_module
()
@
PIPELINES
.
register_module
()
class
Resize
(
object
):
class
Resize
(
object
):
"""Resize images.
"""Resize images.
...
@@ -613,35 +673,49 @@ class Resize(object):
...
@@ -613,35 +673,49 @@ class Resize(object):
size (int | tuple): Images scales for resizing (h, w).
size (int | tuple): Images scales for resizing (h, w).
When size is int, the default behavior is to resize an image
When size is int, the default behavior is to resize an image
to (size, size). When size is tuple and the second value is -1,
to (size, size). When size is tuple and the second value is -1,
the short edge of an image is resized to its first value.
the image will be resized according to adaptive_side. For example,
For example, when size is 224, the image is resized to 224x224.
when size is 224, the image is resized to 224x224. When size is
When size is (224, -1), the short side is resized to 224 and the
(224, -1) and adaptive_size is "short", the short side is resized
other side is computed based on the short side, maintaining the
to 224 and the other side is computed based on the short side,
aspect ratio.
maintaining the aspect ratio.
interpolation (str): Interpolation method, accepted values are
interpolation (str): Interpolation method. For "cv2" backend, accepted
"nearest", "bilinear", "bicubic", "area", "lanczos".
values are "nearest", "bilinear", "bicubic", "area", "lanczos". For
"pillow" backend, accepted values are "nearest", "bilinear",
"bicubic", "box", "lanczos", "hamming".
More details can be found in `mmcv.image.geometric`.
More details can be found in `mmcv.image.geometric`.
backend (str): The image resize backend type, accpeted values are
adaptive_side(str): Adaptive resize policy, accepted values are
"short", "long", "height", "width". Default to "short".
backend (str): The image resize backend type, accepted values are
`cv2` and `pillow`. Default: `cv2`.
`cv2` and `pillow`. Default: `cv2`.
"""
"""
-    def __init__(self, size, interpolation='bilinear', backend='cv2'):
+    def __init__(self,
+                 size,
+                 interpolation='bilinear',
+                 adaptive_side='short',
+                 backend='cv2'):
         assert isinstance(size, int) or (isinstance(size, tuple)
                                          and len(size) == 2)
-        self.resize_w_short_side = False
+        assert adaptive_side in {'short', 'long', 'height', 'width'}
+
+        self.adaptive_side = adaptive_side
+        self.adaptive_resize = False
         if isinstance(size, int):
             assert size > 0
             size = (size, size)
         else:
             assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
             if size[1] == -1:
-                self.resize_w_short_side = True
-        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
-                                 'lanczos')
+                self.adaptive_resize = True
         if backend not in ['cv2', 'pillow']:
             raise ValueError(f'backend: {backend} is not supported for '
                              'resize. Supported backends are "cv2", "pillow"')
+        if backend == 'cv2':
+            assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
+                                     'lanczos')
+        else:
+            assert interpolation in ('nearest', 'bilinear', 'bicubic', 'box',
+                                     'lanczos', 'hamming')
         self.size = size
         self.interpolation = interpolation
         self.backend = backend
    ...
@@ -650,19 +724,29 @@ class Resize(object):
         for key in results.get('img_fields', ['img']):
             img = results[key]
             ignore_resize = False
-            if self.resize_w_short_side:
+            if self.adaptive_resize:
                 h, w = img.shape[:2]
-                short_side = self.size[0]
-                if (w <= h and w == short_side) or (h <= w
-                                                    and h == short_side):
+                target_size = self.size[0]
+
+                condition_ignore_resize = {
+                    'short': min(h, w) == target_size,
+                    'long': max(h, w) == target_size,
+                    'height': h == target_size,
+                    'width': w == target_size
+                }
+
+                if condition_ignore_resize[self.adaptive_side]:
                     ignore_resize = True
+                elif any([
+                        self.adaptive_side == 'short' and w < h,
+                        self.adaptive_side == 'long' and w > h,
+                        self.adaptive_side == 'width',
+                ]):
+                    width = target_size
+                    height = int(target_size * h / w)
                 else:
-                    if w < h:
-                        width = short_side
-                        height = int(short_side * h / w)
-                    else:
-                        height = short_side
-                        width = int(short_side * w / h)
+                    height = target_size
+                    width = int(target_size * w / h)
             else:
                 height, width = self.size
             if not ignore_resize:
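A quick worked example of the adaptive logic (numbers are illustrative): with size=(224, -1) and adaptive_side='short', a 448x336 (h x w) image keeps its aspect ratio; the short side becomes 224 and the other side becomes int(224 * 448 / 336) == 298:

    import numpy as np

    resize = Resize(size=(224, -1), adaptive_side='short')
    results = dict(img=np.zeros((448, 336, 3), dtype=np.uint8))
    out = resize(results)  # out['img'].shape == (298, 224, 3)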
    ...
@@ -700,21 +784,23 @@ class CenterCrop(object):
         32.
     interpolation (str): Interpolation method, accepted values are
         'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if
-        efficientnet style is True. Defaults to 'bilinear'.
+        ``efficientnet_style`` is True. Defaults to 'bilinear'.
-    backend (str): The image resize backend type, accpeted values are
+    backend (str): The image resize backend type, accepted values are
         `cv2` and `pillow`. Only valid if efficientnet style is True.
         Defaults to `cv2`.

     Notes:
-        If the image is smaller than the crop size, return the original image.
-        If efficientnet_style is set to False, the pipeline would be a simple
-        center crop using the crop_size.
-        If efficientnet_style is set to True, the pipeline will be to first to
-        perform the center crop with the crop_size_ as:
+        - If the image is smaller than the crop size, return the original
+          image.
+        - If efficientnet_style is set to False, the pipeline would be a
+          simple center crop using the crop_size.
+        - If efficientnet_style is set to True, the pipeline will first
+          perform the center crop with the ``crop_size_`` as:

     .. math::
-        crop\_size\_ = crop\_size / (crop\_size + crop\_padding) * short\_edge
+        \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} +
+        \text{crop_padding}} \times \text{short_edge}

     And then the pipeline resizes the img to the input crop size.
     """
    ...
@@ -886,7 +972,7 @@ class Lighting(object):
     eigvec (list[list]): the eigenvectors of the covariance matrix of pixel
         values, respectively.
     alphastd (float): The standard deviation for distribution of alpha.
-        Dafaults to 0.1
+        Defaults to 0.1
     to_rgb (bool): Whether to convert img to rgb.
     """
    ...
@@ -1032,19 +1118,23 @@ class Albu(object):
         return updated_dict

     def __call__(self, results):
+        # backup gt_label in case Albu modifies it.
+        _gt_label = copy.deepcopy(results.get('gt_label', None))
+
         # dict to albumentations format
         results = self.mapper(results, self.keymap_to_albu)
-        # process aug
         results = self.aug(**results)

+        if 'gt_labels' in results:
+            if isinstance(results['gt_labels'], list):
+                results['gt_labels'] = np.array(results['gt_labels'])
+            results['gt_labels'] = results['gt_labels'].astype(np.int64)
+
         # back to the original format
         results = self.mapper(results, self.keymap_back)
+        if _gt_label is not None:
+            # recover the backed-up gt_label
+            results.update({'gt_label': _gt_label})

         # update final shape
         if self.update_pad_shape:
             results['pad_shape'] = results['img'].shape
    ...
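A hedged config sketch for the updated Albu wrapper (the transform name comes from the albumentations package; the keys follow the usual mmcls pipeline conventions):

    albu_transform = dict(
        type='Albu',
        transforms=[
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=0.2,
                contrast_limit=0.2,
                p=0.5),
        ])

Because albumentations may rewrite or drop keys, the __call__ above snapshots gt_label before the key mapping and restores it afterwards, so classification labels survive the round trip.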
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/samplers/__init__.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import DistributedSampler
from .repeat_aug import RepeatAugSampler

__all__ = ('DistributedSampler', 'RepeatAugSampler')
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/samplers/distributed_sampler.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler

from mmcls.core.utils import sync_random_seed
from mmcls.datasets import SAMPLERS
from mmcls.utils import auto_select_device


@SAMPLERS.register_module()
class DistributedSampler(_DistributedSampler):

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 round_up=True,
                 seed=0):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.round_up = round_up
        if self.round_up:
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(self.dataset)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed, device=auto_select_device())

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if self.round_up:
            assert len(indices) == self.num_samples

        return iter(indices)
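A worked example of the round_up padding (illustrative numbers, not from the commit): with 10 samples and 4 replicas, num_samples is 3 per rank and total_size is 12, so the index list is tiled and each rank takes a strided slice:

    indices = list(range(10))
    total_size, num_replicas = 12, 4
    indices = (indices * int(total_size / len(indices) + 1))[:total_size]
    # indices == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]
    rank0 = indices[0:total_size:num_replicas]  # [0, 4, 8]
    rank3 = indices[3:total_size:num_replicas]  # [3, 7, 1]

Every rank gets exactly num_samples indices; the first two samples are simply seen twice per epoch.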
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/samplers/repeat_aug.py  0 → 100644
import math

import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler

from mmcls.core.utils import sync_random_seed
from mmcls.datasets import SAMPLERS


@SAMPLERS.register_module()
class RepeatAugSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset for
    distributed training, with repeated augmentation. It ensures that each
    augmented version of a sample is visible to a different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler.

    This sampler was taken from
    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
    (Copyright (c) 2015-present, Facebook, Inc.).
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 num_repeats=3,
                 selected_round=256,
                 selected_ratio=0,
                 seed=0):
        default_rank, default_world_size = get_dist_info()
        rank = default_rank if rank is None else rank
        num_replicas = (
            default_world_size if num_replicas is None else num_replicas)
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.num_repeats = num_repeats
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Determine the number of samples to select per epoch for each rank.
        # num_selected logic defaults to be the same as the original RASampler
        # impl, but this one can be tweaked
        # via selected_ratio and selected_round args.
        selected_ratio = selected_ratio or num_replicas  # ratio to reduce
        # selected samples by, num_replicas if 0
        if selected_round:
            self.num_selected_samples = int(
                math.floor(
                    len(self.dataset) // selected_round * selected_round /
                    selected_ratio))
        else:
            self.num_selected_samples = int(
                math.ceil(len(self.dataset) / selected_ratio))

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            if self.num_replicas > 1:
                # In distributed environments, deterministically
                # shuffle based on epoch
                g = torch.Generator()
                # When :attr:`shuffle=True`, this ensures all replicas
                # use a different random ordering for each epoch.
                # Otherwise, the next iteration of this sampler will
                # yield the same ordering.
                g.manual_seed(self.epoch + self.seed)
                indices = torch.randperm(
                    len(self.dataset), generator=g).tolist()
            else:
                indices = torch.randperm(len(self.dataset)).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]
        indices = [x for x in indices for _ in range(self.num_repeats)]
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample per rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        # return up to num selected samples
        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
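To see the "each repeat goes to a different GPU" property concretely (toy numbers, illustrative only): with num_repeats=3 and 3 replicas, the repeated index stream gives every rank one copy of each sample, and each rank then applies its own random augmentations:

    indices = [x for x in range(4) for _ in range(3)]
    # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
    ranks = [indices[r::3] for r in range(3)]
    # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]

Note that with very small datasets the default selected_round=256 floors num_selected_samples to 0 (len(dataset) // 256 == 0); pass selected_round=0 in such experiments.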
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/stanford_cars.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Optional

import numpy as np

from .base_dataset import BaseDataset
from .builder import DATASETS


@DATASETS.register_module()
class StanfordCars(BaseDataset):
    """`Stanford Cars`_ Dataset.

    After downloading and decompression, the dataset
    directory structure is as follows.

    Stanford Cars dataset directory::

        Stanford Cars
        ├── cars_train
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        ├── cars_test
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        └── devkit
            ├── cars_meta.mat
            ├── cars_train_annos.mat
            ├── cars_test_annos.mat
            ├── cars_test_annos_withlabels.mat
            ├── eval_train.m
            └── train_perfect_preds.txt

    .. _Stanford Cars: https://ai.stanford.edu/~jkrause/cars/car_dataset.html

    Args:
        data_prefix (str): The prefix of the data path.
        test_mode (bool): ``test_mode=True`` means in test phase. It
            determines whether to use the training set or the test set.
        ann_file (str, optional): The annotation file. If it is a string,
            sample paths are read from the ann_file. If it is None, sample
            paths are read from the cars_{train|test}_annos.mat file.
            Defaults to None.
    """  # noqa: E501

    CLASSES = [
        'AM General Hummer SUV 2000', 'Acura RL Sedan 2012',
        'Acura TL Sedan 2012', 'Acura TL Type-S 2008',
        'Acura TSX Sedan 2012', 'Acura Integra Type R 2001',
        'Acura ZDX Hatchback 2012',
        'Aston Martin V8 Vantage Convertible 2012',
        'Aston Martin V8 Vantage Coupe 2012',
        'Aston Martin Virage Convertible 2012',
        'Aston Martin Virage Coupe 2012', 'Audi RS 4 Convertible 2008',
        'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', 'Audi R8 Coupe 2012',
        'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', 'Audi 100 Wagon 1994',
        'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011',
        'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012',
        'Audi S4 Sedan 2012', 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012',
        'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012',
        'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012',
        'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007',
        'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012',
        'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012',
        'BMW Z4 Convertible 2012',
        'Bentley Continental Supersports Conv. Convertible 2012',
        'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011',
        'Bentley Continental GT Coupe 2012',
        'Bentley Continental GT Coupe 2007',
        'Bentley Continental Flying Spur Sedan 2007',
        'Bugatti Veyron 16.4 Convertible 2009',
        'Bugatti Veyron 16.4 Coupe 2009', 'Buick Regal GS 2012',
        'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012',
        'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012',
        'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007',
        'Chevrolet Silverado 1500 Hybrid Crew Cab 2012',
        'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012',
        'Chevrolet Corvette Ron Fellows Edition Z06 2007',
        'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012',
        'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007',
        'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012',
        'Chevrolet Express Cargo Van 2007',
        'Chevrolet Avalanche Crew Cab 2012', 'Chevrolet Cobalt SS 2010',
        'Chevrolet Malibu Hybrid Sedan 2010',
        'Chevrolet TrailBlazer SS 2009',
        'Chevrolet Silverado 2500HD Regular Cab 2012',
        'Chevrolet Silverado 1500 Classic Extended Cab 2007',
        'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007',
        'Chevrolet Malibu Sedan 2007',
        'Chevrolet Silverado 1500 Extended Cab 2012',
        'Chevrolet Silverado 1500 Regular Cab 2012',
        'Chrysler Aspen SUV 2009', 'Chrysler Sebring Convertible 2010',
        'Chrysler Town and Country Minivan 2012',
        'Chrysler 300 SRT-8 2010', 'Chrysler Crossfire Convertible 2008',
        'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002',
        'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007',
        'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010',
        'Dodge Ram Pickup 3500 Quad Cab 2009',
        'Dodge Sprinter Cargo Van 2009', 'Dodge Journey SUV 2012',
        'Dodge Dakota Crew Cab 2010', 'Dodge Dakota Club Cab 2007',
        'Dodge Magnum Wagon 2008', 'Dodge Challenger SRT8 2011',
        'Dodge Durango SUV 2012', 'Dodge Durango SUV 2007',
        'Dodge Charger Sedan 2012', 'Dodge Charger SRT-8 2009',
        'Eagle Talon Hatchback 1998', 'FIAT 500 Abarth 2012',
        'FIAT 500 Convertible 2012', 'Ferrari FF Coupe 2012',
        'Ferrari California Convertible 2012',
        'Ferrari 458 Italia Convertible 2012',
        'Ferrari 458 Italia Coupe 2012', 'Fisker Karma Sedan 2012',
        'Ford F-450 Super Duty Crew Cab 2012',
        'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007',
        'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012',
        'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006',
        'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007',
        'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012',
        'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012',
        'GMC Savana Van 2012', 'GMC Yukon Hybrid SUV 2012',
        'GMC Acadia SUV 2012', 'GMC Canyon Extended Cab 2012',
        'Geo Metro Convertible 1993', 'HUMMER H3T Crew Cab 2010',
        'HUMMER H2 SUT Crew Cab 2009', 'Honda Odyssey Minivan 2012',
        'Honda Odyssey Minivan 2007', 'Honda Accord Coupe 2012',
        'Honda Accord Sedan 2012', 'Hyundai Veloster Hatchback 2012',
        'Hyundai Santa Fe SUV 2012', 'Hyundai Tucson SUV 2012',
        'Hyundai Veracruz SUV 2012', 'Hyundai Sonata Hybrid Sedan 2012',
        'Hyundai Elantra Sedan 2007', 'Hyundai Accent Sedan 2012',
        'Hyundai Genesis Sedan 2012', 'Hyundai Sonata Sedan 2012',
        'Hyundai Elantra Touring Hatchback 2012', 'Hyundai Azera Sedan 2012',
        'Infiniti G Coupe IPL 2012', 'Infiniti QX56 SUV 2011',
        'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012',
        'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012',
        'Jeep Liberty SUV 2012', 'Jeep Grand Cherokee SUV 2012',
        'Jeep Compass SUV 2012', 'Lamborghini Reventon Coupe 2008',
        'Lamborghini Aventador Coupe 2012',
        'Lamborghini Gallardo LP 570-4 Superleggera 2012',
        'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012',
        'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011',
        'MINI Cooper Roadster Convertible 2012',
        'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011',
        'McLaren MP4-12C Coupe 2012',
        'Mercedes-Benz 300-Class Convertible 1993',
        'Mercedes-Benz C-Class Sedan 2012',
        'Mercedes-Benz SL-Class Coupe 2009',
        'Mercedes-Benz E-Class Sedan 2012',
        'Mercedes-Benz S-Class Sedan 2012',
        'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012',
        'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012',
        'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998',
        'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012',
        'Ram C/V Cargo Van Minivan 2012',
        'Rolls-Royce Phantom Drophead Coupe Convertible 2012',
        'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012',
        'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009',
        'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007',
        'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012',
        'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012',
        'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012',
        'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012',
        'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991',
        'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012',
        'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007',
        'smart fortwo Convertible 2012'
    ]

    def __init__(self,
                 data_prefix: str,
                 test_mode: bool,
                 ann_file: Optional[str] = None,
                 **kwargs):
        if test_mode:
            if ann_file is not None:
                self.test_ann_file = ann_file
            else:
                self.test_ann_file = osp.join(
                    data_prefix, 'devkit/cars_test_annos_withlabels.mat')
            data_prefix = osp.join(data_prefix, 'cars_test')
        else:
            if ann_file is not None:
                self.train_ann_file = ann_file
            else:
                self.train_ann_file = osp.join(
                    data_prefix, 'devkit/cars_train_annos.mat')
            data_prefix = osp.join(data_prefix, 'cars_train')
        super(StanfordCars, self).__init__(
            ann_file=ann_file,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_annotations(self):
        try:
            import scipy.io as sio
        except ImportError:
            raise ImportError(
                'please run `pip install scipy` to install package `scipy`.')

        data_infos = []
        if self.test_mode:
            data = sio.loadmat(self.test_ann_file)
        else:
            data = sio.loadmat(self.train_ann_file)

        for img in data['annotations'][0]:
            info = {'img_prefix': self.data_prefix}
            # Each record is organized as follows:
            # 0: bbox_x1, 1: bbox_y1, 2: bbox_x2, 3: bbox_y2,
            # 4: class_id (1-based in the .mat file, so subtract 1
            #    to make labels start from 0),
            # 5: file name of each image.
            info['img_info'] = {'filename': img[5][0]}
            info['gt_label'] = np.array(img[4][0][0] - 1, dtype=np.int64)
            data_infos.append(info)

        return data_infos
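A hedged construction sketch (paths are illustrative; assumes the dataset is laid out as in the docstring above):

    from mmcls.datasets import build_dataset

    cars = build_dataset(
        dict(
            type='StanfordCars',
            data_prefix='data/stanford_cars',
            test_mode=False,
            pipeline=[]))
    # load_annotations() parses devkit/cars_train_annos.mat via scipy and
    # produces dicts with 'img_prefix', 'img_info' and zero-based 'gt_label'.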
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/utils.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import gzip
import hashlib
import os
import os.path
import shutil
import tarfile
import urllib.error
import urllib.request
import zipfile

__all__ = ['rm_suffix', 'check_integrity', 'download_and_extract_archive']


def rm_suffix(s, suffix=None):
    if suffix is None:
        return s[:s.rfind('.')]
    else:
        return s[:s.rfind(suffix)]


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def download_url_to_file(url, fpath):
    with urllib.request.urlopen(url) as resp, open(fpath, 'wb') as of:
        shutil.copyfileobj(resp, of)


def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from.
        root (str): Directory to place downloaded file in.
        filename (str | None): Name to save the file under.
            If filename is None, use the basename of the URL.
        md5 (str | None): MD5 checksum of the download.
            If md5 is None, download without md5 check.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f'Using downloaded and verified file: {fpath}')
    else:
        try:
            print(f'Downloading {url} to {fpath}')
            download_url_to_file(url, fpath)
        except (urllib.error.URLError, IOError) as e:
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      f' Downloading {url} to {fpath}')
                download_url_to_file(url, fpath)
            else:
                raise e
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')


def _is_tarxz(filename):
    return filename.endswith('.tar.xz')


def _is_tar(filename):
    return filename.endswith('.tar')


def _is_targz(filename):
    return filename.endswith('.tar.gz')


def _is_tgz(filename):
    return filename.endswith('.tgz')


def _is_gzip(filename):
    return filename.endswith('.gz') and not filename.endswith('.tar.gz')


def _is_zip(filename):
    return filename.endswith('.zip')


def extract_archive(from_path, to_path=None, remove_finished=False):
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(
            to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)


def download_and_extract_archive(url,
                                 download_root,
                                 extract_root=None,
                                 filename=None,
                                 md5=None,
                                 remove_finished=False):
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f'Extracting {archive} to {extract_root}')
    extract_archive(archive, extract_root, remove_finished)
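Usage sketch for the helper chain above (the URL is a placeholder, not a real endpoint):

    download_and_extract_archive(
        url='https://example.com/toy_dataset.tar.gz',  # placeholder URL
        download_root='data/downloads',
        extract_root='data',
        md5=None,  # supply a real MD5 string to enable integrity checking
        remove_finished=True)  # delete the archive after extraction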
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/voc.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import xml.etree.ElementTree as ET

import mmcv
import numpy as np

from .builder import DATASETS
from .multi_label import MultiLabelDataset


@DATASETS.register_module()
class VOC(MultiLabelDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.

    Args:
        data_prefix (str): The prefix of the data path.
        pipeline (list): A list of dicts, where each element represents
            an operation defined in `mmcls.datasets.pipelines`.
        ann_file (str | None): The annotation file. When ann_file is str,
            the subclass is expected to read from the ann_file. When ann_file
            is None, the subclass is expected to read according to data_prefix.
        difficult_as_postive (Optional[bool]): Whether to map difficult
            labels as positive. If it is set to True, difficult examples are
            mapped to positive ones (1); if it is set to False, they are
            mapped to negative ones (0). Defaults to None, in which case
            difficult labels are set to '-1'.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, difficult_as_postive=None, **kwargs):
        self.difficult_as_postive = difficult_as_postive
        super(VOC, self).__init__(**kwargs)
        if 'VOC2007' in self.data_prefix:
            self.year = 2007
        else:
            raise ValueError('Cannot infer dataset year from img_prefix.')

    def load_annotations(self):
        """Load annotations.

        Returns:
            list[dict]: Annotation info from XML file.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(self.ann_file)
        for img_id in img_ids:
            filename = f'JPEGImages/{img_id}.jpg'
            xml_path = osp.join(self.data_prefix, 'Annotations',
                                f'{img_id}.xml')
            tree = ET.parse(xml_path)
            root = tree.getroot()
            labels = []
            labels_difficult = []
            for obj in root.findall('object'):
                label_name = obj.find('name').text
                # in case a customized dataset has wrong labels
                # or CLASSES has been overridden.
                if label_name not in self.CLASSES:
                    continue
                label = self.class_to_idx[label_name]
                difficult = int(obj.find('difficult').text)
                if difficult:
                    labels_difficult.append(label)
                else:
                    labels.append(label)

            gt_label = np.zeros(len(self.CLASSES))
            # set difficult examples first, then set positive examples.
            # The order cannot be swapped for the case where multiple objects
            # of the same kind exist and some are difficult.
            if self.difficult_as_postive is None:
                # map difficult examples to -1,
                # it may be used in evaluation to ignore difficult targets.
                gt_label[labels_difficult] = -1
            elif self.difficult_as_postive:
                # map difficult examples to positive ones (1).
                gt_label[labels_difficult] = 1
            else:
                # map difficult examples to negative ones (0).
                gt_label[labels_difficult] = 0
            gt_label[labels] = 1

            info = dict(
                img_prefix=self.data_prefix,
                img_info=dict(filename=filename),
                gt_label=gt_label.astype(np.int8))
            data_infos.append(info)

        return data_infos
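For intuition about the label encoding (a made-up example): an image containing a non-difficult 'dog' (index 11) and a difficult 'person' (index 14) under the default difficult_as_postive=None becomes:

    gt_label = np.zeros(20)
    gt_label[[14]] = -1  # difficult target, can be ignored by the evaluator
    gt_label[[11]] = 1   # ordinary positive target
    # all remaining classes stay 0 (negative)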
openmmlab_test/mmclassification-0.24.1/mmcls/models/__init__.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import *  # noqa: F401,F403
from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS,
                      build_backbone, build_classifier, build_head,
                      build_loss, build_neck)
from .classifiers import *  # noqa: F401,F403
from .heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403

__all__ = [
    'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone',
    'build_head', 'build_neck', 'build_loss', 'build_classifier'
]
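A minimal sketch of how these registries are used (the config values are illustrative):

    from mmcls.models import build_backbone

    backbone = build_backbone(dict(type='ResNet', depth=50, out_indices=(3, )))
    # the builder looks up 'ResNet' in the BACKBONES registry and passes
    # the remaining keys to its constructor.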
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/__init__.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
from .conformer import Conformer
from .convmixer import ConvMixer
from .convnext import ConvNeXt
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
from .deit import DistilledVisionTransformer
from .densenet import DenseNet
from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
from .hornet import HorNet
from .hrnet import HRNet
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mvit import MViT
from .poolformer import PoolFormer
from .regnet import RegNet
from .repmlp import RepMLPNet
from .repvgg import RepVGG
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnet_cifar import ResNet_CIFAR
from .resnext import ResNeXt
from .seresnet import SEResNet
from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin_transformer import SwinTransformer
from .swin_transformer_v2 import SwinTransformerV2
from .t2t_vit import T2T_ViT
from .timm_backbone import TIMMBackbone
from .tnt import TNT
from .twins import PCPVT, SVT
from .van import VAN
from .vgg import VGG
from .vision_transformer import VisionTransformer

__all__ = [
    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
    'SwinTransformer', 'SwinTransformerV2', 'TNT', 'TIMMBackbone', 'T2T_ViT',
    'Res2Net', 'RepVGG', 'Conformer', 'MlpMixer',
    'DistilledVisionTransformer', 'PCPVT', 'SVT', 'EfficientNet', 'ConvNeXt',
    'HRNet', 'ResNetV1c', 'ConvMixer', 'CSPDarkNet', 'CSPResNet',
    'CSPResNeXt', 'CSPNet', 'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN',
    'MViT', 'EfficientFormer', 'HorNet'
]
openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/alexnet.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/alexnet.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from ..builder import BACKBONES
    ...
@@ -52,4 +53,4 @@ class AlexNet(BaseBackbone):
         x = x.view(x.size(0), 256 * 6 * 6)
         x = self.classifier(x)

-        return x
+        return (x, )
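A reading of this change (an interpretation, not stated in the commit): mmcls backbones conventionally return a tuple of per-stage feature maps so that necks and heads can index them uniformly, and wrapping the final feature as `(x, )` brings AlexNet in line with that contract:

    feats = backbone(imgs)  # now a 1-tuple instead of a bare tensor
    last = feats[-1]        # downstream code can index stages uniformly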
openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/base_backbone.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/base_backbone.py
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

from mmcv.runner import BaseModule
    ...