Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
dcuai
dlexamples
Commits
85529f35
Commit
85529f35
authored
Jul 30, 2022
by
unknown
Browse files
添加openmmlab测试用例
parent
b21b0c01
Changes
977
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2811 additions
and
0 deletions
+2811
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/evaluation/eval_metrics.py
...ion-speed-benchmark/mmcls/core/evaluation/eval_metrics.py
+235
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/evaluation/mean_ap.py
...fication-speed-benchmark/mmcls/core/evaluation/mean_ap.py
+73
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/evaluation/multilabel_eval_metrics.py
...enchmark/mmcls/core/evaluation/multilabel_eval_metrics.py
+71
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/export/__init__.py
...ssification-speed-benchmark/mmcls/core/export/__init__.py
+3
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/export/test.py
...mclassification-speed-benchmark/mmcls/core/export/test.py
+95
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/__init__.py
...lassification-speed-benchmark/mmcls/core/fp16/__init__.py
+4
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/decorators.py
...ssification-speed-benchmark/mmcls/core/fp16/decorators.py
+160
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/hooks.py
...mmclassification-speed-benchmark/mmcls/core/fp16/hooks.py
+128
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/utils.py
...mmclassification-speed-benchmark/mmcls/core/fp16/utils.py
+23
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/utils/__init__.py
...assification-speed-benchmark/mmcls/core/utils/__init__.py
+4
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/utils/dist_utils.py
...sification-speed-benchmark/mmcls/core/utils/dist_utils.py
+56
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/utils/misc.py
...mmclassification-speed-benchmark/mmcls/core/utils/misc.py
+7
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/__init__.py
...classification-speed-benchmark/mmcls/datasets/__init__.py
+18
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/base_dataset.py
...sification-speed-benchmark/mmcls/datasets/base_dataset.py
+198
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/builder.py
...mclassification-speed-benchmark/mmcls/datasets/builder.py
+108
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/cifar.py
.../mmclassification-speed-benchmark/mmcls/datasets/cifar.py
+132
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/dataset_wrappers.py
...cation-speed-benchmark/mmcls/datasets/dataset_wrappers.py
+162
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/dummy.py
.../mmclassification-speed-benchmark/mmcls/datasets/dummy.py
+45
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/imagenet.py
...classification-speed-benchmark/mmcls/datasets/imagenet.py
+1105
-0
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/mnist.py
.../mmclassification-speed-benchmark/mmcls/datasets/mnist.py
+184
-0
No files found.
Too many changes to show.
To preserve performance only
977 of 977+
files are displayed.
Plain diff
Email patch
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/evaluation/eval_metrics.py
0 → 100644
View file @
85529f35
import
numpy
as
np
import
torch
def
calculate_confusion_matrix
(
pred
,
target
):
"""Calculate confusion matrix according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
Returns:
torch.Tensor: Confusion matrix with shape (C, C), where C is the number
of classes.
"""
if
isinstance
(
pred
,
np
.
ndarray
):
pred
=
torch
.
from_numpy
(
pred
)
if
isinstance
(
target
,
np
.
ndarray
):
target
=
torch
.
from_numpy
(
target
)
assert
(
isinstance
(
pred
,
torch
.
Tensor
)
and
isinstance
(
target
,
torch
.
Tensor
)),
\
(
f
'pred and target should be torch.Tensor or np.ndarray, '
f
'but got
{
type
(
pred
)
}
and
{
type
(
target
)
}
.'
)
num_classes
=
pred
.
size
(
1
)
_
,
pred_label
=
pred
.
topk
(
1
,
dim
=
1
)
pred_label
=
pred_label
.
view
(
-
1
)
target_label
=
target
.
view
(
-
1
)
assert
len
(
pred_label
)
==
len
(
target_label
)
confusion_matrix
=
torch
.
zeros
(
num_classes
,
num_classes
)
with
torch
.
no_grad
():
for
t
,
p
in
zip
(
target_label
,
pred_label
):
confusion_matrix
[
t
.
long
(),
p
.
long
()]
+=
1
return
confusion_matrix
def
precision_recall_f1
(
pred
,
target
,
average_mode
=
'macro'
,
thrs
=
None
):
"""Calculate precision, recall and f1 score according to the prediction and
target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (float | tuple[float], optional): Predictions with scores under
the thresholds are considered negative. Default to None.
Returns:
float | np.array | list[float | np.array]: Precision, recall, f1 score.
If the ``average_mode`` is set to macro, np.array is used in favor
of float to give class-wise results. If the ``average_mode`` is set
to none, float is used to return a single value.
If ``thrs`` is a single float or None, the function will return
float or np.array. If ``thrs`` is a tuple, the function will return
a list containing metrics for each ``thrs`` condition.
"""
allowed_average_mode
=
[
'macro'
,
'none'
]
if
average_mode
not
in
allowed_average_mode
:
raise
ValueError
(
f
'Unsupport type of averaging
{
average_mode
}
.'
)
if
isinstance
(
pred
,
torch
.
Tensor
):
pred
=
pred
.
numpy
()
if
isinstance
(
target
,
torch
.
Tensor
):
target
=
target
.
numpy
()
assert
(
isinstance
(
pred
,
np
.
ndarray
)
and
isinstance
(
target
,
np
.
ndarray
)),
\
(
f
'pred and target should be torch.Tensor or np.ndarray, '
f
'but got
{
type
(
pred
)
}
and
{
type
(
target
)
}
.'
)
if
thrs
is
None
:
thrs
=
0.0
if
isinstance
(
thrs
,
float
):
thrs
=
(
thrs
,
)
return_single
=
True
elif
isinstance
(
thrs
,
tuple
):
return_single
=
False
else
:
raise
TypeError
(
f
'thrs should be float or tuple, but got
{
type
(
thrs
)
}
.'
)
label
=
np
.
indices
(
pred
.
shape
)[
1
]
pred_label
=
np
.
argsort
(
pred
,
axis
=
1
)[:,
-
1
]
pred_score
=
np
.
sort
(
pred
,
axis
=
1
)[:,
-
1
]
precisions
=
[]
recalls
=
[]
f1_scores
=
[]
for
thr
in
thrs
:
# Only prediction values larger than thr are counted as positive
_pred_label
=
pred_label
.
copy
()
if
thr
is
not
None
:
_pred_label
[
pred_score
<=
thr
]
=
-
1
pred_positive
=
label
==
_pred_label
.
reshape
(
-
1
,
1
)
gt_positive
=
label
==
target
.
reshape
(
-
1
,
1
)
precision
=
(
pred_positive
&
gt_positive
).
sum
(
0
)
/
np
.
maximum
(
pred_positive
.
sum
(
0
),
1
)
*
100
recall
=
(
pred_positive
&
gt_positive
).
sum
(
0
)
/
np
.
maximum
(
gt_positive
.
sum
(
0
),
1
)
*
100
f1_score
=
2
*
precision
*
recall
/
np
.
maximum
(
precision
+
recall
,
1e-20
)
if
average_mode
==
'macro'
:
precision
=
float
(
precision
.
mean
())
recall
=
float
(
recall
.
mean
())
f1_score
=
float
(
f1_score
.
mean
())
precisions
.
append
(
precision
)
recalls
.
append
(
recall
)
f1_scores
.
append
(
f1_score
)
if
return_single
:
return
precisions
[
0
],
recalls
[
0
],
f1_scores
[
0
]
else
:
return
precisions
,
recalls
,
f1_scores
def
precision
(
pred
,
target
,
average_mode
=
'macro'
,
thrs
=
None
):
"""Calculate precision according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (float | tuple[float], optional): Predictions with scores under
the thresholds are considered negative. Default to None.
Returns:
float | np.array | list[float | np.array]: Precision.
If the ``average_mode`` is set to macro, np.array is used in favor
of float to give class-wise results. If the ``average_mode`` is set
to none, float is used to return a single value.
If ``thrs`` is a single float or None, the function will return
float or np.array. If ``thrs`` is a tuple, the function will return
a list containing metrics for each ``thrs`` condition.
"""
precisions
,
_
,
_
=
precision_recall_f1
(
pred
,
target
,
average_mode
,
thrs
)
return
precisions
def
recall
(
pred
,
target
,
average_mode
=
'macro'
,
thrs
=
None
):
"""Calculate recall according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (float | tuple[float], optional): Predictions with scores under
the thresholds are considered negative. Default to None.
Returns:
float | np.array | list[float | np.array]: Recall.
If the ``average_mode`` is set to macro, np.array is used in favor
of float to give class-wise results. If the ``average_mode`` is set
to none, float is used to return a single value.
If ``thrs`` is a single float or None, the function will return
float or np.array. If ``thrs`` is a tuple, the function will return
a list containing metrics for each ``thrs`` condition.
"""
_
,
recalls
,
_
=
precision_recall_f1
(
pred
,
target
,
average_mode
,
thrs
)
return
recalls
def
f1_score
(
pred
,
target
,
average_mode
=
'macro'
,
thrs
=
None
):
"""Calculate F1 score according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
thrs (float | tuple[float], optional): Predictions with scores under
the thresholds are considered negative. Default to None.
Returns:
float | np.array | list[float | np.array]: F1 score.
If the ``average_mode`` is set to macro, np.array is used in favor
of float to give class-wise results. If the ``average_mode`` is set
to none, float is used to return a single value.
If ``thrs`` is a single float or None, the function will return
float or np.array. If ``thrs`` is a tuple, the function will return
a list containing metrics for each ``thrs`` condition.
"""
_
,
_
,
f1_scores
=
precision_recall_f1
(
pred
,
target
,
average_mode
,
thrs
)
return
f1_scores
def
support
(
pred
,
target
,
average_mode
=
'macro'
):
"""Calculate the total number of occurrences of each label according to the
prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
target (torch.Tensor | np.array): The target of each prediction with
shape (N, 1) or (N,).
average_mode (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. If 'none', the scores for each
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted sum.
Defaults to 'macro'.
Returns:
float | np.array: Precision, recall, f1 score.
The function returns a single float if the average_mode is set to
macro, or a np.array with shape C if the average_mode is set to
none.
"""
confusion_matrix
=
calculate_confusion_matrix
(
pred
,
target
)
with
torch
.
no_grad
():
res
=
confusion_matrix
.
sum
(
1
)
if
average_mode
==
'macro'
:
res
=
float
(
res
.
sum
().
numpy
())
elif
average_mode
==
'none'
:
res
=
res
.
numpy
()
else
:
raise
ValueError
(
f
'Unsupport type of averaging
{
average_mode
}
.'
)
return
res
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/evaluation/mean_ap.py
0 → 100644
View file @
85529f35
import
numpy
as
np
import
torch
def
average_precision
(
pred
,
target
):
"""Calculate the average precision for a single class.
AP summarizes a precision-recall curve as the weighted mean of maximum
precisions obtained for any r'>r, where r is the recall:
..math::
\\
text{AP} =
\\
sum_n (R_n - R_{n-1}) P_n
Note that no approximation is involved since the curve is piecewise
constant.
Args:
pred (np.ndarray): The model prediction with shape (N, ).
target (np.ndarray): The target of each prediction with shape (N, ).
Returns:
float: a single float as average precision value.
"""
eps
=
np
.
finfo
(
np
.
float32
).
eps
# sort examples
sort_inds
=
np
.
argsort
(
-
pred
)
sort_target
=
target
[
sort_inds
]
# count true positive examples
pos_inds
=
sort_target
==
1
tp
=
np
.
cumsum
(
pos_inds
)
total_pos
=
tp
[
-
1
]
# count not difficult examples
pn_inds
=
sort_target
!=
-
1
pn
=
np
.
cumsum
(
pn_inds
)
tp
[
np
.
logical_not
(
pos_inds
)]
=
0
precision
=
tp
/
np
.
maximum
(
pn
,
eps
)
ap
=
np
.
sum
(
precision
)
/
np
.
maximum
(
total_pos
,
eps
)
return
ap
def
mAP
(
pred
,
target
):
"""Calculate the mean average precision with respect of classes.
Args:
pred (torch.Tensor | np.ndarray): The model prediction with shape
(N, C), where C is the number of classes.
target (torch.Tensor | np.ndarray): The target of each prediction with
shape (N, C), where C is the number of classes. 1 stands for
positive examples, 0 stands for negative examples and -1 stands for
difficult examples.
Returns:
float: A single float as mAP value.
"""
if
isinstance
(
pred
,
torch
.
Tensor
)
and
isinstance
(
target
,
torch
.
Tensor
):
pred
=
pred
.
detach
().
cpu
().
numpy
()
target
=
target
.
detach
().
cpu
().
numpy
()
elif
not
(
isinstance
(
pred
,
np
.
ndarray
)
and
isinstance
(
target
,
np
.
ndarray
)):
raise
TypeError
(
'pred and target should both be torch.Tensor or'
'np.ndarray'
)
assert
pred
.
shape
==
\
target
.
shape
,
'pred and target should be in the same shape.'
num_classes
=
pred
.
shape
[
1
]
ap
=
np
.
zeros
(
num_classes
)
for
k
in
range
(
num_classes
):
ap
[
k
]
=
average_precision
(
pred
[:,
k
],
target
[:,
k
])
mean_ap
=
ap
.
mean
()
*
100.0
return
mean_ap
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/evaluation/multilabel_eval_metrics.py
0 → 100644
View file @
85529f35
import
warnings
import
numpy
as
np
import
torch
def
average_performance
(
pred
,
target
,
thr
=
None
,
k
=
None
):
"""Calculate CP, CR, CF1, OP, OR, OF1, where C stands for per-class
average, O stands for overall average, P stands for precision, R stands for
recall and F1 stands for F1-score.
Args:
pred (torch.Tensor | np.ndarray): The model prediction with shape
(N, C), where C is the number of classes.
target (torch.Tensor | np.ndarray): The target of each prediction with
shape (N, C), where C is the number of classes. 1 stands for
positive examples, 0 stands for negative examples and -1 stands for
difficult examples.
thr (float): The confidence threshold. Defaults to None.
k (int): Top-k performance. Note that if thr and k are both given, k
will be ignored. Defaults to None.
Returns:
tuple: (CP, CR, CF1, OP, OR, OF1)
"""
if
isinstance
(
pred
,
torch
.
Tensor
)
and
isinstance
(
target
,
torch
.
Tensor
):
pred
=
pred
.
detach
().
cpu
().
numpy
()
target
=
target
.
detach
().
cpu
().
numpy
()
elif
not
(
isinstance
(
pred
,
np
.
ndarray
)
and
isinstance
(
target
,
np
.
ndarray
)):
raise
TypeError
(
'pred and target should both be torch.Tensor or'
'np.ndarray'
)
if
thr
is
None
and
k
is
None
:
thr
=
0.5
warnings
.
warn
(
'Neither thr nor k is given, set thr as 0.5 by '
'default.'
)
elif
thr
is
not
None
and
k
is
not
None
:
warnings
.
warn
(
'Both thr and k are given, use threshold in favor of '
'top-k.'
)
assert
pred
.
shape
==
\
target
.
shape
,
'pred and target should be in the same shape.'
eps
=
np
.
finfo
(
np
.
float32
).
eps
target
[
target
==
-
1
]
=
0
if
thr
is
not
None
:
# a label is predicted positive if the confidence is no lower than thr
pos_inds
=
pred
>=
thr
else
:
# top-k labels will be predicted positive for any example
sort_inds
=
np
.
argsort
(
-
pred
,
axis
=
1
)
sort_inds_
=
sort_inds
[:,
:
k
]
inds
=
np
.
indices
(
sort_inds_
.
shape
)
pos_inds
=
np
.
zeros_like
(
pred
)
pos_inds
[
inds
[
0
],
sort_inds_
]
=
1
tp
=
(
pos_inds
*
target
)
==
1
fp
=
(
pos_inds
*
(
1
-
target
))
==
1
fn
=
((
1
-
pos_inds
)
*
target
)
==
1
precision_class
=
tp
.
sum
(
axis
=
0
)
/
np
.
maximum
(
tp
.
sum
(
axis
=
0
)
+
fp
.
sum
(
axis
=
0
),
eps
)
recall_class
=
tp
.
sum
(
axis
=
0
)
/
np
.
maximum
(
tp
.
sum
(
axis
=
0
)
+
fn
.
sum
(
axis
=
0
),
eps
)
CP
=
precision_class
.
mean
()
*
100.0
CR
=
recall_class
.
mean
()
*
100.0
CF1
=
2
*
CP
*
CR
/
np
.
maximum
(
CP
+
CR
,
eps
)
OP
=
tp
.
sum
()
/
np
.
maximum
(
tp
.
sum
()
+
fp
.
sum
(),
eps
)
*
100.0
OR
=
tp
.
sum
()
/
np
.
maximum
(
tp
.
sum
()
+
fn
.
sum
(),
eps
)
*
100.0
OF1
=
2
*
OP
*
OR
/
np
.
maximum
(
OP
+
OR
,
eps
)
return
CP
,
CR
,
CF1
,
OP
,
OR
,
OF1
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/export/__init__.py
0 → 100644
View file @
85529f35
from
.test
import
ONNXRuntimeClassifier
,
TensorRTClassifier
__all__
=
[
'ONNXRuntimeClassifier'
,
'TensorRTClassifier'
]
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/export/test.py
0 → 100644
View file @
85529f35
import
warnings
import
numpy
as
np
import
onnxruntime
as
ort
import
torch
from
mmcls.models.classifiers
import
BaseClassifier
class
ONNXRuntimeClassifier
(
BaseClassifier
):
"""Wrapper for classifier's inference with ONNXRuntime."""
def
__init__
(
self
,
onnx_file
,
class_names
,
device_id
):
super
(
ONNXRuntimeClassifier
,
self
).
__init__
()
sess
=
ort
.
InferenceSession
(
onnx_file
)
providers
=
[
'CPUExecutionProvider'
]
options
=
[{}]
is_cuda_available
=
ort
.
get_device
()
==
'GPU'
if
is_cuda_available
:
providers
.
insert
(
0
,
'CUDAExecutionProvider'
)
options
.
insert
(
0
,
{
'device_id'
:
device_id
})
sess
.
set_providers
(
providers
,
options
)
self
.
sess
=
sess
self
.
CLASSES
=
class_names
self
.
device_id
=
device_id
self
.
io_binding
=
sess
.
io_binding
()
self
.
output_names
=
[
_
.
name
for
_
in
sess
.
get_outputs
()]
self
.
is_cuda_available
=
is_cuda_available
def
simple_test
(
self
,
img
,
img_metas
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
extract_feat
(
self
,
imgs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
forward_train
(
self
,
imgs
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
forward_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
input_data
=
imgs
# set io binding for inputs/outputs
device_type
=
'cuda'
if
self
.
is_cuda_available
else
'cpu'
if
not
self
.
is_cuda_available
:
input_data
=
input_data
.
cpu
()
self
.
io_binding
.
bind_input
(
name
=
'input'
,
device_type
=
device_type
,
device_id
=
self
.
device_id
,
element_type
=
np
.
float32
,
shape
=
input_data
.
shape
,
buffer_ptr
=
input_data
.
data_ptr
())
for
name
in
self
.
output_names
:
self
.
io_binding
.
bind_output
(
name
)
# run session to get outputs
self
.
sess
.
run_with_iobinding
(
self
.
io_binding
)
results
=
self
.
io_binding
.
copy_outputs_to_cpu
()[
0
]
return
list
(
results
)
class
TensorRTClassifier
(
BaseClassifier
):
def
__init__
(
self
,
trt_file
,
class_names
,
device_id
):
super
(
TensorRTClassifier
,
self
).
__init__
()
from
mmcv.tensorrt
import
TRTWraper
,
load_tensorrt_plugin
try
:
load_tensorrt_plugin
()
except
(
ImportError
,
ModuleNotFoundError
):
warnings
.
warn
(
'If input model has custom op from mmcv,
\
you may have to build mmcv with TensorRT from source.'
)
model
=
TRTWraper
(
trt_file
,
input_names
=
[
'input'
],
output_names
=
[
'probs'
])
self
.
model
=
model
self
.
device_id
=
device_id
self
.
CLASSES
=
class_names
def
simple_test
(
self
,
img
,
img_metas
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
extract_feat
(
self
,
imgs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
forward_train
(
self
,
imgs
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
forward_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
input_data
=
imgs
with
torch
.
cuda
.
device
(
self
.
device_id
),
torch
.
no_grad
():
results
=
self
.
model
({
'input'
:
input_data
})[
'probs'
]
results
=
results
.
detach
().
cpu
().
numpy
()
return
list
(
results
)
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/__init__.py
0 → 100644
View file @
85529f35
from
.decorators
import
auto_fp16
,
force_fp32
from
.hooks
import
Fp16OptimizerHook
,
wrap_fp16_model
__all__
=
[
'auto_fp16'
,
'force_fp32'
,
'Fp16OptimizerHook'
,
'wrap_fp16_model'
]
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/decorators.py
0 → 100644
View file @
85529f35
import
functools
from
inspect
import
getfullargspec
import
torch
from
.utils
import
cast_tensor_type
def
auto_fp16
(
apply_to
=
None
,
out_fp32
=
False
):
"""Decorator to enable fp16 training automatically.
This decorator is useful when you write custom modules and want to support
mixed precision training. If inputs arguments are fp32 tensors, they will
be converted to fp16 automatically. Arguments other than fp32 tensors are
ignored.
Args:
apply_to (Iterable, optional): The argument names to be converted.
`None` indicates all arguments.
out_fp32 (bool): Whether to convert the output back to fp32.
:Example:
class MyModule1(nn.Module)
# Convert x and y to fp16
@auto_fp16()
def forward(self, x, y):
pass
class MyModule2(nn.Module):
# convert pred to fp16
@auto_fp16(apply_to=('pred', ))
def do_something(self, pred, others):
pass
"""
def
auto_fp16_wrapper
(
old_func
):
@
functools
.
wraps
(
old_func
)
def
new_func
(
*
args
,
**
kwargs
):
# check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method.
if
not
isinstance
(
args
[
0
],
torch
.
nn
.
Module
):
raise
TypeError
(
'@auto_fp16 can only be used to decorate the '
'method of nn.Module'
)
if
not
(
hasattr
(
args
[
0
],
'fp16_enabled'
)
and
args
[
0
].
fp16_enabled
):
return
old_func
(
*
args
,
**
kwargs
)
# get the arg spec of the decorated method
args_info
=
getfullargspec
(
old_func
)
# get the argument names to be casted
args_to_cast
=
args_info
.
args
if
apply_to
is
None
else
apply_to
# convert the args that need to be processed
new_args
=
[]
# NOTE: default args are not taken into consideration
if
args
:
arg_names
=
args_info
.
args
[:
len
(
args
)]
for
i
,
arg_name
in
enumerate
(
arg_names
):
if
arg_name
in
args_to_cast
:
new_args
.
append
(
cast_tensor_type
(
args
[
i
],
torch
.
float
,
torch
.
half
))
else
:
new_args
.
append
(
args
[
i
])
# convert the kwargs that need to be processed
new_kwargs
=
{}
if
kwargs
:
for
arg_name
,
arg_value
in
kwargs
.
items
():
if
arg_name
in
args_to_cast
:
new_kwargs
[
arg_name
]
=
cast_tensor_type
(
arg_value
,
torch
.
float
,
torch
.
half
)
else
:
new_kwargs
[
arg_name
]
=
arg_value
# apply converted arguments to the decorated method
output
=
old_func
(
*
new_args
,
**
new_kwargs
)
# cast the results back to fp32 if necessary
if
out_fp32
:
output
=
cast_tensor_type
(
output
,
torch
.
half
,
torch
.
float
)
return
output
return
new_func
return
auto_fp16_wrapper
def
force_fp32
(
apply_to
=
None
,
out_fp16
=
False
):
"""Decorator to convert input arguments to fp32 in force.
This decorator is useful when you write custom modules and want to support
mixed precision training. If there are some inputs that must be processed
in fp32 mode, then this decorator can handle it. If inputs arguments are
fp16 tensors, they will be converted to fp32 automatically. Arguments other
than fp16 tensors are ignored.
Args:
apply_to (Iterable, optional): The argument names to be converted.
`None` indicates all arguments.
out_fp16 (bool): Whether to convert the output back to fp16.
:Example:
class MyModule1(nn.Module)
# Convert x and y to fp32
@force_fp32()
def loss(self, x, y):
pass
class MyModule2(nn.Module):
# convert pred to fp32
@force_fp32(apply_to=('pred', ))
def post_process(self, pred, others):
pass
"""
def
force_fp32_wrapper
(
old_func
):
@
functools
.
wraps
(
old_func
)
def
new_func
(
*
args
,
**
kwargs
):
# check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method.
if
not
isinstance
(
args
[
0
],
torch
.
nn
.
Module
):
raise
TypeError
(
'@force_fp32 can only be used to decorate the '
'method of nn.Module'
)
if
not
(
hasattr
(
args
[
0
],
'fp16_enabled'
)
and
args
[
0
].
fp16_enabled
):
return
old_func
(
*
args
,
**
kwargs
)
# get the arg spec of the decorated method
args_info
=
getfullargspec
(
old_func
)
# get the argument names to be casted
args_to_cast
=
args_info
.
args
if
apply_to
is
None
else
apply_to
# convert the args that need to be processed
new_args
=
[]
if
args
:
arg_names
=
args_info
.
args
[:
len
(
args
)]
for
i
,
arg_name
in
enumerate
(
arg_names
):
if
arg_name
in
args_to_cast
:
new_args
.
append
(
cast_tensor_type
(
args
[
i
],
torch
.
half
,
torch
.
float
))
else
:
new_args
.
append
(
args
[
i
])
# convert the kwargs that need to be processed
new_kwargs
=
dict
()
if
kwargs
:
for
arg_name
,
arg_value
in
kwargs
.
items
():
if
arg_name
in
args_to_cast
:
new_kwargs
[
arg_name
]
=
cast_tensor_type
(
arg_value
,
torch
.
half
,
torch
.
float
)
else
:
new_kwargs
[
arg_name
]
=
arg_value
# apply converted arguments to the decorated method
output
=
old_func
(
*
new_args
,
**
new_kwargs
)
# cast the results back to fp32 if necessary
if
out_fp16
:
output
=
cast_tensor_type
(
output
,
torch
.
float
,
torch
.
half
)
return
output
return
new_func
return
force_fp32_wrapper
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/hooks.py
0 → 100644
View file @
85529f35
import
copy
import
torch
import
torch.nn
as
nn
from
mmcv.runner
import
OptimizerHook
from
mmcv.utils.parrots_wrapper
import
_BatchNorm
from
..utils
import
allreduce_grads
from
.utils
import
cast_tensor_type
class
Fp16OptimizerHook
(
OptimizerHook
):
"""FP16 optimizer hook.
The steps of fp16 optimizer is as follows.
1. Scale the loss value.
2. BP in the fp16 model.
2. Copy gradients from fp16 model to fp32 weights.
3. Update fp32 weights.
4. Copy updated parameters from fp32 weights to fp16 model.
Refer to https://arxiv.org/abs/1710.03740 for more details.
Args:
loss_scale (float): Scale factor multiplied with loss.
"""
def
__init__
(
self
,
grad_clip
=
None
,
coalesce
=
True
,
bucket_size_mb
=-
1
,
loss_scale
=
512.
,
distributed
=
True
):
self
.
grad_clip
=
grad_clip
self
.
coalesce
=
coalesce
self
.
bucket_size_mb
=
bucket_size_mb
self
.
loss_scale
=
loss_scale
self
.
distributed
=
distributed
def
before_run
(
self
,
runner
):
# keep a copy of fp32 weights
runner
.
optimizer
.
param_groups
=
copy
.
deepcopy
(
runner
.
optimizer
.
param_groups
)
# convert model to fp16
wrap_fp16_model
(
runner
.
model
)
def
copy_grads_to_fp32
(
self
,
fp16_net
,
fp32_weights
):
"""Copy gradients from fp16 model to fp32 weight copy."""
for
fp32_param
,
fp16_param
in
zip
(
fp32_weights
,
fp16_net
.
parameters
()):
if
fp16_param
.
grad
is
not
None
:
if
fp32_param
.
grad
is
None
:
fp32_param
.
grad
=
fp32_param
.
data
.
new
(
fp32_param
.
size
())
fp32_param
.
grad
.
copy_
(
fp16_param
.
grad
)
def
copy_params_to_fp16
(
self
,
fp16_net
,
fp32_weights
):
"""Copy updated params from fp32 weight copy to fp16 model."""
for
fp16_param
,
fp32_param
in
zip
(
fp16_net
.
parameters
(),
fp32_weights
):
fp16_param
.
data
.
copy_
(
fp32_param
.
data
)
def
after_train_iter
(
self
,
runner
):
# clear grads of last iteration
runner
.
model
.
zero_grad
()
runner
.
optimizer
.
zero_grad
()
# scale the loss value
scaled_loss
=
runner
.
outputs
[
'loss'
]
*
self
.
loss_scale
scaled_loss
.
backward
()
# copy fp16 grads in the model to fp32 params in the optimizer
fp32_weights
=
[]
for
param_group
in
runner
.
optimizer
.
param_groups
:
fp32_weights
+=
param_group
[
'params'
]
self
.
copy_grads_to_fp32
(
runner
.
model
,
fp32_weights
)
# allreduce grads
if
self
.
distributed
:
allreduce_grads
(
fp32_weights
,
self
.
coalesce
,
self
.
bucket_size_mb
)
# scale the gradients back
for
param
in
fp32_weights
:
if
param
.
grad
is
not
None
:
param
.
grad
.
div_
(
self
.
loss_scale
)
if
self
.
grad_clip
is
not
None
:
self
.
clip_grads
(
fp32_weights
)
# update fp32 params
runner
.
optimizer
.
step
()
# copy fp32 params to the fp16 model
self
.
copy_params_to_fp16
(
runner
.
model
,
fp32_weights
)
def
wrap_fp16_model
(
model
):
# convert model to fp16
model
.
half
()
# patch the normalization layers to make it work in fp32 mode
patch_norm_fp32
(
model
)
# set `fp16_enabled` flag
for
m
in
model
.
modules
():
if
hasattr
(
m
,
'fp16_enabled'
):
m
.
fp16_enabled
=
True
def
patch_norm_fp32
(
module
):
if
isinstance
(
module
,
(
_BatchNorm
,
nn
.
GroupNorm
)):
module
.
float
()
module
.
forward
=
patch_forward_method
(
module
.
forward
,
torch
.
half
,
torch
.
float
)
for
child
in
module
.
children
():
patch_norm_fp32
(
child
)
return
module
def
patch_forward_method
(
func
,
src_type
,
dst_type
,
convert_output
=
True
):
"""Patch the forward method of a module.
Args:
func (callable): The original forward method.
src_type (torch.dtype): Type of input arguments to be converted from.
dst_type (torch.dtype): Type of input arguments to be converted to.
convert_output (bool): Whether to convert the output back to src_type.
Returns:
callable: The patched forward method.
"""
def
new_forward
(
*
args
,
**
kwargs
):
output
=
func
(
*
cast_tensor_type
(
args
,
src_type
,
dst_type
),
**
cast_tensor_type
(
kwargs
,
src_type
,
dst_type
))
if
convert_output
:
output
=
cast_tensor_type
(
output
,
dst_type
,
src_type
)
return
output
return
new_forward
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/fp16/utils.py
0 → 100644
View file @
85529f35
from
collections
import
abc
import
numpy
as
np
import
torch
def
cast_tensor_type
(
inputs
,
src_type
,
dst_type
):
if
isinstance
(
inputs
,
torch
.
Tensor
):
return
inputs
.
to
(
dst_type
)
elif
isinstance
(
inputs
,
str
):
return
inputs
elif
isinstance
(
inputs
,
np
.
ndarray
):
return
inputs
elif
isinstance
(
inputs
,
abc
.
Mapping
):
return
type
(
inputs
)({
k
:
cast_tensor_type
(
v
,
src_type
,
dst_type
)
for
k
,
v
in
inputs
.
items
()
})
elif
isinstance
(
inputs
,
abc
.
Iterable
):
return
type
(
inputs
)(
cast_tensor_type
(
item
,
src_type
,
dst_type
)
for
item
in
inputs
)
else
:
return
inputs
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/utils/__init__.py
0 → 100644
View file @
85529f35
from
.dist_utils
import
DistOptimizerHook
,
allreduce_grads
from
.misc
import
multi_apply
__all__
=
[
'allreduce_grads'
,
'DistOptimizerHook'
,
'multi_apply'
]
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/utils/dist_utils.py
0 → 100644
View file @
85529f35
from
collections
import
OrderedDict
import
torch.distributed
as
dist
from
mmcv.runner
import
OptimizerHook
from
torch._utils
import
(
_flatten_dense_tensors
,
_take_tensors
,
_unflatten_dense_tensors
)
def
_allreduce_coalesced
(
tensors
,
world_size
,
bucket_size_mb
=-
1
):
if
bucket_size_mb
>
0
:
bucket_size_bytes
=
bucket_size_mb
*
1024
*
1024
buckets
=
_take_tensors
(
tensors
,
bucket_size_bytes
)
else
:
buckets
=
OrderedDict
()
for
tensor
in
tensors
:
tp
=
tensor
.
type
()
if
tp
not
in
buckets
:
buckets
[
tp
]
=
[]
buckets
[
tp
].
append
(
tensor
)
buckets
=
buckets
.
values
()
for
bucket
in
buckets
:
flat_tensors
=
_flatten_dense_tensors
(
bucket
)
dist
.
all_reduce
(
flat_tensors
)
flat_tensors
.
div_
(
world_size
)
for
tensor
,
synced
in
zip
(
bucket
,
_unflatten_dense_tensors
(
flat_tensors
,
bucket
)):
tensor
.
copy_
(
synced
)
def
allreduce_grads
(
params
,
coalesce
=
True
,
bucket_size_mb
=-
1
):
grads
=
[
param
.
grad
.
data
for
param
in
params
if
param
.
requires_grad
and
param
.
grad
is
not
None
]
world_size
=
dist
.
get_world_size
()
if
coalesce
:
_allreduce_coalesced
(
grads
,
world_size
,
bucket_size_mb
)
else
:
for
tensor
in
grads
:
dist
.
all_reduce
(
tensor
.
div_
(
world_size
))
class
DistOptimizerHook
(
OptimizerHook
):
def
__init__
(
self
,
grad_clip
=
None
,
coalesce
=
True
,
bucket_size_mb
=-
1
):
self
.
grad_clip
=
grad_clip
self
.
coalesce
=
coalesce
self
.
bucket_size_mb
=
bucket_size_mb
def
after_train_iter
(
self
,
runner
):
runner
.
optimizer
.
zero_grad
()
runner
.
outputs
[
'loss'
].
backward
()
if
self
.
grad_clip
is
not
None
:
self
.
clip_grads
(
runner
.
model
.
parameters
())
runner
.
optimizer
.
step
()
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/utils/misc.py
0 → 100644
View file @
85529f35
from
functools
import
partial
def
multi_apply
(
func
,
*
args
,
**
kwargs
):
pfunc
=
partial
(
func
,
**
kwargs
)
if
kwargs
else
func
map_results
=
map
(
pfunc
,
*
args
)
return
tuple
(
map
(
list
,
zip
(
*
map_results
)))
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/__init__.py
0 → 100644
View file @
85529f35
from
.base_dataset
import
BaseDataset
from
.builder
import
DATASETS
,
PIPELINES
,
build_dataloader
,
build_dataset
from
.cifar
import
CIFAR10
,
CIFAR100
from
.dataset_wrappers
import
(
ClassBalancedDataset
,
ConcatDataset
,
RepeatDataset
)
from
.dummy
import
DummyImageNet
from
.imagenet
import
ImageNet
from
.mnist
import
MNIST
,
FashionMNIST
from
.multi_label
import
MultiLabelDataset
from
.samplers
import
DistributedSampler
from
.voc
import
VOC
__all__
=
[
'BaseDataset'
,
'ImageNet'
,
'CIFAR10'
,
'CIFAR100'
,
'MNIST'
,
'FashionMNIST'
,
'VOC'
,
'MultiLabelDataset'
,
'build_dataloader'
,
'build_dataset'
,
'Compose'
,
'DistributedSampler'
,
'ConcatDataset'
,
'RepeatDataset'
,
'DummyImageNet'
,
'ClassBalancedDataset'
,
'DATASETS'
,
'PIPELINES'
]
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/base_dataset.py
0 → 100644
View file @
85529f35
import
copy
from
abc
import
ABCMeta
,
abstractmethod
import
mmcv
import
numpy
as
np
from
torch.utils.data
import
Dataset
from
mmcls.core.evaluation
import
precision_recall_f1
,
support
from
mmcls.models.losses
import
accuracy
from
.pipelines
import
Compose
class
BaseDataset
(
Dataset
,
metaclass
=
ABCMeta
):
"""Base dataset.
Args:
data_prefix (str): the prefix of data path
pipeline (list): a list of dict, where each element represents
a operation defined in `mmcls.datasets.pipelines`
ann_file (str | None): the annotation file. When ann_file is str,
the subclass is expected to read from the ann_file. When ann_file
is None, the subclass is expected to read according to data_prefix
test_mode (bool): in train mode or test mode
"""
CLASSES
=
None
def
__init__
(
self
,
data_prefix
,
pipeline
,
classes
=
None
,
ann_file
=
None
,
test_mode
=
False
):
super
(
BaseDataset
,
self
).
__init__
()
self
.
ann_file
=
ann_file
self
.
data_prefix
=
data_prefix
self
.
test_mode
=
test_mode
self
.
pipeline
=
Compose
(
pipeline
)
self
.
CLASSES
=
self
.
get_classes
(
classes
)
self
.
data_infos
=
self
.
load_annotations
()
@
abstractmethod
def
load_annotations
(
self
):
pass
@
property
def
class_to_idx
(
self
):
"""Map mapping class name to class index.
Returns:
dict: mapping from class name to class index.
"""
return
{
_class
:
i
for
i
,
_class
in
enumerate
(
self
.
CLASSES
)}
def
get_gt_labels
(
self
):
"""Get all ground-truth labels (categories).
Returns:
list[int]: categories for all images.
"""
gt_labels
=
np
.
array
([
data
[
'gt_label'
]
for
data
in
self
.
data_infos
])
return
gt_labels
def
get_cat_ids
(
self
,
idx
):
"""Get category id by index.
Args:
idx (int): Index of data.
Returns:
int: Image category of specified index.
"""
return
self
.
data_infos
[
idx
][
'gt_label'
].
astype
(
np
.
int
)
def
prepare_data
(
self
,
idx
):
results
=
copy
.
deepcopy
(
self
.
data_infos
[
idx
])
return
self
.
pipeline
(
results
)
def
__len__
(
self
):
return
len
(
self
.
data_infos
)
def
__getitem__
(
self
,
idx
):
return
self
.
prepare_data
(
idx
)
@
classmethod
def
get_classes
(
cls
,
classes
=
None
):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
Returns:
tuple[str] or list[str]: Names of categories of the dataset.
"""
if
classes
is
None
:
return
cls
.
CLASSES
if
isinstance
(
classes
,
str
):
# take it as a file path
class_names
=
mmcv
.
list_from_file
(
classes
)
elif
isinstance
(
classes
,
(
tuple
,
list
)):
class_names
=
classes
else
:
raise
ValueError
(
f
'Unsupported type
{
type
(
classes
)
}
of classes.'
)
return
class_names
def
evaluate
(
self
,
results
,
metric
=
'accuracy'
,
metric_options
=
None
,
logger
=
None
):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
Default value is `accuracy`.
metric_options (dict, optional): Options for calculating metrics.
Allowed keys are 'topk', 'thrs' and 'average_mode'.
Defaults to None.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None.
Returns:
dict: evaluation results
"""
if
metric_options
is
None
:
metric_options
=
{
'topk'
:
(
1
,
5
)}
if
isinstance
(
metric
,
str
):
metrics
=
[
metric
]
else
:
metrics
=
metric
allowed_metrics
=
[
'accuracy'
,
'precision'
,
'recall'
,
'f1_score'
,
'support'
]
eval_results
=
{}
results
=
np
.
vstack
(
results
)
gt_labels
=
self
.
get_gt_labels
()
num_imgs
=
len
(
results
)
assert
len
(
gt_labels
)
==
num_imgs
,
'dataset testing results should '
\
'be of the same length as gt_labels.'
invalid_metrics
=
set
(
metrics
)
-
set
(
allowed_metrics
)
if
len
(
invalid_metrics
)
!=
0
:
raise
ValueError
(
f
'metirc
{
invalid_metrics
}
is not supported.'
)
topk
=
metric_options
.
get
(
'topk'
,
(
1
,
5
))
thrs
=
metric_options
.
get
(
'thrs'
)
average_mode
=
metric_options
.
get
(
'average_mode'
,
'macro'
)
if
'accuracy'
in
metrics
:
acc
=
accuracy
(
results
,
gt_labels
,
topk
=
topk
,
thrs
=
thrs
)
if
isinstance
(
topk
,
tuple
):
eval_results_
=
{
f
'accuracy_top-
{
k
}
'
:
a
for
k
,
a
in
zip
(
topk
,
acc
)
}
else
:
eval_results_
=
{
'accuracy'
:
acc
}
if
isinstance
(
thrs
,
tuple
):
for
key
,
values
in
eval_results_
.
items
():
eval_results
.
update
({
f
'
{
key
}
_thr_
{
thr
:.
2
f
}
'
:
value
.
item
()
for
thr
,
value
in
zip
(
thrs
,
values
)
})
else
:
eval_results
.
update
(
{
k
:
v
.
item
()
for
k
,
v
in
eval_results_
.
items
()})
if
'support'
in
metrics
:
support_value
=
support
(
results
,
gt_labels
,
average_mode
=
average_mode
)
eval_results
[
'support'
]
=
support_value
precision_recall_f1_keys
=
[
'precision'
,
'recall'
,
'f1_score'
]
if
len
(
set
(
metrics
)
&
set
(
precision_recall_f1_keys
))
!=
0
:
precision_recall_f1_values
=
precision_recall_f1
(
results
,
gt_labels
,
average_mode
=
average_mode
,
thrs
=
thrs
)
for
key
,
values
in
zip
(
precision_recall_f1_keys
,
precision_recall_f1_values
):
if
key
in
metrics
:
if
isinstance
(
thrs
,
tuple
):
eval_results
.
update
({
f
'
{
key
}
_thr_
{
thr
:.
2
f
}
'
:
value
for
thr
,
value
in
zip
(
thrs
,
values
)
})
else
:
eval_results
[
key
]
=
values
return
eval_results
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/builder.py
0 → 100644
View file @
85529f35
import
platform
import
random
from
functools
import
partial
import
numpy
as
np
from
mmcv.parallel
import
collate
from
mmcv.runner
import
get_dist_info
from
mmcv.utils
import
Registry
,
build_from_cfg
from
torch.utils.data
import
DataLoader
from
.samplers
import
DistributedSampler
if
platform
.
system
()
!=
'Windows'
:
# https://github.com/pytorch/pytorch/issues/973
import
resource
rlimit
=
resource
.
getrlimit
(
resource
.
RLIMIT_NOFILE
)
hard_limit
=
rlimit
[
1
]
soft_limit
=
min
(
4096
,
hard_limit
)
resource
.
setrlimit
(
resource
.
RLIMIT_NOFILE
,
(
soft_limit
,
hard_limit
))
DATASETS
=
Registry
(
'dataset'
)
PIPELINES
=
Registry
(
'pipeline'
)
def
build_dataset
(
cfg
,
default_args
=
None
):
from
.dataset_wrappers
import
(
ConcatDataset
,
RepeatDataset
,
ClassBalancedDataset
)
if
isinstance
(
cfg
,
(
list
,
tuple
)):
dataset
=
ConcatDataset
([
build_dataset
(
c
,
default_args
)
for
c
in
cfg
])
elif
cfg
[
'type'
]
==
'RepeatDataset'
:
dataset
=
RepeatDataset
(
build_dataset
(
cfg
[
'dataset'
],
default_args
),
cfg
[
'times'
])
elif
cfg
[
'type'
]
==
'ClassBalancedDataset'
:
dataset
=
ClassBalancedDataset
(
build_dataset
(
cfg
[
'dataset'
],
default_args
),
cfg
[
'oversample_thr'
])
else
:
dataset
=
build_from_cfg
(
cfg
,
DATASETS
,
default_args
)
return
dataset
def
build_dataloader
(
dataset
,
samples_per_gpu
,
workers_per_gpu
,
num_gpus
=
1
,
dist
=
True
,
shuffle
=
True
,
round_up
=
True
,
seed
=
None
,
**
kwargs
):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
In non-distributed training, there is only one dataloader for all GPUs.
Args:
dataset (Dataset): A PyTorch dataset.
samples_per_gpu (int): Number of training samples on each GPU, i.e.,
batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed training.
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
round_up (bool): Whether to round up the length of dataset by adding
extra samples to make it evenly divisible. Default: True.
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
DataLoader: A PyTorch dataloader.
"""
rank
,
world_size
=
get_dist_info
()
if
dist
:
sampler
=
DistributedSampler
(
dataset
,
world_size
,
rank
,
shuffle
=
shuffle
,
round_up
=
round_up
)
shuffle
=
False
batch_size
=
samples_per_gpu
num_workers
=
workers_per_gpu
else
:
sampler
=
None
batch_size
=
num_gpus
*
samples_per_gpu
num_workers
=
num_gpus
*
workers_per_gpu
init_fn
=
partial
(
worker_init_fn
,
num_workers
=
num_workers
,
rank
=
rank
,
seed
=
seed
)
if
seed
is
not
None
else
None
data_loader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
sampler
=
sampler
,
num_workers
=
num_workers
,
collate_fn
=
partial
(
collate
,
samples_per_gpu
=
samples_per_gpu
),
pin_memory
=
False
,
shuffle
=
shuffle
,
worker_init_fn
=
init_fn
,
**
kwargs
)
return
data_loader
def
worker_init_fn
(
worker_id
,
num_workers
,
rank
,
seed
):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed
=
num_workers
*
rank
+
worker_id
+
seed
np
.
random
.
seed
(
worker_seed
)
random
.
seed
(
worker_seed
)
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/cifar.py
0 → 100644
View file @
85529f35
import
os
import
os.path
import
pickle
import
numpy
as
np
import
torch.distributed
as
dist
from
mmcv.runner
import
get_dist_info
from
.base_dataset
import
BaseDataset
from
.builder
import
DATASETS
from
.utils
import
check_integrity
,
download_and_extract_archive
@
DATASETS
.
register_module
()
class
CIFAR10
(
BaseDataset
):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py # noqa: E501
"""
base_folder
=
'cifar-10-batches-py'
url
=
'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
filename
=
'cifar-10-python.tar.gz'
tgz_md5
=
'c58f30108f718f92721af3b95e74349a'
train_list
=
[
[
'data_batch_1'
,
'c99cafc152244af753f735de768cd75f'
],
[
'data_batch_2'
,
'd4bba439e000b95fd0a9bffe97cbabec'
],
[
'data_batch_3'
,
'54ebc095f3ab1f0389bbae665268c751'
],
[
'data_batch_4'
,
'634d18415352ddfa80567beed471001a'
],
[
'data_batch_5'
,
'482c414d41f54cd18b22e5b47cb7c3cb'
],
]
test_list
=
[
[
'test_batch'
,
'40351d587109b95175f43aff81a1287e'
],
]
meta
=
{
'filename'
:
'batches.meta'
,
'key'
:
'label_names'
,
'md5'
:
'5ff9c542aee3614f3951f8cda6e48888'
,
}
def
load_annotations
(
self
):
rank
,
world_size
=
get_dist_info
()
if
rank
==
0
and
not
self
.
_check_integrity
():
download_and_extract_archive
(
self
.
url
,
self
.
data_prefix
,
filename
=
self
.
filename
,
md5
=
self
.
tgz_md5
)
if
world_size
>
1
:
dist
.
barrier
()
assert
self
.
_check_integrity
(),
\
'Shared storage seems unavailable. '
\
f
'Please download the dataset manually through
{
self
.
url
}
.'
if
not
self
.
test_mode
:
downloaded_list
=
self
.
train_list
else
:
downloaded_list
=
self
.
test_list
self
.
imgs
=
[]
self
.
gt_labels
=
[]
# load the picked numpy arrays
for
file_name
,
checksum
in
downloaded_list
:
file_path
=
os
.
path
.
join
(
self
.
data_prefix
,
self
.
base_folder
,
file_name
)
with
open
(
file_path
,
'rb'
)
as
f
:
entry
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
self
.
imgs
.
append
(
entry
[
'data'
])
if
'labels'
in
entry
:
self
.
gt_labels
.
extend
(
entry
[
'labels'
])
else
:
self
.
gt_labels
.
extend
(
entry
[
'fine_labels'
])
self
.
imgs
=
np
.
vstack
(
self
.
imgs
).
reshape
(
-
1
,
3
,
32
,
32
)
self
.
imgs
=
self
.
imgs
.
transpose
((
0
,
2
,
3
,
1
))
# convert to HWC
self
.
_load_meta
()
data_infos
=
[]
for
img
,
gt_label
in
zip
(
self
.
imgs
,
self
.
gt_labels
):
gt_label
=
np
.
array
(
gt_label
,
dtype
=
np
.
int64
)
info
=
{
'img'
:
img
,
'gt_label'
:
gt_label
}
data_infos
.
append
(
info
)
return
data_infos
def
_load_meta
(
self
):
path
=
os
.
path
.
join
(
self
.
data_prefix
,
self
.
base_folder
,
self
.
meta
[
'filename'
])
if
not
check_integrity
(
path
,
self
.
meta
[
'md5'
]):
raise
RuntimeError
(
'Dataset metadata file not found or corrupted.'
+
' You can use download=True to download it'
)
with
open
(
path
,
'rb'
)
as
infile
:
data
=
pickle
.
load
(
infile
,
encoding
=
'latin1'
)
self
.
CLASSES
=
data
[
self
.
meta
[
'key'
]]
def
_check_integrity
(
self
):
root
=
self
.
data_prefix
for
fentry
in
(
self
.
train_list
+
self
.
test_list
):
filename
,
md5
=
fentry
[
0
],
fentry
[
1
]
fpath
=
os
.
path
.
join
(
root
,
self
.
base_folder
,
filename
)
if
not
check_integrity
(
fpath
,
md5
):
return
False
return
True
@
DATASETS
.
register_module
()
class
CIFAR100
(
CIFAR10
):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset."""
base_folder
=
'cifar-100-python'
url
=
'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
filename
=
'cifar-100-python.tar.gz'
tgz_md5
=
'eb9058c3a382ffc7106e4002c42a8d85'
train_list
=
[
[
'train'
,
'16019d7e3df5f24257cddd939b257f8d'
],
]
test_list
=
[
[
'test'
,
'f0ef6b0ae62326f3e7ffdfab6717acfc'
],
]
meta
=
{
'filename'
:
'meta'
,
'key'
:
'fine_label_names'
,
'md5'
:
'7973b15100ade9c7d40fb424638fde48'
,
}
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/dataset_wrappers.py
0 → 100644
View file @
85529f35
import
bisect
import
math
from
collections
import
defaultdict
import
numpy
as
np
from
torch.utils.data.dataset
import
ConcatDataset
as
_ConcatDataset
from
.builder
import
DATASETS
@
DATASETS
.
register_module
()
class
ConcatDataset
(
_ConcatDataset
):
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
add `get_cat_ids` function.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
def
__init__
(
self
,
datasets
):
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
self
.
CLASSES
=
datasets
[
0
].
CLASSES
def
get_cat_ids
(
self
,
idx
):
if
idx
<
0
:
if
-
idx
>
len
(
self
):
raise
ValueError
(
'absolute value of index should not exceed dataset length'
)
idx
=
len
(
self
)
+
idx
dataset_idx
=
bisect
.
bisect_right
(
self
.
cumulative_sizes
,
idx
)
if
dataset_idx
==
0
:
sample_idx
=
idx
else
:
sample_idx
=
idx
-
self
.
cumulative_sizes
[
dataset_idx
-
1
]
return
self
.
datasets
[
dataset_idx
].
get_cat_ids
(
sample_idx
)
@
DATASETS
.
register_module
()
class
RepeatDataset
(
object
):
"""A wrapper of repeated dataset.
The length of repeated dataset will be `times` larger than the original
dataset. This is useful when the data loading time is long but the dataset
is small. Using RepeatDataset can reduce the data loading time between
epochs.
Args:
dataset (:obj:`Dataset`): The dataset to be repeated.
times (int): Repeat times.
"""
def
__init__
(
self
,
dataset
,
times
):
self
.
dataset
=
dataset
self
.
times
=
times
self
.
CLASSES
=
dataset
.
CLASSES
self
.
_ori_len
=
len
(
self
.
dataset
)
def
__getitem__
(
self
,
idx
):
return
self
.
dataset
[
idx
%
self
.
_ori_len
]
def
get_cat_ids
(
self
,
idx
):
return
self
.
dataset
.
get_cat_ids
(
idx
%
self
.
_ori_len
)
def
__len__
(
self
):
return
self
.
times
*
self
.
_ori_len
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@
DATASETS
.
register_module
()
class
ClassBalancedDataset
(
object
):
"""A wrapper of repeated dataset with repeat factor.
Suitable for training on class imbalanced datasets like LVIS. Following
the sampling strategy in [1], in each epoch, an image may appear multiple
times based on its "repeat factor".
The repeat factor for an image is a function of the frequency the rarest
category labeled in that image. The "frequency of category c" in [0, 1]
is defined by the fraction of images in the training set (without repeats)
in which category c appears.
The dataset needs to instantiate :func:`self.get_cat_ids(idx)` to support
ClassBalancedDataset.
The repeat factor is computed as followed.
1. For each category c, compute the fraction # of images
that contain it: f(c)
2. For each category c, compute the category-level repeat factor:
r(c) = max(1, sqrt(t/f(c)))
3. For each image I and its labels L(I), compute the image-level repeat
factor:
r(I) = max_{c in L(I)} r(c)
References:
.. [1] https://arxiv.org/pdf/1908.03195.pdf
Args:
dataset (:obj:`CustomDataset`): The dataset to be repeated.
oversample_thr (float): frequency threshold below which data is
repeated. For categories with `f_c` >= `oversample_thr`, there is
no oversampling. For categories with `f_c` < `oversample_thr`, the
degree of oversampling following the square-root inverse frequency
heuristic above.
"""
def
__init__
(
self
,
dataset
,
oversample_thr
):
self
.
dataset
=
dataset
self
.
oversample_thr
=
oversample_thr
self
.
CLASSES
=
dataset
.
CLASSES
repeat_factors
=
self
.
_get_repeat_factors
(
dataset
,
oversample_thr
)
repeat_indices
=
[]
for
dataset_index
,
repeat_factor
in
enumerate
(
repeat_factors
):
repeat_indices
.
extend
([
dataset_index
]
*
math
.
ceil
(
repeat_factor
))
self
.
repeat_indices
=
repeat_indices
flags
=
[]
if
hasattr
(
self
.
dataset
,
'flag'
):
for
flag
,
repeat_factor
in
zip
(
self
.
dataset
.
flag
,
repeat_factors
):
flags
.
extend
([
flag
]
*
int
(
math
.
ceil
(
repeat_factor
)))
assert
len
(
flags
)
==
len
(
repeat_indices
)
self
.
flag
=
np
.
asarray
(
flags
,
dtype
=
np
.
uint8
)
def
_get_repeat_factors
(
self
,
dataset
,
repeat_thr
):
# 1. For each category c, compute the fraction # of images
# that contain it: f(c)
category_freq
=
defaultdict
(
int
)
num_images
=
len
(
dataset
)
for
idx
in
range
(
num_images
):
cat_ids
=
set
(
self
.
dataset
.
get_cat_ids
(
idx
))
for
cat_id
in
cat_ids
:
category_freq
[
cat_id
]
+=
1
for
k
,
v
in
category_freq
.
items
():
assert
v
>
0
,
f
'caterogy
{
k
}
does not contain any images'
category_freq
[
k
]
=
v
/
num_images
# 2. For each category c, compute the category-level repeat factor:
# r(c) = max(1, sqrt(t/f(c)))
category_repeat
=
{
cat_id
:
max
(
1.0
,
math
.
sqrt
(
repeat_thr
/
cat_freq
))
for
cat_id
,
cat_freq
in
category_freq
.
items
()
}
# 3. For each image I and its labels L(I), compute the image-level
# repeat factor:
# r(I) = max_{c in L(I)} r(c)
repeat_factors
=
[]
for
idx
in
range
(
num_images
):
cat_ids
=
set
(
self
.
dataset
.
get_cat_ids
(
idx
))
repeat_factor
=
max
(
{
category_repeat
[
cat_id
]
for
cat_id
in
cat_ids
})
repeat_factors
.
append
(
repeat_factor
)
return
repeat_factors
def
__getitem__
(
self
,
idx
):
ori_index
=
self
.
repeat_indices
[
idx
]
return
self
.
dataset
[
ori_index
]
def
__len__
(
self
):
return
len
(
self
.
repeat_indices
)
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/dummy.py
0 → 100644
View file @
85529f35
import
numpy
as
np
from
.base_dataset
import
BaseDataset
from
.builder
import
DATASETS
@
DATASETS
.
register_module
()
class
DummyImageNet
(
BaseDataset
):
"""`Dummy ImageNet <http://www.image-net.org>`_ Dataset.
This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py # noqa: E501
"""
dummy_images
=
{
i
:
np
.
random
.
randint
(
0
,
256
,
size
=
(
224
,
224
,
3
),
dtype
=
np
.
uint8
)
for
i
in
range
(
1000
)
}
def
__init__
(
self
,
data_prefix
,
pipeline
,
classes
=
None
,
ann_file
=
None
,
test_mode
=
False
):
if
test_mode
:
self
.
size
=
50000
else
:
self
.
size
=
1281167
super
().
__init__
(
data_prefix
,
pipeline
,
classes
=
classes
,
ann_file
=
ann_file
,
test_mode
=
test_mode
)
def
load_annotations
(
self
):
data_infos
=
[]
for
i
in
range
(
self
.
size
):
gt_label
=
i
%
1000
info
=
{
'img_prefix'
:
self
.
data_prefix
}
info
[
'img'
]
=
self
.
dummy_images
[
gt_label
]
info
[
'gt_label'
]
=
np
.
array
(
gt_label
,
dtype
=
np
.
int64
)
data_infos
.
append
(
info
)
return
data_infos
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/imagenet.py
0 → 100644
View file @
85529f35
import
os
import
numpy
as
np
from
.base_dataset
import
BaseDataset
from
.builder
import
DATASETS
def
has_file_allowed_extension
(
filename
,
extensions
):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
filename_lower
=
filename
.
lower
()
return
any
(
filename_lower
.
endswith
(
ext
)
for
ext
in
extensions
)
def
find_folders
(
root
):
"""Find classes by folders under a root.
Args:
root (string): root directory of folders
Returns:
folder_to_idx (dict): the map from folder name to class idx
"""
folders
=
[
d
for
d
in
os
.
listdir
(
root
)
if
os
.
path
.
isdir
(
os
.
path
.
join
(
root
,
d
))
]
folders
.
sort
()
folder_to_idx
=
{
folders
[
i
]:
i
for
i
in
range
(
len
(
folders
))}
return
folder_to_idx
def
get_samples
(
root
,
folder_to_idx
,
extensions
):
"""Make dataset by walking all images under a root.
Args:
root (string): root directory of folders
folder_to_idx (dict): the map from class name to class idx
extensions (tuple): allowed extensions
Returns:
samples (list): a list of tuple where each element is (image, label)
"""
samples
=
[]
root
=
os
.
path
.
expanduser
(
root
)
for
folder_name
in
sorted
(
os
.
listdir
(
root
)):
_dir
=
os
.
path
.
join
(
root
,
folder_name
)
if
not
os
.
path
.
isdir
(
_dir
):
continue
for
_
,
_
,
fns
in
sorted
(
os
.
walk
(
_dir
)):
for
fn
in
sorted
(
fns
):
if
has_file_allowed_extension
(
fn
,
extensions
):
path
=
os
.
path
.
join
(
folder_name
,
fn
)
item
=
(
path
,
folder_to_idx
[
folder_name
])
samples
.
append
(
item
)
return
samples
@
DATASETS
.
register_module
()
class
ImageNet
(
BaseDataset
):
"""`ImageNet <http://www.image-net.org>`_ Dataset.
This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py # noqa: E501
"""
IMG_EXTENSIONS
=
(
'.jpg'
,
'.jpeg'
,
'.png'
,
'.ppm'
,
'.bmp'
,
'.pgm'
,
'.tif'
)
CLASSES
=
[
'tench, Tinca tinca'
,
'goldfish, Carassius auratus'
,
'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias'
,
# noqa: E501
'tiger shark, Galeocerdo cuvieri'
,
'hammerhead, hammerhead shark'
,
'electric ray, crampfish, numbfish, torpedo'
,
'stingray'
,
'cock'
,
'hen'
,
'ostrich, Struthio camelus'
,
'brambling, Fringilla montifringilla'
,
'goldfinch, Carduelis carduelis'
,
'house finch, linnet, Carpodacus mexicanus'
,
'junco, snowbird'
,
'indigo bunting, indigo finch, indigo bird, Passerina cyanea'
,
'robin, American robin, Turdus migratorius'
,
'bulbul'
,
'jay'
,
'magpie'
,
'chickadee'
,
'water ouzel, dipper'
,
'kite'
,
'bald eagle, American eagle, Haliaeetus leucocephalus'
,
'vulture'
,
'great grey owl, great gray owl, Strix nebulosa'
,
'European fire salamander, Salamandra salamandra'
,
'common newt, Triturus vulgaris'
,
'eft'
,
'spotted salamander, Ambystoma maculatum'
,
'axolotl, mud puppy, Ambystoma mexicanum'
,
'bullfrog, Rana catesbeiana'
,
'tree frog, tree-frog'
,
'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui'
,
'loggerhead, loggerhead turtle, Caretta caretta'
,
'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea'
,
# noqa: E501
'mud turtle'
,
'terrapin'
,
'box turtle, box tortoise'
,
'banded gecko'
,
'common iguana, iguana, Iguana iguana'
,
'American chameleon, anole, Anolis carolinensis'
,
'whiptail, whiptail lizard'
,
'agama'
,
'frilled lizard, Chlamydosaurus kingi'
,
'alligator lizard'
,
'Gila monster, Heloderma suspectum'
,
'green lizard, Lacerta viridis'
,
'African chameleon, Chamaeleo chamaeleon'
,
'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis'
,
# noqa: E501
'African crocodile, Nile crocodile, Crocodylus niloticus'
,
'American alligator, Alligator mississipiensis'
,
'triceratops'
,
'thunder snake, worm snake, Carphophis amoenus'
,
'ringneck snake, ring-necked snake, ring snake'
,
'hognose snake, puff adder, sand viper'
,
'green snake, grass snake'
,
'king snake, kingsnake'
,
'garter snake, grass snake'
,
'water snake'
,
'vine snake'
,
'night snake, Hypsiglena torquata'
,
'boa constrictor, Constrictor constrictor'
,
'rock python, rock snake, Python sebae'
,
'Indian cobra, Naja naja'
,
'green mamba'
,
'sea snake'
,
'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus'
,
'diamondback, diamondback rattlesnake, Crotalus adamanteus'
,
'sidewinder, horned rattlesnake, Crotalus cerastes'
,
'trilobite'
,
'harvestman, daddy longlegs, Phalangium opilio'
,
'scorpion'
,
'black and gold garden spider, Argiope aurantia'
,
'barn spider, Araneus cavaticus'
,
'garden spider, Aranea diademata'
,
'black widow, Latrodectus mactans'
,
'tarantula'
,
'wolf spider, hunting spider'
,
'tick'
,
'centipede'
,
'black grouse'
,
'ptarmigan'
,
'ruffed grouse, partridge, Bonasa umbellus'
,
'prairie chicken, prairie grouse, prairie fowl'
,
'peacock'
,
'quail'
,
'partridge'
,
'African grey, African gray, Psittacus erithacus'
,
'macaw'
,
'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita'
,
'lorikeet'
,
'coucal'
,
'bee eater'
,
'hornbill'
,
'hummingbird'
,
'jacamar'
,
'toucan'
,
'drake'
,
'red-breasted merganser, Mergus serrator'
,
'goose'
,
'black swan, Cygnus atratus'
,
'tusker'
,
'echidna, spiny anteater, anteater'
,
'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus'
,
# noqa: E501
'wallaby, brush kangaroo'
,
'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus'
,
# noqa: E501
'wombat'
,
'jellyfish'
,
'sea anemone, anemone'
,
'brain coral'
,
'flatworm, platyhelminth'
,
'nematode, nematode worm, roundworm'
,
'conch'
,
'snail'
,
'slug'
,
'sea slug, nudibranch'
,
'chiton, coat-of-mail shell, sea cradle, polyplacophore'
,
'chambered nautilus, pearly nautilus, nautilus'
,
'Dungeness crab, Cancer magister'
,
'rock crab, Cancer irroratus'
,
'fiddler crab'
,
'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica'
,
# noqa: E501
'American lobster, Northern lobster, Maine lobster, Homarus americanus'
,
# noqa: E501
'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish'
,
# noqa: E501
'crayfish, crawfish, crawdad, crawdaddy'
,
'hermit crab'
,
'isopod'
,
'white stork, Ciconia ciconia'
,
'black stork, Ciconia nigra'
,
'spoonbill'
,
'flamingo'
,
'little blue heron, Egretta caerulea'
,
'American egret, great white heron, Egretta albus'
,
'bittern'
,
'crane'
,
'limpkin, Aramus pictus'
,
'European gallinule, Porphyrio porphyrio'
,
'American coot, marsh hen, mud hen, water hen, Fulica americana'
,
'bustard'
,
'ruddy turnstone, Arenaria interpres'
,
'red-backed sandpiper, dunlin, Erolia alpina'
,
'redshank, Tringa totanus'
,
'dowitcher'
,
'oystercatcher, oyster catcher'
,
'pelican'
,
'king penguin, Aptenodytes patagonica'
,
'albatross, mollymawk'
,
'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus'
,
# noqa: E501
'killer whale, killer, orca, grampus, sea wolf, Orcinus orca'
,
'dugong, Dugong dugon'
,
'sea lion'
,
'Chihuahua'
,
'Japanese spaniel'
,
'Maltese dog, Maltese terrier, Maltese'
,
'Pekinese, Pekingese, Peke'
,
'Shih-Tzu'
,
'Blenheim spaniel'
,
'papillon'
,
'toy terrier'
,
'Rhodesian ridgeback'
,
'Afghan hound, Afghan'
,
'basset, basset hound'
,
'beagle'
,
'bloodhound, sleuthhound'
,
'bluetick'
,
'black-and-tan coonhound'
,
'Walker hound, Walker foxhound'
,
'English foxhound'
,
'redbone'
,
'borzoi, Russian wolfhound'
,
'Irish wolfhound'
,
'Italian greyhound'
,
'whippet'
,
'Ibizan hound, Ibizan Podenco'
,
'Norwegian elkhound, elkhound'
,
'otterhound, otter hound'
,
'Saluki, gazelle hound'
,
'Scottish deerhound, deerhound'
,
'Weimaraner'
,
'Staffordshire bullterrier, Staffordshire bull terrier'
,
'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier'
,
# noqa: E501
'Bedlington terrier'
,
'Border terrier'
,
'Kerry blue terrier'
,
'Irish terrier'
,
'Norfolk terrier'
,
'Norwich terrier'
,
'Yorkshire terrier'
,
'wire-haired fox terrier'
,
'Lakeland terrier'
,
'Sealyham terrier, Sealyham'
,
'Airedale, Airedale terrier'
,
'cairn, cairn terrier'
,
'Australian terrier'
,
'Dandie Dinmont, Dandie Dinmont terrier'
,
'Boston bull, Boston terrier'
,
'miniature schnauzer'
,
'giant schnauzer'
,
'standard schnauzer'
,
'Scotch terrier, Scottish terrier, Scottie'
,
'Tibetan terrier, chrysanthemum dog'
,
'silky terrier, Sydney silky'
,
'soft-coated wheaten terrier'
,
'West Highland white terrier'
,
'Lhasa, Lhasa apso'
,
'flat-coated retriever'
,
'curly-coated retriever'
,
'golden retriever'
,
'Labrador retriever'
,
'Chesapeake Bay retriever'
,
'German short-haired pointer'
,
'vizsla, Hungarian pointer'
,
'English setter'
,
'Irish setter, red setter'
,
'Gordon setter'
,
'Brittany spaniel'
,
'clumber, clumber spaniel'
,
'English springer, English springer spaniel'
,
'Welsh springer spaniel'
,
'cocker spaniel, English cocker spaniel, cocker'
,
'Sussex spaniel'
,
'Irish water spaniel'
,
'kuvasz'
,
'schipperke'
,
'groenendael'
,
'malinois'
,
'briard'
,
'kelpie'
,
'komondor'
,
'Old English sheepdog, bobtail'
,
'Shetland sheepdog, Shetland sheep dog, Shetland'
,
'collie'
,
'Border collie'
,
'Bouvier des Flandres, Bouviers des Flandres'
,
'Rottweiler'
,
'German shepherd, German shepherd dog, German police dog, alsatian'
,
'Doberman, Doberman pinscher'
,
'miniature pinscher'
,
'Greater Swiss Mountain dog'
,
'Bernese mountain dog'
,
'Appenzeller'
,
'EntleBucher'
,
'boxer'
,
'bull mastiff'
,
'Tibetan mastiff'
,
'French bulldog'
,
'Great Dane'
,
'Saint Bernard, St Bernard'
,
'Eskimo dog, husky'
,
'malamute, malemute, Alaskan malamute'
,
'Siberian husky'
,
'dalmatian, coach dog, carriage dog'
,
'affenpinscher, monkey pinscher, monkey dog'
,
'basenji'
,
'pug, pug-dog'
,
'Leonberg'
,
'Newfoundland, Newfoundland dog'
,
'Great Pyrenees'
,
'Samoyed, Samoyede'
,
'Pomeranian'
,
'chow, chow chow'
,
'keeshond'
,
'Brabancon griffon'
,
'Pembroke, Pembroke Welsh corgi'
,
'Cardigan, Cardigan Welsh corgi'
,
'toy poodle'
,
'miniature poodle'
,
'standard poodle'
,
'Mexican hairless'
,
'timber wolf, grey wolf, gray wolf, Canis lupus'
,
'white wolf, Arctic wolf, Canis lupus tundrarum'
,
'red wolf, maned wolf, Canis rufus, Canis niger'
,
'coyote, prairie wolf, brush wolf, Canis latrans'
,
'dingo, warrigal, warragal, Canis dingo'
,
'dhole, Cuon alpinus'
,
'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus'
,
'hyena, hyaena'
,
'red fox, Vulpes vulpes'
,
'kit fox, Vulpes macrotis'
,
'Arctic fox, white fox, Alopex lagopus'
,
'grey fox, gray fox, Urocyon cinereoargenteus'
,
'tabby, tabby cat'
,
'tiger cat'
,
'Persian cat'
,
'Siamese cat, Siamese'
,
'Egyptian cat'
,
'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'
,
# noqa: E501
'lynx, catamount'
,
'leopard, Panthera pardus'
,
'snow leopard, ounce, Panthera uncia'
,
'jaguar, panther, Panthera onca, Felis onca'
,
'lion, king of beasts, Panthera leo'
,
'tiger, Panthera tigris'
,
'cheetah, chetah, Acinonyx jubatus'
,
'brown bear, bruin, Ursus arctos'
,
'American black bear, black bear, Ursus americanus, Euarctos americanus'
,
# noqa: E501
'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus'
,
'sloth bear, Melursus ursinus, Ursus ursinus'
,
'mongoose'
,
'meerkat, mierkat'
,
'tiger beetle'
,
'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle'
,
'ground beetle, carabid beetle'
,
'long-horned beetle, longicorn, longicorn beetle'
,
'leaf beetle, chrysomelid'
,
'dung beetle'
,
'rhinoceros beetle'
,
'weevil'
,
'fly'
,
'bee'
,
'ant, emmet, pismire'
,
'grasshopper, hopper'
,
'cricket'
,
'walking stick, walkingstick, stick insect'
,
'cockroach, roach'
,
'mantis, mantid'
,
'cicada, cicala'
,
'leafhopper'
,
'lacewing, lacewing fly'
,
"dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk"
,
# noqa: E501
'damselfly'
,
'admiral'
,
'ringlet, ringlet butterfly'
,
'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus'
,
'cabbage butterfly'
,
'sulphur butterfly, sulfur butterfly'
,
'lycaenid, lycaenid butterfly'
,
'starfish, sea star'
,
'sea urchin'
,
'sea cucumber, holothurian'
,
'wood rabbit, cottontail, cottontail rabbit'
,
'hare'
,
'Angora, Angora rabbit'
,
'hamster'
,
'porcupine, hedgehog'
,
'fox squirrel, eastern fox squirrel, Sciurus niger'
,
'marmot'
,
'beaver'
,
'guinea pig, Cavia cobaya'
,
'sorrel'
,
'zebra'
,
'hog, pig, grunter, squealer, Sus scrofa'
,
'wild boar, boar, Sus scrofa'
,
'warthog'
,
'hippopotamus, hippo, river horse, Hippopotamus amphibius'
,
'ox'
,
'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis'
,
'bison'
,
'ram, tup'
,
'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis'
,
# noqa: E501
'ibex, Capra ibex'
,
'hartebeest'
,
'impala, Aepyceros melampus'
,
'gazelle'
,
'Arabian camel, dromedary, Camelus dromedarius'
,
'llama'
,
'weasel'
,
'mink'
,
'polecat, fitch, foulmart, foumart, Mustela putorius'
,
'black-footed ferret, ferret, Mustela nigripes'
,
'otter'
,
'skunk, polecat, wood pussy'
,
'badger'
,
'armadillo'
,
'three-toed sloth, ai, Bradypus tridactylus'
,
'orangutan, orang, orangutang, Pongo pygmaeus'
,
'gorilla, Gorilla gorilla'
,
'chimpanzee, chimp, Pan troglodytes'
,
'gibbon, Hylobates lar'
,
'siamang, Hylobates syndactylus, Symphalangus syndactylus'
,
'guenon, guenon monkey'
,
'patas, hussar monkey, Erythrocebus patas'
,
'baboon'
,
'macaque'
,
'langur'
,
'colobus, colobus monkey'
,
'proboscis monkey, Nasalis larvatus'
,
'marmoset'
,
'capuchin, ringtail, Cebus capucinus'
,
'howler monkey, howler'
,
'titi, titi monkey'
,
'spider monkey, Ateles geoffroyi'
,
'squirrel monkey, Saimiri sciureus'
,
'Madagascar cat, ring-tailed lemur, Lemur catta'
,
'indri, indris, Indri indri, Indri brevicaudatus'
,
'Indian elephant, Elephas maximus'
,
'African elephant, Loxodonta africana'
,
'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens'
,
'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca'
,
'barracouta, snoek'
,
'eel'
,
'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch'
,
# noqa: E501
'rock beauty, Holocanthus tricolor'
,
'anemone fish'
,
'sturgeon'
,
'gar, garfish, garpike, billfish, Lepisosteus osseus'
,
'lionfish'
,
'puffer, pufferfish, blowfish, globefish'
,
'abacus'
,
'abaya'
,
"academic gown, academic robe, judge's robe"
,
'accordion, piano accordion, squeeze box'
,
'acoustic guitar'
,
'aircraft carrier, carrier, flattop, attack aircraft carrier'
,
'airliner'
,
'airship, dirigible'
,
'altar'
,
'ambulance'
,
'amphibian, amphibious vehicle'
,
'analog clock'
,
'apiary, bee house'
,
'apron'
,
'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin'
,
# noqa: E501
'assault rifle, assault gun'
,
'backpack, back pack, knapsack, packsack, rucksack, haversack'
,
'bakery, bakeshop, bakehouse'
,
'balance beam, beam'
,
'balloon'
,
'ballpoint, ballpoint pen, ballpen, Biro'
,
'Band Aid'
,
'banjo'
,
'bannister, banister, balustrade, balusters, handrail'
,
'barbell'
,
'barber chair'
,
'barbershop'
,
'barn'
,
'barometer'
,
'barrel, cask'
,
'barrow, garden cart, lawn cart, wheelbarrow'
,
'baseball'
,
'basketball'
,
'bassinet'
,
'bassoon'
,
'bathing cap, swimming cap'
,
'bath towel'
,
'bathtub, bathing tub, bath, tub'
,
'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon'
,
# noqa: E501
'beacon, lighthouse, beacon light, pharos'
,
'beaker'
,
'bearskin, busby, shako'
,
'beer bottle'
,
'beer glass'
,
'bell cote, bell cot'
,
'bib'
,
'bicycle-built-for-two, tandem bicycle, tandem'
,
'bikini, two-piece'
,
'binder, ring-binder'
,
'binoculars, field glasses, opera glasses'
,
'birdhouse'
,
'boathouse'
,
'bobsled, bobsleigh, bob'
,
'bolo tie, bolo, bola tie, bola'
,
'bonnet, poke bonnet'
,
'bookcase'
,
'bookshop, bookstore, bookstall'
,
'bottlecap'
,
'bow'
,
'bow tie, bow-tie, bowtie'
,
'brass, memorial tablet, plaque'
,
'brassiere, bra, bandeau'
,
'breakwater, groin, groyne, mole, bulwark, seawall, jetty'
,
'breastplate, aegis, egis'
,
'broom'
,
'bucket, pail'
,
'buckle'
,
'bulletproof vest'
,
'bullet train, bullet'
,
'butcher shop, meat market'
,
'cab, hack, taxi, taxicab'
,
'caldron, cauldron'
,
'candle, taper, wax light'
,
'cannon'
,
'canoe'
,
'can opener, tin opener'
,
'cardigan'
,
'car mirror'
,
'carousel, carrousel, merry-go-round, roundabout, whirligig'
,
"carpenter's kit, tool kit"
,
'carton'
,
'car wheel'
,
'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM'
,
# noqa: E501
'cassette'
,
'cassette player'
,
'castle'
,
'catamaran'
,
'CD player'
,
'cello, violoncello'
,
'cellular telephone, cellular phone, cellphone, cell, mobile phone'
,
'chain'
,
'chainlink fence'
,
'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour'
,
# noqa: E501
'chain saw, chainsaw'
,
'chest'
,
'chiffonier, commode'
,
'chime, bell, gong'
,
'china cabinet, china closet'
,
'Christmas stocking'
,
'church, church building'
,
'cinema, movie theater, movie theatre, movie house, picture palace'
,
'cleaver, meat cleaver, chopper'
,
'cliff dwelling'
,
'cloak'
,
'clog, geta, patten, sabot'
,
'cocktail shaker'
,
'coffee mug'
,
'coffeepot'
,
'coil, spiral, volute, whorl, helix'
,
'combination lock'
,
'computer keyboard, keypad'
,
'confectionery, confectionary, candy store'
,
'container ship, containership, container vessel'
,
'convertible'
,
'corkscrew, bottle screw'
,
'cornet, horn, trumpet, trump'
,
'cowboy boot'
,
'cowboy hat, ten-gallon hat'
,
'cradle'
,
'crane'
,
'crash helmet'
,
'crate'
,
'crib, cot'
,
'Crock Pot'
,
'croquet ball'
,
'crutch'
,
'cuirass'
,
'dam, dike, dyke'
,
'desk'
,
'desktop computer'
,
'dial telephone, dial phone'
,
'diaper, nappy, napkin'
,
'digital clock'
,
'digital watch'
,
'dining table, board'
,
'dishrag, dishcloth'
,
'dishwasher, dish washer, dishwashing machine'
,
'disk brake, disc brake'
,
'dock, dockage, docking facility'
,
'dogsled, dog sled, dog sleigh'
,
'dome'
,
'doormat, welcome mat'
,
'drilling platform, offshore rig'
,
'drum, membranophone, tympan'
,
'drumstick'
,
'dumbbell'
,
'Dutch oven'
,
'electric fan, blower'
,
'electric guitar'
,
'electric locomotive'
,
'entertainment center'
,
'envelope'
,
'espresso maker'
,
'face powder'
,
'feather boa, boa'
,
'file, file cabinet, filing cabinet'
,
'fireboat'
,
'fire engine, fire truck'
,
'fire screen, fireguard'
,
'flagpole, flagstaff'
,
'flute, transverse flute'
,
'folding chair'
,
'football helmet'
,
'forklift'
,
'fountain'
,
'fountain pen'
,
'four-poster'
,
'freight car'
,
'French horn, horn'
,
'frying pan, frypan, skillet'
,
'fur coat'
,
'garbage truck, dustcart'
,
'gasmask, respirator, gas helmet'
,
'gas pump, gasoline pump, petrol pump, island dispenser'
,
'goblet'
,
'go-kart'
,
'golf ball'
,
'golfcart, golf cart'
,
'gondola'
,
'gong, tam-tam'
,
'gown'
,
'grand piano, grand'
,
'greenhouse, nursery, glasshouse'
,
'grille, radiator grille'
,
'grocery store, grocery, food market, market'
,
'guillotine'
,
'hair slide'
,
'hair spray'
,
'half track'
,
'hammer'
,
'hamper'
,
'hand blower, blow dryer, blow drier, hair dryer, hair drier'
,
'hand-held computer, hand-held microcomputer'
,
'handkerchief, hankie, hanky, hankey'
,
'hard disc, hard disk, fixed disk'
,
'harmonica, mouth organ, harp, mouth harp'
,
'harp'
,
'harvester, reaper'
,
'hatchet'
,
'holster'
,
'home theater, home theatre'
,
'honeycomb'
,
'hook, claw'
,
'hoopskirt, crinoline'
,
'horizontal bar, high bar'
,
'horse cart, horse-cart'
,
'hourglass'
,
'iPod'
,
'iron, smoothing iron'
,
"jack-o'-lantern"
,
'jean, blue jean, denim'
,
'jeep, landrover'
,
'jersey, T-shirt, tee shirt'
,
'jigsaw puzzle'
,
'jinrikisha, ricksha, rickshaw'
,
'joystick'
,
'kimono'
,
'knee pad'
,
'knot'
,
'lab coat, laboratory coat'
,
'ladle'
,
'lampshade, lamp shade'
,
'laptop, laptop computer'
,
'lawn mower, mower'
,
'lens cap, lens cover'
,
'letter opener, paper knife, paperknife'
,
'library'
,
'lifeboat'
,
'lighter, light, igniter, ignitor'
,
'limousine, limo'
,
'liner, ocean liner'
,
'lipstick, lip rouge'
,
'Loafer'
,
'lotion'
,
'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system'
,
# noqa: E501
"loupe, jeweler's loupe"
,
'lumbermill, sawmill'
,
'magnetic compass'
,
'mailbag, postbag'
,
'mailbox, letter box'
,
'maillot'
,
'maillot, tank suit'
,
'manhole cover'
,
'maraca'
,
'marimba, xylophone'
,
'mask'
,
'matchstick'
,
'maypole'
,
'maze, labyrinth'
,
'measuring cup'
,
'medicine chest, medicine cabinet'
,
'megalith, megalithic structure'
,
'microphone, mike'
,
'microwave, microwave oven'
,
'military uniform'
,
'milk can'
,
'minibus'
,
'miniskirt, mini'
,
'minivan'
,
'missile'
,
'mitten'
,
'mixing bowl'
,
'mobile home, manufactured home'
,
'Model T'
,
'modem'
,
'monastery'
,
'monitor'
,
'moped'
,
'mortar'
,
'mortarboard'
,
'mosque'
,
'mosquito net'
,
'motor scooter, scooter'
,
'mountain bike, all-terrain bike, off-roader'
,
'mountain tent'
,
'mouse, computer mouse'
,
'mousetrap'
,
'moving van'
,
'muzzle'
,
'nail'
,
'neck brace'
,
'necklace'
,
'nipple'
,
'notebook, notebook computer'
,
'obelisk'
,
'oboe, hautboy, hautbois'
,
'ocarina, sweet potato'
,
'odometer, hodometer, mileometer, milometer'
,
'oil filter'
,
'organ, pipe organ'
,
'oscilloscope, scope, cathode-ray oscilloscope, CRO'
,
'overskirt'
,
'oxcart'
,
'oxygen mask'
,
'packet'
,
'paddle, boat paddle'
,
'paddlewheel, paddle wheel'
,
'padlock'
,
'paintbrush'
,
"pajama, pyjama, pj's, jammies"
,
'palace'
,
'panpipe, pandean pipe, syrinx'
,
'paper towel'
,
'parachute, chute'
,
'parallel bars, bars'
,
'park bench'
,
'parking meter'
,
'passenger car, coach, carriage'
,
'patio, terrace'
,
'pay-phone, pay-station'
,
'pedestal, plinth, footstall'
,
'pencil box, pencil case'
,
'pencil sharpener'
,
'perfume, essence'
,
'Petri dish'
,
'photocopier'
,
'pick, plectrum, plectron'
,
'pickelhaube'
,
'picket fence, paling'
,
'pickup, pickup truck'
,
'pier'
,
'piggy bank, penny bank'
,
'pill bottle'
,
'pillow'
,
'ping-pong ball'
,
'pinwheel'
,
'pirate, pirate ship'
,
'pitcher, ewer'
,
"plane, carpenter's plane, woodworking plane"
,
'planetarium'
,
'plastic bag'
,
'plate rack'
,
'plow, plough'
,
"plunger, plumber's helper"
,
'Polaroid camera, Polaroid Land camera'
,
'pole'
,
'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria'
,
# noqa: E501
'poncho'
,
'pool table, billiard table, snooker table'
,
'pop bottle, soda bottle'
,
'pot, flowerpot'
,
"potter's wheel"
,
'power drill'
,
'prayer rug, prayer mat'
,
'printer'
,
'prison, prison house'
,
'projectile, missile'
,
'projector'
,
'puck, hockey puck'
,
'punching bag, punch bag, punching ball, punchball'
,
'purse'
,
'quill, quill pen'
,
'quilt, comforter, comfort, puff'
,
'racer, race car, racing car'
,
'racket, racquet'
,
'radiator'
,
'radio, wireless'
,
'radio telescope, radio reflector'
,
'rain barrel'
,
'recreational vehicle, RV, R.V.'
,
'reel'
,
'reflex camera'
,
'refrigerator, icebox'
,
'remote control, remote'
,
'restaurant, eating house, eating place, eatery'
,
'revolver, six-gun, six-shooter'
,
'rifle'
,
'rocking chair, rocker'
,
'rotisserie'
,
'rubber eraser, rubber, pencil eraser'
,
'rugby ball'
,
'rule, ruler'
,
'running shoe'
,
'safe'
,
'safety pin'
,
'saltshaker, salt shaker'
,
'sandal'
,
'sarong'
,
'sax, saxophone'
,
'scabbard'
,
'scale, weighing machine'
,
'school bus'
,
'schooner'
,
'scoreboard'
,
'screen, CRT screen'
,
'screw'
,
'screwdriver'
,
'seat belt, seatbelt'
,
'sewing machine'
,
'shield, buckler'
,
'shoe shop, shoe-shop, shoe store'
,
'shoji'
,
'shopping basket'
,
'shopping cart'
,
'shovel'
,
'shower cap'
,
'shower curtain'
,
'ski'
,
'ski mask'
,
'sleeping bag'
,
'slide rule, slipstick'
,
'sliding door'
,
'slot, one-armed bandit'
,
'snorkel'
,
'snowmobile'
,
'snowplow, snowplough'
,
'soap dispenser'
,
'soccer ball'
,
'sock'
,
'solar dish, solar collector, solar furnace'
,
'sombrero'
,
'soup bowl'
,
'space bar'
,
'space heater'
,
'space shuttle'
,
'spatula'
,
'speedboat'
,
"spider web, spider's web"
,
'spindle'
,
'sports car, sport car'
,
'spotlight, spot'
,
'stage'
,
'steam locomotive'
,
'steel arch bridge'
,
'steel drum'
,
'stethoscope'
,
'stole'
,
'stone wall'
,
'stopwatch, stop watch'
,
'stove'
,
'strainer'
,
'streetcar, tram, tramcar, trolley, trolley car'
,
'stretcher'
,
'studio couch, day bed'
,
'stupa, tope'
,
'submarine, pigboat, sub, U-boat'
,
'suit, suit of clothes'
,
'sundial'
,
'sunglass'
,
'sunglasses, dark glasses, shades'
,
'sunscreen, sunblock, sun blocker'
,
'suspension bridge'
,
'swab, swob, mop'
,
'sweatshirt'
,
'swimming trunks, bathing trunks'
,
'swing'
,
'switch, electric switch, electrical switch'
,
'syringe'
,
'table lamp'
,
'tank, army tank, armored combat vehicle, armoured combat vehicle'
,
'tape player'
,
'teapot'
,
'teddy, teddy bear'
,
'television, television system'
,
'tennis ball'
,
'thatch, thatched roof'
,
'theater curtain, theatre curtain'
,
'thimble'
,
'thresher, thrasher, threshing machine'
,
'throne'
,
'tile roof'
,
'toaster'
,
'tobacco shop, tobacconist shop, tobacconist'
,
'toilet seat'
,
'torch'
,
'totem pole'
,
'tow truck, tow car, wrecker'
,
'toyshop'
,
'tractor'
,
'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi'
,
# noqa: E501
'tray'
,
'trench coat'
,
'tricycle, trike, velocipede'
,
'trimaran'
,
'tripod'
,
'triumphal arch'
,
'trolleybus, trolley coach, trackless trolley'
,
'trombone'
,
'tub, vat'
,
'turnstile'
,
'typewriter keyboard'
,
'umbrella'
,
'unicycle, monocycle'
,
'upright, upright piano'
,
'vacuum, vacuum cleaner'
,
'vase'
,
'vault'
,
'velvet'
,
'vending machine'
,
'vestment'
,
'viaduct'
,
'violin, fiddle'
,
'volleyball'
,
'waffle iron'
,
'wall clock'
,
'wallet, billfold, notecase, pocketbook'
,
'wardrobe, closet, press'
,
'warplane, military plane'
,
'washbasin, handbasin, washbowl, lavabo, wash-hand basin'
,
'washer, automatic washer, washing machine'
,
'water bottle'
,
'water jug'
,
'water tower'
,
'whiskey jug'
,
'whistle'
,
'wig'
,
'window screen'
,
'window shade'
,
'Windsor tie'
,
'wine bottle'
,
'wing'
,
'wok'
,
'wooden spoon'
,
'wool, woolen, woollen'
,
'worm fence, snake fence, snake-rail fence, Virginia fence'
,
'wreck'
,
'yawl'
,
'yurt'
,
'web site, website, internet site, site'
,
'comic book'
,
'crossword puzzle, crossword'
,
'street sign'
,
'traffic light, traffic signal, stoplight'
,
'book jacket, dust cover, dust jacket, dust wrapper'
,
'menu'
,
'plate'
,
'guacamole'
,
'consomme'
,
'hot pot, hotpot'
,
'trifle'
,
'ice cream, icecream'
,
'ice lolly, lolly, lollipop, popsicle'
,
'French loaf'
,
'bagel, beigel'
,
'pretzel'
,
'cheeseburger'
,
'hotdog, hot dog, red hot'
,
'mashed potato'
,
'head cabbage'
,
'broccoli'
,
'cauliflower'
,
'zucchini, courgette'
,
'spaghetti squash'
,
'acorn squash'
,
'butternut squash'
,
'cucumber, cuke'
,
'artichoke, globe artichoke'
,
'bell pepper'
,
'cardoon'
,
'mushroom'
,
'Granny Smith'
,
'strawberry'
,
'orange'
,
'lemon'
,
'fig'
,
'pineapple, ananas'
,
'banana'
,
'jackfruit, jak, jack'
,
'custard apple'
,
'pomegranate'
,
'hay'
,
'carbonara'
,
'chocolate sauce, chocolate syrup'
,
'dough'
,
'meat loaf, meatloaf'
,
'pizza, pizza pie'
,
'potpie'
,
'burrito'
,
'red wine'
,
'espresso'
,
'cup'
,
'eggnog'
,
'alp'
,
'bubble'
,
'cliff, drop, drop-off'
,
'coral reef'
,
'geyser'
,
'lakeside, lakeshore'
,
'promontory, headland, head, foreland'
,
'sandbar, sand bar'
,
'seashore, coast, seacoast, sea-coast'
,
'valley, vale'
,
'volcano'
,
'ballplayer, baseball player'
,
'groom, bridegroom'
,
'scuba diver'
,
'rapeseed'
,
'daisy'
,
"yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum"
,
# noqa: E501
'corn'
,
'acorn'
,
'hip, rose hip, rosehip'
,
'buckeye, horse chestnut, conker'
,
'coral fungus'
,
'agaric'
,
'gyromitra'
,
'stinkhorn, carrion fungus'
,
'earthstar'
,
'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa'
,
# noqa: E501
'bolete'
,
'ear, spike, capitulum'
,
'toilet tissue, toilet paper, bathroom tissue'
]
def
load_annotations
(
self
):
if
self
.
ann_file
is
None
:
folder_to_idx
=
find_folders
(
self
.
data_prefix
)
samples
=
get_samples
(
self
.
data_prefix
,
folder_to_idx
,
extensions
=
self
.
IMG_EXTENSIONS
)
if
len
(
samples
)
==
0
:
raise
(
RuntimeError
(
'Found 0 files in subfolders of: '
f
'
{
self
.
data_prefix
}
. '
'Supported extensions are: '
f
'
{
","
.
join
(
self
.
IMG_EXTENSIONS
)
}
'
))
self
.
folder_to_idx
=
folder_to_idx
elif
isinstance
(
self
.
ann_file
,
str
):
with
open
(
self
.
ann_file
)
as
f
:
samples
=
[
x
.
strip
().
split
(
' '
)
for
x
in
f
.
readlines
()]
else
:
raise
TypeError
(
'ann_file must be a str or None'
)
self
.
samples
=
samples
data_infos
=
[]
for
filename
,
gt_label
in
self
.
samples
:
info
=
{
'img_prefix'
:
self
.
data_prefix
}
info
[
'img_info'
]
=
{
'filename'
:
filename
}
info
[
'gt_label'
]
=
np
.
array
(
gt_label
,
dtype
=
np
.
int64
)
data_infos
.
append
(
info
)
return
data_infos
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/mnist.py
0 → 100644
View file @
85529f35
import
codecs
import
os
import
os.path
as
osp
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
from
mmcv.runner
import
get_dist_info
,
master_only
from
.base_dataset
import
BaseDataset
from
.builder
import
DATASETS
from
.utils
import
download_and_extract_archive
,
rm_suffix
@
DATASETS
.
register_module
()
class
MNIST
(
BaseDataset
):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py # noqa: E501
"""
resource_prefix
=
'http://yann.lecun.com/exdb/mnist/'
resources
=
{
'train_image_file'
:
(
'train-images-idx3-ubyte.gz'
,
'f68b3c2dcbeaaa9fbdd348bbdeb94873'
),
'train_label_file'
:
(
'train-labels-idx1-ubyte.gz'
,
'd53e105ee54ea40749a09fcbcd1e9432'
),
'test_image_file'
:
(
't10k-images-idx3-ubyte.gz'
,
'9fb629c4189551a2d022fa330f9573f3'
),
'test_label_file'
:
(
't10k-labels-idx1-ubyte.gz'
,
'ec29112dd5afa0611ce80d1b7f02629c'
)
}
CLASSES
=
[
'0 - zero'
,
'1 - one'
,
'2 - two'
,
'3 - three'
,
'4 - four'
,
'5 - five'
,
'6 - six'
,
'7 - seven'
,
'8 - eight'
,
'9 - nine'
]
def
load_annotations
(
self
):
train_image_file
=
osp
.
join
(
self
.
data_prefix
,
rm_suffix
(
self
.
resources
[
'train_image_file'
][
0
]))
train_label_file
=
osp
.
join
(
self
.
data_prefix
,
rm_suffix
(
self
.
resources
[
'train_label_file'
][
0
]))
test_image_file
=
osp
.
join
(
self
.
data_prefix
,
rm_suffix
(
self
.
resources
[
'test_image_file'
][
0
]))
test_label_file
=
osp
.
join
(
self
.
data_prefix
,
rm_suffix
(
self
.
resources
[
'test_label_file'
][
0
]))
if
not
osp
.
exists
(
train_image_file
)
or
not
osp
.
exists
(
train_label_file
)
or
not
osp
.
exists
(
test_image_file
)
or
not
osp
.
exists
(
test_label_file
):
self
.
download
()
_
,
world_size
=
get_dist_info
()
if
world_size
>
1
:
dist
.
barrier
()
assert
osp
.
exists
(
train_image_file
)
and
osp
.
exists
(
train_label_file
)
and
osp
.
exists
(
test_image_file
)
and
osp
.
exists
(
test_label_file
),
\
'Shared storage seems unavailable. Please download dataset '
\
f
'manually through
{
self
.
resource_prefix
}
.'
train_set
=
(
read_image_file
(
train_image_file
),
read_label_file
(
train_label_file
))
test_set
=
(
read_image_file
(
test_image_file
),
read_label_file
(
test_label_file
))
if
not
self
.
test_mode
:
imgs
,
gt_labels
=
train_set
else
:
imgs
,
gt_labels
=
test_set
data_infos
=
[]
for
img
,
gt_label
in
zip
(
imgs
,
gt_labels
):
gt_label
=
np
.
array
(
gt_label
,
dtype
=
np
.
int64
)
info
=
{
'img'
:
img
.
numpy
(),
'gt_label'
:
gt_label
}
data_infos
.
append
(
info
)
return
data_infos
@
master_only
def
download
(
self
):
os
.
makedirs
(
self
.
data_prefix
,
exist_ok
=
True
)
# download files
for
url
,
md5
in
self
.
resources
.
values
():
url
=
osp
.
join
(
self
.
resource_prefix
,
url
)
filename
=
url
.
rpartition
(
'/'
)[
2
]
download_and_extract_archive
(
url
,
download_root
=
self
.
data_prefix
,
filename
=
filename
,
md5
=
md5
)
@
DATASETS
.
register_module
()
class
FashionMNIST
(
MNIST
):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
Dataset."""
resource_prefix
=
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
# noqa: E501
resources
=
{
'train_image_file'
:
(
'train-images-idx3-ubyte.gz'
,
'8d4fb7e6c68d591d4c3dfef9ec88bf0d'
),
'train_label_file'
:
(
'train-labels-idx1-ubyte.gz'
,
'25c81989df183df01b3e8a0aad5dffbe'
),
'test_image_file'
:
(
't10k-images-idx3-ubyte.gz'
,
'bef4ecab320f06d8554ea6380940ec79'
),
'test_label_file'
:
(
't10k-labels-idx1-ubyte.gz'
,
'bb300cfdad3c16e7a12a480ee83cd310'
)
}
CLASSES
=
[
'T-shirt/top'
,
'Trouser'
,
'Pullover'
,
'Dress'
,
'Coat'
,
'Sandal'
,
'Shirt'
,
'Sneaker'
,
'Bag'
,
'Ankle boot'
]
def
get_int
(
b
):
return
int
(
codecs
.
encode
(
b
,
'hex'
),
16
)
def
open_maybe_compressed_file
(
path
):
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz'
or '.xz'.
"""
if
not
isinstance
(
path
,
str
):
return
path
if
path
.
endswith
(
'.gz'
):
import
gzip
return
gzip
.
open
(
path
,
'rb'
)
if
path
.
endswith
(
'.xz'
):
import
lzma
return
lzma
.
open
(
path
,
'rb'
)
return
open
(
path
,
'rb'
)
def
read_sn3_pascalvincent_tensor
(
path
,
strict
=
True
):
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# typemap
if
not
hasattr
(
read_sn3_pascalvincent_tensor
,
'typemap'
):
read_sn3_pascalvincent_tensor
.
typemap
=
{
8
:
(
torch
.
uint8
,
np
.
uint8
,
np
.
uint8
),
9
:
(
torch
.
int8
,
np
.
int8
,
np
.
int8
),
11
:
(
torch
.
int16
,
np
.
dtype
(
'>i2'
),
'i2'
),
12
:
(
torch
.
int32
,
np
.
dtype
(
'>i4'
),
'i4'
),
13
:
(
torch
.
float32
,
np
.
dtype
(
'>f4'
),
'f4'
),
14
:
(
torch
.
float64
,
np
.
dtype
(
'>f8'
),
'f8'
)
}
# read
with
open_maybe_compressed_file
(
path
)
as
f
:
data
=
f
.
read
()
# parse
magic
=
get_int
(
data
[
0
:
4
])
nd
=
magic
%
256
ty
=
magic
//
256
assert
nd
>=
1
and
nd
<=
3
assert
ty
>=
8
and
ty
<=
14
m
=
read_sn3_pascalvincent_tensor
.
typemap
[
ty
]
s
=
[
get_int
(
data
[
4
*
(
i
+
1
):
4
*
(
i
+
2
)])
for
i
in
range
(
nd
)]
parsed
=
np
.
frombuffer
(
data
,
dtype
=
m
[
1
],
offset
=
(
4
*
(
nd
+
1
)))
assert
parsed
.
shape
[
0
]
==
np
.
prod
(
s
)
or
not
strict
return
torch
.
from_numpy
(
parsed
.
astype
(
m
[
2
],
copy
=
False
)).
view
(
*
s
)
def
read_label_file
(
path
):
with
open
(
path
,
'rb'
)
as
f
:
x
=
read_sn3_pascalvincent_tensor
(
f
,
strict
=
False
)
assert
(
x
.
dtype
==
torch
.
uint8
)
assert
(
x
.
ndimension
()
==
1
)
return
x
.
long
()
def
read_image_file
(
path
):
with
open
(
path
,
'rb'
)
as
f
:
x
=
read_sn3_pascalvincent_tensor
(
f
,
strict
=
False
)
assert
(
x
.
dtype
==
torch
.
uint8
)
assert
(
x
.
ndimension
()
==
3
)
return
x
Prev
1
…
10
11
12
13
14
15
16
17
18
…
49
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment