Commit 0fd8347d authored by unknown

Add mmclassification-0.24.1 code and remove mmclassification-speed-benchmark

parent cc567e9e
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import onnxruntime as ort
import torch
from mmcls.models.classifiers import BaseClassifier
class ONNXRuntimeClassifier(BaseClassifier):
"""Wrapper for classifier's inference with ONNXRuntime."""
def __init__(self, onnx_file, class_names, device_id):
super(ONNXRuntimeClassifier, self).__init__()
sess = ort.InferenceSession(onnx_file)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess.set_providers(providers, options)
self.sess = sess
self.CLASSES = class_names
self.device_id = device_id
self.io_binding = sess.io_binding()
self.output_names = [_.name for _ in sess.get_outputs()]
self.is_cuda_available = is_cuda_available
def simple_test(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, **kwargs):
raise NotImplementedError('This method is not implemented.')
def forward_test(self, imgs, img_metas, **kwargs):
input_data = imgs
# set io binding for inputs/outputs
device_type = 'cuda' if self.is_cuda_available else 'cpu'
if not self.is_cuda_available:
input_data = input_data.cpu()
self.io_binding.bind_input(
name='input',
device_type=device_type,
device_id=self.device_id,
element_type=np.float32,
shape=input_data.shape,
buffer_ptr=input_data.data_ptr())
for name in self.output_names:
self.io_binding.bind_output(name)
# run session to get outputs
self.sess.run_with_iobinding(self.io_binding)
results = self.io_binding.copy_outputs_to_cpu()[0]
return list(results)
class TensorRTClassifier(BaseClassifier):
def __init__(self, trt_file, class_names, device_id):
super(TensorRTClassifier, self).__init__()
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
try:
load_tensorrt_plugin()
except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, '
                          'you may have to build mmcv with TensorRT '
                          'from source.')
model = TRTWraper(
trt_file, input_names=['input'], output_names=['probs'])
self.model = model
self.device_id = device_id
self.CLASSES = class_names
def simple_test(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, **kwargs):
raise NotImplementedError('This method is not implemented.')
def forward_test(self, imgs, img_metas, **kwargs):
input_data = imgs
with torch.cuda.device(self.device_id), torch.no_grad():
results = self.model({'input': input_data})['probs']
results = results.detach().cpu().numpy()
return list(results)
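

# A minimal inference sketch for the two wrappers above (assumptions: a
# model exported from mmcls to 'model.onnx' with a single 1x3x224x224
# float32 input named 'input'; the path and class names are hypothetical).
if __name__ == '__main__':
    model = ONNXRuntimeClassifier(
        'model.onnx', class_names=['cat', 'dog'], device_id=0)
    dummy = torch.randn(1, 3, 224, 224)
    scores = model.forward_test(dummy, img_metas={})
    print(len(scores), scores[0].shape)  # one score vector per image
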
# Copyright (c) OpenMMLab. All rights reserved.
from .class_num_check_hook import ClassNumCheckHook
from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
from .precise_bn_hook import PreciseBNHook
from .wandblogger_hook import MMClsWandbHook
__all__ = [
'ClassNumCheckHook', 'PreciseBNHook',
'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
]
# Copyright (c) OpenMMLab. All rights reserved
from mmcv.runner import IterBasedRunner
from mmcv.runner.hooks import HOOKS, Hook
from mmcv.utils import is_seq_of
@HOOKS.register_module()
class ClassNumCheckHook(Hook):
def _check_head(self, runner, dataset):
"""Check whether the `num_classes` in head matches the length of
`CLASSES` in `dataset`.
Args:
runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
dataset (obj: `BaseDataset`): the dataset to check.
"""
model = runner.model
if dataset.CLASSES is None:
runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset.__class__.__name__} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
else:
            assert is_seq_of(dataset.CLASSES, str), \
                (f'`CLASSES` in {dataset.__class__.__name__} '
                 f'should be a tuple of str.')
for name, module in model.named_modules():
if hasattr(module, 'num_classes'):
                    assert module.num_classes == len(dataset.CLASSES), \
                        (f'The `num_classes` ({module.num_classes}) in '
                         f'{module.__class__.__name__} of '
                         f'{model.__class__.__name__} does not match '
                         f'the length of `CLASSES` '
                         f'({len(dataset.CLASSES)}) in '
                         f'{dataset.__class__.__name__}.')
def before_train_iter(self, runner):
"""Check whether the training dataset is compatible with head.
Args:
runner (obj: `IterBasedRunner`): Iter based Runner.
"""
if not isinstance(runner, IterBasedRunner):
return
self._check_head(runner, runner.data_loader._dataloader.dataset)
def before_val_iter(self, runner):
"""Check whether the eval dataset is compatible with head.
Args:
runner (obj:`IterBasedRunner`): Iter based Runner.
"""
if not isinstance(runner, IterBasedRunner):
return
self._check_head(runner, runner.data_loader._dataloader.dataset)
def before_train_epoch(self, runner):
"""Check whether the training dataset is compatible with head.
Args:
runner (obj:`EpochBasedRunner`): Epoch based Runner.
"""
self._check_head(runner, runner.data_loader.dataset)
def before_val_epoch(self, runner):
"""Check whether the eval dataset is compatible with head.
Args:
runner (obj:`EpochBasedRunner`): Epoch based Runner.
"""
self._check_head(runner, runner.data_loader.dataset)
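

# Usage sketch: the hook takes no arguments, so registering it through the
# mmcv `custom_hooks` config convention is enough for the head/dataset
# compatibility check to run before each epoch or iteration.
custom_hooks = [dict(type='ClassNumCheckHook')]
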
# Copyright (c) OpenMMLab. All rights reserved.
from math import cos, pi
from mmcv.runner.hooks import HOOKS, LrUpdaterHook
@HOOKS.register_module()
class CosineAnnealingCooldownLrUpdaterHook(LrUpdaterHook):
"""Cosine annealing learning rate scheduler with cooldown.
Args:
min_lr (float, optional): The minimum learning rate after annealing.
Defaults to None.
        min_lr_ratio (float, optional): The minimum learning ratio after
            annealing. Defaults to None.
        cool_down_ratio (float): The ratio of the cooldown learning rate to
            the target (minimum) learning rate. Defaults to 0.1.
        cool_down_time (int): The number of final epochs (or iterations)
            kept at the cooldown learning rate. Defaults to 10.
by_epoch (bool): If True, the learning rate changes epoch by epoch. If
False, the learning rate changes iter by iter. Defaults to True.
warmup (string, optional): Type of warmup used. It can be None (use no
warmup), 'constant', 'linear' or 'exp'. Defaults to None.
warmup_iters (int): The number of iterations or epochs that warmup
lasts. Defaults to 0.
warmup_ratio (float): LR used at the beginning of warmup equals to
``warmup_ratio * initial_lr``. Defaults to 0.1.
warmup_by_epoch (bool): If True, the ``warmup_iters``
means the number of epochs that warmup lasts, otherwise means the
number of iteration that warmup lasts. Defaults to False.
Note:
You need to set one and only one of ``min_lr`` and ``min_lr_ratio``.
"""
def __init__(self,
min_lr=None,
min_lr_ratio=None,
cool_down_ratio=0.1,
cool_down_time=10,
**kwargs):
assert (min_lr is None) ^ (min_lr_ratio is None)
self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio
self.cool_down_time = cool_down_time
self.cool_down_ratio = cool_down_ratio
super(CosineAnnealingCooldownLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
else:
progress = runner.iter
max_progress = runner.max_iters
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
target_lr = self.min_lr
if progress > max_progress - self.cool_down_time:
return target_lr * self.cool_down_ratio
else:
max_progress = max_progress - self.cool_down_time
return annealing_cos(base_lr, target_lr, progress / max_progress)
def annealing_cos(start, end, factor, weight=1):
"""Calculate annealing cos learning rate.
Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
percentage goes from 0.0 to 1.0.
Args:
start (float): The starting learning rate of the cosine annealing.
        end (float): The ending learning rate of the cosine annealing.
factor (float): The coefficient of `pi` when calculating the current
percentage. Range from 0.0 to 1.0.
weight (float, optional): The combination factor of `start` and `end`
when calculating the actual starting learning rate. Default to 1.
"""
cos_out = cos(pi * factor) + 1
return end + 0.5 * weight * (start - end) * cos_out
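

# A small worked example of `annealing_cos` (a sketch): with start=0.1 and
# end=0.0, the annealed value traces half a cosine as factor goes 0 -> 1.
# In a config, the hook is typically selected via the LrUpdater `policy`
# name, e.g. (assumed from the mmcv naming convention):
#     lr_config = dict(policy='CosineAnnealingCooldown', min_lr=1e-5)
if __name__ == '__main__':
    for factor in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(f'{factor:.2f} -> {annealing_cos(0.1, 0.0, factor):.4f}')
    # 0.00 -> 0.1000, 0.25 -> 0.0854, 0.50 -> 0.0500,
    # 0.75 -> 0.0146, 1.00 -> 0.0000
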
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py # noqa: E501
# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501
import itertools
import logging
from typing import List, Optional
import mmcv
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner, get_dist_info
from mmcv.runner.hooks import HOOKS, Hook
from mmcv.utils import print_log
from torch.functional import Tensor
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.utils.data import DataLoader
def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
"""Performs the scaled all_reduce operation on the provided tensors.
The input tensors are modified in-place. Currently supports only the sum
reduction operator. The reduced values are scaled by the inverse size of
the process group.
Args:
tensors (List[torch.Tensor]): The tensors to process.
        num_gpus (int): The number of gpus to use.
Returns:
List[torch.Tensor]: The processed tensors.
"""
# There is no need for reduction in the single-proc case
if num_gpus == 1:
return tensors
# Queue the reductions
reductions = []
for tensor in tensors:
reduction = torch.distributed.all_reduce(tensor, async_op=True)
reductions.append(reduction)
# Wait for reductions to finish
for reduction in reductions:
reduction.wait()
# Scale the results
for tensor in tensors:
tensor.mul_(1.0 / num_gpus)
return tensors
@torch.no_grad()
def update_bn_stats(model: nn.Module,
loader: DataLoader,
num_samples: int = 8192,
logger: Optional[logging.Logger] = None) -> None:
"""Computes precise BN stats on training data.
Args:
model (nn.module): The model whose bn stats will be recomputed.
        loader (DataLoader): PyTorch dataloader.
num_samples (int): The number of samples to update the bn stats.
Defaults to 8192.
logger (:obj:`logging.Logger` | None): Logger for logging.
Default: None.
"""
# get dist info
rank, world_size = get_dist_info()
    # Compute the number of mini-batches to use; if the dataloader has fewer
    # batches than num_iter, use all the samples in the dataloader.
num_iter = num_samples // (loader.batch_size * world_size)
num_iter = min(num_iter, len(loader))
# Retrieve the BN layers
bn_layers = [
m for m in model.modules()
if m.training and isinstance(m, (_BatchNorm))
]
if len(bn_layers) == 0:
print_log('No BN found in model', logger=logger, level=logging.WARNING)
return
print_log(
f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger)
# Finds all the other norm layers with training=True.
other_norm_layers = [
m for m in model.modules()
if m.training and isinstance(m, (_InstanceNorm, GroupNorm))
]
if len(other_norm_layers) > 0:
print_log(
        'IN/GN stats will not be updated in PreciseBNHook.',
logger=logger,
level=logging.INFO)
# Initialize BN stats storage for computing
# mean(mean(batch)) and mean(var(batch))
running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers]
# Remember momentum values
momentums = [bn.momentum for bn in bn_layers]
# Set momentum to 1.0 to compute BN stats that reflect the current batch
for bn in bn_layers:
bn.momentum = 1.0
# Average the BN stats for each BN layer over the batches
if rank == 0:
prog_bar = mmcv.ProgressBar(num_iter)
for data in itertools.islice(loader, num_iter):
model.train_step(data)
for i, bn in enumerate(bn_layers):
running_means[i] += bn.running_mean / num_iter
running_vars[i] += bn.running_var / num_iter
if rank == 0:
prog_bar.update()
# Sync BN stats across GPUs (no reduction if 1 GPU used)
running_means = scaled_all_reduce(running_means, world_size)
running_vars = scaled_all_reduce(running_vars, world_size)
# Set BN stats and restore original momentum values
for i, bn in enumerate(bn_layers):
bn.running_mean = running_means[i]
bn.running_var = running_vars[i]
bn.momentum = momentums[i]
@HOOKS.register_module()
class PreciseBNHook(Hook):
"""Precise BN hook.
    Recompute and update the batch norm stats to make them more precise.
    During training, both the BN stats and the model weights change after
    every iteration, so the running average cannot precisely reflect the
    actual stats of the current model.
With this hook, the BN stats are recomputed with fixed weights, to make the
running average more precise. Specifically, it computes the true average of
per-batch mean/variance instead of the running average. See Sec. 3 of the
paper `Rethinking Batch in BatchNorm <https://arxiv.org/abs/2105.07576>`
for details.
    This hook will update BN stats, so it should be executed before
    ``CheckpointHook`` and ``EMAHook``; generally, set its priority to
    "ABOVE_NORMAL".
Args:
num_samples (int): The number of samples to update the bn stats.
Defaults to 8192.
        interval (int): Perform precise BN every ``interval`` epochs.
            Defaults to 1.
"""
def __init__(self, num_samples: int = 8192, interval: int = 1) -> None:
assert interval > 0 and num_samples > 0
self.interval = interval
self.num_samples = num_samples
def _perform_precise_bn(self, runner: EpochBasedRunner) -> None:
print_log(
f'Running Precise BN for {self.num_samples} items...',
logger=runner.logger)
update_bn_stats(
runner.model,
runner.data_loader,
self.num_samples,
logger=runner.logger)
print_log('Finish Precise BN, BN stats updated.', logger=runner.logger)
def after_train_epoch(self, runner: EpochBasedRunner) -> None:
"""Calculate prcise BN and broadcast BN stats across GPUs.
Args:
runner (obj:`EpochBasedRunner`): runner object.
"""
assert isinstance(runner, EpochBasedRunner), \
'PreciseBN only supports `EpochBasedRunner` by now'
# if by epoch, do perform precise every `self.interval` epochs;
if self.every_n_epochs(runner, self.interval):
self._perform_precise_bn(runner)
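

# A config sketch for enabling the hook (the priority follows the docstring
# above; num_samples and interval are the documented defaults).
custom_hooks = [
    dict(
        type='PreciseBNHook',
        num_samples=8192,
        interval=1,
        priority='ABOVE_NORMAL')
]
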
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import numpy as np
from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
@HOOKS.register_module()
class MMClsWandbHook(WandbLoggerHook):
"""Enhanced Wandb logger hook for classification.
    Compared with :class:`mmcv.runner.WandbLoggerHook`, this hook can not
only automatically log all information in ``log_buffer`` but also log
the following extra information.
- **Checkpoints**: If ``log_checkpoint`` is True, the checkpoint saved at
every checkpoint interval will be saved as W&B Artifacts. This depends on
      the :class:`mmcv.runner.CheckpointHook` whose priority is higher than
this hook. Please refer to
https://docs.wandb.ai/guides/artifacts/model-versioning to learn more
about model versioning with W&B Artifacts.
- **Checkpoint Metadata**: If ``log_checkpoint_metadata`` is True, every
checkpoint artifact will have a metadata associated with it. The metadata
contains the evaluation metrics computed on validation data with that
checkpoint along with the current epoch/iter. It depends on
:class:`EvalHook` whose priority is higher than this hook.
- **Evaluation**: At every interval, this hook logs the model prediction as
interactive W&B Tables. The number of samples logged is given by
``num_eval_images``. Currently, this hook logs the predicted labels along
with the ground truth at every evaluation interval. This depends on the
      :class:`EvalHook` whose priority is higher than this hook. Also note
      that the data is logged just once, and subsequent evaluation tables
      use references to the logged data to save memory usage. Please refer to
https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
Here is a config example:
.. code:: python
checkpoint_config = dict(interval=10)
# To log checkpoint metadata, the interval of checkpoint saving should
# be divisible by the interval of evaluation.
evaluation = dict(interval=5)
log_config = dict(
...
hooks=[
...
dict(type='MMClsWandbHook',
init_kwargs={
'entity': "YOUR_ENTITY",
'project': "YOUR_PROJECT_NAME"
},
log_checkpoint=True,
log_checkpoint_metadata=True,
num_eval_images=100)
])
Args:
init_kwargs (dict): A dict passed to wandb.init to initialize
a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
for possible key-value pairs.
interval (int): Logging interval (every k iterations). Defaults to 10.
log_checkpoint (bool): Save the checkpoint at every checkpoint interval
as W&B Artifacts. Use this for model versioning where each version
is a checkpoint. Defaults to False.
        log_checkpoint_metadata (bool): Log the evaluation metrics computed
            on the validation data with the checkpoint, along with the
            current epoch, as metadata for that checkpoint.
            Defaults to True.
num_eval_images (int): The number of validation images to be logged.
If zero, the evaluation won't be logged. Defaults to 100.
"""
def __init__(self,
init_kwargs=None,
interval=10,
log_checkpoint=False,
log_checkpoint_metadata=False,
num_eval_images=100,
**kwargs):
super(MMClsWandbHook, self).__init__(init_kwargs, interval, **kwargs)
self.log_checkpoint = log_checkpoint
self.log_checkpoint_metadata = (
log_checkpoint and log_checkpoint_metadata)
self.num_eval_images = num_eval_images
self.log_evaluation = (num_eval_images > 0)
self.ckpt_hook: CheckpointHook = None
self.eval_hook: EvalHook = None
@master_only
def before_run(self, runner: BaseRunner):
super(MMClsWandbHook, self).before_run(runner)
# Inspect CheckpointHook and EvalHook
for hook in runner.hooks:
if isinstance(hook, CheckpointHook):
self.ckpt_hook = hook
if isinstance(hook, (EvalHook, DistEvalHook)):
self.eval_hook = hook
# Check conditions to log checkpoint
if self.log_checkpoint:
if self.ckpt_hook is None:
self.log_checkpoint = False
self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log checkpoint in MMClsWandbHook, `CheckpointHook` '
                    'is required, please check hooks in the runner.')
else:
self.ckpt_interval = self.ckpt_hook.interval
# Check conditions to log evaluation
if self.log_evaluation or self.log_checkpoint_metadata:
if self.eval_hook is None:
self.log_evaluation = False
self.log_checkpoint_metadata = False
runner.logger.warning(
'To log evaluation or checkpoint metadata in '
'MMClsWandbHook, `EvalHook` or `DistEvalHook` in mmcls '
'is required, please check whether the validation '
'is enabled.')
else:
self.eval_interval = self.eval_hook.interval
self.val_dataset = self.eval_hook.dataloader.dataset
if (self.log_evaluation
and self.num_eval_images > len(self.val_dataset)):
self.num_eval_images = len(self.val_dataset)
runner.logger.warning(
f'The num_eval_images ({self.num_eval_images}) is '
'greater than the total number of validation samples '
f'({len(self.val_dataset)}). The complete validation '
'dataset will be logged.')
# Check conditions to log checkpoint metadata
if self.log_checkpoint_metadata:
assert self.ckpt_interval % self.eval_interval == 0, \
'To log checkpoint metadata in MMClsWandbHook, the interval ' \
f'of checkpoint saving ({self.ckpt_interval}) should be ' \
'divisible by the interval of evaluation ' \
f'({self.eval_interval}).'
# Initialize evaluation table
if self.log_evaluation:
# Initialize data table
self._init_data_table()
# Add ground truth to the data table
self._add_ground_truth()
# Log ground truth data
self._log_data_table()
@master_only
def after_train_epoch(self, runner):
super(MMClsWandbHook, self).after_train_epoch(runner)
if not self.by_epoch:
return
# Save checkpoint and metadata
if (self.log_checkpoint
and self.every_n_epochs(runner, self.ckpt_interval)
or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
if self.log_checkpoint_metadata and self.eval_hook:
metadata = {
'epoch': runner.epoch + 1,
**self._get_eval_results()
}
else:
metadata = None
aliases = [f'epoch_{runner.epoch+1}', 'latest']
model_path = osp.join(self.ckpt_hook.out_dir,
f'epoch_{runner.epoch+1}.pth')
self._log_ckpt_as_artifact(model_path, aliases, metadata)
# Save prediction table
if self.log_evaluation and self.eval_hook._should_evaluate(runner):
results = self.eval_hook.latest_results
# Initialize evaluation table
self._init_pred_table()
# Add predictions to evaluation table
self._add_predictions(results, runner.epoch + 1)
# Log the evaluation table
self._log_eval_table(runner.epoch + 1)
@master_only
def after_train_iter(self, runner):
if self.get_mode(runner) == 'train':
# An ugly patch. The iter-based eval hook will call the
# `after_train_iter` method of all logger hooks before evaluation.
# Use this trick to skip that call.
            # Don't call the super method first; it will clear the log_buffer.
return super(MMClsWandbHook, self).after_train_iter(runner)
else:
super(MMClsWandbHook, self).after_train_iter(runner)
if self.by_epoch:
return
# Save checkpoint and metadata
if (self.log_checkpoint
and self.every_n_iters(runner, self.ckpt_interval)
or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
if self.log_checkpoint_metadata and self.eval_hook:
metadata = {
'iter': runner.iter + 1,
**self._get_eval_results()
}
else:
metadata = None
aliases = [f'iter_{runner.iter+1}', 'latest']
model_path = osp.join(self.ckpt_hook.out_dir,
f'iter_{runner.iter+1}.pth')
self._log_ckpt_as_artifact(model_path, aliases, metadata)
# Save prediction table
if self.log_evaluation and self.eval_hook._should_evaluate(runner):
results = self.eval_hook.latest_results
# Initialize evaluation table
self._init_pred_table()
# Log predictions
self._add_predictions(results, runner.iter + 1)
# Log the table
self._log_eval_table(runner.iter + 1)
@master_only
def after_run(self, runner):
self.wandb.finish()
def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
"""Log model checkpoint as W&B Artifact.
Args:
model_path (str): Path of the checkpoint to log.
aliases (list): List of the aliases associated with this artifact.
metadata (dict, optional): Metadata associated with this artifact.
"""
model_artifact = self.wandb.Artifact(
f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
model_artifact.add_file(model_path)
self.wandb.log_artifact(model_artifact, aliases=aliases)
def _get_eval_results(self):
"""Get model evaluation results."""
results = self.eval_hook.latest_results
eval_results = self.val_dataset.evaluate(
results, logger='silent', **self.eval_hook.eval_kwargs)
return eval_results
def _init_data_table(self):
"""Initialize the W&B Tables for validation data."""
columns = ['image_name', 'image', 'ground_truth']
self.data_table = self.wandb.Table(columns=columns)
def _init_pred_table(self):
"""Initialize the W&B Tables for model evaluation."""
columns = ['epoch'] if self.by_epoch else ['iter']
columns += ['image_name', 'image', 'ground_truth', 'prediction'
] + list(self.val_dataset.CLASSES)
self.eval_table = self.wandb.Table(columns=columns)
def _add_ground_truth(self):
# Get image loading pipeline
from mmcls.datasets.pipelines import LoadImageFromFile
img_loader = None
for t in self.val_dataset.pipeline.transforms:
if isinstance(t, LoadImageFromFile):
img_loader = t
CLASSES = self.val_dataset.CLASSES
self.eval_image_indexs = np.arange(len(self.val_dataset))
# Set seed so that same validation set is logged each time.
np.random.seed(42)
np.random.shuffle(self.eval_image_indexs)
self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]
for idx in self.eval_image_indexs:
img_info = self.val_dataset.data_infos[idx]
if img_loader is not None:
img_info = img_loader(img_info)
# Get image and convert from BGR to RGB
image = img_info['img'][..., ::-1]
else:
# For CIFAR dataset.
image = img_info['img']
image_name = img_info.get('filename', f'img_{idx}')
gt_label = img_info.get('gt_label').item()
self.data_table.add_data(image_name, self.wandb.Image(image),
CLASSES[gt_label])
def _add_predictions(self, results, idx):
table_idxs = self.data_table_ref.get_index()
assert len(table_idxs) == len(self.eval_image_indexs)
for ndx, eval_image_index in enumerate(self.eval_image_indexs):
result = results[eval_image_index]
self.eval_table.add_data(
idx, self.data_table_ref.data[ndx][0],
self.data_table_ref.data[ndx][1],
self.data_table_ref.data[ndx][2],
self.val_dataset.CLASSES[np.argmax(result)], *tuple(result))
def _log_data_table(self):
"""Log the W&B Tables for validation data as artifact and calls
`use_artifact` on it so that the evaluation table can use the reference
of already uploaded images.
This allows the data to be uploaded just once.
"""
data_artifact = self.wandb.Artifact('val', type='dataset')
data_artifact.add(self.data_table, 'val_data')
self.wandb.run.use_artifact(data_artifact)
data_artifact.wait()
self.data_table_ref = data_artifact.get('val_data')
def _log_eval_table(self, idx):
"""Log the W&B Tables for model evaluation.
        The table will be logged multiple times, creating a new version each
        time. Use this to compare models at different intervals
        interactively.
"""
pred_artifact = self.wandb.Artifact(
f'run_{self.wandb.run.id}_pred', type='evaluation')
pred_artifact.add(self.eval_table, 'eval_data')
if self.by_epoch:
aliases = ['latest', f'epoch_{idx}']
else:
aliases = ['latest', f'iter_{idx}']
self.wandb.run.log_artifact(pred_artifact, aliases=aliases)
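

# A standalone sketch of the artifact pattern used by `_log_ckpt_as_artifact`
# and `_log_eval_table` above (assumptions: `wandb` is installed and logged
# in; the project name, checkpoint path and metadata values are hypothetical).
if __name__ == '__main__':
    import wandb
    run = wandb.init(project='demo-project')
    artifact = wandb.Artifact(
        f'run_{run.id}_model', type='model', metadata={'epoch': 1})
    artifact.add_file('epoch_1.pth')  # hypothetical checkpoint file
    wandb.log_artifact(artifact, aliases=['epoch_1', 'latest'])
    run.finish()
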
# Copyright (c) OpenMMLab. All rights reserved.
from .lamb import Lamb
__all__ = [
'Lamb',
]
"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb.
This optimizer code was adapted from the following (starting with latest)
* https://github.com/HabanaAI/Model-References/blob/
2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb
is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or
cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support
PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below.
Modifications Copyright 2021 Ross Wightman
"""
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
import torch
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module()
class Lamb(Optimizer):
"""A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer.
This class is copied from `timm`_. The LAMB was proposed in `Large Batch
Optimization for Deep Learning - Training BERT in 76 minutes`_.
.. _timm:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (bool, optional): whether to apply (1-beta2) to grad
            when calculating running averages of gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
""" # noqa: E501
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0.01,
grad_averaging=True,
max_grad_norm=1.0,
trust_clip=False,
always_adapt=False):
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm,
trust_clip=trust_clip,
always_adapt=always_adapt)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]['params'][0].device
one_tensor = torch.tensor(
1.0, device=device
) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
'Lamb does not support sparse gradients, consider '
'SparseAdam instead.')
global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm = torch.sqrt(global_grad_norm)
# FIXME it'd be nice to remove explicit tensor conversion of scalars
# when torch.where promotes
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
max_grad_norm = torch.tensor(
self.defaults['max_grad_norm'], device=device)
clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm,
global_grad_norm / max_grad_norm,
one_tensor)
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
beta3 = 1 - beta1 if grad_averaging else 1.0
# assume same step across group now to simplify things
            # per-parameter step can easily be supported by making it a
            # tensor, or by passing a list into the kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
if bias_correction:
bias_correction1 = 1 - beta1**group['step']
bias_correction2 = 1 - beta2**group['step']
else:
bias_correction1, bias_correction2 = 1.0, 1.0
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.div_(clip_global_grad_norm)
state = self.state[p]
# State initialization
if len(state) == 0:
                    # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
exp_avg_sq.mul_(beta2).addcmul_(
grad, grad, value=1 - beta2) # v_t
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
group['eps'])
update = (exp_avg / bias_correction1).div_(denom)
weight_decay = group['weight_decay']
if weight_decay != 0:
update.add_(p, alpha=weight_decay)
if weight_decay != 0 or group['always_adapt']:
# Layer-wise LR adaptation. By default, skip adaptation on
# parameters that are
# excluded from weight decay, unless always_adapt == True,
# then always enabled.
w_norm = p.norm(2.0)
g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not
# working in PT XLA
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
if group['trust_clip']:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio = torch.minimum(trust_ratio, one_tensor)
update.mul_(trust_ratio)
p.add_(update, alpha=-group['lr'])
return loss
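

# A minimal smoke test of the Lamb step on a toy regression problem
# (a sketch; the hyper-parameters below are illustrative).
if __name__ == '__main__':
    import torch.nn as nn
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = Lamb(model.parameters(), lr=1e-2, weight_decay=0.01)
    x, y = torch.randn(64, 4), torch.randn(64, 1)
    for _ in range(10):
        optimizer.zero_grad()
        loss = ((model(x) - y)**2).mean()
        loss.backward()
        optimizer.step()
    print(f'final loss: {loss.item():.4f}')
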
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import DistOptimizerHook, allreduce_grads, sync_random_seed
from .misc import multi_apply
__all__ = [
'allreduce_grads', 'DistOptimizerHook', 'multi_apply', 'sync_random_seed'
]
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import OptimizerHook, get_dist_info
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
dist.all_reduce(flat_tensors)
flat_tensors.div_(world_size)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
grads = [
param.grad.data for param in params
if param.requires_grad and param.grad is not None
]
world_size = dist.get_world_size()
if coalesce:
_allreduce_coalesced(grads, world_size, bucket_size_mb)
else:
for tensor in grads:
dist.all_reduce(tensor.div_(world_size))
class DistOptimizerHook(OptimizerHook):
def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
def after_train_iter(self, runner):
runner.optimizer.zero_grad()
runner.outputs['loss'].backward()
if self.grad_clip is not None:
self.clip_grads(runner.model.parameters())
runner.optimizer.step()
def sync_random_seed(seed=None, device='cuda'):
"""Make sure different ranks share the same seed.
All workers must call this function, otherwise it will deadlock.
This method is generally used in `DistributedSampler`,
because the seed should be identical across all processes
in the distributed group.
In distributed sampling, different ranks should sample non-overlapped
data in the dataset. Therefore, this function is used to make sure that
each rank shuffles the data indices in the same order based
on the same seed. Then different ranks could use different indices
to select non-overlapped data from the same data list.
Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device where the seed will be put on.
            Defaults to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is None:
seed = np.random.randint(2**31)
assert isinstance(seed, int)
rank, world_size = get_dist_info()
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
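

# A quick sanity sketch: with a single process (no initialized process
# group, so world_size == 1), `sync_random_seed` returns the seed as-is.
if __name__ == '__main__':
    assert sync_random_seed(seed=42) == 42
    print(sync_random_seed())  # a random int in [0, 2**31), shared by ranks
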
# Copyright (c) OpenMMLab. All rights reserved.
from .image import (BaseFigureContextManager, ImshowInfosContextManager,
color_val_matplotlib, imshow_infos)
__all__ = [
'BaseFigureContextManager', 'ImshowInfosContextManager', 'imshow_infos',
'color_val_matplotlib'
]
# Copyright (c) OpenMMLab. All rights reserved.
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.backend_bases import CloseEvent
# A small value
EPS = 1e-2
def color_val_matplotlib(color):
"""Convert various input in BGR order to normalized RGB matplotlib color
tuples,
Args:
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
Returns:
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
"""
color = mmcv.color_val(color)
color = [color / 255 for color in color[::-1]]
return tuple(color)
class BaseFigureContextManager:
"""Context Manager to reuse matplotlib figure.
It provides a figure for saving and a figure for showing to support
different settings.
Args:
axis (bool): Whether to show the axis lines.
fig_save_cfg (dict): Keyword parameters of figure for saving.
Defaults to empty dict.
fig_show_cfg (dict): Keyword parameters of figure for showing.
Defaults to empty dict.
"""
def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
self.is_inline = 'inline' in plt.get_backend()
        # Because saving and showing need different figure sizes, we use
        # two separate figures and axes to handle them.
self.fig_save: plt.Figure = None
self.fig_save_cfg = fig_save_cfg
self.ax_save: plt.Axes = None
self.fig_show: plt.Figure = None
self.fig_show_cfg = fig_show_cfg
self.ax_show: plt.Axes = None
self.axis = axis
def __enter__(self):
if not self.is_inline:
            # If using the inline backend, we cannot control which figure
            # to show, so disable the interactive fig_show, and put the
            # initialization of fig_save into the `prepare` function.
self._initialize_fig_save()
self._initialize_fig_show()
return self
def _initialize_fig_save(self):
fig = plt.figure(**self.fig_save_cfg)
ax = fig.add_subplot()
        # remove white edges by setting subplot margins
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
self.fig_save, self.ax_save = fig, ax
def _initialize_fig_show(self):
# fig_save will be resized to image size, only fig_show needs fig_size.
fig = plt.figure(**self.fig_show_cfg)
ax = fig.add_subplot()
        # remove white edges by setting subplot margins
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
self.fig_show, self.ax_show = fig, ax
def __exit__(self, exc_type, exc_value, traceback):
if self.is_inline:
# If use inline backend, whether to close figure depends on if
# users want to show the image.
return
plt.close(self.fig_save)
plt.close(self.fig_show)
def prepare(self):
if self.is_inline:
            # If using the inline backend, just rebuild fig_save.
self._initialize_fig_save()
self.ax_save.cla()
self.ax_save.axis(self.axis)
return
        # If the user has destroyed the window, rebuild fig_show.
if not plt.fignum_exists(self.fig_show.number):
self._initialize_fig_show()
# Clear all axes
self.ax_save.cla()
self.ax_save.axis(self.axis)
self.ax_show.cla()
self.ax_show.axis(self.axis)
def wait_continue(self, timeout=0, continue_key=' ') -> int:
"""Show the image and wait for the user's input.
This implementation refers to
https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py
Args:
timeout (int): If positive, continue after ``timeout`` seconds.
Defaults to 0.
continue_key (str): The key for users to continue. Defaults to
the space key.
Returns:
int: If zero, means time out or the user pressed ``continue_key``,
and if one, means the user closed the show figure.
""" # noqa: E501
if self.is_inline:
            # If using the inline backend, interactive input and timeout
            # are of no use.
return
if self.fig_show.canvas.manager:
# Ensure that the figure is shown
self.fig_show.show()
while True:
# Connect the events to the handler function call.
event = None
def handler(ev):
# Set external event variable
nonlocal event
# Qt backend may fire two events at the same time,
# use a condition to avoid missing close event.
event = ev if not isinstance(event, CloseEvent) else event
self.fig_show.canvas.stop_event_loop()
cids = [
self.fig_show.canvas.mpl_connect(name, handler)
for name in ('key_press_event', 'close_event')
]
try:
self.fig_show.canvas.start_event_loop(timeout)
finally: # Run even on exception like ctrl-c.
# Disconnect the callbacks.
for cid in cids:
self.fig_show.canvas.mpl_disconnect(cid)
if isinstance(event, CloseEvent):
return 1 # Quit for close.
elif event is None or event.key == continue_key:
return 0 # Quit for continue.
class ImshowInfosContextManager(BaseFigureContextManager):
"""Context Manager to reuse matplotlib figure and put infos on images.
Args:
fig_size (tuple[int]): Size of the figure to show image.
Examples:
>>> import mmcv
>>> from mmcls.core import visualization as vis
>>> img1 = mmcv.imread("./1.png")
>>> info1 = {'class': 'cat', 'label': 0}
>>> img2 = mmcv.imread("./2.png")
>>> info2 = {'class': 'dog', 'label': 1}
>>> with vis.ImshowInfosContextManager() as manager:
... # Show img1
... manager.put_img_infos(img1, info1)
... # Show img2 on the same figure and save output image.
... manager.put_img_infos(
... img2, info2, out_file='./2_out.png')
"""
def __init__(self, fig_size=(15, 10)):
super().__init__(
axis=False,
# A proper dpi for image save with default font size.
fig_save_cfg=dict(frameon=False, dpi=36),
fig_show_cfg=dict(frameon=False, figsize=fig_size))
def _put_text(self, ax, text, x, y, text_color, font_size):
ax.text(
x,
y,
f'{text}',
bbox={
'facecolor': 'black',
'alpha': 0.7,
'pad': 0.2,
'edgecolor': 'none',
'boxstyle': 'round'
},
color=text_color,
fontsize=font_size,
family='monospace',
verticalalignment='top',
horizontalalignment='left')
def put_img_infos(self,
img,
infos,
text_color='white',
font_size=26,
row_width=20,
win_name='',
show=True,
wait_time=0,
out_file=None):
"""Show image with extra information.
Args:
img (str | ndarray): The image to be displayed.
infos (dict): Extra infos to display in the image.
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
display color. Defaults to 'white'.
font_size (int): Extra infos display font size. Defaults to 26.
row_width (int): width between each row of results on the image.
win_name (str): The image title. Defaults to ''
show (bool): Whether to show the image. Defaults to True.
wait_time (int): How many seconds to display the image.
Defaults to 0.
out_file (Optional[str]): The filename to write the image.
Defaults to None.
Returns:
            np.ndarray: The image with extra information.
"""
self.prepare()
text_color = color_val_matplotlib(text_color)
img = mmcv.imread(img).astype(np.uint8)
x, y = 3, row_width // 2
img = mmcv.bgr2rgb(img)
width, height = img.shape[1], img.shape[0]
img = np.ascontiguousarray(img)
        # add a small EPS to avoid precision loss due to matplotlib's
        # truncation (https://github.com/matplotlib/matplotlib/issues/15363)
dpi = self.fig_save.get_dpi()
self.fig_save.set_size_inches((width + EPS) / dpi,
(height + EPS) / dpi)
for k, v in infos.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
self._put_text(self.ax_save, label_text, x, y, text_color,
font_size)
if show and not self.is_inline:
self._put_text(self.ax_show, label_text, x, y, text_color,
font_size)
y += row_width
self.ax_save.imshow(img)
stream, _ = self.fig_save.canvas.print_to_buffer()
buffer = np.frombuffer(stream, dtype='uint8')
img_rgba = buffer.reshape(height, width, 4)
rgb, _ = np.split(img_rgba, [3], axis=2)
img_save = rgb.astype('uint8')
img_save = mmcv.rgb2bgr(img_save)
if out_file is not None:
mmcv.imwrite(img_save, out_file)
ret = 0
if show and not self.is_inline:
# Reserve some space for the tip.
self.ax_show.set_title(win_name)
self.ax_show.set_ylim(height + 20)
self.ax_show.text(
width // 2,
height + 18,
'Press SPACE to continue.',
ha='center',
fontsize=font_size)
self.ax_show.imshow(img)
# Refresh canvas, necessary for Qt5 backend.
self.fig_show.canvas.draw()
ret = self.wait_continue(timeout=wait_time)
elif (not show) and self.is_inline:
            # If using the inline backend, we use fig_save to show the
            # image, so we need to close it if users don't want to show it.
plt.close(self.fig_save)
return ret, img_save
def imshow_infos(img,
infos,
text_color='white',
font_size=26,
row_width=20,
win_name='',
show=True,
fig_size=(15, 10),
wait_time=0,
out_file=None):
"""Show image with extra information.
Args:
img (str | ndarray): The image to be displayed.
infos (dict): Extra infos to display in the image.
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
display color. Defaults to 'white'.
font_size (int): Extra infos display font size. Defaults to 26.
row_width (int): width between each row of results on the image.
win_name (str): The image title. Defaults to ''
show (bool): Whether to show the image. Defaults to True.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
wait_time (int): How many seconds to display the image. Defaults to 0.
out_file (Optional[str]): The filename to write the image.
Defaults to None.
Returns:
        np.ndarray: The image with extra information.
"""
with ImshowInfosContextManager(fig_size=fig_size) as manager:
_, img = manager.put_img_infos(
img,
infos,
text_color=text_color,
font_size=font_size,
row_width=row_width,
win_name=win_name,
show=show,
wait_time=wait_time,
out_file=out_file)
return img
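

# A minimal usage sketch for `imshow_infos` (the infos dict and output path
# are hypothetical; show=False keeps it non-interactive).
if __name__ == '__main__':
    demo_img = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in image
    out = imshow_infos(
        demo_img,
        infos={'pred_class': 'cat', 'pred_score': 0.98},
        show=False,
        out_file='demo_out.png')
    print(out.shape)  # (224, 224, 3)
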
# Copyright (c) OpenMMLab. All rights reserved.
from .base_dataset import BaseDataset
from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader,
build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler, RepeatAugSampler
from .stanford_cars import StanfordCars
from .voc import VOC
__all__ = [
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB',
'CustomDataset', 'StanfordCars'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from abc import ABCMeta, abstractmethod
from os import PathLike
from typing import List
import mmcv
import numpy as np
@@ -10,6 +14,13 @@ from mmcls.models.losses import accuracy
from .pipelines import Compose
def expanduser(path):
if isinstance(path, (str, PathLike)):
return osp.expanduser(path)
else:
return path
class BaseDataset(Dataset, metaclass=ABCMeta):
"""Base dataset.
@@ -32,12 +43,11 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
ann_file=None,
test_mode=False):
super(BaseDataset, self).__init__()
-        self.ann_file = ann_file
-        self.data_prefix = data_prefix
-        self.test_mode = test_mode
+        self.data_prefix = expanduser(data_prefix)
        self.pipeline = Compose(pipeline)
        self.CLASSES = self.get_classes(classes)
+        self.ann_file = expanduser(ann_file)
+        self.test_mode = test_mode
self.data_infos = self.load_annotations()
@abstractmethod
@@ -58,23 +68,23 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
"""Get all ground-truth labels (categories).
Returns:
-            list[int]: categories for all images.
+            np.ndarray: categories for all images.
"""
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
return gt_labels
-    def get_cat_ids(self, idx):
+    def get_cat_ids(self, idx: int) -> List[int]:
"""Get category id by index.
Args:
idx (int): Index of data.
Returns:
-            int: Image category of specified index.
+            cat_ids (List[int]): Image category of specified index.
"""
-        return self.data_infos[idx]['gt_label'].astype(np.int)
+        return [int(self.data_infos[idx]['gt_label'])]
def prepare_data(self, idx):
results = copy.deepcopy(self.data_infos[idx])
@@ -89,6 +99,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
@classmethod
def get_classes(cls, classes=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
@@ -104,7 +115,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
if isinstance(classes, str):
# take it as a file path
-            class_names = mmcv.list_from_file(classes)
+            class_names = mmcv.list_from_file(expanduser(classes))
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
@@ -116,6 +127,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
results,
metric='accuracy',
metric_options=None,
indices=None,
logger=None):
"""Evaluate the dataset.
@@ -126,6 +138,8 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
metric_options (dict, optional): Options for calculating metrics.
Allowed keys are 'topk', 'thrs' and 'average_mode'.
Defaults to None.
indices (list, optional): The indices of samples corresponding to
the results. Defaults to None.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None.
Returns:
@@ -143,20 +157,25 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
eval_results = {}
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]
num_imgs = len(results)
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
'be of the same length as gt_labels.'
invalid_metrics = set(metrics) - set(allowed_metrics)
if len(invalid_metrics) != 0:
-            raise ValueError(f'metirc {invalid_metrics} is not supported.')
+            raise ValueError(f'metric {invalid_metrics} is not supported.')
topk = metric_options.get('topk', (1, 5))
thrs = metric_options.get('thrs')
average_mode = metric_options.get('average_mode', 'macro')
if 'accuracy' in metrics:
-            acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            if thrs is not None:
+                acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            else:
+                acc = accuracy(results, gt_labels, topk=topk)
if isinstance(topk, tuple):
eval_results_ = {
f'accuracy_top-{k}': a
@@ -182,8 +201,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-            precision_recall_f1_values = precision_recall_f1(
-                results, gt_labels, average_mode=average_mode, thrs=thrs)
+            if thrs is not None:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode, thrs=thrs)
+            else:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode)
for key, values in zip(precision_recall_f1_keys,
precision_recall_f1_values):
if key in metrics:
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
from functools import partial
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader
try:
from mmcv.utils import IS_IPU_AVAILABLE
except ImportError:
IS_IPU_AVAILABLE = False
if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
hard_limit = rlimit[1]
soft_limit = min(4096, hard_limit)
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
SAMPLERS = Registry('sampler')
def build_dataset(cfg, default_args=None):
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
KFoldDataset, RepeatDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'ConcatDataset':
dataset = ConcatDataset(
[build_dataset(c, default_args) for c in cfg['datasets']],
separate_eval=cfg.get('separate_eval', True))
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif cfg['type'] == 'ClassBalancedDataset':
dataset = ClassBalancedDataset(
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
elif cfg['type'] == 'KFoldDataset':
cp_cfg = copy.deepcopy(cfg)
if cp_cfg.get('test_mode', None) is None:
cp_cfg['test_mode'] = (default_args or {}).pop('test_mode', False)
cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'], default_args)
cp_cfg.pop('type')
dataset = KFoldDataset(**cp_cfg)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
return dataset
def build_dataloader(dataset,
samples_per_gpu,
workers_per_gpu,
num_gpus=1,
dist=True,
shuffle=True,
round_up=True,
seed=None,
pin_memory=True,
persistent_workers=True,
sampler_cfg=None,
**kwargs):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
In non-distributed training, there is only one dataloader for all GPUs.
Args:
dataset (Dataset): A PyTorch dataset.
samples_per_gpu (int): Number of training samples on each GPU, i.e.,
batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed training.
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
round_up (bool): Whether to round up the length of dataset by adding
extra samples to make it evenly divisible. Default: True.
pin_memory (bool): Whether to use pin_memory in DataLoader.
Default: True
persistent_workers (bool): If True, the data loader will not shutdown
the worker processes after a dataset has been consumed once.
            This keeps the worker Dataset instances alive.
The argument also has effect in PyTorch>=1.7.0.
Default: True
sampler_cfg (dict): sampler configuration to override the default
sampler
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
DataLoader: A PyTorch dataloader.
"""
rank, world_size = get_dist_info()
# Custom sampler logic
if sampler_cfg:
# shuffle=False when val and test
sampler_cfg.update(shuffle=shuffle)
sampler = build_sampler(
sampler_cfg,
default_args=dict(
dataset=dataset, num_replicas=world_size, rank=rank,
seed=seed))
# Default sampler logic
elif dist:
sampler = build_sampler(
dict(
type='DistributedSampler',
dataset=dataset,
num_replicas=world_size,
rank=rank,
shuffle=shuffle,
round_up=round_up,
seed=seed))
else:
sampler = None
# If sampler exists, turn off dataloader shuffle
if sampler is not None:
shuffle = False
if dist:
batch_size = samples_per_gpu
num_workers = workers_per_gpu
else:
batch_size = num_gpus * samples_per_gpu
num_workers = num_gpus * workers_per_gpu
init_fn = partial(
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
if digit_version(torch.__version__) >= digit_version('1.8.0'):
kwargs['persistent_workers'] = persistent_workers
if IS_IPU_AVAILABLE:
from mmcv.device.ipu import IPUDataLoader
data_loader = IPUDataLoader(
dataset,
None,
batch_size=samples_per_gpu,
num_workers=num_workers,
shuffle=shuffle,
worker_init_fn=init_fn,
**kwargs)
else:
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=pin_memory,
shuffle=shuffle,
worker_init_fn=init_fn,
**kwargs)
return data_loader
def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
def build_sampler(cfg, default_args=None):
if cfg is None:
return None
else:
return build_from_cfg(cfg, SAMPLERS, default_args=default_args)
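

# A sketch of building a dataset and dataloader from config (the
# data_prefix path is hypothetical; the pipeline is a pared-down example
# and may need adjusting for real training).
if __name__ == '__main__':
    cfg = dict(
        type='CIFAR10',
        data_prefix='data/cifar10',
        pipeline=[
            dict(type='ImageToTensor', keys=['img']),
            dict(type='ToTensor', keys=['gt_label']),
            dict(type='Collect', keys=['img', 'gt_label'])
        ])
    dataset = build_dataset(cfg)
    loader = build_dataloader(
        dataset, samples_per_gpu=32, workers_per_gpu=2, dist=False,
        shuffle=True)
    print(len(dataset), len(loader))  # 50000 images -> 1563 batches
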
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path
import pickle
@@ -16,8 +17,8 @@ class CIFAR10(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
This implementation is modified from
-    https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py # noqa: E501
-    """
+    https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
+    """  # noqa: E501
base_folder = 'cifar-10-batches-py'
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
@@ -39,6 +40,10 @@ class CIFAR10(BaseDataset):
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
CLASSES = [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
'horse', 'ship', 'truck'
]
def load_annotations(self):
class CIFAR100(CIFAR10):
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}
CLASSES = [
'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',
'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man',
'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
'willow_tree', 'wolf', 'woman', 'worm'
]
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from .base_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register_module()
class CUB(BaseDataset):
"""The CUB-200-2011 Dataset.
    Support the `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
    Compared with the `CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ Dataset,
    `CUB-200-2011` contains many more images.
    Args:
        ann_file (str): The annotation file, i.e. ``images.txt`` in CUB.
        image_class_labels_file (str): The label file, i.e.
            ``image_class_labels.txt`` in CUB.
        train_test_split_file (str): The split file, i.e.
            ``train_test_split.txt`` in CUB.
""" # noqa: E501
CLASSES = [
'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross',
'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet',
'Parakeet_Auklet', 'Rhinoceros_Auklet', 'Brewer_Blackbird',
'Red_winged_Blackbird', 'Rusty_Blackbird', 'Yellow_headed_Blackbird',
'Bobolink', 'Indigo_Bunting', 'Lazuli_Bunting', 'Painted_Bunting',
'Cardinal', 'Spotted_Catbird', 'Gray_Catbird', 'Yellow_breasted_Chat',
'Eastern_Towhee', 'Chuck_will_Widow', 'Brandt_Cormorant',
'Red_faced_Cormorant', 'Pelagic_Cormorant', 'Bronzed_Cowbird',
'Shiny_Cowbird', 'Brown_Creeper', 'American_Crow', 'Fish_Crow',
'Black_billed_Cuckoo', 'Mangrove_Cuckoo', 'Yellow_billed_Cuckoo',
'Gray_crowned_Rosy_Finch', 'Purple_Finch', 'Northern_Flicker',
'Acadian_Flycatcher', 'Great_Crested_Flycatcher', 'Least_Flycatcher',
'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher',
'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird',
'Northern_Fulmar', 'Gadwall', 'American_Goldfinch',
'European_Goldfinch', 'Boat_tailed_Grackle', 'Eared_Grebe',
'Horned_Grebe', 'Pied_billed_Grebe', 'Western_Grebe', 'Blue_Grosbeak',
'Evening_Grosbeak', 'Pine_Grosbeak', 'Rose_breasted_Grosbeak',
'Pigeon_Guillemot', 'California_Gull', 'Glaucous_winged_Gull',
'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull', 'Ring_billed_Gull',
'Slaty_backed_Gull', 'Western_Gull', 'Anna_Hummingbird',
'Ruby_throated_Hummingbird', 'Rufous_Hummingbird', 'Green_Violetear',
'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay', 'Florida_Jay',
'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird', 'Gray_Kingbird',
'Belted_Kingfisher', 'Green_Kingfisher', 'Pied_Kingfisher',
'Ringed_Kingfisher', 'White_breasted_Kingfisher',
'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard',
'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser',
'Mockingbird', 'Nighthawk', 'Clark_Nutcracker',
'White_breasted_Nuthatch', 'Baltimore_Oriole', 'Hooded_Oriole',
'Orchard_Oriole', 'Scott_Oriole', 'Ovenbird', 'Brown_Pelican',
'White_Pelican', 'Western_Wood_Pewee', 'Sayornis', 'American_Pipit',
'Whip_poor_Will', 'Horned_Puffin', 'Common_Raven',
'White_necked_Raven', 'American_Redstart', 'Geococcyx',
'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow',
'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow',
'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow',
'Fox_Sparrow', 'Grasshopper_Sparrow', 'Harris_Sparrow',
'Henslow_Sparrow', 'Le_Conte_Sparrow', 'Lincoln_Sparrow',
'Nelson_Sharp_tailed_Sparrow', 'Savannah_Sparrow', 'Seaside_Sparrow',
'Song_Sparrow', 'Tree_Sparrow', 'Vesper_Sparrow',
'White_crowned_Sparrow', 'White_throated_Sparrow',
'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow',
'Cliff_Swallow', 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager',
'Artic_Tern', 'Black_Tern', 'Caspian_Tern', 'Common_Tern',
'Elegant_Tern', 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee',
'Brown_Thrasher', 'Sage_Thrasher', 'Black_capped_Vireo',
'Blue_headed_Vireo', 'Philadelphia_Vireo', 'Red_eyed_Vireo',
'Warbling_Vireo', 'White_eyed_Vireo', 'Yellow_throated_Vireo',
'Bay_breasted_Warbler', 'Black_and_white_Warbler',
'Black_throated_Blue_Warbler', 'Blue_winged_Warbler', 'Canada_Warbler',
'Cape_May_Warbler', 'Cerulean_Warbler', 'Chestnut_sided_Warbler',
'Golden_winged_Warbler', 'Hooded_Warbler', 'Kentucky_Warbler',
'Magnolia_Warbler', 'Mourning_Warbler', 'Myrtle_Warbler',
'Nashville_Warbler', 'Orange_crowned_Warbler', 'Palm_Warbler',
'Pine_Warbler', 'Prairie_Warbler', 'Prothonotary_Warbler',
'Swainson_Warbler', 'Tennessee_Warbler', 'Wilson_Warbler',
'Worm_eating_Warbler', 'Yellow_Warbler', 'Northern_Waterthrush',
'Louisiana_Waterthrush', 'Bohemian_Waxwing', 'Cedar_Waxwing',
'American_Three_toed_Woodpecker', 'Pileated_Woodpecker',
'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker',
'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren',
'Cactus_Wren', 'Carolina_Wren', 'House_Wren', 'Marsh_Wren',
'Rock_Wren', 'Winter_Wren', 'Common_Yellowthroat'
]
def __init__(self, *args, ann_file, image_class_labels_file,
train_test_split_file, **kwargs):
self.image_class_labels_file = image_class_labels_file
self.train_test_split_file = train_test_split_file
super(CUB, self).__init__(*args, ann_file=ann_file, **kwargs)
def load_annotations(self):
with open(self.ann_file) as f:
samples = [x.strip().split(' ')[1] for x in f.readlines()]
with open(self.image_class_labels_file) as f:
gt_labels = [
                # In the official CUB-200-2011 dataset, labels in
                # image_class_labels_file start from 1, so we subtract 1
                # here to make them start from 0.
int(x.strip().split(' ')[1]) - 1 for x in f.readlines()
]
with open(self.train_test_split_file) as f:
splits = [int(x.strip().split(' ')[1]) for x in f.readlines()]
assert len(samples) == len(gt_labels) == len(splits),\
f'samples({len(samples)}), gt_labels({len(gt_labels)}) and ' \
f'splits({len(splits)}) should have same length.'
data_infos = []
for filename, gt_label, split in zip(samples, gt_labels, splits):
if split and self.test_mode:
# skip train samples when test_mode=True
continue
elif not split and not self.test_mode:
# skip test samples when test_mode=False
continue
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
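# Illustrative contents of the three CUB-200-2011 files parsed above (each
# line is '<image_id> <value>'; the filename below is a made-up example):
#   images.txt              ->  '1 001.Black_footed_Albatross/img_0001.jpg'
#   image_class_labels.txt  ->  '1 1'   (class ids are 1-based, hence the -1)
#   train_test_split.txt    ->  '1 1'   (1 = training image, 0 = test image)
#
#   dataset = CUB(
#       data_prefix='data/CUB_200_2011/images',
#       pipeline=[],
#       ann_file='data/CUB_200_2011/images.txt',
#       image_class_labels_file='data/CUB_200_2011/image_class_labels.txt',
#       train_test_split_file='data/CUB_200_2011/train_test_split.txt')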
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
from mmcv import FileClient
from .base_dataset import BaseDataset
from .builder import DATASETS
def find_folders(root: str,
file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
"""Find classes by folders under a root.
Args:
root (string): root directory of folders
Returns:
Tuple[List[str], Dict[str, int]]:
- folders: The name of sub folders under the root.
- folder_to_idx: The map from folder name to class idx.
"""
folders = list(
file_client.list_dir_or_file(
root,
list_dir=True,
list_file=False,
recursive=False,
))
folders.sort()
folder_to_idx = {folders[i]: i for i in range(len(folders))}
return folders, folder_to_idx
def get_samples(root: str, folder_to_idx: Dict[str, int],
is_valid_file: Callable, file_client: FileClient):
"""Make dataset by walking all images under a root.
Args:
root (string): root directory of folders
folder_to_idx (dict): the map from class name to class idx
is_valid_file (Callable): A function that takes path of a file
and check if the file is a valid sample file.
Returns:
Tuple[list, set]:
- samples: a list of tuple where each element is (image, class_idx)
- empty_folders: The folders don't have any valid files.
"""
samples = []
available_classes = set()
for folder_name in sorted(list(folder_to_idx.keys())):
_dir = file_client.join_path(root, folder_name)
files = list(
file_client.list_dir_or_file(
_dir,
list_dir=False,
list_file=True,
recursive=True,
))
for file in sorted(list(files)):
if is_valid_file(file):
path = file_client.join_path(folder_name, file)
item = (path, folder_to_idx[folder_name])
samples.append(item)
available_classes.add(folder_name)
empty_folders = set(folder_to_idx.keys()) - available_classes
return samples, empty_folders
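# A minimal sketch of how the two helpers above compose, assuming a local
# directory 'data/train' with one sub-folder per class:
#
#   client = FileClient.infer_client(None, 'data/train')
#   folders, folder_to_idx = find_folders('data/train', client)
#   samples, empty = get_samples(
#       'data/train',
#       folder_to_idx,
#       is_valid_file=lambda name: name.lower().endswith('.png'),
#       file_client=client)
#   # samples is a list like [('class_x/xxx.png', 0), ('class_y/123.png', 1)]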
@DATASETS.register_module()
class CustomDataset(BaseDataset):
"""Custom dataset for classification.
The dataset supports two kinds of annotation format.
1. An annotation file is provided, and each line indicates a sample:
The sample files: ::
data_prefix/
├── folder_1
│ ├── xxx.png
│ ├── xxy.png
│ └── ...
└── folder_2
├── 123.png
├── nsdf3.png
└── ...
The annotation file (the first column is the image path and the second
column is the index of category): ::
folder_1/xxx.png 0
folder_1/xxy.png 1
folder_2/123.png 5
folder_2/nsdf3.png 3
...
Please specify the name of categories by the argument ``classes``.
2. The samples are arranged in the specific way: ::
data_prefix/
├── class_x
│ ├── xxx.png
│ ├── xxy.png
│ └── ...
│ └── xxz.png
└── class_y
├── 123.png
├── nsdf3.png
├── ...
└── asd932_.png
    If the ``ann_file`` is specified, the dataset will be generated in the
    first way; otherwise, the second way is used.
Args:
data_prefix (str): The path of data directory.
        pipeline (Sequence[dict]): A list of dict, where each element
            represents an operation defined in :mod:`mmcls.datasets.pipelines`.
            Defaults to an empty tuple.
        classes (str | Sequence[str], optional): Specify names of classes.
            - If it is a string, it should be a file path, and every line of
                the file is the name of a class.
            - If it is a sequence of strings, every item is the name of a
                class.
            - If it is None, use ``cls.CLASSES`` or the names of sub-folders
                (if the samples are arranged in the second way).
            Defaults to None.
ann_file (str, optional): The annotation file. If is string, read
samples paths from the ann_file. If is None, find samples in
``data_prefix``. Defaults to None.
extensions (Sequence[str]): A sequence of allowed extensions. Defaults
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
        test_mode (bool): Whether in test mode. It's only a flag and won't
            be used in this class. Defaults to False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            If None, it is automatically inferred from the specified path.
            Defaults to None.
"""
def __init__(self,
data_prefix: str,
pipeline: Sequence = (),
classes: Union[str, Sequence[str], None] = None,
ann_file: Optional[str] = None,
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
'.bmp', '.pgm', '.tif'),
test_mode: bool = False,
file_client_args: Optional[dict] = None):
self.extensions = tuple(set([i.lower() for i in extensions]))
self.file_client_args = file_client_args
super().__init__(
data_prefix=data_prefix,
pipeline=pipeline,
classes=classes,
ann_file=ann_file,
test_mode=test_mode)
def _find_samples(self):
"""find samples from ``data_prefix``."""
file_client = FileClient.infer_client(self.file_client_args,
self.data_prefix)
classes, folder_to_idx = find_folders(self.data_prefix, file_client)
samples, empty_classes = get_samples(
self.data_prefix,
folder_to_idx,
is_valid_file=self.is_valid_file,
file_client=file_client,
)
if len(samples) == 0:
raise RuntimeError(
f'Found 0 files in subfolders of: {self.data_prefix}. '
f'Supported extensions are: {",".join(self.extensions)}')
if self.CLASSES is not None:
assert len(self.CLASSES) == len(classes), \
f"The number of subfolders ({len(classes)}) doesn't match " \
f'the number of specified classes ({len(self.CLASSES)}). ' \
'Please check the data folder.'
else:
self.CLASSES = classes
if empty_classes:
warnings.warn(
'Found no valid file in the folder '
f'{", ".join(empty_classes)}. '
f"Supported extensions are: {', '.join(self.extensions)}",
UserWarning)
self.folder_to_idx = folder_to_idx
return samples
def load_annotations(self):
"""Load image paths and gt_labels."""
if self.ann_file is None:
samples = self._find_samples()
elif isinstance(self.ann_file, str):
lines = mmcv.list_from_file(
self.ann_file, file_client_args=self.file_client_args)
samples = [x.strip().rsplit(' ', 1) for x in lines]
else:
raise TypeError('ann_file must be a str or None')
data_infos = []
for filename, gt_label in samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
def is_valid_file(self, filename: str) -> bool:
"""Check if a file is a valid sample."""
return filename.lower().endswith(self.extensions)
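# Two ways to instantiate the dataset, matching the two annotation formats in
# the class docstring (paths below are illustrative):
#
#   # 1. With an annotation file; class names must be given explicitly.
#   ds = CustomDataset(
#       data_prefix='data/train',
#       ann_file='data/train/ann.txt',
#       classes=['cat', 'dog'],
#       pipeline=[])
#
#   # 2. Without an annotation file; sub-folder names become the classes.
#   ds = CustomDataset(data_prefix='data/train', pipeline=[])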
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.
    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but additionally
    supports the ``get_cat_ids`` function.
Args:
datasets (list[:obj:`BaseDataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the results
separately if it is used as validation dataset.
Defaults to True.
"""
def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.separate_eval = separate_eval
self.CLASSES = datasets[0].CLASSES
if not separate_eval:
if len(set([type(ds) for ds in datasets])) != 1:
raise NotImplementedError(
'To evaluate a concat dataset non-separately, '
'all the datasets should have same types')
def get_cat_ids(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx].get_cat_ids(sample_idx)
def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[list | tuple]): Testing results of the dataset.
indices (list, optional): The indices of samples corresponding to
the results. It's unavailable on ConcatDataset.
Defaults to None.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None.
Returns:
            dict[str: float]: Evaluation results of the total dataset or each
            separate dataset if ``self.separate_eval=True``.
"""
if indices is not None:
            raise NotImplementedError(
                'Using indices to evaluate specific samples in a '
                'ConcatDataset is not supported for now.')
assert len(results) == len(self), \
('Dataset and results have different sizes: '
f'{len(self)} v.s. {len(results)}')
# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f"{type(dataset)} haven't implemented the evaluate function."
if self.separate_eval:
total_eval_results = dict()
for dataset_idx, dataset in enumerate(self.datasets):
start_idx = 0 if dataset_idx == 0 else \
self.cumulative_sizes[dataset_idx-1]
end_idx = self.cumulative_sizes[dataset_idx]
results_per_dataset = results[start_idx:end_idx]
print_log(
                    f'Evaluating dataset-{dataset_idx} with '
f'{len(results_per_dataset)} images now',
logger=logger)
eval_results_per_dataset = dataset.evaluate(
results_per_dataset, *args, logger=logger, **kwargs)
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})
return total_eval_results
else:
original_data_infos = self.datasets[0].data_infos
self.datasets[0].data_infos = sum(
[dataset.data_infos for dataset in self.datasets], [])
eval_results = self.datasets[0].evaluate(
results, logger=logger, **kwargs)
self.datasets[0].data_infos = original_data_infos
return eval_results
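# Usage sketch (illustrative): with separate_eval=True, each wrapped dataset
# is evaluated on its own slice of `results`, and its metrics are prefixed
# with the dataset index, e.g. '0_accuracy' and '1_accuracy':
#
#   concat = ConcatDataset([ds_a, ds_b], separate_eval=True)
#   metrics = concat.evaluate(results)  # needs len(results) == len(concat)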
@DATASETS.register_module()
class RepeatDataset(object):
"""A wrapper of repeated dataset.
The length of repeated dataset will be `times` larger than the original
dataset. This is useful when the data loading time is long but the dataset
is small. Using RepeatDataset can reduce the data loading time between
epochs.
Args:
dataset (:obj:`BaseDataset`): The dataset to be repeated.
times (int): Repeat times.
"""
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self.CLASSES = dataset.CLASSES
self._ori_len = len(self.dataset)
def __getitem__(self, idx):
return self.dataset[idx % self._ori_len]
def get_cat_ids(self, idx):
return self.dataset.get_cat_ids(idx % self._ori_len)
def __len__(self):
return self.times * self._ori_len
def evaluate(self, *args, **kwargs):
        raise NotImplementedError(
            'Evaluating results on a repeated dataset is not meaningful. '
            'Please run inference and evaluation on the original dataset.')
def __repr__(self):
"""Print the number of instance number."""
dataset_type = 'Test' if self.test_mode else 'Train'
result = (
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
f'{dataset_type} dataset with total number of samples {len(self)}.'
)
return result
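# Usage sketch: repeating a 1000-sample dataset 5 times yields a 5000-sample
# "epoch", which amortizes per-epoch setup costs for small datasets:
#
#   repeated = RepeatDataset(small_dataset, times=5)
#   assert len(repeated) == 5 * len(small_dataset)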
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module()
class ClassBalancedDataset(object):
r"""A wrapper of repeated dataset with repeat factor.
Suitable for training on class imbalanced datasets like LVIS. Following the
sampling strategy in `this paper`_, in each epoch, an image may appear
multiple times based on its "repeat factor".
.. _this paper: https://arxiv.org/pdf/1908.03195.pdf
    The repeat factor for an image is a function of the frequency of the
    rarest category labeled in that image. The "frequency of category c" in
    [0, 1] is defined as the fraction of images in the training set (without
    repeats) in which category c appears.
The dataset needs to implement :func:`self.get_cat_ids` to support
ClassBalancedDataset.
    The repeat factor is computed as follows.
1. For each category c, compute the fraction :math:`f(c)` of images that
contain it.
2. For each category c, compute the category-level repeat factor
.. math::
r(c) = \max(1, \sqrt{\frac{t}{f(c)}})
3. For each image I and its labels :math:`L(I)`, compute the image-level
repeat factor
.. math::
r(I) = \max_{c \in L(I)} r(c)
Args:
dataset (:obj:`BaseDataset`): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c`` >= ``oversample_thr``, there
            is no oversampling. For categories with ``f_c`` <
            ``oversample_thr``, the degree of oversampling follows the
            square-root inverse frequency heuristic above.
"""
def __init__(self, dataset, oversample_thr):
self.dataset = dataset
self.oversample_thr = oversample_thr
self.CLASSES = dataset.CLASSES
repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
repeat_indices = []
for dataset_index, repeat_factor in enumerate(repeat_factors):
repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
self.repeat_indices = repeat_indices
flags = []
if hasattr(self.dataset, 'flag'):
for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
flags.extend([flag] * int(math.ceil(repeat_factor)))
assert len(flags) == len(repeat_indices)
self.flag = np.asarray(flags, dtype=np.uint8)
def _get_repeat_factors(self, dataset, repeat_thr):
        # 1. For each category c, compute the fraction of images
        # that contain it: f(c)
category_freq = defaultdict(int)
num_images = len(dataset)
for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx))
for cat_id in cat_ids:
category_freq[cat_id] += 1
for k, v in category_freq.items():
            assert v > 0, f'category {k} does not contain any images'
category_freq[k] = v / num_images
# 2. For each category c, compute the category-level repeat factor:
# r(c) = max(1, sqrt(t/f(c)))
category_repeat = {
cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
for cat_id, cat_freq in category_freq.items()
}
# 3. For each image I and its labels L(I), compute the image-level
# repeat factor:
# r(I) = max_{c in L(I)} r(c)
repeat_factors = []
for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx))
repeat_factor = max(
{category_repeat[cat_id]
for cat_id in cat_ids})
repeat_factors.append(repeat_factor)
return repeat_factors
def __getitem__(self, idx):
ori_index = self.repeat_indices[idx]
return self.dataset[ori_index]
def __len__(self):
return len(self.repeat_indices)
def evaluate(self, *args, **kwargs):
        raise NotImplementedError(
            'Evaluating results on a class-balanced dataset is not '
            'meaningful. Please run inference and evaluation on the '
            'original dataset.')
def __repr__(self):
"""Print the number of instance number."""
dataset_type = 'Test' if self.test_mode else 'Train'
result = (
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
f'{dataset_type} dataset with total number of samples {len(self)}.'
)
return result
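# Worked example of the repeat factor: with oversample_thr t = 0.01, a
# category appearing in 0.1% of the images (f(c) = 0.001) gets
#   r(c) = max(1, sqrt(0.01 / 0.001)) ~= 3.16,
# so every image containing it appears ceil(3.16) = 4 times per epoch, while
# any category with f(c) >= 0.01 keeps r(c) = 1 (no oversampling).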
@DATASETS.register_module()
class KFoldDataset:
"""A wrapper of dataset for K-Fold cross-validation.
    K-Fold cross-validation divides all the samples into groups of almost
    equal size, called folds. We use k-1 folds for training and the
    remaining fold for validation.
Args:
dataset (:obj:`BaseDataset`): The dataset to be divided.
fold (int): The fold used to do validation. Defaults to 0.
num_splits (int): The number of all folds. Defaults to 5.
        test_mode (bool): Whether to use the validation dataset (True) or
            the training dataset (False). Defaults to False.
        seed (int, optional): The seed to shuffle the dataset before
            splitting. If None, the dataset is not shuffled. Defaults to None.
"""
def __init__(self,
dataset,
fold=0,
num_splits=5,
test_mode=False,
seed=None):
self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.test_mode = test_mode
self.num_splits = num_splits
length = len(dataset)
indices = list(range(length))
if isinstance(seed, int):
rng = np.random.default_rng(seed)
rng.shuffle(indices)
test_start = length * fold // num_splits
test_end = length * (fold + 1) // num_splits
if test_mode:
self.indices = indices[test_start:test_end]
else:
self.indices = indices[:test_start] + indices[test_end:]
def get_cat_ids(self, idx):
return self.dataset.get_cat_ids(self.indices[idx])
def get_gt_labels(self):
dataset_gt_labels = self.dataset.get_gt_labels()
gt_labels = np.array([dataset_gt_labels[idx] for idx in self.indices])
return gt_labels
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
def evaluate(self, *args, **kwargs):
kwargs['indices'] = self.indices
return self.dataset.evaluate(*args, **kwargs)
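# Usage sketch: a typical 5-fold run trains one model per fold, where fold k
# uses the k-th ~20% slice for validation and the rest for training:
#
#   for fold in range(5):
#       train_set = KFoldDataset(dataset, fold=fold, num_splits=5,
#                                test_mode=False, seed=42)
#       val_set = KFoldDataset(dataset, fold=fold, num_splits=5,
#                              test_mode=True, seed=42)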