dcuai / dlexamples · Commits

Commit 0fd8347d
Authored Jan 08, 2023 by unknown
Add the mmclassification-0.24.1 code; remove mmclassification-speed-benchmark
Parent: cc567e9e
Changes: 838 files changed. Showing 20 changed files with 2428 additions and 15 deletions (+2428 −15).
openmmlab_test/mmclassification-0.24.1/mmcls/core/export/test.py  +96 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/__init__.py  +10 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/class_num_check_hook.py  +73 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/lr_updater.py  +83 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/precise_bn_hook.py  +180 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/wandblogger_hook.py  +340 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/optimizers/__init__.py  +6 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/optimizers/lamb.py  +227 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/utils/__init__.py  +7 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/utils/dist_utils.py  +98 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/utils/misc.py  +1 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/visualization/__init__.py  +8 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/visualization/image.py  +343 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/__init__.py  +25 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/base_dataset.py  +36 −13
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/builder.py  +183 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/cifar.py  +25 −2
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/cub.py  +129 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/custom.py  +229 −0
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/dataset_wrappers.py  +329 −0
openmmlab_test/mmclassification-0.24.1/mmcls/core/export/test.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import onnxruntime as ort
import torch

from mmcls.models.classifiers import BaseClassifier


class ONNXRuntimeClassifier(BaseClassifier):
    """Wrapper for classifier's inference with ONNXRuntime."""

    def __init__(self, onnx_file, class_names, device_id):
        super(ONNXRuntimeClassifier, self).__init__()
        sess = ort.InferenceSession(onnx_file)

        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})
        sess.set_providers(providers, options)

        self.sess = sess
        self.CLASSES = class_names
        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        self.is_cuda_available = is_cuda_available

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def forward_test(self, imgs, img_metas, **kwargs):
        input_data = imgs
        # set io binding for inputs/outputs
        device_type = 'cuda' if self.is_cuda_available else 'cpu'
        if not self.is_cuda_available:
            input_data = input_data.cpu()
        self.io_binding.bind_input(
            name='input',
            device_type=device_type,
            device_id=self.device_id,
            element_type=np.float32,
            shape=input_data.shape,
            buffer_ptr=input_data.data_ptr())

        for name in self.output_names:
            self.io_binding.bind_output(name)
        # run session to get outputs
        self.sess.run_with_iobinding(self.io_binding)
        results = self.io_binding.copy_outputs_to_cpu()[0]
        return list(results)


class TensorRTClassifier(BaseClassifier):

    def __init__(self, trt_file, class_names, device_id):
        super(TensorRTClassifier, self).__init__()
        from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, \
                you may have to build mmcv with TensorRT from source.')
        model = TRTWraper(
            trt_file, input_names=['input'], output_names=['probs'])

        self.model = model
        self.device_id = device_id
        self.CLASSES = class_names

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def forward_test(self, imgs, img_metas, **kwargs):
        input_data = imgs
        with torch.cuda.device(self.device_id), torch.no_grad():
            results = self.model({'input': input_data})['probs']
        results = results.detach().cpu().numpy()
        return list(results)
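
A minimal usage sketch for the wrappers above (the ONNX path, class names, and dummy input are illustrative, not part of the commit):

# Hypothetical example: run an exported ONNX classifier on a dummy batch.
import torch
from mmcls.core.export.test import ONNXRuntimeClassifier

model = ONNXRuntimeClassifier(
    'model.onnx',                    # assumed export path
    class_names=['cat', 'dog'],      # assumed class list
    device_id=0)
imgs = torch.randn(1, 3, 224, 224)   # dummy NCHW input
scores = model.forward_test(imgs, img_metas=None)
print(len(scores), scores[0].shape)  # one score vector per image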
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/__init__.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from .class_num_check_hook import ClassNumCheckHook
from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
from .precise_bn_hook import PreciseBNHook
from .wandblogger_hook import MMClsWandbHook

__all__ = [
    'ClassNumCheckHook', 'PreciseBNHook',
    'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
]
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/class_num_check_hook.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved
from mmcv.runner import IterBasedRunner
from mmcv.runner.hooks import HOOKS, Hook
from mmcv.utils import is_seq_of


@HOOKS.register_module()
class ClassNumCheckHook(Hook):

    def _check_head(self, runner, dataset):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        Args:
            runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
            dataset (obj:`BaseDataset`): the dataset to check.
        """
        model = runner.model
        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset.__class__.__name__} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
        else:
            assert is_seq_of(dataset.CLASSES, str), \
                (f'`CLASSES` in {dataset.__class__.__name__} '
                 f'should be a tuple of str.')
            for name, module in model.named_modules():
                if hasattr(module, 'num_classes'):
                    assert module.num_classes == len(dataset.CLASSES), \
                        (f'The `num_classes` ({module.num_classes}) in '
                         f'{module.__class__.__name__} of '
                         f'{model.__class__.__name__} does not match '
                         f'the length of `CLASSES` '
                         f'({len(dataset.CLASSES)}) in '
                         f'{dataset.__class__.__name__}')

    def before_train_iter(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`IterBasedRunner`): Iter based Runner.
        """
        if not isinstance(runner, IterBasedRunner):
            return
        self._check_head(runner, runner.data_loader._dataloader.dataset)

    def before_val_iter(self, runner):
        """Check whether the eval dataset is compatible with head.

        Args:
            runner (obj:`IterBasedRunner`): Iter based Runner.
        """
        if not isinstance(runner, IterBasedRunner):
            return
        self._check_head(runner, runner.data_loader._dataloader.dataset)

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner, runner.data_loader.dataset)

    def before_val_epoch(self, runner):
        """Check whether the eval dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner, runner.data_loader.dataset)
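
Since the hook is registered in mmcv's HOOKS registry, enabling it only takes a config entry; a minimal sketch (assuming the usual mmcv `custom_hooks` convention):

# In a training config: run the class-number consistency check.
custom_hooks = [
    dict(type='ClassNumCheckHook'),
]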
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/lr_updater.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from math import cos, pi

from mmcv.runner.hooks import HOOKS, LrUpdaterHook


@HOOKS.register_module()
class CosineAnnealingCooldownLrUpdaterHook(LrUpdaterHook):
    """Cosine annealing learning rate scheduler with cooldown.

    Args:
        min_lr (float, optional): The minimum learning rate after annealing.
            Defaults to None.
        min_lr_ratio (float, optional): The minimum learning ratio after
            annealing. Defaults to None.
        cool_down_ratio (float): The cooldown ratio. Defaults to 0.1.
        cool_down_time (int): The cooldown time. Defaults to 10.
        by_epoch (bool): If True, the learning rate changes epoch by epoch. If
            False, the learning rate changes iter by iter. Defaults to True.
        warmup (string, optional): Type of warmup used. It can be None (use no
            warmup), 'constant', 'linear' or 'exp'. Defaults to None.
        warmup_iters (int): The number of iterations or epochs that warmup
            lasts. Defaults to 0.
        warmup_ratio (float): LR used at the beginning of warmup equals to
            ``warmup_ratio * initial_lr``. Defaults to 0.1.
        warmup_by_epoch (bool): If True, the ``warmup_iters``
            means the number of epochs that warmup lasts, otherwise means the
            number of iterations that warmup lasts. Defaults to False.

    Note:
        You need to set one and only one of ``min_lr`` and ``min_lr_ratio``.
    """

    def __init__(self,
                 min_lr=None,
                 min_lr_ratio=None,
                 cool_down_ratio=0.1,
                 cool_down_time=10,
                 **kwargs):
        assert (min_lr is None) ^ (min_lr_ratio is None)
        self.min_lr = min_lr
        self.min_lr_ratio = min_lr_ratio
        self.cool_down_time = cool_down_time
        self.cool_down_ratio = cool_down_ratio
        super(CosineAnnealingCooldownLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters

        if self.min_lr_ratio is not None:
            target_lr = base_lr * self.min_lr_ratio
        else:
            target_lr = self.min_lr

        if progress > max_progress - self.cool_down_time:
            return target_lr * self.cool_down_ratio
        else:
            max_progress = max_progress - self.cool_down_time

        return annealing_cos(base_lr, target_lr, progress / max_progress)


def annealing_cos(start, end, factor, weight=1):
    """Calculate annealing cos learning rate.

    Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
    percentage goes from 0.0 to 1.0.

    Args:
        start (float): The starting learning rate of the cosine annealing.
        end (float): The ending learning rate of the cosine annealing.
        factor (float): The coefficient of `pi` when calculating the current
            percentage. Range from 0.0 to 1.0.
        weight (float, optional): The combination factor of `start` and `end`
            when calculating the actual starting learning rate. Default to 1.
    """
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out
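
In mmcv, an LrUpdaterHook subclass is normally selected through the `policy` key of `lr_config` (mmcv appends 'LrUpdaterHook' to the policy name). A hedged config sketch with illustrative values:

# Cosine annealing to 1% of the base LR, then a constant cooldown phase.
lr_config = dict(
    policy='CosineAnnealingCooldown',
    min_lr_ratio=0.01,      # anneal target: 1% of the initial LR
    cool_down_time=10,      # last 10 epochs use the cooldown LR
    cool_down_ratio=0.1,    # cooldown LR = target LR * 0.1
    warmup='linear',
    warmup_iters=5,
    warmup_by_epoch=True)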
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/precise_bn_hook.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py  # noqa: E501
# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0  # noqa: E501

import itertools
import logging
from typing import List, Optional

import mmcv
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner, get_dist_info
from mmcv.runner.hooks import HOOKS, Hook
from mmcv.utils import print_log
from torch.functional import Tensor
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.utils.data import DataLoader


def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of
    the process group.

    Args:
        tensors (List[torch.Tensor]): The tensors to process.
        num_gpus (int): The number of gpus to use.
    Returns:
        List[torch.Tensor]: The processed tensors.
    """
    # There is no need for reduction in the single-proc case
    if num_gpus == 1:
        return tensors
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / num_gpus)
    return tensors


@torch.no_grad()
def update_bn_stats(model: nn.Module,
                    loader: DataLoader,
                    num_samples: int = 8192,
                    logger: Optional[logging.Logger] = None) -> None:
    """Computes precise BN stats on training data.

    Args:
        model (nn.Module): The model whose bn stats will be recomputed.
        loader (DataLoader): PyTorch dataloader.
        num_samples (int): The number of samples to update the bn stats.
            Defaults to 8192.
        logger (:obj:`logging.Logger` | None): Logger for logging.
            Default: None.
    """
    # get dist info
    rank, world_size = get_dist_info()
    # Compute the number of mini-batches to use; if the size of dataloader is
    # less than num_iters, use all the samples in dataloader.
    num_iter = num_samples // (loader.batch_size * world_size)
    num_iter = min(num_iter, len(loader))
    # Retrieve the BN layers
    bn_layers = [
        m for m in model.modules() if m.training and isinstance(m, _BatchNorm)
    ]
    if len(bn_layers) == 0:
        print_log('No BN found in model', logger=logger, level=logging.WARNING)
        return
    print_log(
        f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger)

    # Finds all the other norm layers with training=True.
    other_norm_layers = [
        m for m in model.modules()
        if m.training and isinstance(m, (_InstanceNorm, GroupNorm))
    ]
    if len(other_norm_layers) > 0:
        print_log(
            'IN/GN stats will not be updated in PreciseHook.',
            logger=logger,
            level=logging.INFO)

    # Initialize BN stats storage for computing
    # mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    # Remember momentum values
    momentums = [bn.momentum for bn in bn_layers]
    # Set momentum to 1.0 to compute BN stats that reflect the current batch
    for bn in bn_layers:
        bn.momentum = 1.0
    # Average the BN stats for each BN layer over the batches
    if rank == 0:
        prog_bar = mmcv.ProgressBar(num_iter)

    for data in itertools.islice(loader, num_iter):
        model.train_step(data)
        for i, bn in enumerate(bn_layers):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
        if rank == 0:
            prog_bar.update()

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, world_size)
    running_vars = scaled_all_reduce(running_vars, world_size)
    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bn_layers):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]


@HOOKS.register_module()
class PreciseBNHook(Hook):
    """Precise BN hook.

    Recompute and update the batch norm stats to make them more precise.
    During training both BN stats and the weight are changing after every
    iteration, so the running average can not precisely reflect the actual
    stats of the current model.

    With this hook, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true
    average of per-batch mean/variance instead of the running average. See
    Sec. 3 of the paper `Rethinking Batch in BatchNorm
    <https://arxiv.org/abs/2105.07576>`_ for details.

    This hook will update BN stats, so it should be executed before
    ``CheckpointHook`` and ``EMAHook``, generally set its priority to
    "ABOVE_NORMAL".

    Args:
        num_samples (int): The number of samples to update the bn stats.
            Defaults to 8192.
        interval (int): Perform precise bn interval. Defaults to 1.
    """

    def __init__(self, num_samples: int = 8192, interval: int = 1) -> None:
        assert interval > 0 and num_samples > 0
        self.interval = interval
        self.num_samples = num_samples

    def _perform_precise_bn(self, runner: EpochBasedRunner) -> None:
        print_log(
            f'Running Precise BN for {self.num_samples} items...',
            logger=runner.logger)
        update_bn_stats(
            runner.model,
            runner.data_loader,
            self.num_samples,
            logger=runner.logger)
        print_log('Finish Precise BN, BN stats updated.', logger=runner.logger)

    def after_train_epoch(self, runner: EpochBasedRunner) -> None:
        """Calculate precise BN and broadcast BN stats across GPUs.

        Args:
            runner (obj:`EpochBasedRunner`): runner object.
        """
        assert isinstance(runner, EpochBasedRunner), \
            'PreciseBN only supports `EpochBasedRunner` by now'
        # if by epoch, perform precise bn every `self.interval` epochs
        if self.every_n_epochs(runner, self.interval):
            self._perform_precise_bn(runner)
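
As the docstring notes, the hook must update BN stats before ``CheckpointHook`` and ``EMAHook`` run, so it is typically registered with an elevated priority; a hedged config sketch:

# In a training config: recompute precise BN stats after every epoch.
custom_hooks = [
    dict(type='PreciseBNHook', num_samples=8192, interval=1,
         priority='ABOVE_NORMAL'),
]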
openmmlab_test/mmclassification-0.24.1/mmcls/core/hook/wandblogger_hook.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np
from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook


@HOOKS.register_module()
class MMClsWandbHook(WandbLoggerHook):
    """Enhanced Wandb logger hook for classification.

    Compared with the :class:`mmcv.runner.WandbLoggerHook`, this hook not
    only automatically logs all the information in ``log_buffer`` but also
    logs the following extra information.

    - **Checkpoints**: If ``log_checkpoint`` is True, the checkpoint saved at
      every checkpoint interval will be saved as W&B Artifacts. This depends
      on the :class:`mmcv.runner.CheckpointHook` whose priority is higher
      than this hook. Please refer to
      https://docs.wandb.ai/guides/artifacts/model-versioning to learn more
      about model versioning with W&B Artifacts.

    - **Checkpoint Metadata**: If ``log_checkpoint_metadata`` is True, every
      checkpoint artifact will have metadata associated with it. The metadata
      contains the evaluation metrics computed on validation data with that
      checkpoint along with the current epoch/iter. It depends on
      :class:`EvalHook` whose priority is higher than this hook.

    - **Evaluation**: At every interval, this hook logs the model prediction
      as interactive W&B Tables. The number of samples logged is given by
      ``num_eval_images``. Currently, this hook logs the predicted labels
      along with the ground truth at every evaluation interval. This depends
      on the :class:`EvalHook` whose priority is higher than this hook. Also
      note that the data is just logged once and subsequent evaluation tables
      use references to the logged data to save memory usage. Please refer to
      https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.

    Here is a config example:

    .. code:: python

        checkpoint_config = dict(interval=10)

        # To log checkpoint metadata, the interval of checkpoint saving
        # should be divisible by the interval of evaluation.
        evaluation = dict(interval=5)

        log_config = dict(
            ...
            hooks=[
                ...
                dict(type='MMClsWandbHook',
                     init_kwargs={
                         'entity': "YOUR_ENTITY",
                         'project': "YOUR_PROJECT_NAME"
                     },
                     log_checkpoint=True,
                     log_checkpoint_metadata=True,
                     num_eval_images=100)
            ])

    Args:
        init_kwargs (dict): A dict passed to wandb.init to initialize
            a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
            for possible key-value pairs.
        interval (int): Logging interval (every k iterations). Defaults to 10.
        log_checkpoint (bool): Save the checkpoint at every checkpoint
            interval as W&B Artifacts. Use this for model versioning where
            each version is a checkpoint. Defaults to False.
        log_checkpoint_metadata (bool): Log the evaluation metrics computed
            on the validation data with the checkpoint, along with the
            current epoch, as metadata of that checkpoint.
            Defaults to False.
        num_eval_images (int): The number of validation images to be logged.
            If zero, the evaluation won't be logged. Defaults to 100.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 log_checkpoint=False,
                 log_checkpoint_metadata=False,
                 num_eval_images=100,
                 **kwargs):
        super(MMClsWandbHook, self).__init__(init_kwargs, interval, **kwargs)

        self.log_checkpoint = log_checkpoint
        self.log_checkpoint_metadata = (
            log_checkpoint and log_checkpoint_metadata)
        self.num_eval_images = num_eval_images
        self.log_evaluation = (num_eval_images > 0)
        self.ckpt_hook: CheckpointHook = None
        self.eval_hook: EvalHook = None

    @master_only
    def before_run(self, runner: BaseRunner):
        super(MMClsWandbHook, self).before_run(runner)

        # Inspect CheckpointHook and EvalHook
        for hook in runner.hooks:
            if isinstance(hook, CheckpointHook):
                self.ckpt_hook = hook
            if isinstance(hook, (EvalHook, DistEvalHook)):
                self.eval_hook = hook

        # Check conditions to log checkpoint
        if self.log_checkpoint:
            if self.ckpt_hook is None:
                self.log_checkpoint = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log checkpoint in MMClsWandbHook, `CheckpointHook` is '
                    'required, please check hooks in the runner.')
            else:
                self.ckpt_interval = self.ckpt_hook.interval

        # Check conditions to log evaluation
        if self.log_evaluation or self.log_checkpoint_metadata:
            if self.eval_hook is None:
                self.log_evaluation = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log evaluation or checkpoint metadata in '
                    'MMClsWandbHook, `EvalHook` or `DistEvalHook` in mmcls '
                    'is required, please check whether the validation '
                    'is enabled.')
            else:
                self.eval_interval = self.eval_hook.interval
                self.val_dataset = self.eval_hook.dataloader.dataset
                if (self.log_evaluation
                        and self.num_eval_images > len(self.val_dataset)):
                    self.num_eval_images = len(self.val_dataset)
                    runner.logger.warning(
                        f'The num_eval_images ({self.num_eval_images}) is '
                        'greater than the total number of validation samples '
                        f'({len(self.val_dataset)}). The complete validation '
                        'dataset will be logged.')

        # Check conditions to log checkpoint metadata
        if self.log_checkpoint_metadata:
            assert self.ckpt_interval % self.eval_interval == 0, \
                'To log checkpoint metadata in MMClsWandbHook, the interval ' \
                f'of checkpoint saving ({self.ckpt_interval}) should be ' \
                'divisible by the interval of evaluation ' \
                f'({self.eval_interval}).'

        # Initialize evaluation table
        if self.log_evaluation:
            # Initialize data table
            self._init_data_table()
            # Add ground truth to the data table
            self._add_ground_truth()
            # Log ground truth data
            self._log_data_table()

    @master_only
    def after_train_epoch(self, runner):
        super(MMClsWandbHook, self).after_train_epoch(runner)

        if not self.by_epoch:
            return

        # Save checkpoint and metadata
        if (self.log_checkpoint
                and self.every_n_epochs(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'epoch': runner.epoch + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'epoch_{runner.epoch + 1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'epoch_{runner.epoch + 1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Add predictions to evaluation table
            self._add_predictions(results, runner.epoch + 1)
            # Log the evaluation table
            self._log_eval_table(runner.epoch + 1)

    @master_only
    def after_train_iter(self, runner):
        if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before evaluation.
            # Use this trick to skip that call.
            # Don't call super method at first, it will clear the log_buffer
            return super(MMClsWandbHook, self).after_train_iter(runner)
        else:
            super(MMClsWandbHook, self).after_train_iter(runner)

        if self.by_epoch:
            return

        # Save checkpoint and metadata
        if (self.log_checkpoint
                and self.every_n_iters(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'iter': runner.iter + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'iter_{runner.iter + 1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'iter_{runner.iter + 1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Log predictions
            self._add_predictions(results, runner.iter + 1)
            # Log the table
            self._log_eval_table(runner.iter + 1)

    @master_only
    def after_run(self, runner):
        self.wandb.finish()

    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
        """Log model checkpoint as W&B Artifact.

        Args:
            model_path (str): Path of the checkpoint to log.
            aliases (list): List of the aliases associated with this artifact.
            metadata (dict, optional): Metadata associated with this artifact.
        """
        model_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
        model_artifact.add_file(model_path)
        self.wandb.log_artifact(model_artifact, aliases=aliases)

    def _get_eval_results(self):
        """Get model evaluation results."""
        results = self.eval_hook.latest_results
        eval_results = self.val_dataset.evaluate(
            results, logger='silent', **self.eval_hook.eval_kwargs)
        return eval_results

    def _init_data_table(self):
        """Initialize the W&B Tables for validation data."""
        columns = ['image_name', 'image', 'ground_truth']
        self.data_table = self.wandb.Table(columns=columns)

    def _init_pred_table(self):
        """Initialize the W&B Tables for model evaluation."""
        columns = ['epoch'] if self.by_epoch else ['iter']
        columns += ['image_name', 'image', 'ground_truth', 'prediction'
                    ] + list(self.val_dataset.CLASSES)
        self.eval_table = self.wandb.Table(columns=columns)

    def _add_ground_truth(self):
        # Get image loading pipeline
        from mmcls.datasets.pipelines import LoadImageFromFile
        img_loader = None
        for t in self.val_dataset.pipeline.transforms:
            if isinstance(t, LoadImageFromFile):
                img_loader = t

        CLASSES = self.val_dataset.CLASSES
        self.eval_image_indexs = np.arange(len(self.val_dataset))
        # Set seed so that same validation set is logged each time.
        np.random.seed(42)
        np.random.shuffle(self.eval_image_indexs)
        self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]

        for idx in self.eval_image_indexs:
            img_info = self.val_dataset.data_infos[idx]
            if img_loader is not None:
                img_info = img_loader(img_info)
                # Get image and convert from BGR to RGB
                image = img_info['img'][..., ::-1]
            else:
                # For CIFAR dataset.
                image = img_info['img']
            image_name = img_info.get('filename', f'img_{idx}')
            gt_label = img_info.get('gt_label').item()

            self.data_table.add_data(image_name, self.wandb.Image(image),
                                     CLASSES[gt_label])

    def _add_predictions(self, results, idx):
        table_idxs = self.data_table_ref.get_index()
        assert len(table_idxs) == len(self.eval_image_indexs)

        for ndx, eval_image_index in enumerate(self.eval_image_indexs):
            result = results[eval_image_index]

            self.eval_table.add_data(
                idx, self.data_table_ref.data[ndx][0],
                self.data_table_ref.data[ndx][1],
                self.data_table_ref.data[ndx][2],
                self.val_dataset.CLASSES[np.argmax(result)], *tuple(result))

    def _log_data_table(self):
        """Log the W&B Tables for validation data as artifact and call
        `use_artifact` on it so that the evaluation table can use the
        reference of already uploaded images.

        This allows the data to be uploaded just once.
        """
        data_artifact = self.wandb.Artifact('val', type='dataset')
        data_artifact.add(self.data_table, 'val_data')

        self.wandb.run.use_artifact(data_artifact)
        data_artifact.wait()

        self.data_table_ref = data_artifact.get('val_data')

    def _log_eval_table(self, idx):
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times, creating a new version each
        time. Use this to compare models at different intervals interactively.
        """
        pred_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_pred', type='evaluation')
        pred_artifact.add(self.eval_table, 'eval_data')
        if self.by_epoch:
            aliases = ['latest', f'epoch_{idx}']
        else:
            aliases = ['latest', f'iter_{idx}']
        self.wandb.run.log_artifact(pred_artifact, aliases=aliases)
openmmlab_test/mmclassification-0.24.1/mmcls/core/optimizers/__init__.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from .lamb import Lamb

__all__ = [
    'Lamb',
]
openmmlab_test/mmclassification-0.24.1/mmcls/core/optimizers/lamb.py (new file, mode 100644)
"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb.
This optimizer code was adapted from the following (starting with latest)
* https://github.com/HabanaAI/Model-References/blob/
2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb
is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or
cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support
PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below.
Modifications Copyright 2021 Ross Wightman
"""
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import
math
import
torch
from
mmcv.runner
import
OPTIMIZERS
from
torch.optim
import
Optimizer
@
OPTIMIZERS
.
register_module
()
class
Lamb
(
Optimizer
):
"""A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer.
This class is copied from `timm`_. The LAMB was proposed in `Large Batch
Optimization for Deep Learning - Training BERT in 76 minutes`_.
.. _timm:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
"""
# noqa: E501
def
__init__
(
self
,
params
,
lr
=
1e-3
,
bias_correction
=
True
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-6
,
weight_decay
=
0.01
,
grad_averaging
=
True
,
max_grad_norm
=
1.0
,
trust_clip
=
False
,
always_adapt
=
False
):
defaults
=
dict
(
lr
=
lr
,
bias_correction
=
bias_correction
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
,
grad_averaging
=
grad_averaging
,
max_grad_norm
=
max_grad_norm
,
trust_clip
=
trust_clip
,
always_adapt
=
always_adapt
)
super
().
__init__
(
params
,
defaults
)
@
torch
.
no_grad
()
def
step
(
self
,
closure
=
None
):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss
=
None
if
closure
is
not
None
:
with
torch
.
enable_grad
():
loss
=
closure
()
device
=
self
.
param_groups
[
0
][
'params'
][
0
].
device
one_tensor
=
torch
.
tensor
(
1.0
,
device
=
device
)
# because torch.where doesn't handle scalars correctly
global_grad_norm
=
torch
.
zeros
(
1
,
device
=
device
)
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
continue
grad
=
p
.
grad
if
grad
.
is_sparse
:
raise
RuntimeError
(
'Lamb does not support sparse gradients, consider '
'SparseAdam instead.'
)
global_grad_norm
.
add_
(
grad
.
pow
(
2
).
sum
())
global_grad_norm
=
torch
.
sqrt
(
global_grad_norm
)
# FIXME it'd be nice to remove explicit tensor conversion of scalars
# when torch.where promotes
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
max_grad_norm
=
torch
.
tensor
(
self
.
defaults
[
'max_grad_norm'
],
device
=
device
)
clip_global_grad_norm
=
torch
.
where
(
global_grad_norm
>
max_grad_norm
,
global_grad_norm
/
max_grad_norm
,
one_tensor
)
for
group
in
self
.
param_groups
:
bias_correction
=
1
if
group
[
'bias_correction'
]
else
0
beta1
,
beta2
=
group
[
'betas'
]
grad_averaging
=
1
if
group
[
'grad_averaging'
]
else
0
beta3
=
1
-
beta1
if
grad_averaging
else
1.0
# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or
# pass list into kernel
if
'step'
in
group
:
group
[
'step'
]
+=
1
else
:
group
[
'step'
]
=
1
if
bias_correction
:
bias_correction1
=
1
-
beta1
**
group
[
'step'
]
bias_correction2
=
1
-
beta2
**
group
[
'step'
]
else
:
bias_correction1
,
bias_correction2
=
1.0
,
1.0
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
continue
grad
=
p
.
grad
.
div_
(
clip_global_grad_norm
)
state
=
self
.
state
[
p
]
# State initialization
if
len
(
state
)
==
0
:
# Exponential moving average of gradient valuesa
state
[
'exp_avg'
]
=
torch
.
zeros_like
(
p
)
# Exponential moving average of squared gradient values
state
[
'exp_avg_sq'
]
=
torch
.
zeros_like
(
p
)
exp_avg
,
exp_avg_sq
=
state
[
'exp_avg'
],
state
[
'exp_avg_sq'
]
# Decay the first and second moment running average coefficient
exp_avg
.
mul_
(
beta1
).
add_
(
grad
,
alpha
=
beta3
)
# m_t
exp_avg_sq
.
mul_
(
beta2
).
addcmul_
(
grad
,
grad
,
value
=
1
-
beta2
)
# v_t
denom
=
(
exp_avg_sq
.
sqrt
()
/
math
.
sqrt
(
bias_correction2
)).
add_
(
group
[
'eps'
])
update
=
(
exp_avg
/
bias_correction1
).
div_
(
denom
)
weight_decay
=
group
[
'weight_decay'
]
if
weight_decay
!=
0
:
update
.
add_
(
p
,
alpha
=
weight_decay
)
if
weight_decay
!=
0
or
group
[
'always_adapt'
]:
# Layer-wise LR adaptation. By default, skip adaptation on
# parameters that are
# excluded from weight decay, unless always_adapt == True,
# then always enabled.
w_norm
=
p
.
norm
(
2.0
)
g_norm
=
update
.
norm
(
2.0
)
# FIXME nested where required since logical and/or not
# working in PT XLA
trust_ratio
=
torch
.
where
(
w_norm
>
0
,
torch
.
where
(
g_norm
>
0
,
w_norm
/
g_norm
,
one_tensor
),
one_tensor
,
)
if
group
[
'trust_clip'
]:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio
=
torch
.
minimum
(
trust_ratio
,
one_tensor
)
update
.
mul_
(
trust_ratio
)
p
.
add_
(
update
,
alpha
=-
group
[
'lr'
])
return
loss
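
Because the class is registered in mmcv's OPTIMIZERS registry, it can be selected from a config like any built-in optimizer; it also works as a plain torch.optim.Optimizer. Both sketches use illustrative hyperparameters:

# Config-style selection (mmcv builds the optimizer from this dict):
optimizer = dict(type='Lamb', lr=5e-3, weight_decay=0.02)

# Direct use as a PyTorch optimizer on a toy model:
import torch
import torch.nn as nn
from mmcls.core.optimizers import Lamb

model = nn.Linear(10, 2)
opt = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()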
openmmlab_test/mmclassification-0.24.1/mmcls/core/utils/__init__.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import DistOptimizerHook, allreduce_grads, sync_random_seed
from .misc import multi_apply

__all__ = [
    'allreduce_grads', 'DistOptimizerHook', 'multi_apply', 'sync_random_seed'
]
openmmlab_test/mmclassification-0.24.1/mmcls/core/utils/dist_utils.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import OptimizerHook, get_dist_info
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        flat_tensors = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    grads = [
        param.grad.data for param in params
        if param.requires_grad and param.grad is not None
    ]
    world_size = dist.get_world_size()
    if coalesce:
        _allreduce_coalesced(grads, world_size, bucket_size_mb)
    else:
        for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))


class DistOptimizerHook(OptimizerHook):

    def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            self.clip_grads(runner.model.parameters())
        runner.optimizer.step()


def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.
    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
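
A hedged sketch of how `sync_random_seed` keeps shuffling identical across ranks (the sampler context is assumed, not shown in this commit):

# Every rank receives rank 0's seed, so all ranks shuffle identically.
import torch
from mmcls.core.utils import sync_random_seed

seed = sync_random_seed()   # broadcast from rank 0 when world_size > 1
g = torch.Generator()
g.manual_seed(seed)         # use g to draw the same index permutation on each rank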
openmmlab_test/mmclassification-speed-benchmark/mmcls/core/utils/misc.py → openmmlab_test/mmclassification-0.24.1/mmcls/core/utils/misc.py (renamed)
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
...
openmmlab_test/mmclassification-0.24.1/mmcls/core/visualization/__init__.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from .image import (BaseFigureContextManager, ImshowInfosContextManager,
                    color_val_matplotlib, imshow_infos)

__all__ = [
    'BaseFigureContextManager', 'ImshowInfosContextManager', 'imshow_infos',
    'color_val_matplotlib'
]
openmmlab_test/mmclassification-0.24.1/mmcls/core/visualization/image.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.backend_bases import CloseEvent

# A small value
EPS = 1e-2


def color_val_matplotlib(color):
    """Convert various input in BGR order to normalized RGB matplotlib color
    tuples.

    Args:
        color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs

    Returns:
        tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
    """
    color = mmcv.color_val(color)
    color = [color / 255 for color in color[::-1]]
    return tuple(color)


class BaseFigureContextManager:
    """Context Manager to reuse matplotlib figure.

    It provides a figure for saving and a figure for showing to support
    different settings.

    Args:
        axis (bool): Whether to show the axis lines.
        fig_save_cfg (dict): Keyword parameters of figure for saving.
            Defaults to empty dict.
        fig_show_cfg (dict): Keyword parameters of figure for showing.
            Defaults to empty dict.
    """

    def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
        self.is_inline = 'inline' in plt.get_backend()

        # Because save and show need different figure size
        # We set two figures and axes to handle save and show
        self.fig_save: plt.Figure = None
        self.fig_save_cfg = fig_save_cfg
        self.ax_save: plt.Axes = None

        self.fig_show: plt.Figure = None
        self.fig_show_cfg = fig_show_cfg
        self.ax_show: plt.Axes = None

        self.axis = axis

    def __enter__(self):
        if not self.is_inline:
            # If use inline backend, we cannot control which figure to show,
            # so disable the interactive fig_show, and put the initialization
            # of fig_save to `prepare` function.
            self._initialize_fig_save()
            self._initialize_fig_show()
        return self

    def _initialize_fig_save(self):
        fig = plt.figure(**self.fig_save_cfg)
        ax = fig.add_subplot()

        # remove white edges by setting subplot margin
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

        self.fig_save, self.ax_save = fig, ax

    def _initialize_fig_show(self):
        # fig_save will be resized to image size, only fig_show needs fig_size.
        fig = plt.figure(**self.fig_show_cfg)
        ax = fig.add_subplot()

        # remove white edges by setting subplot margin
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

        self.fig_show, self.ax_show = fig, ax

    def __exit__(self, exc_type, exc_value, traceback):
        if self.is_inline:
            # If use inline backend, whether to close figure depends on if
            # users want to show the image.
            return

        plt.close(self.fig_save)
        plt.close(self.fig_show)

    def prepare(self):
        if self.is_inline:
            # if use inline backend, just rebuild the fig_save.
            self._initialize_fig_save()
            self.ax_save.cla()
            self.ax_save.axis(self.axis)
            return

        # If users force to destroy the window, rebuild fig_show.
        if not plt.fignum_exists(self.fig_show.number):
            self._initialize_fig_show()

        # Clear all axes
        self.ax_save.cla()
        self.ax_save.axis(self.axis)
        self.ax_show.cla()
        self.ax_show.axis(self.axis)

    def wait_continue(self, timeout=0, continue_key=' ') -> int:
        """Show the image and wait for the user's input.

        This implementation refers to
        https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py

        Args:
            timeout (int): If positive, continue after ``timeout`` seconds.
                Defaults to 0.
            continue_key (str): The key for users to continue. Defaults to
                the space key.

        Returns:
            int: If zero, means time out or the user pressed ``continue_key``,
                and if one, means the user closed the show figure.
        """  # noqa: E501
        if self.is_inline:
            # If use inline backend, interactive input and timeout is no use.
            return

        if self.fig_show.canvas.manager:
            # Ensure that the figure is shown
            self.fig_show.show()

        while True:
            # Connect the events to the handler function call.
            event = None

            def handler(ev):
                # Set external event variable
                nonlocal event
                # Qt backend may fire two events at the same time,
                # use a condition to avoid missing close event.
                event = ev if not isinstance(event, CloseEvent) else event
                self.fig_show.canvas.stop_event_loop()

            cids = [
                self.fig_show.canvas.mpl_connect(name, handler)
                for name in ('key_press_event', 'close_event')
            ]

            try:
                self.fig_show.canvas.start_event_loop(timeout)
            finally:  # Run even on exception like ctrl-c.
                # Disconnect the callbacks.
                for cid in cids:
                    self.fig_show.canvas.mpl_disconnect(cid)

            if isinstance(event, CloseEvent):
                return 1  # Quit for close.
            elif event is None or event.key == continue_key:
                return 0  # Quit for continue.


class ImshowInfosContextManager(BaseFigureContextManager):
    """Context Manager to reuse matplotlib figure and put infos on images.

    Args:
        fig_size (tuple[int]): Size of the figure to show image.

    Examples:
        >>> import mmcv
        >>> from mmcls.core import visualization as vis
        >>> img1 = mmcv.imread("./1.png")
        >>> info1 = {'class': 'cat', 'label': 0}
        >>> img2 = mmcv.imread("./2.png")
        >>> info2 = {'class': 'dog', 'label': 1}
        >>> with vis.ImshowInfosContextManager() as manager:
        ...     # Show img1
        ...     manager.put_img_infos(img1, info1)
        ...     # Show img2 on the same figure and save output image.
        ...     manager.put_img_infos(
        ...         img2, info2, out_file='./2_out.png')
    """

    def __init__(self, fig_size=(15, 10)):
        super().__init__(
            axis=False,
            # A proper dpi for image save with default font size.
            fig_save_cfg=dict(frameon=False, dpi=36),
            fig_show_cfg=dict(frameon=False, figsize=fig_size))

    def _put_text(self, ax, text, x, y, text_color, font_size):
        ax.text(
            x,
            y,
            f'{text}',
            bbox={
                'facecolor': 'black',
                'alpha': 0.7,
                'pad': 0.2,
                'edgecolor': 'none',
                'boxstyle': 'round'
            },
            color=text_color,
            fontsize=font_size,
            family='monospace',
            verticalalignment='top',
            horizontalalignment='left')

    def put_img_infos(self,
                      img,
                      infos,
                      text_color='white',
                      font_size=26,
                      row_width=20,
                      win_name='',
                      show=True,
                      wait_time=0,
                      out_file=None):
        """Show image with extra information.

        Args:
            img (str | ndarray): The image to be displayed.
            infos (dict): Extra infos to display in the image.
            text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
                display color. Defaults to 'white'.
            font_size (int): Extra infos display font size. Defaults to 26.
            row_width (int): width between each row of results on the image.
            win_name (str): The image title. Defaults to ''.
            show (bool): Whether to show the image. Defaults to True.
            wait_time (int): How many seconds to display the image.
                Defaults to 0.
            out_file (Optional[str]): The filename to write the image.
                Defaults to None.

        Returns:
            np.ndarray: The image with extra information.
        """
        self.prepare()

        text_color = color_val_matplotlib(text_color)
        img = mmcv.imread(img).astype(np.uint8)

        x, y = 3, row_width // 2
        img = mmcv.bgr2rgb(img)
        width, height = img.shape[1], img.shape[0]
        img = np.ascontiguousarray(img)

        # add a small EPS to avoid precision lost due to matplotlib's
        # truncation (https://github.com/matplotlib/matplotlib/issues/15363)
        dpi = self.fig_save.get_dpi()
        self.fig_save.set_size_inches((width + EPS) / dpi,
                                      (height + EPS) / dpi)

        for k, v in infos.items():
            if isinstance(v, float):
                v = f'{v:.2f}'
            label_text = f'{k}: {v}'
            self._put_text(self.ax_save, label_text, x, y, text_color,
                           font_size)
            if show and not self.is_inline:
                self._put_text(self.ax_show, label_text, x, y, text_color,
                               font_size)
            y += row_width

        self.ax_save.imshow(img)
        stream, _ = self.fig_save.canvas.print_to_buffer()
        buffer = np.frombuffer(stream, dtype='uint8')
        img_rgba = buffer.reshape(height, width, 4)
        rgb, _ = np.split(img_rgba, [3], axis=2)
        img_save = rgb.astype('uint8')
        img_save = mmcv.rgb2bgr(img_save)

        if out_file is not None:
            mmcv.imwrite(img_save, out_file)

        ret = 0
        if show and not self.is_inline:
            # Reserve some space for the tip.
            self.ax_show.set_title(win_name)
            self.ax_show.set_ylim(height + 20)
            self.ax_show.text(
                width // 2,
                height + 18,
                'Press SPACE to continue.',
                ha='center',
                fontsize=font_size)
            self.ax_show.imshow(img)

            # Refresh canvas, necessary for Qt5 backend.
            self.fig_show.canvas.draw()

            ret = self.wait_continue(timeout=wait_time)
        elif (not show) and self.is_inline:
            # If use inline backend, we use fig_save to show the image
            # So we need to close it if users don't want to show.
            plt.close(self.fig_save)

        return ret, img_save


def imshow_infos(img,
                 infos,
                 text_color='white',
                 font_size=26,
                 row_width=20,
                 win_name='',
                 show=True,
                 fig_size=(15, 10),
                 wait_time=0,
                 out_file=None):
    """Show image with extra information.

    Args:
        img (str | ndarray): The image to be displayed.
        infos (dict): Extra infos to display in the image.
        text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
            display color. Defaults to 'white'.
        font_size (int): Extra infos display font size. Defaults to 26.
        row_width (int): width between each row of results on the image.
        win_name (str): The image title. Defaults to ''.
        show (bool): Whether to show the image. Defaults to True.
        fig_size (tuple): Image show figure size. Defaults to (15, 10).
        wait_time (int): How many seconds to display the image. Defaults to 0.
        out_file (Optional[str]): The filename to write the image.
            Defaults to None.

    Returns:
        np.ndarray: The image with extra information.
    """
    with ImshowInfosContextManager(fig_size=fig_size) as manager:
        _, img = manager.put_img_infos(
            img,
            infos,
            text_color=text_color,
            font_size=font_size,
            row_width=row_width,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)
    return img
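
A hedged one-shot usage of `imshow_infos` (file names and the info dict are illustrative):

import mmcv
from mmcls.core import visualization as vis

img = mmcv.imread('./demo.jpg')                 # assumed input image
out = vis.imshow_infos(
    img,
    {'pred_class': 'dog', 'pred_score': 0.98},  # floats are rendered as '%.2f'
    show=False,
    out_file='./demo_out.jpg')                  # returns the annotated BGR array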
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/__init__.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from .base_dataset import BaseDataset
from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader,
                      build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler, RepeatAugSampler
from .stanford_cars import StanfordCars
from .voc import VOC

__all__ = [
    'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
    'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
    'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
    'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB',
    'CustomDataset', 'StanfordCars'
]
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/base_dataset.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/base_dataset.py (renamed)
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
os.path
as
osp
from
abc
import
ABCMeta
,
abstractmethod
from
os
import
PathLike
from
typing
import
List
import
mmcv
import
numpy
as
np
...
...
@@ -10,6 +14,13 @@ from mmcls.models.losses import accuracy
from
.pipelines
import
Compose
def
expanduser
(
path
):
if
isinstance
(
path
,
(
str
,
PathLike
)):
return
osp
.
expanduser
(
path
)
else
:
return
path
class
BaseDataset
(
Dataset
,
metaclass
=
ABCMeta
):
"""Base dataset.
...
...
@@ -32,12 +43,11 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
ann_file
=
None
,
test_mode
=
False
):
super
(
BaseDataset
,
self
).
__init__
()
self
.
ann_file
=
ann_file
self
.
data_prefix
=
data_prefix
self
.
test_mode
=
test_mode
self
.
data_prefix
=
expanduser
(
data_prefix
)
self
.
pipeline
=
Compose
(
pipeline
)
self
.
CLASSES
=
self
.
get_classes
(
classes
)
self
.
ann_file
=
expanduser
(
ann_file
)
self
.
test_mode
=
test_mode
self
.
data_infos
=
self
.
load_annotations
()
@
abstractmethod
...
...
@@ -58,23 +68,23 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
"""Get all ground-truth labels (categories).
Returns:
list[int]
: categories for all images.
np.ndarray
: categories for all images.
"""
gt_labels
=
np
.
array
([
data
[
'gt_label'
]
for
data
in
self
.
data_infos
])
return
gt_labels
def
get_cat_ids
(
self
,
idx
)
:
def
get_cat_ids
(
self
,
idx
:
int
)
->
List
[
int
]
:
"""Get category id by index.
Args:
idx (int): Index of data.
Returns:
int: Image category of specified index.
cat_ids (List[
int
])
: Image category of specified index.
"""
return
self
.
data_infos
[
idx
][
'gt_label'
]
.
astype
(
np
.
int
)
return
[
int
(
self
.
data_infos
[
idx
][
'gt_label'
])
]
def
prepare_data
(
self
,
idx
):
results
=
copy
.
deepcopy
(
self
.
data_infos
[
idx
])
...
...
@@ -89,6 +99,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
@
classmethod
def
get_classes
(
cls
,
classes
=
None
):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
...
...
@@ -104,7 +115,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         if isinstance(classes, str):
             # take it as a file path
-            class_names = mmcv.list_from_file(classes)
+            class_names = mmcv.list_from_file(expanduser(classes))
         elif isinstance(classes, (tuple, list)):
             class_names = classes
         else:
...
...
@@ -116,6 +127,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
                  results,
                  metric='accuracy',
                  metric_options=None,
+                 indices=None,
                  logger=None):
         """Evaluate the dataset.
...
...
@@ -126,6 +138,8 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
             metric_options (dict, optional): Options for calculating metrics.
                 Allowed keys are 'topk', 'thrs' and 'average_mode'.
                 Defaults to None.
+            indices (list, optional): The indices of samples corresponding to
+                the results. Defaults to None.
             logger (logging.Logger | str, optional): Logger used for printing
                 related information during evaluation. Defaults to None.

         Returns:
...
...
@@ -143,20 +157,25 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         eval_results = {}
         results = np.vstack(results)
         gt_labels = self.get_gt_labels()
+        if indices is not None:
+            gt_labels = gt_labels[indices]
         num_imgs = len(results)
         assert len(gt_labels) == num_imgs, 'dataset testing results should '\
             'be of the same length as gt_labels.'

         invalid_metrics = set(metrics) - set(allowed_metrics)
         if len(invalid_metrics) != 0:
-            raise ValueError(f'metirc {invalid_metrics} is not supported.')
+            raise ValueError(f'metric {invalid_metrics} is not supported.')

         topk = metric_options.get('topk', (1, 5))
         thrs = metric_options.get('thrs')
         average_mode = metric_options.get('average_mode', 'macro')

         if 'accuracy' in metrics:
-            acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            if thrs is not None:
+                acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            else:
+                acc = accuracy(results, gt_labels, topk=topk)
             if isinstance(topk, tuple):
                 eval_results_ = {f'accuracy_top-{k}': a
...
...
@@ -182,8 +201,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
         if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-            precision_recall_f1_values = precision_recall_f1(
-                results, gt_labels, average_mode=average_mode, thrs=thrs)
+            if thrs is not None:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode, thrs=thrs)
+            else:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode)
             for key, values in zip(precision_recall_f1_keys,
                                    precision_recall_f1_values):
                 if key in metrics:
...
...
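A minimal usage sketch of the ``evaluate`` API shown above, assuming ``dataset`` is some BaseDataset subclass and ``results`` is an (N, num_classes) score array (both hypothetical, not part of the commit):

# Hypothetical driver for BaseDataset.evaluate with the new `indices` arg.
eval_res = dataset.evaluate(
    results,
    metric=['accuracy', 'precision', 'recall'],
    metric_options=dict(topk=(1, 5), average_mode='macro'),
    indices=None)  # or a list of sample indices, as introduced above
print(eval_res['accuracy_top-1'])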
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/builder.py
0 → 100644
View file @ 0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
from functools import partial

import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader

try:
    from mmcv.utils import IS_IPU_AVAILABLE
except ImportError:
    IS_IPU_AVAILABLE = False

if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    hard_limit = rlimit[1]
    soft_limit = min(4096, hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
SAMPLERS = Registry('sampler')


def build_dataset(cfg, default_args=None):
    from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                                   KFoldDataset, RepeatDataset)
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']],
            separate_eval=cfg.get('separate_eval', True))
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args),
            cfg['oversample_thr'])
    elif cfg['type'] == 'KFoldDataset':
        cp_cfg = copy.deepcopy(cfg)
        if cp_cfg.get('test_mode', None) is None:
            cp_cfg['test_mode'] = (default_args or {}).pop('test_mode', False)
        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'], default_args)
        cp_cfg.pop('type')
        dataset = KFoldDataset(**cp_cfg)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset
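For orientation, a hedged sketch of the dispatch above (the paths and empty pipeline are invented; any type not matched by the wrapper branches falls through to build_from_cfg):

# Illustrative only: a config of this shape is routed to the
# RepeatDataset branch of build_dataset.
cfg = dict(
    type='RepeatDataset',
    times=4,
    dataset=dict(type='CIFAR10', data_prefix='data/cifar10', pipeline=[]))
dataset = build_dataset(cfg)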
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     round_up=True,
                     seed=None,
                     pin_memory=True,
                     persistent_workers=True,
                     sampler_cfg=None,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        round_up (bool): Whether to round up the length of dataset by adding
            extra samples to make it evenly divisible. Default: True.
        pin_memory (bool): Whether to use pin_memory in DataLoader.
            Default: True.
        persistent_workers (bool): If True, the data loader will not shut down
            the worker processes after a dataset has been consumed once.
            This keeps the workers' Dataset instances alive. The argument
            only takes effect in PyTorch>=1.7.0. Default: True.
        sampler_cfg (dict): Sampler configuration to override the default
            sampler.
        kwargs: Any keyword argument to be used to initialize the DataLoader.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()

    # Custom sampler logic
    if sampler_cfg:
        # shuffle=False when val and test
        sampler_cfg.update(shuffle=shuffle)
        sampler = build_sampler(
            sampler_cfg,
            default_args=dict(
                dataset=dataset, num_replicas=world_size, rank=rank,
                seed=seed))
    # Default sampler logic
    elif dist:
        sampler = build_sampler(
            dict(
                type='DistributedSampler',
                dataset=dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=shuffle,
                round_up=round_up,
                seed=seed))
    else:
        sampler = None

    # If sampler exists, turn off dataloader shuffle
    if sampler is not None:
        shuffle = False

    if dist:
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    if digit_version(torch.__version__) >= digit_version('1.8.0'):
        kwargs['persistent_workers'] = persistent_workers

    if IS_IPU_AVAILABLE:
        from mmcv.device.ipu import IPUDataLoader
        data_loader = IPUDataLoader(
            dataset,
            None,
            batch_size=samples_per_gpu,
            num_workers=num_workers,
            shuffle=shuffle,
            worker_init_fn=init_fn,
            **kwargs)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
            pin_memory=pin_memory,
            shuffle=shuffle,
            worker_init_fn=init_fn,
            **kwargs)

    return data_loader
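A minimal non-distributed usage sketch, assuming ``dataset`` was built as above (all numbers are arbitrary):

loader = build_dataloader(
    dataset,
    samples_per_gpu=32,
    workers_per_gpu=2,
    num_gpus=1,
    dist=False,
    seed=42)
# each batch is a dict collated by mmcv.parallel.collate,
# e.g. batch['img'] and batch['gt_label']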
def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
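A worked instance of the seeding formula: with num_workers=4, rank=1, worker_id=2 and user seed 42, worker_seed = 4 * 1 + 2 + 42 = 48, so every (rank, worker) pair draws from a distinct yet reproducible random stream.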
def build_sampler(cfg, default_args=None):
    if cfg is None:
        return None
    else:
        return build_from_cfg(cfg, SAMPLERS, default_args=default_args)
openmmlab_test/mmclassification-speed-benchmark/mmcls/datasets/cifar.py → openmmlab_test/mmclassification-0.24.1/mmcls/datasets/cifar.py
View file @ 0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path
import pickle
...
...
@@ -16,8 +17,8 @@ class CIFAR10(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
# noqa: E501
"""
https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
"""
# noqa: E501
base_folder
=
'cifar-10-batches-py'
url
=
'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
...
...
@@ -39,6 +40,10 @@ class CIFAR10(BaseDataset):
         'key': 'label_names',
         'md5': '5ff9c542aee3614f3951f8cda6e48888',
     }
+    CLASSES = [
+        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
+        'horse', 'ship', 'truck'
+    ]

     def load_annotations(self):
...
...
@@ -130,3 +135,21 @@ class CIFAR100(CIFAR10):
         'key': 'fine_label_names',
         'md5': '7973b15100ade9c7d40fb424638fde48',
     }
+    CLASSES = [
+        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
+        'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
+        'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
+        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch',
+        'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
+        'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo',
+        'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
+        'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
+        'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
+        'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
+        'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
+        'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper',
+        'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
+        'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger',
+        'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
+        'willow_tree', 'wolf', 'woman', 'worm'
+    ]
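A hedged instantiation sketch for the classes above (the data_prefix is invented; the url/md5 fields shown earlier are used to fetch and verify the archive when it is missing):

cifar_cfg = dict(type='CIFAR10', data_prefix='data/cifar10', pipeline=[])
cifar = build_dataset(cifar_cfg)
print(len(cifar.CLASSES))  # 10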
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/cub.py
0 → 100644
View file @ 0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from .base_dataset import BaseDataset
from .builder import DATASETS


@DATASETS.register_module()
class CUB(BaseDataset):
    """The CUB-200-2011 Dataset.

    Supports the `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
    Compared with the `CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ Dataset,
    there are many more images in `CUB-200-2011`.

    Args:
        ann_file (str): The annotation file, i.e. images.txt in CUB.
        image_class_labels_file (str): The label file, i.e.
            image_class_labels.txt in CUB.
        train_test_split_file (str): The split file, i.e.
            train_test_split.txt in CUB.
    """  # noqa: E501
    CLASSES = [
        'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross',
        'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet',
        'Parakeet_Auklet', 'Rhinoceros_Auklet', 'Brewer_Blackbird',
        'Red_winged_Blackbird', 'Rusty_Blackbird', 'Yellow_headed_Blackbird',
        'Bobolink', 'Indigo_Bunting', 'Lazuli_Bunting', 'Painted_Bunting',
        'Cardinal', 'Spotted_Catbird', 'Gray_Catbird', 'Yellow_breasted_Chat',
        'Eastern_Towhee', 'Chuck_will_Widow', 'Brandt_Cormorant',
        'Red_faced_Cormorant', 'Pelagic_Cormorant', 'Bronzed_Cowbird',
        'Shiny_Cowbird', 'Brown_Creeper', 'American_Crow', 'Fish_Crow',
        'Black_billed_Cuckoo', 'Mangrove_Cuckoo', 'Yellow_billed_Cuckoo',
        'Gray_crowned_Rosy_Finch', 'Purple_Finch', 'Northern_Flicker',
        'Acadian_Flycatcher', 'Great_Crested_Flycatcher', 'Least_Flycatcher',
        'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher',
        'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird',
        'Northern_Fulmar', 'Gadwall', 'American_Goldfinch',
        'European_Goldfinch', 'Boat_tailed_Grackle', 'Eared_Grebe',
        'Horned_Grebe', 'Pied_billed_Grebe', 'Western_Grebe', 'Blue_Grosbeak',
        'Evening_Grosbeak', 'Pine_Grosbeak', 'Rose_breasted_Grosbeak',
        'Pigeon_Guillemot', 'California_Gull', 'Glaucous_winged_Gull',
        'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull', 'Ring_billed_Gull',
        'Slaty_backed_Gull', 'Western_Gull', 'Anna_Hummingbird',
        'Ruby_throated_Hummingbird', 'Rufous_Hummingbird', 'Green_Violetear',
        'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay', 'Florida_Jay',
        'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird', 'Gray_Kingbird',
        'Belted_Kingfisher', 'Green_Kingfisher', 'Pied_Kingfisher',
        'Ringed_Kingfisher', 'White_breasted_Kingfisher',
        'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard',
        'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser',
        'Mockingbird', 'Nighthawk', 'Clark_Nutcracker',
        'White_breasted_Nuthatch', 'Baltimore_Oriole', 'Hooded_Oriole',
        'Orchard_Oriole', 'Scott_Oriole', 'Ovenbird', 'Brown_Pelican',
        'White_Pelican', 'Western_Wood_Pewee', 'Sayornis', 'American_Pipit',
        'Whip_poor_Will', 'Horned_Puffin', 'Common_Raven',
        'White_necked_Raven', 'American_Redstart', 'Geococcyx',
        'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow',
        'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow',
        'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow',
        'Fox_Sparrow', 'Grasshopper_Sparrow', 'Harris_Sparrow',
        'Henslow_Sparrow', 'Le_Conte_Sparrow', 'Lincoln_Sparrow',
        'Nelson_Sharp_tailed_Sparrow', 'Savannah_Sparrow', 'Seaside_Sparrow',
        'Song_Sparrow', 'Tree_Sparrow', 'Vesper_Sparrow',
        'White_crowned_Sparrow', 'White_throated_Sparrow',
        'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow',
        'Cliff_Swallow', 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager',
        'Artic_Tern', 'Black_Tern', 'Caspian_Tern', 'Common_Tern',
        'Elegant_Tern', 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee',
        'Brown_Thrasher', 'Sage_Thrasher', 'Black_capped_Vireo',
        'Blue_headed_Vireo', 'Philadelphia_Vireo', 'Red_eyed_Vireo',
        'Warbling_Vireo', 'White_eyed_Vireo', 'Yellow_throated_Vireo',
        'Bay_breasted_Warbler', 'Black_and_white_Warbler',
        'Black_throated_Blue_Warbler', 'Blue_winged_Warbler',
        'Canada_Warbler', 'Cape_May_Warbler', 'Cerulean_Warbler',
        'Chestnut_sided_Warbler', 'Golden_winged_Warbler', 'Hooded_Warbler',
        'Kentucky_Warbler', 'Magnolia_Warbler', 'Mourning_Warbler',
        'Myrtle_Warbler', 'Nashville_Warbler', 'Orange_crowned_Warbler',
        'Palm_Warbler', 'Pine_Warbler', 'Prairie_Warbler',
        'Prothonotary_Warbler', 'Swainson_Warbler', 'Tennessee_Warbler',
        'Wilson_Warbler', 'Worm_eating_Warbler', 'Yellow_Warbler',
        'Northern_Waterthrush', 'Louisiana_Waterthrush', 'Bohemian_Waxwing',
        'Cedar_Waxwing', 'American_Three_toed_Woodpecker',
        'Pileated_Woodpecker', 'Red_bellied_Woodpecker',
        'Red_cockaded_Woodpecker', 'Red_headed_Woodpecker',
        'Downy_Woodpecker', 'Bewick_Wren', 'Cactus_Wren', 'Carolina_Wren',
        'House_Wren', 'Marsh_Wren', 'Rock_Wren', 'Winter_Wren',
        'Common_Yellowthroat'
    ]
    def __init__(self, *args, ann_file, image_class_labels_file,
                 train_test_split_file, **kwargs):
        self.image_class_labels_file = image_class_labels_file
        self.train_test_split_file = train_test_split_file
        super(CUB, self).__init__(*args, ann_file=ann_file, **kwargs)

    def load_annotations(self):
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ')[1] for x in f.readlines()]

        with open(self.image_class_labels_file) as f:
            gt_labels = [
                # In the official CUB-200-2011 dataset, the labels in
                # image_class_labels_file start from 1, so subtract 1
                # here to make them start from 0.
                int(x.strip().split(' ')[1]) - 1 for x in f.readlines()
            ]

        with open(self.train_test_split_file) as f:
            splits = [int(x.strip().split(' ')[1]) for x in f.readlines()]

        assert len(samples) == len(gt_labels) == len(splits), \
            f'samples({len(samples)}), gt_labels({len(gt_labels)}) and ' \
            f'splits({len(splits)}) should have the same length.'

        data_infos = []
        for filename, gt_label, split in zip(samples, gt_labels, splits):
            if split and self.test_mode:
                # skip train samples when test_mode=True
                continue
            elif not split and not self.test_mode:
                # skip test samples when test_mode=False
                continue
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos
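A configuration sketch under the standard CUB-200-2011 layout (paths are hypothetical):

cub_cfg = dict(
    type='CUB',
    data_prefix='data/CUB_200_2011/images',
    ann_file='data/CUB_200_2011/images.txt',
    image_class_labels_file='data/CUB_200_2011/image_class_labels.txt',
    train_test_split_file='data/CUB_200_2011/train_test_split.txt',
    pipeline=[])
cub_train = build_dataset(cub_cfg)  # test_mode=False keeps only split==1 rows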
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/custom.py
0 → 100644
View file @ 0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
from mmcv import FileClient

from .base_dataset import BaseDataset
from .builder import DATASETS


def find_folders(root: str,
                 file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
    """Find classes by folders under a root.

    Args:
        root (string): root directory of folders

    Returns:
        Tuple[List[str], Dict[str, int]]:

        - folders: The names of the sub folders under the root.
        - folder_to_idx: The map from folder name to class idx.
    """
    folders = list(
        file_client.list_dir_or_file(
            root,
            list_dir=True,
            list_file=False,
            recursive=False,
        ))
    folders.sort()
    folder_to_idx = {folders[i]: i for i in range(len(folders))}
    return folders, folder_to_idx


def get_samples(root: str, folder_to_idx: Dict[str, int],
                is_valid_file: Callable, file_client: FileClient):
    """Make dataset by walking all images under a root.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        is_valid_file (Callable): A function that takes the path of a file
            and checks whether it is a valid sample file.

    Returns:
        Tuple[list, set]:

        - samples: a list of tuples where each element is (image, class_idx)
        - empty_folders: the folders that don't contain any valid files
    """
    samples = []
    available_classes = set()

    for folder_name in sorted(list(folder_to_idx.keys())):
        _dir = file_client.join_path(root, folder_name)
        files = list(
            file_client.list_dir_or_file(
                _dir,
                list_dir=False,
                list_file=True,
                recursive=True,
            ))
        for file in sorted(list(files)):
            if is_valid_file(file):
                path = file_client.join_path(folder_name, file)
                item = (path, folder_to_idx[folder_name])
                samples.append(item)
                available_classes.add(folder_name)

    empty_folders = set(folder_to_idx.keys()) - available_classes

    return samples, empty_folders
@DATASETS.register_module()
class CustomDataset(BaseDataset):
    """Custom dataset for classification.

    The dataset supports two kinds of annotation format.

    1. An annotation file is provided, and each line indicates a sample:

       The sample files: ::

           data_prefix/
           ├── folder_1
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           └── folder_2
               ├── 123.png
               ├── nsdf3.png
               └── ...

       The annotation file (the first column is the image path and the second
       column is the index of its category): ::

           folder_1/xxx.png 0
           folder_1/xxy.png 1
           folder_2/123.png 5
           folder_2/nsdf3.png 3
           ...

       Please specify the names of the categories by the argument
       ``classes``.

    2. The samples are arranged in this specific way: ::

           data_prefix/
           ├── class_x
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           │   └── xxz.png
           └── class_y
               ├── 123.png
               ├── nsdf3.png
               ├── ...
               └── asd932_.png

    If ``ann_file`` is specified, the dataset is generated by the first way,
    otherwise the second way is tried.

    Args:
        data_prefix (str): The path of the data directory.
        pipeline (Sequence[dict]): A list of dicts, where each element
            represents an operation defined in
            :mod:`mmcls.datasets.pipelines`. Defaults to an empty tuple.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If it is a string, it should be a file path, and every line of
              the file is the name of a class.
            - If it is a sequence of strings, every item is the name of a
              class.
            - If it is None, use ``cls.CLASSES`` or the names of the sub
              folders (when the second way of arranging samples is used).

            Defaults to None.
        ann_file (str, optional): The annotation file. If it is a string,
            read sample paths from the ann_file. If it is None, find samples
            in ``data_prefix``. Defaults to None.
        extensions (Sequence[str]): A sequence of allowed extensions. Defaults
            to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
        test_mode (bool): In train mode or test mode. It's only a mark and
            won't be used in this class. Defaults to False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            If None, automatically infer it from the specified path.
            Defaults to None.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: Sequence = (),
                 classes: Union[str, Sequence[str], None] = None,
                 ann_file: Optional[str] = None,
                 extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
                                              '.bmp', '.pgm', '.tif'),
                 test_mode: bool = False,
                 file_client_args: Optional[dict] = None):
        self.extensions = tuple(set([i.lower() for i in extensions]))
        self.file_client_args = file_client_args

        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            ann_file=ann_file,
            test_mode=test_mode)
    def _find_samples(self):
        """Find samples from ``data_prefix``."""
        file_client = FileClient.infer_client(self.file_client_args,
                                              self.data_prefix)
        classes, folder_to_idx = find_folders(self.data_prefix, file_client)
        samples, empty_classes = get_samples(
            self.data_prefix,
            folder_to_idx,
            is_valid_file=self.is_valid_file,
            file_client=file_client,
        )

        if len(samples) == 0:
            raise RuntimeError(
                f'Found 0 files in subfolders of: {self.data_prefix}. '
                f'Supported extensions are: {",".join(self.extensions)}')

        if self.CLASSES is not None:
            assert len(self.CLASSES) == len(classes), \
                f"The number of subfolders ({len(classes)}) doesn't match " \
                f'the number of specified classes ({len(self.CLASSES)}). ' \
                'Please check the data folder.'
        else:
            self.CLASSES = classes

        if empty_classes:
            warnings.warn(
                'Found no valid file in the folder '
                f'{", ".join(empty_classes)}. '
                f"Supported extensions are: {', '.join(self.extensions)}",
                UserWarning)

        self.folder_to_idx = folder_to_idx

        return samples

    def load_annotations(self):
        """Load image paths and gt_labels."""
        if self.ann_file is None:
            samples = self._find_samples()
        elif isinstance(self.ann_file, str):
            lines = mmcv.list_from_file(
                self.ann_file, file_client_args=self.file_client_args)
            samples = [x.strip().rsplit(' ', 1) for x in lines]
        else:
            raise TypeError('ann_file must be a str or None')

        data_infos = []
        for filename, gt_label in samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos

    def is_valid_file(self, filename: str) -> bool:
        """Check if a file is a valid sample."""
        return filename.lower().endswith(self.extensions)
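A sketch of the folder-per-class flow (the second annotation format above); the directory is hypothetical:

custom_cfg = dict(
    type='CustomDataset',
    data_prefix='data/my_dataset/train',
    pipeline=[])
ds = build_dataset(custom_cfg)
print(ds.CLASSES)  # inferred from the sorted sub-folder names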
openmmlab_test/mmclassification-0.24.1/mmcls/datasets/dataset_wrappers.py
0 → 100644
View file @ 0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict

import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but adds a
    `get_cat_ids` function.

    Args:
        datasets (list[:obj:`BaseDataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results separately
            if it is used as a validation dataset. Defaults to True.
    """

    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        self.separate_eval = separate_eval

        self.CLASSES = datasets[0].CLASSES

        if not separate_eval:
            if len(set([type(ds) for ds in datasets])) != 1:
                raise NotImplementedError(
                    'To evaluate a concat dataset non-separately, '
                    'all the datasets should have same types')

    def get_cat_ids(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset '
                    'length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_cat_ids(sample_idx)

    def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            indices (list, optional): The indices of samples corresponding to
                the results. It's unavailable on ConcatDataset.
                Defaults to None.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Defaults to None.

        Returns:
            dict[str: float]: AP results of the total dataset or each separate
            dataset if `self.separate_eval=True`.
        """
        if indices is not None:
            raise NotImplementedError(
                'Using indices to evaluate specific samples in a '
                'ConcatDataset is not supported for now.')

        assert len(results) == len(self), \
            ('Dataset and results have different sizes: '
             f'{len(self)} v.s. {len(results)}')

        # Check whether all the datasets support evaluation
        for dataset in self.datasets:
            assert hasattr(dataset, 'evaluate'), \
                f"{type(dataset)} hasn't implemented the evaluate function."

        if self.separate_eval:
            total_eval_results = dict()
            for dataset_idx, dataset in enumerate(self.datasets):
                start_idx = 0 if dataset_idx == 0 else \
                    self.cumulative_sizes[dataset_idx - 1]
                end_idx = self.cumulative_sizes[dataset_idx]

                results_per_dataset = results[start_idx:end_idx]
                print_log(
                    f'Evaluating dataset-{dataset_idx} with '
                    f'{len(results_per_dataset)} images now',
                    logger=logger)

                eval_results_per_dataset = dataset.evaluate(
                    results_per_dataset, *args, logger=logger, **kwargs)
                for k, v in eval_results_per_dataset.items():
                    total_eval_results.update({f'{dataset_idx}_{k}': v})

            return total_eval_results
        else:
            original_data_infos = self.datasets[0].data_infos
            self.datasets[0].data_infos = sum(
                [dataset.data_infos for dataset in self.datasets], [])
            eval_results = self.datasets[0].evaluate(
                results, logger=logger, **kwargs)
            self.datasets[0].data_infos = original_data_infos
            return eval_results
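A config-level sketch of this wrapper (cfg_a and cfg_b are placeholder dataset configs); with separate_eval=True each sub-dataset reports its own metrics prefixed by its index:

concat_cfg = dict(
    type='ConcatDataset',
    datasets=[cfg_a, cfg_b],
    separate_eval=True)
concat = build_dataset(concat_cfg)
# metrics come back as '0_accuracy', '1_accuracy', ...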
@DATASETS.register_module()
class RepeatDataset(object):
    """A wrapper of repeated dataset.

    The length of the repeated dataset will be `times` larger than the
    original dataset. This is useful when the data loading time is long but
    the dataset is small. Using RepeatDataset can reduce the data loading
    time between epochs.

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % self._ori_len]

    def get_cat_ids(self, idx):
        return self.dataset.get_cat_ids(idx % self._ori_len)

    def __len__(self):
        return self.times * self._ori_len

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError(
            'evaluate results on a repeated dataset is weird. '
            'Please inference and evaluate on the original dataset.')

    def __repr__(self):
        """Print the number of instances."""
        dataset_type = 'Test' if self.test_mode else 'Train'
        result = (
            f'\n{self.__class__.__name__} '
            f'({self.dataset.__class__.__name__}) '
            f'{dataset_type} dataset with total number of samples '
            f'{len(self)}.')
        return result
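For example, wrapping a 1000-sample dataset with times=4 yields len() == 4000, and index 3321 maps back to the original sample 3321 % 1000 == 321.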
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module()
class ClassBalancedDataset(object):
    r"""A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in `this paper`_, in each epoch, an image may
    appear multiple times based on its "repeat factor".

    .. _this paper: https://arxiv.org/pdf/1908.03195.pdf

    The repeat factor for an image is a function of the frequency of the
    rarest category labeled in that image. The "frequency of category c" in
    [0, 1] is defined as the fraction of images in the training set (without
    repeats) in which category c appears.

    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as follows.

    1. For each category c, compute the fraction :math:`f(c)` of images that
       contain it.
    2. For each category c, compute the category-level repeat factor

       .. math::
           r(c) = \max(1, \sqrt{\frac{t}{f(c)}})

    3. For each image I and its labels :math:`L(I)`, compute the image-level
       repeat factor

       .. math::
           r(I) = \max_{c \in L(I)} r(c)

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c`` >= ``oversample_thr``,
            there is no oversampling. For categories with ``f_c`` <
            ``oversample_thr``, the degree of oversampling follows the
            square-root inverse frequency heuristic above.
    """

    def __init__(self, dataset, oversample_thr):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.CLASSES = dataset.CLASSES

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_index, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        # 1. For each category c, compute the fraction of images
        #    that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            assert v > 0, f'category {k} does not contain any images'
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I and its labels L(I), compute the image-level
        #    repeat factor:
        #    r(I) = max_{c in L(I)} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            repeat_factor = max(
                {category_repeat[cat_id] for cat_id in cat_ids})
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        return len(self.repeat_indices)

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError(
            'evaluate results on a class-balanced dataset is weird. '
            'Please inference and evaluate on the original dataset.')

    def __repr__(self):
        """Print the number of instances."""
        dataset_type = 'Test' if self.test_mode else 'Train'
        result = (
            f'\n{self.__class__.__name__} '
            f'({self.dataset.__class__.__name__}) '
            f'{dataset_type} dataset with total number of samples '
            f'{len(self)}.')
        return result
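A worked instance of the formulas above: with oversample_thr t = 0.1 and a category present in 1% of images (f(c) = 0.01),

import math
r_c = max(1.0, math.sqrt(0.1 / 0.01))  # sqrt(10) ≈ 3.162
repeats = math.ceil(r_c)               # each such image appears 4 times

while any category with f(c) >= 0.1 keeps a repeat factor of 1.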
@DATASETS.register_module()
class KFoldDataset:
    """A wrapper of dataset for K-Fold cross-validation.

    K-Fold cross-validation divides all the samples into folds of almost
    equal size; k-1 folds are used for training and the remaining fold for
    validation.

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be divided.
        fold (int): The fold used to do validation. Defaults to 0.
        num_splits (int): The number of all folds. Defaults to 5.
        test_mode (bool): Use the training dataset or validation dataset.
            Defaults to False.
        seed (int, optional): The seed to shuffle the dataset before
            splitting. If None, the dataset is not shuffled.
            Defaults to None.
    """

    def __init__(self,
                 dataset,
                 fold=0,
                 num_splits=5,
                 test_mode=False,
                 seed=None):
        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        self.test_mode = test_mode
        self.num_splits = num_splits

        length = len(dataset)
        indices = list(range(length))
        if isinstance(seed, int):
            rng = np.random.default_rng(seed)
            rng.shuffle(indices)

        test_start = length * fold // num_splits
        test_end = length * (fold + 1) // num_splits
        if test_mode:
            self.indices = indices[test_start:test_end]
        else:
            self.indices = indices[:test_start] + indices[test_end:]

    def get_cat_ids(self, idx):
        return self.dataset.get_cat_ids(self.indices[idx])

    def get_gt_labels(self):
        dataset_gt_labels = self.dataset.get_gt_labels()
        gt_labels = np.array(
            [dataset_gt_labels[idx] for idx in self.indices])
        return gt_labels

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

    def evaluate(self, *args, **kwargs):
        kwargs['indices'] = self.indices
        return self.dataset.evaluate(*args, **kwargs)
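A paired-config sketch for fold 0 of 5 (base_cfg is a hypothetical dataset config); the same seed must be used in both wrappers so that the train and validation index sets stay complementary:

train_cfg = dict(type='KFoldDataset', dataset=base_cfg, fold=0,
                 num_splits=5, test_mode=False, seed=1)
val_cfg = dict(type='KFoldDataset', dataset=base_cfg, fold=0,
               num_splits=5, test_mode=True, seed=1)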