Commit cbc25585 authored by limm

add mmpretrain/ part

parent 1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union
import mmengine
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.utils import is_seq_of
from mmpretrain.registry import METRICS
from mmpretrain.structures import label_to_onehot
from .single_label import to_tensor
@METRICS.register_module()
class RetrievalRecall(BaseMetric):
r"""Recall evaluation metric for image retrieval.
Args:
        topk (int | Sequence[int]): If the ground truth label matches one of
            the best **k** predictions, the sample will be regarded as a
            positive prediction. If the parameter is a tuple, all of top-k
            recall will be calculated and output together. Defaults to 1.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Examples:
Use in the code:
>>> import torch
>>> from mmpretrain.evaluation import RetrievalRecall
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [[0], [1], [2], [3]]
>>> y_true = [[0, 1], [2], [1], [0, 3]]
>>> RetrievalRecall.calculate(
>>> y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
[tensor([50.])]
>>> # Calculate the recall@1 and recall@5 for non-indices input.
>>> y_score = torch.rand((1000, 10))
>>> import torch.nn.functional as F
>>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10)
>>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5))
[tensor(9.3000), tensor(48.4000)]
>>>
        >>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... DataSample().set_gt_label([0, 1]).set_pred_score(
... torch.rand(10))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5)))
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{'retrieval/Recall@1': 20.700000762939453,
'retrieval/Recall@5': 78.5999984741211}
Use in OpenMMLab configs:
.. code:: python
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
test_evaluator = val_evaluator
"""
default_prefix: Optional[str] = 'retrieval'
def __init__(self,
topk: Union[int, Sequence[int]],
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
topk = (topk, ) if isinstance(topk, int) else topk
for k in topk:
if k <= 0:
                raise ValueError('`topk` must be an integer larger than 0 '
                                 'or a sequence of integers larger than 0.')
self.topk = topk
super().__init__(collect_device=collect_device, prefix=prefix)
def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]):
"""Process one batch of data and predictions.
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_score = data_sample['pred_score'].clone()
gt_label = data_sample['gt_label']
if 'gt_score' in data_sample:
target = data_sample.get('gt_score').clone()
else:
num_classes = pred_score.size()[-1]
target = label_to_onehot(gt_label, num_classes)
            # Because the retrieval output logit vector will be much larger
            # compared to the normal classification, to save resources, the
            # evaluation results are computed for each batch here and then
            # reduced over all results at the end.
result = RetrievalRecall.calculate(
pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
result_metrics = dict()
for i, k in enumerate(self.topk):
recall_at_k = sum([r[i].item() for r in results]) / len(results)
result_metrics[f'Recall@{k}'] = recall_at_k
return result_metrics
@staticmethod
def calculate(pred: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
topk: Union[int, Sequence[int]],
                  pred_indices: bool = False,
                  target_indices: bool = False) -> List[float]:
"""Calculate the average recall.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
shape ``(N, M)`` or a sequence of index/onehot
format labels.
            target (torch.Tensor | np.ndarray | Sequence): The ground truth
                labels. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
topk (int, Sequence[int]): Predictions with the k-th highest
scores are considered as positive.
pred_indices (bool): Whether the ``pred`` is a sequence of
category index labels. Defaults to False.
target_indices (bool): Whether the ``target`` is a sequence of
category index labels. Defaults to False.
Returns:
            List[float]: The average recalls.
"""
topk = (topk, ) if isinstance(topk, int) else topk
for k in topk:
if k <= 0:
                raise ValueError('`topk` must be an integer larger than 0 '
                                 'or a sequence of integers larger than 0.')
max_keep = max(topk)
pred = _format_pred(pred, max_keep, pred_indices)
target = _format_target(target, target_indices)
assert len(pred) == len(target), (
f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
f'must be the same.')
num_samples = len(pred)
results = []
for k in topk:
recalls = torch.zeros(num_samples)
for i, (sample_pred,
sample_target) in enumerate(zip(pred, target)):
sample_pred = np.array(to_tensor(sample_pred).cpu())
sample_target = np.array(to_tensor(sample_target).cpu())
recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max())
results.append(recalls.mean() * 100)
return results
@METRICS.register_module()
class RetrievalAveragePrecision(BaseMetric):
r"""Calculate the average precision for image retrieval.
Args:
topk (int, optional): Predictions with the k-th highest scores are
considered as positive.
mode (str, optional): The mode to calculate AP, choose from
'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
    Note:
        If the ``mode`` is set to 'IR', use the Stanford AP calculation of
        information retrieval as on the Wikipedia page [1]; if set to
        'integrate', the method integrates over the precision-recall curve
        by averaging two adjacent precision points, then multiplying by the
        recall step, like mAP in the detection task. This is the convention
        for the Revisited Oxford/Paris datasets [2].
References:
[1] `Wikipedia entry for the Average precision <https://en.wikipedia.
org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_
[2] `The Oxford Buildings Dataset
<https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
Examples:
Use in code:
>>> import torch
>>> import numpy as np
        >>> from mmpretrain.evaluation import RetrievalAveragePrecision
>>> # using index format inputs
>>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3
>>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
>>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
29.246031746031747
>>> # using tensor format inputs
>>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
>>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
>>> RetrievalAveragePrecision.calculate(pred, target, 10)
62.222222222222214
Use in OpenMMLab config files:
.. code:: python
val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
test_evaluator = val_evaluator
"""
default_prefix: Optional[str] = 'retrieval'
def __init__(self,
topk: Optional[int] = None,
mode: Optional[str] = 'IR',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')
mode_options = ['IR', 'integrate']
assert mode in mode_options, \
f'Invalid `mode` argument, please specify from {mode_options}.'
self.topk = topk
self.mode = mode
super().__init__(collect_device=collect_device, prefix=prefix)
def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]):
"""Process one batch of data and predictions.
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_score = data_sample.get('pred_score').clone()
if 'gt_score' in data_sample:
target = data_sample.get('gt_score').clone()
else:
gt_label = data_sample.get('gt_label')
num_classes = pred_score.size()[-1]
target = label_to_onehot(gt_label, num_classes)
            # Because the retrieval output logit vector will be much larger
            # compared to the normal classification, to save resources, the
            # evaluation results are computed for each batch here and then
            # reduced over all results at the end.
result = RetrievalAveragePrecision.calculate(
pred_score.unsqueeze(0),
target.unsqueeze(0),
self.topk,
mode=self.mode)
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
result_metrics = dict()
result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item()
return result_metrics
@staticmethod
def calculate(pred: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
topk: Optional[int] = None,
                  pred_indices: bool = False,
                  target_indices: bool = False,
                  mode: str = 'IR') -> float:
"""Calculate the average precision.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
shape ``(N, M)`` or a sequence of index/onehot
format labels.
            target (torch.Tensor | np.ndarray | Sequence): The ground truth
                labels. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
topk (int, optional): Predictions with the k-th highest scores
are considered as positive.
pred_indices (bool): Whether the ``pred`` is a sequence of
category index labels. Defaults to False.
target_indices (bool): Whether the ``target`` is a sequence of
category index labels. Defaults to False.
mode (Optional[str]): The mode to calculate AP, choose from
'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.
        Note:
            If the ``mode`` is set to 'IR', use the Stanford AP calculation of
            information retrieval as on the Wikipedia page; if set to
            'integrate', the method integrates over the precision-recall
            curve by averaging two adjacent precision points, then
            multiplying by the recall step, like mAP in the detection task.
            This is the convention for the Revisited Oxford/Paris datasets.
Returns:
float: the average precision of the query image.
References:
[1] `Wikipedia entry for Average precision(information_retrieval)
<https://en.wikipedia.org/wiki/Evaluation_measures_
(information_retrieval)#Average_precision>`_
            [2] `The Oxford Buildings Dataset <https://www.robots.ox.ac.uk/
            ~vgg/data/oxbuildings/>`_
"""
if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')
mode_options = ['IR', 'integrate']
assert mode in mode_options, \
f'Invalid `mode` argument, please specify from {mode_options}.'
pred = _format_pred(pred, topk, pred_indices)
target = _format_target(target, target_indices)
assert len(pred) == len(target), (
f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
f'must be the same.')
num_samples = len(pred)
aps = np.zeros(num_samples)
for i, (sample_pred, sample_target) in enumerate(zip(pred, target)):
aps[i] = _calculateAp_for_sample(sample_pred, sample_target, mode)
return aps.mean()
def _calculateAp_for_sample(pred, target, mode):
pred = np.array(to_tensor(pred).cpu())
target = np.array(to_tensor(target).cpu())
num_preds = len(pred)
# TODO: use ``torch.isin`` in torch1.10.
positive_ranks = np.arange(num_preds)[np.in1d(pred, target)]
ap = 0
for i, rank in enumerate(positive_ranks):
if mode == 'IR':
precision = (i + 1) / (rank + 1)
ap += precision
        elif mode == 'integrate':
            # code is modified from https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/compute_ap.cpp # noqa:
            old_precision = i / rank if rank > 0 else 1
            cur_precision = (i + 1) / (rank + 1)
            avg_precision = (old_precision + cur_precision) / 2
            ap += avg_precision
ap = ap / len(target)
return ap * 100
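# A minimal sketch (not part of the original file) of the two AP modes on a
# toy ranking, where `pred` holds retrieved indices (best first) and `target`
# the relevant ones, so the positive ranks are 0 and 2:
#
#   >>> pred, target = [0, 1, 2, 3, 4], [0, 2]
#   >>> _calculateAp_for_sample(pred, target, 'IR')
#   83.33...   # (1/1 + 2/3) / 2 * 100
#   >>> _calculateAp_for_sample(pred, target, 'integrate')
#   79.16...   # ((1 + 1)/2 + (1/2 + 2/3)/2) / 2 * 100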
def _format_pred(label, topk=None, is_indices=False):
"""format various label to List[indices]."""
if is_indices:
        assert isinstance(label, Sequence), \
            '`pred` must be Sequence of indices when' \
            f' `pred_indices` is set to True, but got {type(label)}'
        for i, sample_pred in enumerate(label):
            assert is_seq_of(sample_pred, int) or isinstance(
                sample_pred, (np.ndarray, torch.Tensor)), \
                '`pred` should be Sequence of indices when `pred_indices` ' \
                f'is set to True, but pred[{i}] is {sample_pred}'
if topk:
label[i] = sample_pred[:min(topk, len(sample_pred))]
return label
if isinstance(label, np.ndarray):
label = torch.from_numpy(label)
    elif not isinstance(label, torch.Tensor):
        raise TypeError(f'The `pred` must be a torch.Tensor, '
                        f'np.ndarray or Sequence, but got {type(label)}.')
topk = topk if topk else label.size()[-1]
_, indices = label.topk(topk)
return indices
def _format_target(label, is_indices=False):
"""format various label to List[indices]."""
if is_indices:
        assert isinstance(label, Sequence), \
            '`target` must be Sequence of indices when' \
            f' `target_indices` is set to True, but got {type(label)}'
        for i, sample_gt in enumerate(label):
            assert is_seq_of(sample_gt, int) or isinstance(
                sample_gt, (np.ndarray, torch.Tensor)), \
                '`target` should be Sequence of indices when ' \
                f'`target_indices` is set to True, but target[{i}] is {sample_gt}'
return label
if isinstance(label, np.ndarray):
label = torch.from_numpy(label)
elif isinstance(label, Sequence) and not mmengine.is_str(label):
label = torch.tensor(label)
    elif not isinstance(label, torch.Tensor):
        raise TypeError(f'The `target` must be a torch.Tensor, '
                        f'np.ndarray or Sequence, but got {type(label)}.')
indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label]
return indices
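# Illustrative check (assumed toy shapes, not from the original file): the
# two helpers above normalize score matrices and one-hot targets into index
# form before recall/AP are computed:
#
#   >>> scores = torch.tensor([[0.1, 0.9, 0.5]])
#   >>> _format_pred(scores, topk=2)   # indices of the top-2 scores
#   tensor([[1, 2]])
#   >>> _format_target(torch.tensor([[1, 0, 1]]))
#   [tensor([0, 2])]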
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional
from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS
def get_pred_idx(prediction: str, choices: List[str],
options: List[str]) -> int: # noqa
"""Get the index (e.g. 2) from the prediction (e.g. 'C')
Args:
prediction (str): The prediction from the model,
from ['A', 'B', 'C', 'D', 'E']
choices (List(str)): The choices for the question,
from ['A', 'B', 'C', 'D', 'E']
options (List(str)): The options for the question,
from ['A', 'B', 'C', 'D', 'E']
Returns:
int: The index of the prediction, from [0, 1, 2, 3, 4]
"""
if prediction in options[:len(choices)]:
return options.index(prediction)
else:
return random.choice(range(len(choices)))
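# For example (illustrative values), get_pred_idx('C', ['cat', 'dog', 'ant'],
# ['A', 'B', 'C', 'D', 'E']) returns 2, while a prediction outside the first
# len(choices) options falls back to a random in-range index, so malformed
# predictions are scored as a random guess instead of always being wrong.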
@METRICS.register_module()
class ScienceQAMetric(BaseMetric):
"""Evaluation Metric for ScienceQA.
Args:
options (List(str)): Options for each question. Defaults to
["A", "B", "C", "D", "E"].
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
"""
def __init__(self,
options: List[str] = ['A', 'B', 'C', 'D', 'E'],
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.options = options
def process(self, data_batch, data_samples) -> None:
"""Process one batch of data samples.
data_samples should contain the following keys:
1. pred_answer (str): The prediction from the model,
from ['A', 'B', 'C', 'D', 'E']
2. choices (List(str)): The choices for the question,
from ['A', 'B', 'C', 'D', 'E']
3. grade (int): The grade for the question, from grade1 to grade12
4. subject (str): The subject for the question, from
['natural science', 'social science', 'language science']
5. answer (str): The answer for the question, from
['A', 'B', 'C', 'D', 'E']
6. hint (str): The hint for the question
7. has_image (bool): Whether or not the question has image
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
result = dict()
choices = data_sample.get('choices')
result['prediction'] = get_pred_idx(
data_sample.get('pred_answer'), choices, self.options)
result['grade'] = data_sample.get('grade')
result['subject'] = data_sample.get('subject')
result['answer'] = data_sample.get('gt_answer')
hint = data_sample.get('hint')
has_image = data_sample.get('has_image', False)
            result['no_context'] = not has_image and len(hint) == 0
            result['has_text'] = len(hint) > 0
result['has_image'] = has_image
# Save the result to `self.results`.
self.results.append(result)
def compute_metrics(self, results: List) -> dict:
"""Compute the metrics from processed results.
Args:
            results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
# NOTICE: don't access `self.results` from the method.
metrics = dict()
all_acc = []
acc_natural = []
acc_social = []
acc_language = []
acc_has_text = []
acc_has_image = []
acc_no_context = []
acc_grade_1_6 = []
acc_grade_7_12 = []
for result in results:
correct = result['prediction'] == result['answer']
all_acc.append(correct)
# different subjects
if result['subject'] == 'natural science':
acc_natural.append(correct)
elif result['subject'] == 'social science':
acc_social.append(correct)
elif result['subject'] == 'language science':
acc_language.append(correct)
# different context
if result['has_text']:
acc_has_text.append(correct)
elif result['has_image']:
acc_has_image.append(correct)
elif result['no_context']:
acc_no_context.append(correct)
# different grade
if result['grade'] in [
'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'
]:
acc_grade_1_6.append(correct)
elif result['grade'] in [
'grade7', 'grade8', 'grade9', 'grade10', 'grade11',
'grade12'
]:
acc_grade_7_12.append(correct)
metrics['all_acc'] = sum(all_acc) / len(all_acc)
if len(acc_natural) > 0:
metrics['acc_natural'] = sum(acc_natural) / len(acc_natural)
if len(acc_social) > 0:
metrics['acc_social'] = sum(acc_social) / len(acc_social)
if len(acc_language) > 0:
metrics['acc_language'] = sum(acc_language) / len(acc_language)
if len(acc_has_text) > 0:
metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text)
if len(acc_has_image) > 0:
metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image)
if len(acc_no_context) > 0:
metrics['acc_no_context'] = sum(acc_no_context) / len(
acc_no_context)
if len(acc_grade_1_6) > 0:
metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6)
if len(acc_grade_7_12) > 0:
metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len(
acc_grade_7_12)
return metrics
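# A minimal usage sketch (hypothetical sample values, assuming only the keys
# listed in `process` above; not from the original file):
#
#   >>> metric = ScienceQAMetric()
#   >>> metric.process(None, [dict(
#   ...     pred_answer='A', choices=['cat', 'dog'], gt_answer=0,
#   ...     grade='grade3', subject='natural science',
#   ...     hint='', has_image=False)])
#   >>> metric.compute_metrics(metric.results)
#   {'all_acc': 1.0, 'acc_natural': 1.0, 'acc_no_context': 1.0,
#    'acc_grade_1_6': 1.0}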
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os
import os.path as osp
from typing import List, Sequence
import numpy as np
import torch
from mmengine.dist.utils import get_rank
from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS
@METRICS.register_module()
class ShapeBiasMetric(BaseMetric):
"""Evaluate the model on ``cue_conflict`` dataset.
    This module will evaluate the model on an OOD dataset, cue_conflict, in
    order to measure the shape bias of the model. In addition to computing the
    Top-1 accuracy, this module also generates a csv file to record the
    detailed prediction results, such that this csv file can be used to
    generate the shape bias curve.
Args:
csv_dir (str): The directory to save the csv file.
        model_name (str): The name of the csv file. Please note that the
            model name should be a unique identifier.
dataset_name (str): The name of the dataset. Default: 'cue_conflict'.
"""
# mapping several classes from ImageNet-1K to the same category
airplane_indices = [404]
bear_indices = [294, 295, 296, 297]
bicycle_indices = [444, 671]
bird_indices = [
8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
145
]
boat_indices = [472, 554, 625, 814, 914]
bottle_indices = [440, 720, 737, 898, 899, 901, 907]
car_indices = [436, 511, 817]
cat_indices = [281, 282, 283, 284, 285, 286]
chair_indices = [423, 559, 765, 857]
clock_indices = [409, 530, 892]
dog_indices = [
152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
]
elephant_indices = [385, 386]
keyboard_indices = [508, 878]
knife_indices = [499]
oven_indices = [766]
truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]
def __init__(self,
csv_dir: str,
model_name: str,
dataset_name: str = 'cue_conflict',
**kwargs) -> None:
super().__init__(**kwargs)
self.categories = sorted([
'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
'bird', 'dog'
])
self.csv_dir = csv_dir
self.model_name = model_name
self.dataset_name = dataset_name
if get_rank() == 0:
self.csv_path = self.create_csv()
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples.
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
result = dict()
if 'pred_score' in data_sample:
result['pred_score'] = data_sample['pred_score'].cpu()
else:
result['pred_label'] = data_sample['pred_label'].cpu()
result['gt_label'] = data_sample['gt_label'].cpu()
result['gt_category'] = data_sample['img_path'].split('/')[-2]
result['img_name'] = data_sample['img_path'].split('/')[-1]
aggregated_category_probabilities = []
# get the prediction for each category of current instance
for category in self.categories:
category_indices = getattr(self, f'{category}_indices')
category_probabilities = torch.gather(
result['pred_score'], 0,
torch.tensor(category_indices)).mean()
aggregated_category_probabilities.append(
category_probabilities)
# sort the probabilities in descending order
pred_indices = torch.stack(aggregated_category_probabilities
).argsort(descending=True).numpy()
result['pred_category'] = np.take(self.categories, pred_indices)
# Save the result to `self.results`.
self.results.append(result)
def create_csv(self) -> str:
"""Create a csv file to store the results."""
session_name = 'session-1'
csv_path = osp.join(
self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
session_name + '.csv')
if osp.exists(csv_path):
os.remove(csv_path)
directory = osp.dirname(csv_path)
if not osp.exists(directory):
os.makedirs(directory, exist_ok=True)
with open(csv_path, 'w') as f:
writer = csv.writer(f)
writer.writerow([
'subj', 'session', 'trial', 'rt', 'object_response',
'category', 'condition', 'imagename'
])
return csv_path
def dump_results_to_csv(self, results: List[dict]) -> None:
"""Dump the results to a csv file.
Args:
results (List[dict]): A list of results.
"""
for i, result in enumerate(results):
img_name = result['img_name']
category = result['gt_category']
condition = 'NaN'
with open(self.csv_path, 'a') as f:
writer = csv.writer(f)
writer.writerow([
self.model_name, 1, i + 1, 'NaN',
result['pred_category'][0], category, condition, img_name
])
def compute_metrics(self, results: List[dict]) -> dict:
"""Compute the metrics from the results.
Args:
results (List[dict]): A list of results.
Returns:
dict: A dict of metrics.
"""
if get_rank() == 0:
self.dump_results_to_csv(results)
metrics = dict()
metrics['accuracy/top1'] = np.mean([
result['pred_category'][0] == result['gt_category']
for result in results
])
return metrics
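# Usage sketch (hypothetical paths and names, not from the original file):
# every data sample needs `pred_score`, `gt_label` and an `img_path` whose
# parent directory is the ground-truth category, e.g.
# '.../cue_conflict/dog/dog1-bird2.png'. One CSV row is written per image,
# and `compute_metrics` reports the category-level top-1 accuracy:
#
#   >>> metric = ShapeBiasMetric(
#   ...     csv_dir='work_dirs/shape_bias', model_name='my_model')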
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import List, Optional, Sequence, Union
import mmengine
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS
def to_tensor(value):
"""Convert value to torch.Tensor."""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value)
elif isinstance(value, Sequence) and not mmengine.is_str(value):
value = torch.tensor(value)
elif not isinstance(value, torch.Tensor):
raise TypeError(f'{type(value)} is not an available argument.')
return value
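# For instance, to_tensor([1, 2]) and to_tensor(np.array([1, 2])) both return
# tensor([1, 2]), while unsupported types (including plain strings) raise a
# TypeError.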
def _precision_recall_f1_support(pred_positive, gt_positive, average):
"""calculate base classification task metrics, such as precision, recall,
f1_score, support."""
average_options = ['micro', 'macro', None]
assert average in average_options, 'Invalid `average` argument, ' \
f'please specify from {average_options}.'
    # ignore -1 target such as difficult samples that are not wanted
    # in evaluation results.
    # only for calculating multi-label metrics without affecting
    # single-label behavior
ignored_index = gt_positive == -1
pred_positive[ignored_index] = 0
gt_positive[ignored_index] = 0
class_correct = (pred_positive & gt_positive)
if average == 'micro':
tp_sum = class_correct.sum()
pred_sum = pred_positive.sum()
gt_sum = gt_positive.sum()
else:
tp_sum = class_correct.sum(0)
pred_sum = pred_positive.sum(0)
gt_sum = gt_positive.sum(0)
precision = tp_sum / torch.clamp(pred_sum, min=1).float() * 100
recall = tp_sum / torch.clamp(gt_sum, min=1).float() * 100
f1_score = 2 * precision * recall / torch.clamp(
precision + recall, min=torch.finfo(torch.float32).eps)
if average in ['macro', 'micro']:
precision = precision.mean(0)
recall = recall.mean(0)
f1_score = f1_score.mean(0)
support = gt_sum.sum(0)
else:
support = gt_sum
return precision, recall, f1_score, support
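# A small worked example (toy one-hot tensors, not from the original file):
# three samples over two classes, with two correct positive predictions out
# of three predicted and three actual positives, so micro precision, recall
# and f1-score are all 2/3:
#
#   >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]])
#   >>> gt = torch.tensor([[1, 0], [1, 0], [1, 0]])
#   >>> _precision_recall_f1_support(pred, gt, 'micro')
#   (tensor(66.6667), tensor(66.6667), tensor(66.6667), tensor(3))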
@METRICS.register_module()
class Accuracy(BaseMetric):
r"""Accuracy evaluation metric.
For either binary classification or multi-class classification, the
accuracy is the fraction of correct predictions in all predictions:
.. math::
\text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}}
Args:
        topk (int | Sequence[int]): If the ground truth label matches one of
            the best **k** predictions, the sample will be regarded as a
            positive prediction. If the parameter is a tuple, all of top-k
            accuracy will be calculated and output together. Defaults to 1.
        thrs (Sequence[float | None] | float | None): If a float, predictions
            with score lower than the threshold will be regarded as the
            negative prediction. If None, no threshold is applied. If the
            parameter is a tuple, accuracy based on all thresholds will be
            calculated and output together. Defaults to 0.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Examples:
>>> import torch
>>> from mmpretrain.evaluation import Accuracy
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [0, 2, 1, 3]
>>> y_true = [0, 1, 2, 3]
>>> Accuracy.calculate(y_pred, y_true)
tensor([50.])
>>> # Calculate the top1 and top5 accuracy.
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.zeros((1000, ))
>>> Accuracy.calculate(y_score, y_true, topk=(1, 5))
[[tensor([9.9000])], [tensor([51.5000])]]
>>>
        >>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... DataSample().set_gt_label(0).set_pred_score(torch.rand(10))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{
'accuracy/top1': 9.300000190734863,
'accuracy/top5': 51.20000076293945
}
"""
default_prefix: Optional[str] = 'accuracy'
def __init__(self,
topk: Union[int, Sequence[int]] = (1, ),
thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
if isinstance(topk, int):
self.topk = (topk, )
else:
self.topk = tuple(topk)
if isinstance(thrs, float) or thrs is None:
self.thrs = (thrs, )
else:
self.thrs = tuple(thrs)
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
result = dict()
if 'pred_score' in data_sample:
result['pred_score'] = data_sample['pred_score'].cpu()
else:
result['pred_label'] = data_sample['pred_label'].cpu()
result['gt_label'] = data_sample['gt_label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
            results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
# NOTICE: don't access `self.results` from the method.
metrics = {}
# concat
target = torch.cat([res['gt_label'] for res in results])
if 'pred_score' in results[0]:
pred = torch.stack([res['pred_score'] for res in results])
try:
acc = self.calculate(pred, target, self.topk, self.thrs)
except ValueError as e:
# If the topk is invalid.
raise ValueError(
str(e) + ' Please check the `val_evaluator` and '
'`test_evaluator` fields in your config file.')
multi_thrs = len(self.thrs) > 1
for i, k in enumerate(self.topk):
for j, thr in enumerate(self.thrs):
name = f'top{k}'
if multi_thrs:
name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
metrics[name] = acc[i][j].item()
else:
            # If only labels are available in `pred_label`.
pred = torch.cat([res['pred_label'] for res in results])
acc = self.calculate(pred, target, self.topk, self.thrs)
metrics['top1'] = acc.item()
return metrics
@staticmethod
def calculate(
pred: Union[torch.Tensor, np.ndarray, Sequence],
target: Union[torch.Tensor, np.ndarray, Sequence],
topk: Sequence[int] = (1, ),
thrs: Sequence[Union[float, None]] = (0., ),
) -> Union[torch.Tensor, List[List[torch.Tensor]]]:
"""Calculate the accuracy.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. It can be labels (N, ), or scores of every
class (N, C).
target (torch.Tensor | np.ndarray | Sequence): The target of
each prediction with shape (N, ).
            topk (Sequence[int]): Predictions with the k-th highest scores
                are considered as positive. Defaults to (1, ).
            thrs (Sequence[float | None]): Predictions with scores under
                the thresholds are considered negative. It's only used
                when ``pred`` is scores. None means no thresholds.
                Defaults to (0., ).
Returns:
torch.Tensor | List[List[torch.Tensor]]: Accuracy.
            - torch.Tensor: If the ``pred`` is a sequence of label instead of
              score (number of dimensions is 1). Only return a top-1 accuracy
              tensor, and ignore the arguments ``topk`` and ``thrs``.
            - List[List[torch.Tensor]]: If the ``pred`` is a sequence of score
              (number of dimensions is 2). Return the accuracy on each ``topk``
              and ``thrs``. The first dim is ``topk``, the second dim is
              ``thrs``.
"""
pred = to_tensor(pred)
target = to_tensor(target).to(torch.int64)
num = pred.size(0)
assert pred.size(0) == target.size(0), \
f"The size of pred ({pred.size(0)}) doesn't match "\
f'the target ({target.size(0)}).'
if pred.ndim == 1:
            # For pred label, ignore topk and thrs
pred_label = pred.int()
correct = pred.eq(target).float().sum(0, keepdim=True)
acc = correct.mul_(100. / num)
return acc
else:
# For pred score, calculate on all topk and thresholds.
pred = pred.float()
maxk = max(topk)
if maxk > pred.size(1):
raise ValueError(
f'Top-{maxk} accuracy is unavailable since the number of '
f'categories is {pred.size(1)}.')
pred_score, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t()
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
results = []
for k in topk:
results.append([])
for thr in thrs:
# Only prediction values larger than thr are counted
# as correct
_correct = correct
if thr is not None:
_correct = _correct & (pred_score.t() > thr)
correct_k = _correct[:k].reshape(-1).float().sum(
0, keepdim=True)
acc = correct_k.mul_(100. / num)
results[-1].append(acc)
return results
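# A small threshold sketch (toy values, not from the original file): with the
# scores below both top-1 predictions are correct, but a 0.65 threshold
# filters out the first sample's 0.6 score:
#
#   >>> pred = torch.tensor([[0.6, 0.4], [0.3, 0.7]])
#   >>> target = torch.tensor([0, 1])
#   >>> Accuracy.calculate(pred, target, topk=(1, ), thrs=(None, 0.65))
#   [[tensor([100.]), tensor([50.])]]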
@METRICS.register_module()
class SingleLabelMetric(BaseMetric):
r"""A collection of precision, recall, f1-score and support for
single-label tasks.
The collection of metrics is for single-label multi-class classification.
And all these metrics are based on the confusion matrix of every category:
.. image:: ../../_static/image/confusion-matrix.png
:width: 60%
:align: center
    All metrics can be formulated using the variables above:
**Precision** is the fraction of correct predictions in all predictions:
.. math::
\text{Precision} = \frac{TP}{TP+FP}
**Recall** is the fraction of correct predictions in all targets:
.. math::
\text{Recall} = \frac{TP}{TP+FN}
**F1-score** is the harmonic mean of the precision and recall:
.. math::
\text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}
**Support** is the number of samples:
.. math::
\text{Support} = TP + TN + FN + FP
Args:
        thrs (Sequence[float | None] | float | None): If a float, predictions
            with score lower than the threshold will be regarded as the
            negative prediction. If None, only the top-1 prediction will be
            regarded as the positive prediction. If the parameter is a tuple,
            accuracy based on all thresholds will be calculated and output
            together. Defaults to 0.
items (Sequence[str]): The detailed metric items to evaluate, select
from "precision", "recall", "f1-score" and "support".
Defaults to ``('precision', 'recall', 'f1-score')``.
average (str | None): How to calculate the final metrics from the
confusion matrix of every category. It supports three modes:
- `"macro"`: Calculate metrics for each category, and calculate
the mean value over all categories.
- `"micro"`: Average the confusion matrix over all categories and
calculate metrics on the mean confusion matrix.
- `None`: Calculate metrics of every category and output directly.
Defaults to "macro".
num_classes (int, optional): The number of classes. Defaults to None.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Examples:
>>> import torch
>>> from mmpretrain.evaluation import SingleLabelMetric
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
>>> # Output precision, recall, f1-score and support.
>>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4)
(tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4))
>>> # Calculate with different thresholds.
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.zeros((1000, ))
>>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9))
[(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)),
(tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
>>>
        >>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=SingleLabelMetric())
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{'single-label/precision': 19.650691986083984,
'single-label/recall': 19.600000381469727,
'single-label/f1-score': 19.619548797607422}
>>> # Evaluate on each class
>>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None))
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{
'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1],
'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0],
'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0]
}
""" # noqa: E501
default_prefix: Optional[str] = 'single-label'
def __init__(self,
thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
items: Sequence[str] = ('precision', 'recall', 'f1-score'),
average: Optional[str] = 'macro',
num_classes: Optional[int] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
if isinstance(thrs, float) or thrs is None:
self.thrs = (thrs, )
else:
self.thrs = tuple(thrs)
for item in items:
assert item in ['precision', 'recall', 'f1-score', 'support'], \
f'The metric {item} is not supported by `SingleLabelMetric`,' \
' please specify from "precision", "recall", "f1-score" and ' \
'"support".'
self.items = tuple(items)
self.average = average
self.num_classes = num_classes
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
result = dict()
if 'pred_score' in data_sample:
result['pred_score'] = data_sample['pred_score'].cpu()
else:
num_classes = self.num_classes or data_sample.get(
'num_classes')
assert num_classes is not None, \
'The `num_classes` must be specified if no `pred_score`.'
result['pred_label'] = data_sample['pred_label'].cpu()
result['num_classes'] = num_classes
result['gt_label'] = data_sample['gt_label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
        # NOTICE: don't access `self.results` from the method. `self.results`
        # is a list of results from multiple batches, while the input
        # `results` are the collected results.
metrics = {}
def pack_results(precision, recall, f1_score, support):
single_metrics = {}
if 'precision' in self.items:
single_metrics['precision'] = precision
if 'recall' in self.items:
single_metrics['recall'] = recall
if 'f1-score' in self.items:
single_metrics['f1-score'] = f1_score
if 'support' in self.items:
single_metrics['support'] = support
return single_metrics
# concat
target = torch.cat([res['gt_label'] for res in results])
if 'pred_score' in results[0]:
pred = torch.stack([res['pred_score'] for res in results])
metrics_list = self.calculate(
pred, target, thrs=self.thrs, average=self.average)
multi_thrs = len(self.thrs) > 1
for i, thr in enumerate(self.thrs):
if multi_thrs:
suffix = '_no-thr' if thr is None else f'_thr-{thr:.2f}'
else:
suffix = ''
for k, v in pack_results(*metrics_list[i]).items():
metrics[k + suffix] = v
else:
            # If only labels are available in `pred_label`.
pred = torch.cat([res['pred_label'] for res in results])
res = self.calculate(
pred,
target,
average=self.average,
num_classes=results[0]['num_classes'])
metrics = pack_results(*res)
result_metrics = dict()
for k, v in metrics.items():
if self.average is None:
result_metrics[k + '_classwise'] = v.cpu().detach().tolist()
elif self.average == 'micro':
result_metrics[k + f'_{self.average}'] = v.item()
else:
result_metrics[k] = v.item()
return result_metrics
@staticmethod
def calculate(
pred: Union[torch.Tensor, np.ndarray, Sequence],
target: Union[torch.Tensor, np.ndarray, Sequence],
thrs: Sequence[Union[float, None]] = (0., ),
average: Optional[str] = 'macro',
num_classes: Optional[int] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Calculate the precision, recall, f1-score and support.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. It can be labels (N, ), or scores of every
class (N, C).
target (torch.Tensor | np.ndarray | Sequence): The target of
each prediction with shape (N, ).
thrs (Sequence[float | None]): Predictions with scores under
the thresholds are considered negative. It's only used
when ``pred`` is scores. None means no thresholds.
Defaults to (0., ).
average (str | None): How to calculate the final metrics from
the confusion matrix of every category. It supports three
modes:
- `"macro"`: Calculate metrics for each category, and calculate
the mean value over all categories.
- `"micro"`: Average the confusion matrix over all categories
and calculate metrics on the mean confusion matrix.
- `None`: Calculate metrics of every category and output
directly.
Defaults to "macro".
num_classes (Optional, int): The number of classes. If the ``pred``
is label instead of scores, this argument is required.
Defaults to None.
Returns:
Tuple: The tuple contains precision, recall and f1-score.
And the type of each item is:
            - torch.Tensor: If the ``pred`` is a sequence of label instead of
              score (number of dimensions is 1). Only returns a tensor for
              each metric. The shape is (1, ) if ``average`` is not None, and
              (C, ) if ``average`` is None.
            - List[torch.Tensor]: If the ``pred`` is a sequence of score
              (number of dimensions is 2). Return the metrics on each ``thrs``.
              The shape of tensor is (1, ) if ``average`` is not None, and
              (C, ) if ``average`` is None.
"""
average_options = ['micro', 'macro', None]
assert average in average_options, 'Invalid `average` argument, ' \
f'please specify from {average_options}.'
pred = to_tensor(pred)
target = to_tensor(target).to(torch.int64)
assert pred.size(0) == target.size(0), \
f"The size of pred ({pred.size(0)}) doesn't match "\
f'the target ({target.size(0)}).'
if pred.ndim == 1:
            assert num_classes is not None, \
                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
gt_positive = F.one_hot(target.flatten(), num_classes)
pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
return _precision_recall_f1_support(pred_positive, gt_positive,
average)
else:
# For pred score, calculate on all thresholds.
num_classes = pred.size(1)
pred_score, pred_label = torch.topk(pred, k=1)
pred_score = pred_score.flatten()
pred_label = pred_label.flatten()
gt_positive = F.one_hot(target.flatten(), num_classes)
results = []
for thr in thrs:
pred_positive = F.one_hot(pred_label, num_classes)
if thr is not None:
pred_positive[pred_score <= thr] = 0
results.append(
_precision_recall_f1_support(pred_positive, gt_positive,
average))
return results
@METRICS.register_module()
class ConfusionMatrix(BaseMetric):
r"""A metric to calculate confusion matrix for single-label tasks.
Args:
num_classes (int, optional): The number of classes. Defaults to None.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Examples:
1. The basic usage.
>>> import torch
>>> from mmpretrain.evaluation import ConfusionMatrix
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
>>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
tensor([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]])
>>> # plot the confusion matrix
>>> import matplotlib.pyplot as plt
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.randint(10, (1000, ))
>>> matrix = ConfusionMatrix.calculate(y_score, y_true)
>>> ConfusionMatrix().plot(matrix)
>>> plt.show()
2. In the config file
.. code:: python
val_evaluator = dict(type='ConfusionMatrix')
test_evaluator = dict(type='ConfusionMatrix')
""" # noqa: E501
default_prefix = 'confusion_matrix'
def __init__(self,
num_classes: Optional[int] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device, prefix)
self.num_classes = num_classes
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
for data_sample in data_samples:
if 'pred_score' in data_sample:
pred_score = data_sample['pred_score']
pred_label = pred_score.argmax(dim=0, keepdim=True)
self.num_classes = pred_score.size(0)
else:
pred_label = data_sample['pred_label']
self.results.append({
'pred_label': pred_label,
'gt_label': data_sample['gt_label'],
})
def compute_metrics(self, results: list) -> dict:
pred_labels = []
gt_labels = []
for result in results:
pred_labels.append(result['pred_label'])
gt_labels.append(result['gt_label'])
confusion_matrix = ConfusionMatrix.calculate(
torch.cat(pred_labels),
torch.cat(gt_labels),
num_classes=self.num_classes)
return {'result': confusion_matrix}
@staticmethod
def calculate(pred, target, num_classes=None) -> dict:
"""Calculate the confusion matrix for single-label task.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. It can be labels (N, ), or scores of every
class (N, C).
target (torch.Tensor | np.ndarray | Sequence): The target of
each prediction with shape (N, ).
num_classes (Optional, int): The number of classes. If the ``pred``
is label instead of scores, this argument is required.
Defaults to None.
Returns:
torch.Tensor: The confusion matrix.
"""
pred = to_tensor(pred)
target_label = to_tensor(target).int()
assert pred.size(0) == target_label.size(0), \
f"The size of pred ({pred.size(0)}) doesn't match "\
f'the target ({target_label.size(0)}).'
assert target_label.ndim == 1
if pred.ndim == 1:
            assert num_classes is not None, \
                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
pred_label = pred
else:
num_classes = num_classes or pred.size(1)
pred_label = torch.argmax(pred, dim=1).flatten()
with torch.no_grad():
indices = num_classes * target_label + pred_label
matrix = torch.bincount(indices, minlength=num_classes**2)
matrix = matrix.reshape(num_classes, num_classes)
return matrix
@staticmethod
def plot(confusion_matrix: torch.Tensor,
include_values: bool = False,
cmap: str = 'viridis',
classes: Optional[List[str]] = None,
colorbar: bool = True,
show: bool = True):
"""Draw a confusion matrix by matplotlib.
Modified from `Scikit-Learn
<https://github.com/scikit-learn/scikit-learn/blob/dc580a8ef/sklearn/metrics/_plot/confusion_matrix.py#L81>`_
Args:
confusion_matrix (torch.Tensor): The confusion matrix to draw.
include_values (bool): Whether to draw the values in the figure.
Defaults to False.
cmap (str): The color map to use. Defaults to use "viridis".
classes (list[str], optional): The names of categories.
Defaults to None, which means to use index number.
colorbar (bool): Whether to show the colorbar. Defaults to True.
show (bool): Whether to show the figure immediately.
Defaults to True.
""" # noqa: E501
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 10))
num_classes = confusion_matrix.size(0)
im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)
text_ = None
cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0)
if include_values:
text_ = np.empty_like(confusion_matrix, dtype=object)
# print text with appropriate color depending on background
thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0
for i, j in product(range(num_classes), range(num_classes)):
color = cmap_max if confusion_matrix[i,
j] < thresh else cmap_min
text_cm = format(confusion_matrix[i, j], '.2g')
text_d = format(confusion_matrix[i, j], 'd')
if len(text_d) < len(text_cm):
text_cm = text_d
text_[i, j] = ax.text(
j, i, text_cm, ha='center', va='center', color=color)
display_labels = classes or np.arange(num_classes)
if colorbar:
fig.colorbar(im_, ax=ax)
ax.set(
xticks=np.arange(num_classes),
yticks=np.arange(num_classes),
xticklabels=display_labels,
yticklabels=display_labels,
ylabel='True label',
xlabel='Predicted label',
)
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.set_ylim((num_classes - 0.5, -0.5))
# Automatically rotate the x labels.
fig.autofmt_xdate(ha='center')
if show:
plt.show()
return fig
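# `calculate` above builds the whole matrix with a single bincount by
# flattening each (target, pred) pair into the index
# `num_classes * target + pred`. A quick check (toy labels, not from the
# original file):
#
#   >>> ConfusionMatrix.calculate(
#   ...     torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1]), num_classes=2)
#   tensor([[1, 1],
#           [0, 1]])
#
# Rows are true labels and columns are predicted labels.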
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torchvision.ops.boxes as boxes
from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS
def aligned_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor):
area1 = boxes.box_area(boxes1)
area2 = boxes.box_area(boxes2)
lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (B, 2)
rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (B, 2)
wh = boxes._upcast(rb - lt).clamp(min=0) # (B, 2)
inter = wh[:, 0] * wh[:, 1] # (B, )
union = area1 + area2 - inter
iou = inter / union
return iou
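# A quick numeric check (toy boxes in xyxy format, not from the original
# file): two 2x2 boxes overlapping in a 1x1 square give IoU 1 / 7:
#
#   >>> b1 = torch.tensor([[0., 0., 2., 2.]])
#   >>> b2 = torch.tensor([[1., 1., 3., 3.]])
#   >>> aligned_box_iou(b1, b2)
#   tensor([0.1429])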
@METRICS.register_module()
class VisualGroundingMetric(BaseMetric):
"""Visual Grounding evaluator.
Calculate the box mIOU and box grounding accuracy for visual grounding
model.
Args:
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
"""
default_prefix = 'visual-grounding'
def process(self, data_batch, data_samples):
"""Process one batch of data samples.
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for preds in data_samples:
pred_box = preds['pred_bboxes'].squeeze()
box_gt = torch.Tensor(preds['gt_bboxes']).squeeze()
result = {
'box': pred_box.to('cpu').squeeze(),
'box_target': box_gt.squeeze(),
}
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
            results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
pred_boxes = torch.stack([each['box'] for each in results])
gt_boxes = torch.stack([each['box_target'] for each in results])
iou = aligned_box_iou(pred_boxes, gt_boxes)
accu_num = torch.sum(iou >= 0.5)
miou = torch.mean(iou)
acc = accu_num / len(gt_boxes)
coco_val = {'miou': miou, 'acc': acc}
return coco_val
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from mmpretrain.registry import METRICS
from mmpretrain.structures import label_to_onehot
from .multi_label import AveragePrecision, MultiLabelMetric
class VOCMetricMixin:
"""A mixin class for VOC dataset metrics, VOC annotations have extra
`difficult` attribute for each object, therefore, extra option is needed
for calculating VOC metrics.
Args:
difficult_as_postive (Optional[bool]): Whether to map the difficult
labels as positive in one-hot ground truth for evaluation. If it
set to True, map difficult gt labels to positive ones(1), If it
set to False, map difficult gt labels to negative ones(0).
Defaults to None, the difficult labels will be set to '-1'.
"""
def __init__(self,
*arg,
difficult_as_positive: Optional[bool] = None,
**kwarg):
self.difficult_as_positive = difficult_as_positive
super().__init__(*arg, **kwarg)
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
result = dict()
gt_label = data_sample['gt_label']
gt_label_difficult = data_sample['gt_label_difficult']
result['pred_score'] = data_sample['pred_score'].clone()
num_classes = result['pred_score'].size()[-1]
if 'gt_score' in data_sample:
result['gt_score'] = data_sample['gt_score'].clone()
else:
result['gt_score'] = label_to_onehot(gt_label, num_classes)
            # VOC annotations label all the objects in a single image,
            # therefore, some categories appear in both difficult objects
            # and non-difficult objects. Here we reckon labels which only
            # exist in difficult objects as difficult labels.
difficult_label = set(gt_label_difficult) - (
set(gt_label_difficult) & set(gt_label.tolist()))
# set difficult label for better eval
if self.difficult_as_positive is None:
result['gt_score'][[*difficult_label]] = -1
elif self.difficult_as_positive:
result['gt_score'][[*difficult_label]] = 1
# Save the result to `self.results`.
self.results.append(result)
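# Illustration of the default difficult-label mapping (hypothetical tensors,
# assuming `torch` is imported; not from the original file): with
# gt_label = [1], gt_label_difficult = [1, 3] and 4 classes, label 3 only
# appears on difficult objects, so its one-hot entry becomes -1 and is
# ignored during evaluation:
#
#   >>> gt_score = label_to_onehot(torch.tensor([1]), 4)  # one-hot [0, 1, 0, 0]
#   >>> difficult = {1, 3} - ({1, 3} & {1})               # -> {3}
#   >>> gt_score[[*difficult]] = -1                       # now [0, 1, 0, -1]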
@METRICS.register_module()
class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
"""A collection of metrics for multi-label multi-class classification task
based on confusion matrix for VOC dataset.
It includes precision, recall, f1-score and support.
Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If it
            is set to True, map difficult gt labels to positive ones (1); if
            it is set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels will be set
            to '-1'.
**kwarg: Refers to `MultiLabelMetric` for detailed docstrings.
"""
@METRICS.register_module()
class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
"""Calculate the average precision with respect of classes for VOC dataset.
Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If it
            is set to True, map difficult gt labels to positive ones (1); if
            it is set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels will be set
            to '-1'.
**kwarg: Refers to `AveragePrecision` for detailed docstrings.
"""
# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal
from typing import List, Optional
import mmengine
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from mmpretrain.registry import METRICS
def _process_punctuation(inText):
import re
outText = inText
punct = [
';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
'>', '<', '@', '`', ',', '?', '!'
]
    commaStrip = re.compile(r'(\d)(,)(\d)')
    periodStrip = re.compile(r'(?!<=\d)(\.)(?!\d)')
for p in punct:
if (p + ' ' in inText or ' ' + p in inText) or (re.search(
commaStrip, inText) is not None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
    outText = periodStrip.sub('', outText)  # Pattern.sub takes no flags arg
return outText
def _process_digit_article(inText):
outText = []
tempText = inText.lower().split()
articles = ['a', 'an', 'the']
manualMap = {
'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10',
}
contractions = {
'aint': "ain't",
'arent': "aren't",
'cant': "can't",
'couldve': "could've",
'couldnt': "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
'didnt': "didn't",
'doesnt': "doesn't",
'dont': "don't",
'hadnt': "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
'hasnt': "hasn't",
'havent': "haven't",
'hed': "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
'hes': "he's",
'howd': "how'd",
'howll': "how'll",
'hows': "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
'Im': "I'm",
'Ive': "I've",
'isnt': "isn't",
'itd': "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
'itll': "it'll",
"let's": "let's",
'maam': "ma'am",
'mightnt': "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
'mightve': "might've",
'mustnt': "mustn't",
'mustve': "must've",
'neednt': "needn't",
'notve': "not've",
'oclock': "o'clock",
'oughtnt': "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
'shant': "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
'shouldve': "should've",
'shouldnt': "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": 'somebodyd',
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
'somebodyll': "somebody'll",
'somebodys': "somebody's",
'someoned': "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
'someonell': "someone'll",
'someones': "someone's",
'somethingd': "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
'somethingll': "something'll",
'thats': "that's",
'thered': "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
'therere': "there're",
'theres': "there's",
'theyd': "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
'theyll': "they'll",
'theyre': "they're",
'theyve': "they've",
'twas': "'twas",
'wasnt': "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
'weve': "we've",
'werent': "weren't",
'whatll': "what'll",
'whatre': "what're",
'whats': "what's",
'whatve': "what've",
'whens': "when's",
'whered': "where'd",
'wheres': "where's",
'whereve': "where've",
'whod': "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
'wholl': "who'll",
'whos': "who's",
'whove': "who've",
'whyll': "why'll",
'whyre': "why're",
'whys': "why's",
'wont': "won't",
'wouldve': "would've",
'wouldnt': "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
'yall': "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
'youd': "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
'youll': "you'll",
'youre': "you're",
'youve': "you've",
}
for word in tempText:
word = manualMap.setdefault(word, word)
if word not in articles:
outText.append(word)
for wordId, word in enumerate(outText):
if word in contractions:
outText[wordId] = contractions[word]
outText = ' '.join(outText)
return outText
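# A small sketch of the normalization pipeline above on a hypothetical
# answer: punctuation is stripped, number words are mapped to digits, and
# articles are dropped.
def _demo_answer_normalization():
    answer = 'there are two apples, on the table!'
    answer = _process_punctuation(answer)  # drops the ',' and blanks the '!'
    answer = _process_digit_article(answer)  # 'two' -> '2', drops 'the'
    return answer  # -> 'there are 2 apples on table'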
@METRICS.register_module()
class VQAAcc(BaseMetric):
    '''VQA Acc metric.
    Args:
        full_score_weight (float): The divisor applied to the summed weights
            of the matched ground-truth answers, so that matching enough
            annotators yields the full score of 1.0. Defaults to 0.3.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    '''
default_prefix = 'VQA'
def __init__(self,
full_score_weight: float = 0.3,
collect_device: str = 'cpu',
prefix: Optional[str] = None):
super().__init__(collect_device=collect_device, prefix=prefix)
self.full_score_weight = full_score_weight
def process(self, data_batch, data_samples):
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for sample in data_samples:
gt_answer = sample.get('gt_answer')
gt_answer_weight = sample.get('gt_answer_weight')
if isinstance(gt_answer, str):
gt_answer = [gt_answer]
if gt_answer_weight is None:
gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer)
result = {
'pred_answer': sample.get('pred_answer'),
'gt_answer': gt_answer,
'gt_answer_weight': gt_answer_weight,
}
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
            results (List): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
acc = []
for result in results:
pred_answer = self._process_answer(result['pred_answer'])
gt_answer = [
self._process_answer(answer) for answer in result['gt_answer']
]
answer_weight = result['gt_answer_weight']
weight_sum = 0
for i, gt in enumerate(gt_answer):
if gt == pred_answer:
weight_sum += answer_weight[i]
vqa_acc = min(1.0, weight_sum / self.full_score_weight)
acc.append(vqa_acc)
accuracy = sum(acc) / len(acc) * 100
metrics = {'acc': accuracy}
return metrics
def _process_answer(self, answer):
answer = answer.replace('\n', ' ')
answer = answer.replace('\t', ' ')
answer = answer.strip()
answer = _process_punctuation(answer)
answer = _process_digit_article(answer)
return answer
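# A minimal sketch of the scoring rule in ``compute_metrics`` with
# hypothetical annotator answers: a prediction earns
# min(1, matched_weight / full_score_weight), so with the default weight of
# 0.3, matching at least 3 of 10 equally weighted annotators gives full
# credit.
def _demo_vqa_acc():
    gt_answers = ['blue'] * 4 + ['dark blue'] * 6  # 10 annotator answers
    weights = [1.0 / 10] * 10
    pred = 'blue'
    weight_sum = sum(w for gt, w in zip(gt_answers, weights) if gt == pred)
    return min(1.0, weight_sum / 0.3)  # -> 1.0 (0.4 / 0.3, capped)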
@METRICS.register_module()
class ReportVQA(BaseMetric):
"""Dump VQA result to the standard json format for VQA evaluation.
Args:
file_path (str): The file path to save the result file.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
"""
default_prefix = 'VQA'
def __init__(self,
file_path: str,
collect_device: str = 'cpu',
prefix: Optional[str] = None):
super().__init__(collect_device=collect_device, prefix=prefix)
if not file_path.endswith('.json'):
raise ValueError('The output file must be a json file.')
self.file_path = file_path
def process(self, data_batch, data_samples) -> None:
"""transfer tensors in predictions to CPU."""
for sample in data_samples:
question_id = sample['question_id']
pred_answer = sample['pred_answer']
result = {
'question_id': int(question_id),
'answer': pred_answer,
}
self.results.append(result)
def compute_metrics(self, results: List):
"""Dump the result to json file."""
mmengine.dump(results, self.file_path)
logger = MMLogger.get_current_instance()
        logger.info(f'Results have been saved to {self.file_path}.')
return {}
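# A usage sketch for ``ReportVQA`` in an OpenMMLab config (the file path
# below is hypothetical). The metric only dumps predictions into the
# standard submission format and reports no scores itself:
#
#     test_evaluator = dict(
#         type='ReportVQA', file_path='work_dirs/vqa_result.json')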
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS,
build_backbone, build_classifier, build_head, build_loss,
build_neck)
from .classifiers import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .multimodal import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .peft import * # noqa: F401,F403
from .retrievers import * # noqa: F401,F403
from .selfsup import * # noqa: F401,F403
from .tta import * # noqa: F401,F403
from .utils import * # noqa: F401,F403
__all__ = [
'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone',
'build_head', 'build_neck', 'build_loss', 'build_classifier'
]
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
from .beit import BEiTViT
from .conformer import Conformer
from .convmixer import ConvMixer
from .convnext import ConvNeXt
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
from .davit import DaViT
from .deit import DistilledVisionTransformer
from .deit3 import DeiT3
from .densenet import DenseNet
from .edgenext import EdgeNeXt
from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
from .efficientnet_v2 import EfficientNetV2
from .hivit import HiViT
from .hornet import HorNet
from .hrnet import HRNet
from .inception_v3 import InceptionV3
from .lenet import LeNet5
from .levit import LeViT
from .mixmim import MixMIMTransformer
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mobileone import MobileOne
from .mobilevit import MobileViT
from .mvit import MViT
from .poolformer import PoolFormer
from .regnet import RegNet
from .replknet import RepLKNet
from .repmlp import RepMLPNet
from .repvgg import RepVGG
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnet_cifar import ResNet_CIFAR
from .resnext import ResNeXt
from .revvit import RevVisionTransformer
from .riformer import RIFormer
from .seresnet import SEResNet
from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .sparse_convnext import SparseConvNeXt
from .sparse_resnet import SparseResNet
from .swin_transformer import SwinTransformer
from .swin_transformer_v2 import SwinTransformerV2
from .t2t_vit import T2T_ViT
from .timm_backbone import TIMMBackbone
from .tinyvit import TinyViT
from .tnt import TNT
from .twins import PCPVT, SVT
from .van import VAN
from .vgg import VGG
from .vig import PyramidVig, Vig
from .vision_transformer import VisionTransformer
from .vit_eva02 import ViTEVA02
from .vit_sam import ViTSAM
from .xcit import XCiT
__all__ = [
'LeNet5',
'AlexNet',
'VGG',
'RegNet',
'ResNet',
'ResNeXt',
'ResNetV1d',
'ResNeSt',
'ResNet_CIFAR',
'SEResNet',
'SEResNeXt',
'ShuffleNetV1',
'ShuffleNetV2',
'MobileNetV2',
'MobileNetV3',
'VisionTransformer',
'SwinTransformer',
'TNT',
'TIMMBackbone',
'T2T_ViT',
'Res2Net',
'RepVGG',
'Conformer',
'MlpMixer',
'DistilledVisionTransformer',
'PCPVT',
'SVT',
'EfficientNet',
'EfficientNetV2',
'ConvNeXt',
'HRNet',
'ResNetV1c',
'ConvMixer',
'EdgeNeXt',
'CSPDarkNet',
'CSPResNet',
'CSPResNeXt',
'CSPNet',
'RepLKNet',
'RepMLPNet',
'PoolFormer',
'RIFormer',
'DenseNet',
'VAN',
'InceptionV3',
'MobileOne',
'EfficientFormer',
'SwinTransformerV2',
'MViT',
'DeiT3',
'HorNet',
'MobileViT',
'DaViT',
'BEiTViT',
'RevVisionTransformer',
'MixMIMTransformer',
'TinyViT',
'LeViT',
'Vig',
'PyramidVig',
'XCiT',
'ViTSAM',
'ViTEVA02',
'HiViT',
'SparseResNet',
'SparseConvNeXt',
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
@MODELS.register_module()
class AlexNet(BaseBackbone):
"""`AlexNet <https://en.wikipedia.org/wiki/AlexNet>`_ backbone.
The input for AlexNet is a 224x224 RGB image.
Args:
num_classes (int): number of classes for classification.
The default value is -1, which uses the backbone as
a feature extractor without the top classifier.
"""
def __init__(self, num_classes=-1):
super(AlexNet, self).__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
if self.num_classes > 0:
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.features(x)
if self.num_classes > 0:
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return (x, )
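# A quick shape check (illustrative values): with the default
# ``num_classes=-1`` the backbone returns the 6x6 feature map, while a
# positive value appends the classifier and returns logits instead.
def _demo_alexnet_shapes():
    import torch
    feat, = AlexNet()(torch.rand(2, 3, 224, 224))
    assert feat.shape == (2, 256, 6, 6)
    logits, = AlexNet(num_classes=10)(torch.rand(2, 3, 224, 224))
    assert logits.shape == (2, 10)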
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from mmengine.model import BaseModule
class BaseBackbone(BaseModule, metaclass=ABCMeta):
"""Base backbone.
This class defines the basic functions of a backbone. Any backbone that
inherits this class should at least define its own `forward` function.
"""
def __init__(self, init_cfg=None):
super(BaseBackbone, self).__init__(init_cfg)
@abstractmethod
def forward(self, x):
"""Forward computation.
Args:
            x (torch.Tensor | tuple[torch.Tensor]): ``x`` could be a
                torch.Tensor or a tuple of torch.Tensor, containing input
                data for forward computation.
"""
pass
def train(self, mode=True):
"""Set module status before forward computation.
Args:
mode (bool): Whether it is train_mode or test_mode
"""
super(BaseBackbone, self).train(mode)
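# A minimal sketch of a custom backbone built on ``BaseBackbone`` (the class
# name and layer sizes are hypothetical): only ``forward`` must be
# implemented, and returning a tuple of stage features keeps it consistent
# with the other backbones in this package.
import torch.nn as nn
class _ToyBackbone(BaseBackbone):
    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    def forward(self, x):
        return (self.conv(x), )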
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed,
resize_relative_position_bias_table, to_2tuple)
from .base_backbone import BaseBackbone
from .vision_transformer import TransformerEncoderLayer
class RelativePositionBias(BaseModule):
"""Relative Position Bias.
This module is copied from
https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py#L209.
Args:
window_size (Sequence[int]): The window size of the relative
position bias.
        num_heads (int): The number of heads in multi-head attention.
        with_cls_token (bool): Whether the backbone uses a cls_token.
            Defaults to True.
"""
def __init__(
self,
window_size: Sequence[int],
num_heads: int,
with_cls_token: bool = True,
) -> None:
super().__init__()
self.window_size = window_size
if with_cls_token:
num_extra_tokens = 3
else:
num_extra_tokens = 0
# cls to token & token to cls & cls to cls
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1) + num_extra_tokens
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance,
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each
# token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] -\
coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
if with_cls_token:
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1, ) * 2,
dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(
-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
else:
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1], ) * 2,
dtype=relative_coords.dtype)
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer('relative_position_index',
relative_position_index)
def forward(self) -> torch.Tensor:
# Wh*Ww,Wh*Ww,nH
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1)
return relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
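# A small sketch of the bookkeeping above, assuming a 14x14 window with a
# class token: (2*14-1) * (2*14-1) = 729 in-window relative offsets plus 3
# extra entries for cls-to-token, token-to-cls and cls-to-cls.
def _demo_relative_position_bias():
    bias_module = RelativePositionBias(window_size=(14, 14), num_heads=12)
    assert bias_module.relative_position_bias_table.shape == (732, 12)
    # forward() gathers the table into a (num_heads, N, N) bias, N = 197
    assert bias_module().shape == (12, 197, 197)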
class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
"""Implements one encoder layer in BEiT.
    Compared with the conventional ``TransformerEncoderLayer``, this module
adds weights to the shortcut connection. In addition, ``BEiTAttention``
is used to replace the original ``MultiheadAttention`` in
``TransformerEncoderLayer``.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
layer_scale_init_value (float): The initialization value for
the learnable scaling of attention and FFN. 1 means no scaling.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
window_size (tuple[int]): The height and width of the window.
Defaults to None.
        use_rel_pos_bias (bool): Whether to use a unique relative position
            bias for each layer. If False, use the shared relative position
            bias defined in the backbone.
attn_drop_rate (float): The drop out rate for attention layer.
Defaults to 0.0.
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
        bias (bool | str): The option to add learnable bias for q, k, v. If
            bias is True, it will add learnable bias. If bias is 'qv_bias',
            it will only add learnable bias for q, v. If bias is False, it
            will not add bias for q, k, v. Defaults to 'qv_bias'.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='LN').
attn_cfg (dict): The configuration for the attention layer.
Defaults to an empty dict.
ffn_cfg (dict): The configuration for the ffn layer.
Defaults to ``dict(add_identity=False)``.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
feedforward_channels: int,
layer_scale_init_value: float,
window_size: Tuple[int, int],
use_rel_pos_bias: bool,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
num_fcs: int = 2,
bias: Union[str, bool] = 'qv_bias',
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
attn_cfg: dict = dict(),
ffn_cfg: dict = dict(add_identity=False),
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
attn_drop_rate=attn_drop_rate,
drop_path_rate=0.,
drop_rate=0.,
num_fcs=num_fcs,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
attn_cfg = {
'window_size': window_size,
'use_rel_pos_bias': use_rel_pos_bias,
'qk_scale': None,
'embed_dims': embed_dims,
'num_heads': num_heads,
'attn_drop': attn_drop_rate,
'proj_drop': drop_rate,
'bias': bias,
**attn_cfg,
}
self.attn = BEiTAttention(**attn_cfg)
ffn_cfg = {
'embed_dims': embed_dims,
'feedforward_channels': feedforward_channels,
'num_fcs': num_fcs,
'ffn_drop': drop_rate,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path_rate),
'act_cfg': act_cfg,
**ffn_cfg,
}
self.ffn = FFN(**ffn_cfg)
# NOTE: drop path for stochastic depth, we shall see if
# this is better than dropout here
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
self.drop_path = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
if layer_scale_init_value > 0:
self.gamma_1 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True)
self.gamma_2 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x: torch.Tensor,
rel_pos_bias: torch.Tensor) -> torch.Tensor:
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.ffn(self.ln2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(
self.ln1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x)))
return x
@MODELS.register_module()
class BEiTViT(BaseBackbone):
"""Backbone for BEiT.
    A PyTorch implementation of: `BEiT: BERT Pre-Training of Image
    Transformers <https://arxiv.org/abs/2106.08254>`_
    A PyTorch implementation of: `BEiT v2: Masked Image Modeling with
    Vector-Quantized Visual Tokenizers <https://arxiv.org/abs/2208.06366>`_
Args:
        arch (str | dict): BEiT architecture. If a string, choose from
            'base' and 'large'. If a dict, it should have the below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
        in_channels (int): The number of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        bias (bool | str): The option to add learnable bias for q, k, v. If
            bias is True, it will add learnable bias. If bias is 'qv_bias',
            it will only add learnable bias for q, v. If bias is False, it
            will not add bias for q, k, v. Defaults to 'qv_bias'.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to False.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"avg_featmap"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
Defaults to False.
use_rel_pos_bias (bool): Use relative position embedding in each
transformer encoder layer. Defaults to True.
use_shared_rel_pos_bias (bool): Use shared relative position embedding,
all transformer encoder layers share the same relative position
embedding. Defaults to False.
layer_scale_init_value (float): The initialization value for
the learnable scaling of attention and FFN. Defaults to 0.1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 768,
'num_layers': 8,
'num_heads': 8,
'feedforward_channels': 768 * 3,
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 3072
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}),
**dict.fromkeys(
['eva-g', 'eva-giant'],
{
# The implementation in EVA
# <https://arxiv.org/abs/2211.07636>
'embed_dims': 1408,
'num_layers': 40,
'num_heads': 16,
'feedforward_channels': 6144
}),
**dict.fromkeys(
['deit-t', 'deit-tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': 192 * 4
}),
**dict.fromkeys(
['deit-s', 'deit-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': 384 * 4
}),
**dict.fromkeys(
['deit-b', 'deit-base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 768 * 4
}),
}
num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0,
drop_path_rate=0,
bias='qv_bias',
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=False,
out_type='avg_featmap',
with_cls_token=True,
frozen_stages=-1,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False,
interpolate_mode='bicubic',
layer_scale_init_value=0.1,
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
super(BEiTViT, self).__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.num_layers = self.arch_settings['num_layers']
self.img_size = to_2tuple(img_size)
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
self.with_cls_token = with_cls_token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
self.num_extra_tokens = 1
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_extra_tokens,
self.embed_dims))
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
else:
self.pos_embed = None
self.drop_after_pos = nn.Dropout(p=drop_rate)
assert not (use_rel_pos_bias and use_shared_rel_pos_bias), (
'`use_rel_pos_bias` and `use_shared_rel_pos_bias` cannot be set '
'to True at the same time')
self.use_rel_pos_bias = use_rel_pos_bias
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_resolution,
num_heads=self.arch_settings['num_heads'])
else:
self.rel_pos_bias = None
self._register_load_state_dict_pre_hook(
self._prepare_relative_position_bias_table)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
layer_scale_init_value=layer_scale_init_value,
window_size=self.patch_resolution,
use_rel_pos_bias=use_rel_pos_bias,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
bias=bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
self.frozen_stages = frozen_stages
self.final_norm = final_norm
if final_norm:
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
if out_type == 'avg_featmap':
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
self._freeze_stages()
@property
def norm1(self):
return self.ln1
@property
def norm2(self):
return self.ln2
def init_weights(self):
super(BEiTViT, self).init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if (not self.with_cls_token
and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
# Remove cls token from state dict if it's not used.
state_dict[name] = state_dict[name][:, 1:]
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
@staticmethod
def resize_pos_embed(*args, **kwargs):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)
def _freeze_stages(self):
# freeze position embedding
if self.pos_embed is not None:
self.pos_embed.requires_grad = False
        # set dropout to eval mode
self.drop_after_pos.eval()
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze cls_token
if self.with_cls_token:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers):
if self.final_norm:
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
if self.out_type == 'avg_featmap':
self.ln2.eval()
for param in self.ln2.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
if self.pos_embed is not None:
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
rel_pos_bias = self.rel_pos_bias() \
if self.rel_pos_bias is not None else None
outs = []
for i, layer in enumerate(self.layers):
x = layer(x, rel_pos_bias)
if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return self.ln2(patch_token.mean(dim=1))
def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
**kwargs):
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
if self.use_rel_pos_bias and 'rel_pos_bias.relative_position_bias_table' in state_dict: # noqa:E501
logger.info('Expand the shared relative position embedding to '
'each transformer block.')
rel_pos_bias = state_dict[
'rel_pos_bias.relative_position_bias_table']
for i in range(self.num_layers):
state_dict[
f'layers.{i}.attn.relative_position_bias_table'] = \
rel_pos_bias.clone()
state_dict.pop('rel_pos_bias.relative_position_bias_table')
state_dict.pop('rel_pos_bias.relative_position_index')
state_dict_model = self.state_dict()
all_keys = list(state_dict_model.keys())
for key in all_keys:
if 'relative_position_bias_table' in key:
ckpt_key = prefix + key
if ckpt_key not in state_dict:
continue
rel_pos_bias_pretrained = state_dict[ckpt_key]
rel_pos_bias_current = state_dict_model[key]
L1, nH1 = rel_pos_bias_pretrained.size()
L2, nH2 = rel_pos_bias_current.size()
src_size = int((L1 - 3)**0.5)
dst_size = int((L2 - 3)**0.5)
if L1 != L2:
extra_tokens = rel_pos_bias_pretrained[-3:, :]
rel_pos_bias = rel_pos_bias_pretrained[:-3, :]
new_rel_pos_bias = resize_relative_position_bias_table(
src_size, dst_size, rel_pos_bias, nH1)
new_rel_pos_bias = torch.cat(
(new_rel_pos_bias, extra_tokens), dim=0)
logger.info('Resize the relative_position_bias_table from '
f'{state_dict[ckpt_key].shape} to '
f'{new_rel_pos_bias.shape}')
state_dict[ckpt_key] = new_rel_pos_bias
# The index buffer need to be re-generated.
index_buffer = ckpt_key.replace('bias_table', 'index')
if index_buffer in state_dict:
del state_dict[index_buffer]
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``)
"""
num_layers = self.num_layers + 2
if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
if param_name in ('cls_token', 'pos_embed'):
layer_depth = 0
elif param_name.startswith('patch_embed'):
layer_depth = 0
elif param_name.startswith('layers'):
layer_id = int(param_name.split('.')[1])
layer_depth = layer_id + 1
else:
layer_depth = num_layers - 1
return layer_depth, num_layers
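# A usage sketch with hypothetical input sizes: the 'base' arch with the
# default ``out_type='avg_featmap'`` yields one (B, 768) feature for each
# stage listed in ``out_indices``.
def _demo_beit_forward():
    import torch
    model = BEiTViT(arch='base', img_size=224, patch_size=16)
    feats = model(torch.rand(1, 3, 224, 224))
    assert feats[-1].shape == (1, 768)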
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
from .vision_transformer import TransformerEncoderLayer
class ConvBlock(BaseModule):
"""Basic convluation block used in Conformer.
This block includes three convluation modules, and supports three new
functions:
1. Returns the output of both the final layers and the second convluation
module.
2. Fuses the input of the second convluation module with an extra input
feature map.
3. Supports to add an extra convluation module to the identity connection.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
        stride (int): The stride of the second convolution module.
            Defaults to 1.
        groups (int): The groups of the second convolution module.
            Defaults to 1.
drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
        with_residual_conv (bool): Whether to add an extra convolution module
            to the identity connection. Defaults to False.
norm_cfg (dict): The config of normalization layers.
Defaults to ``dict(type='BN', eps=1e-6)``.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='ReLU', inplace=True)``.
init_cfg (dict, optional): The extra config to initialize the module.
Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
groups=1,
drop_path_rate=0.,
with_residual_conv=False,
norm_cfg=dict(type='BN', eps=1e-6),
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=None):
super(ConvBlock, self).__init__(init_cfg=init_cfg)
expansion = 4
mid_channels = out_channels // expansion
self.conv1 = nn.Conv2d(
in_channels,
mid_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
self.act1 = build_activation_layer(act_cfg)
self.conv2 = nn.Conv2d(
mid_channels,
mid_channels,
kernel_size=3,
stride=stride,
groups=groups,
padding=1,
bias=False)
self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
self.act2 = build_activation_layer(act_cfg)
self.conv3 = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
self.act3 = build_activation_layer(act_cfg)
if with_residual_conv:
self.residual_conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
padding=0,
bias=False)
self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]
self.with_residual_conv = with_residual_conv
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def zero_init_last_bn(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x, fusion_features=None, out_conv2=True):
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
x = self.conv2(x) if fusion_features is None else self.conv2(
x + fusion_features)
x = self.bn2(x)
x2 = self.act2(x)
x = self.conv3(x2)
x = self.bn3(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.with_residual_conv:
identity = self.residual_conv(identity)
identity = self.residual_bn(identity)
x += identity
x = self.act3(x)
if out_conv2:
return x, x2
else:
return x
class FCUDown(BaseModule):
"""CNN feature maps -> Transformer patch embeddings."""
def __init__(self,
in_channels,
out_channels,
down_stride,
with_cls_token=True,
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
init_cfg=None):
super(FCUDown, self).__init__(init_cfg=init_cfg)
self.down_stride = down_stride
self.with_cls_token = with_cls_token
self.conv_project = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.sample_pooling = nn.AvgPool2d(
kernel_size=down_stride, stride=down_stride)
self.ln = build_norm_layer(norm_cfg, out_channels)[1]
self.act = build_activation_layer(act_cfg)
def forward(self, x, x_t):
x = self.conv_project(x) # [N, C, H, W]
x = self.sample_pooling(x).flatten(2).transpose(1, 2)
x = self.ln(x)
x = self.act(x)
if self.with_cls_token:
x = torch.cat([x_t[:, 0][:, None, :], x], dim=1)
return x
class FCUUp(BaseModule):
"""Transformer patch embeddings -> CNN feature maps."""
def __init__(self,
in_channels,
out_channels,
up_stride,
with_cls_token=True,
norm_cfg=dict(type='BN', eps=1e-6),
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=None):
super(FCUUp, self).__init__(init_cfg=init_cfg)
self.up_stride = up_stride
self.with_cls_token = with_cls_token
self.conv_project = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.bn = build_norm_layer(norm_cfg, out_channels)[1]
self.act = build_activation_layer(act_cfg)
def forward(self, x, H, W):
B, _, C = x.shape
# [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
if self.with_cls_token:
x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W)
else:
x_r = x.transpose(1, 2).reshape(B, C, H, W)
x_r = self.act(self.bn(self.conv_project(x_r)))
return F.interpolate(
x_r, size=(H * self.up_stride, W * self.up_stride))
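# A shape sketch of the two coupling units with hypothetical sizes: FCUDown
# pools a CNN feature map into patch embeddings, and FCUUp projects tokens
# back and upsamples, so the two directions are inverse in resolution.
def _demo_fcu_shapes():
    import torch
    down = FCUDown(in_channels=64, out_channels=384, down_stride=4)
    up = FCUUp(in_channels=384, out_channels=64, up_stride=4)
    x = torch.rand(2, 64, 56, 56)  # CNN feature map
    x_t = torch.rand(2, 197, 384)  # patch tokens incl. the class token
    tokens = down(x, x_t)  # -> (2, 14 * 14 + 1, 384)
    featmap = up(tokens, 14, 14)  # -> (2, 64, 56, 56)
    return tokens.shape, featmap.shape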
class ConvTransBlock(BaseModule):
"""Basic module for Conformer.
    This module is a fusion of a CNN block and a transformer encoder block.
Args:
in_channels (int): The number of input channels in conv blocks.
out_channels (int): The number of output channels in conv blocks.
embed_dims (int): The embedding dimension in transformer blocks.
conv_stride (int): The stride of conv2d layers. Defaults to 1.
groups (int): The groups of conv blocks. Defaults to 1.
with_residual_conv (bool): Whether to add a conv-bn layer to the
identity connect in the conv block. Defaults to False.
down_stride (int): The stride of the downsample pooling layer.
Defaults to 4.
num_heads (int): The number of heads in transformer attention layers.
Defaults to 12.
mlp_ratio (float): The expansion ratio in transformer FFN module.
Defaults to 4.
qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        with_cls_token (bool): Whether to use the class token.
            Defaults to True.
drop_rate (float): The dropout rate of the output projection and
FFN in the transformer block. Defaults to 0.
attn_drop_rate (float): The dropout rate after the attention
calculation in the transformer block. Defaults to 0.
        drop_path_rate (float): The drop path rate in both the conv block
            and the transformer block. Defaults to 0.
last_fusion (bool): Whether this block is the last stage. If so,
downsample the fusion feature map.
init_cfg (dict, optional): The extra config to initialize the module.
Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
embed_dims,
conv_stride=1,
groups=1,
with_residual_conv=False,
down_stride=4,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
with_cls_token=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
last_fusion=False,
init_cfg=None):
super(ConvTransBlock, self).__init__(init_cfg=init_cfg)
expansion = 4
self.cnn_block = ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
with_residual_conv=with_residual_conv,
stride=conv_stride,
groups=groups)
if last_fusion:
self.fusion_block = ConvBlock(
in_channels=out_channels,
out_channels=out_channels,
stride=2,
with_residual_conv=True,
groups=groups,
drop_path_rate=drop_path_rate)
else:
self.fusion_block = ConvBlock(
in_channels=out_channels,
out_channels=out_channels,
groups=groups,
drop_path_rate=drop_path_rate)
self.squeeze_block = FCUDown(
in_channels=out_channels // expansion,
out_channels=embed_dims,
down_stride=down_stride,
with_cls_token=with_cls_token)
self.expand_block = FCUUp(
in_channels=embed_dims,
out_channels=out_channels // expansion,
up_stride=down_stride,
with_cls_token=with_cls_token)
self.trans_block = TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=int(embed_dims * mlp_ratio),
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
attn_drop_rate=attn_drop_rate,
qkv_bias=qkv_bias,
norm_cfg=dict(type='LN', eps=1e-6))
self.down_stride = down_stride
self.embed_dim = embed_dims
self.last_fusion = last_fusion
def forward(self, cnn_input, trans_input):
x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True)
_, _, H, W = x_conv2.shape
# Convert the feature map of conv2 to transformer embedding
# and concat with class token.
conv2_embedding = self.squeeze_block(x_conv2, trans_input)
trans_output = self.trans_block(conv2_embedding + trans_input)
# Convert the transformer output embedding to feature map
trans_features = self.expand_block(trans_output, H // self.down_stride,
W // self.down_stride)
x = self.fusion_block(
x, fusion_features=trans_features, out_conv2=False)
return x, trans_output
@MODELS.register_module()
class Conformer(BaseBackbone):
"""Conformer backbone.
    A PyTorch implementation of: `Conformer: Local Features Coupling Global
    Representations for Visual Recognition <https://arxiv.org/abs/2105.03889>`_
Args:
arch (str | dict): Conformer architecture. Defaults to 'tiny'.
patch_size (int): The patch size. Defaults to 16.
base_channels (int): The base number of channels in CNN network.
Defaults to 64.
mlp_ratio (float): The expansion ratio of FFN network in transformer
block. Defaults to 4.
        with_cls_token (bool): Whether to use the class token.
            Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 384,
'channel_ratio': 1,
'num_heads': 6,
'depths': 12
}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 384,
'channel_ratio': 4,
'num_heads': 6,
'depths': 12
}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 576,
'channel_ratio': 6,
'num_heads': 9,
'depths': 12
}),
} # yapf: disable
_version = 1
def __init__(self,
arch='tiny',
patch_size=16,
base_channels=64,
mlp_ratio=4.,
qkv_bias=True,
with_cls_token=True,
drop_path_rate=0.,
norm_eval=True,
frozen_stages=0,
out_indices=-1,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'depths', 'num_heads', 'channel_ratio'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.num_features = self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.channel_ratio = self.arch_settings['channel_ratio']
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.depths + index + 1
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.norm_eval = norm_eval
self.frozen_stages = frozen_stages
self.with_cls_token = with_cls_token
if self.with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
# stochastic depth decay rule
self.trans_dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
]
# Stem stage: get the feature maps by conv block
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3,
bias=False) # 1 / 2 [112, 112]
self.bn1 = nn.BatchNorm2d(64)
self.act1 = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
assert patch_size % 16 == 0, 'The patch size of Conformer must ' \
'be divisible by 16.'
trans_down_stride = patch_size // 4
# To solve the issue #680
# Auto pad the feature map to be divisible by trans_down_stride
self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride)
# 1 stage
stage1_channels = int(base_channels * self.channel_ratio)
self.conv_1 = ConvBlock(
in_channels=64,
out_channels=stage1_channels,
with_residual_conv=True,
stride=1)
self.trans_patch_conv = nn.Conv2d(
64,
self.embed_dims,
kernel_size=trans_down_stride,
stride=trans_down_stride,
padding=0)
self.trans_1 = TransformerEncoderLayer(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=int(self.embed_dims * mlp_ratio),
drop_path_rate=self.trans_dpr[0],
qkv_bias=qkv_bias,
norm_cfg=dict(type='LN', eps=1e-6))
# 2~4 stage
init_stage = 2
fin_stage = self.depths // 3 + 1
for i in range(init_stage, fin_stage):
self.add_module(
f'conv_trans_{i}',
ConvTransBlock(
in_channels=stage1_channels,
out_channels=stage1_channels,
embed_dims=self.embed_dims,
conv_stride=1,
with_residual_conv=False,
down_stride=trans_down_stride,
num_heads=self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rate=self.trans_dpr[i - 1],
with_cls_token=self.with_cls_token))
stage2_channels = int(base_channels * self.channel_ratio * 2)
# 5~8 stage
init_stage = fin_stage # 5
fin_stage = fin_stage + self.depths // 3 # 9
for i in range(init_stage, fin_stage):
if i == init_stage:
conv_stride = 2
in_channels = stage1_channels
else:
conv_stride = 1
in_channels = stage2_channels
with_residual_conv = True if i == init_stage else False
self.add_module(
f'conv_trans_{i}',
ConvTransBlock(
in_channels=in_channels,
out_channels=stage2_channels,
embed_dims=self.embed_dims,
conv_stride=conv_stride,
with_residual_conv=with_residual_conv,
down_stride=trans_down_stride // 2,
num_heads=self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rate=self.trans_dpr[i - 1],
with_cls_token=self.with_cls_token))
stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
# 9~12 stage
init_stage = fin_stage # 9
fin_stage = fin_stage + self.depths // 3 # 13
for i in range(init_stage, fin_stage):
if i == init_stage:
conv_stride = 2
in_channels = stage2_channels
with_residual_conv = True
else:
conv_stride = 1
in_channels = stage3_channels
with_residual_conv = False
last_fusion = (i == self.depths)
self.add_module(
f'conv_trans_{i}',
ConvTransBlock(
in_channels=in_channels,
out_channels=stage3_channels,
embed_dims=self.embed_dims,
conv_stride=conv_stride,
with_residual_conv=with_residual_conv,
down_stride=trans_down_stride // 4,
num_heads=self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rate=self.trans_dpr[i - 1],
with_cls_token=self.with_cls_token,
last_fusion=last_fusion))
self.fin_stage = fin_stage
self.pooling = nn.AdaptiveAvgPool2d(1)
self.trans_norm = nn.LayerNorm(self.embed_dims)
if self.with_cls_token:
trunc_normal_(self.cls_token, std=.02)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def init_weights(self):
super(Conformer, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
self.apply(self._init_weights)
def forward(self, x):
output = []
B = x.shape[0]
if self.with_cls_token:
cls_tokens = self.cls_token.expand(B, -1, -1)
# stem
x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
x_base = self.auto_pad(x_base)
# 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
x = self.conv_1(x_base, out_conv2=False)
x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
if self.with_cls_token:
x_t = torch.cat([cls_tokens, x_t], dim=1)
x_t = self.trans_1(x_t)
# 2 ~ final
for i in range(2, self.fin_stage):
stage = getattr(self, f'conv_trans_{i}')
x, x_t = stage(x, x_t)
if i in self.out_indices:
if self.with_cls_token:
output.append([
self.pooling(x).flatten(1),
self.trans_norm(x_t)[:, 0]
])
else:
# if no class token, use the mean patch token
# as the transformer feature.
output.append([
self.pooling(x).flatten(1),
self.trans_norm(x_t).mean(dim=1)
])
return tuple(output)
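# A usage sketch for the default 'tiny' arch with a hypothetical batch: each
# selected stage yields a [conv_feature, transformer_feature] pair, here with
# 256 conv channels (base_channels * channel_ratio * 4) and 384 transformer
# dims.
def _demo_conformer_forward():
    import torch
    model = Conformer(arch='tiny')
    conv_feat, trans_feat = model(torch.rand(1, 3, 224, 224))[-1]
    assert conv_feat.shape == (1, 256) and trans_feat.shape == (1, 384)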
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
build_norm_layer)
from mmengine.utils import digit_version
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
@MODELS.register_module()
class ConvMixer(BaseBackbone):
"""ConvMixer. .
A PyTorch implementation of : `Patches Are All You Need?
<https://arxiv.org/pdf/2201.09792.pdf>`_
Modified from the `official repo
<https://github.com/locuslab/convmixer/blob/main/convmixer.py>`_
and `timm
<https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convmixer.py>`_.
Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of the architectures in ``ConvMixer.arch_settings``. And if
            dict, it should include the following four keys:
- embed_dims (int): The dimensions of patch embedding.
- depth (int): Number of repetitions of ConvMixer Layer.
- patch_size (int): The patch size.
- kernel_size (int): The kernel size of depthwise conv layers.
Defaults to '768/32'.
in_channels (int): Number of input image channels. Defaults to 3.
patch_size (int): The size of one patch in the patch embed layer.
Defaults to 7.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='BN')``.
act_cfg (dict): The config dict for activation after each convolution.
Defaults to ``dict(type='GELU')``.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
init_cfg (dict, optional): Initialization config dict.
"""
arch_settings = {
'768/32': {
'embed_dims': 768,
'depth': 32,
'patch_size': 7,
'kernel_size': 7
},
'1024/20': {
'embed_dims': 1024,
'depth': 20,
'patch_size': 14,
'kernel_size': 9
},
'1536/20': {
'embed_dims': 1536,
'depth': 20,
'patch_size': 7,
'kernel_size': 9
},
}
def __init__(self,
arch='768/32',
in_channels=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='GELU'),
out_indices=-1,
frozen_stages=0,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
essential_keys = {
'embed_dims', 'depth', 'patch_size', 'kernel_size'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.embed_dims = arch['embed_dims']
self.depth = arch['depth']
self.patch_size = arch['patch_size']
self.kernel_size = arch['kernel_size']
self.act = build_activation_layer(act_cfg)
# check out indices and frozen stages
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.depth + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# Set stem layers
self.stem = nn.Sequential(
nn.Conv2d(
in_channels,
self.embed_dims,
kernel_size=self.patch_size,
stride=self.patch_size), self.act,
build_norm_layer(norm_cfg, self.embed_dims)[1])
# Set conv2d according to torch version
convfunc = nn.Conv2d
if digit_version(torch.__version__) < digit_version('1.9.0'):
convfunc = Conv2dAdaptivePadding
# Repetitions of ConvMixer Layer
self.stages = nn.Sequential(*[
nn.Sequential(
Residual(
nn.Sequential(
convfunc(
self.embed_dims,
self.embed_dims,
self.kernel_size,
groups=self.embed_dims,
padding='same'), self.act,
build_norm_layer(norm_cfg, self.embed_dims)[1])),
nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
self.act,
build_norm_layer(norm_cfg, self.embed_dims)[1])
for _ in range(self.depth)
])
self._freeze_stages()
def forward(self, x):
x = self.stem(x)
outs = []
for i, stage in enumerate(self.stages):
x = stage(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
super(ConvMixer, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
for i in range(self.frozen_stages):
stage = self.stages[i]
stage.eval()
for param in stage.parameters():
param.requires_grad = False
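# A usage sketch for the default '768/32' arch with a hypothetical batch:
# the stem downsamples by the patch size (7), so a 224x224 input gives a
# 32x32 token map whose shape is kept through every ConvMixer layer.
def _demo_convmixer_forward():
    import torch
    model = ConvMixer(arch='768/32')
    feats = model(torch.rand(1, 3, 224, 224))
    assert feats[-1].shape == (1, 768, 32, 32)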
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from itertools import chain
from typing import Sequence
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, ModuleList, Sequential
from mmpretrain.registry import MODELS
from ..utils import GRN, build_norm_layer
from .base_backbone import BaseBackbone
class ConvNeXtBlock(BaseModule):
"""ConvNeXt Block.
Args:
in_channels (int): The number of input channels.
dw_conv_cfg (dict): Config of depthwise convolution.
Defaults to ``dict(kernel_size=7, padding=3)``.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='LN2d', eps=1e-6)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
        mlp_ratio (float): The expansion ratio in both pointwise convolutions.
            Defaults to 4.
linear_pw_conv (bool): Whether to use linear layer to do pointwise
convolution. More details can be found in the note.
Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        use_grn (bool): Whether to add Global Response Normalization in the
            block. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
Note:
There are two equivalent implementations:
1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
all outputs are in (N, C, H, W).
2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU
-> Linear; Permute back
        By default, we use the second to align with the official repository,
        and it may be slightly faster.
"""
def __init__(self,
in_channels,
dw_conv_cfg=dict(kernel_size=7, padding=3),
norm_cfg=dict(type='LN2d', eps=1e-6),
act_cfg=dict(type='GELU'),
mlp_ratio=4.,
linear_pw_conv=True,
drop_path_rate=0.,
layer_scale_init_value=1e-6,
use_grn=False,
with_cp=False):
super().__init__()
self.with_cp = with_cp
self.depthwise_conv = nn.Conv2d(
in_channels, in_channels, groups=in_channels, **dw_conv_cfg)
self.linear_pw_conv = linear_pw_conv
self.norm = build_norm_layer(norm_cfg, in_channels)
mid_channels = int(mlp_ratio * in_channels)
if self.linear_pw_conv:
# Use linear layer to do pointwise conv.
pw_conv = nn.Linear
else:
pw_conv = partial(nn.Conv2d, kernel_size=1)
self.pointwise_conv1 = pw_conv(in_channels, mid_channels)
self.act = MODELS.build(act_cfg)
self.pointwise_conv2 = pw_conv(mid_channels, in_channels)
if use_grn:
self.grn = GRN(mid_channels)
else:
self.grn = None
self.gamma = nn.Parameter(
layer_scale_init_value * torch.ones((in_channels)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
def _inner_forward(x):
shortcut = x
x = self.depthwise_conv(x)
if self.linear_pw_conv:
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x, data_format='channel_last')
x = self.pointwise_conv1(x)
x = self.act(x)
if self.grn is not None:
x = self.grn(x, data_format='channel_last')
x = self.pointwise_conv2(x)
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
else:
x = self.norm(x, data_format='channel_first')
x = self.pointwise_conv1(x)
x = self.act(x)
if self.grn is not None:
x = self.grn(x, data_format='channel_first')
x = self.pointwise_conv2(x)
if self.gamma is not None:
x = x.mul(self.gamma.view(1, -1, 1, 1))
x = shortcut + self.drop_path(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
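# A minimal sketch of why ``linear_pw_conv`` is safe to toggle: a linear
# layer applied channel-last and a 1x1 convolution applied channel-first
# compute the same map when their weights are shared (illustrative helper,
# hypothetical sizes):
def _demo_pointwise_equivalence():
    x = torch.randn(2, 8, 4, 4)
    conv = nn.Conv2d(8, 16, kernel_size=1)
    linear = nn.Linear(8, 16)
    with torch.no_grad():
        linear.weight.copy_(conv.weight.view(16, 8))
        linear.bias.copy_(conv.bias)
    out_conv = conv(x)
    out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(out_conv, out_linear, atol=1e-6)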
@MODELS.register_module()
class ConvNeXt(BaseBackbone):
"""ConvNeXt v1&v2 backbone.
A PyTorch implementation of `A ConvNet for the 2020s
<https://arxiv.org/abs/2201.03545>`_ and
`ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
<http://arxiv.org/abs/2301.00808>`_
Modified from the `official repo
<https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py>`_
and `timm
<https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py>`_.
To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``.
Args:
arch (str | dict): The model's architecture. If string, it should be
one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
should include the following two keys:
- depths (list[int]): Number of blocks at each stage.
- channels (list[int]): The number of channels at each stage.
Defaults to 'tiny'.
in_channels (int): Number of input image channels. Defaults to 3.
stem_patch_size (int): The size of one patch in the stem layer.
Defaults to 4.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='LN2d', eps=1e-6)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
linear_pw_conv (bool): Whether to use linear layer to do pointwise
convolution. Defaults to True.
use_grn (bool): Whether to add Global Response Normalization in the
blocks. Defaults to False.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 1e-6.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
gap_before_final_norm (bool): Whether to globally average the feature
map before the final norm layer. In the official repo, it's only
used in classification task. Defaults to True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): Initialization config dict
""" # noqa: E501
arch_settings = {
'atto': {
'depths': [2, 2, 6, 2],
'channels': [40, 80, 160, 320]
},
'femto': {
'depths': [2, 2, 6, 2],
'channels': [48, 96, 192, 384]
},
'pico': {
'depths': [2, 2, 6, 2],
'channels': [64, 128, 256, 512]
},
'nano': {
'depths': [2, 2, 8, 2],
'channels': [80, 160, 320, 640]
},
'tiny': {
'depths': [3, 3, 9, 3],
'channels': [96, 192, 384, 768]
},
'small': {
'depths': [3, 3, 27, 3],
'channels': [96, 192, 384, 768]
},
'base': {
'depths': [3, 3, 27, 3],
'channels': [128, 256, 512, 1024]
},
'large': {
'depths': [3, 3, 27, 3],
'channels': [192, 384, 768, 1536]
},
'xlarge': {
'depths': [3, 3, 27, 3],
'channels': [256, 512, 1024, 2048]
},
'huge': {
'depths': [3, 3, 27, 3],
'channels': [352, 704, 1408, 2816]
}
}
def __init__(self,
arch='tiny',
in_channels=3,
stem_patch_size=4,
norm_cfg=dict(type='LN2d', eps=1e-6),
act_cfg=dict(type='GELU'),
linear_pw_conv=True,
use_grn=False,
drop_path_rate=0.,
layer_scale_init_value=1e-6,
out_indices=-1,
frozen_stages=0,
gap_before_final_norm=True,
with_cp=False,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(
type='Constant', layer=['LayerNorm'], val=1.,
bias=0.),
]):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
assert 'depths' in arch and 'channels' in arch, \
f'The arch dict must have "depths" and "channels", ' \
f'but got {list(arch.keys())}.'
self.depths = arch['depths']
self.channels = arch['channels']
assert (isinstance(self.depths, Sequence)
and isinstance(self.channels, Sequence)
and len(self.depths) == len(self.channels)), \
f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
'should be both sequence with the same length.'
self.num_stages = len(self.depths)
if isinstance(out_indices, int):
out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
                out_indices[i] = self.num_stages + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.gap_before_final_norm = gap_before_final_norm
# stochastic depth decay rule
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(self.depths))
]
block_idx = 0
# 4 downsample layers between stages, including the stem layer.
self.downsample_layers = ModuleList()
stem = nn.Sequential(
nn.Conv2d(
in_channels,
self.channels[0],
kernel_size=stem_patch_size,
stride=stem_patch_size),
build_norm_layer(norm_cfg, self.channels[0]),
)
self.downsample_layers.append(stem)
# 4 feature resolution stages, each consisting of multiple residual
# blocks
self.stages = nn.ModuleList()
for i in range(self.num_stages):
depth = self.depths[i]
channels = self.channels[i]
if i >= 1:
downsample_layer = nn.Sequential(
build_norm_layer(norm_cfg, self.channels[i - 1]),
nn.Conv2d(
self.channels[i - 1],
channels,
kernel_size=2,
stride=2),
)
self.downsample_layers.append(downsample_layer)
stage = Sequential(*[
ConvNeXtBlock(
in_channels=channels,
drop_path_rate=dpr[block_idx + j],
norm_cfg=norm_cfg,
act_cfg=act_cfg,
linear_pw_conv=linear_pw_conv,
layer_scale_init_value=layer_scale_init_value,
use_grn=use_grn,
with_cp=with_cp) for j in range(depth)
])
block_idx += depth
self.stages.append(stage)
if i in self.out_indices:
norm_layer = build_norm_layer(norm_cfg, channels)
self.add_module(f'norm{i}', norm_layer)
self._freeze_stages()
def forward(self, x):
outs = []
for i, stage in enumerate(self.stages):
x = self.downsample_layers[i](x)
x = stage(x)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
if self.gap_before_final_norm:
gap = x.mean([-2, -1], keepdim=True)
outs.append(norm_layer(gap).flatten(1))
else:
outs.append(norm_layer(x))
return tuple(outs)
def _freeze_stages(self):
for i in range(self.frozen_stages):
downsample_layer = self.downsample_layers[i]
stage = self.stages[i]
downsample_layer.eval()
stage.eval()
for param in chain(downsample_layer.parameters(),
stage.parameters()):
param.requires_grad = False
def train(self, mode=True):
super(ConvNeXt, self).train(mode)
self._freeze_stages()
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
"""
max_layer_id = 12 if self.depths[-2] > 9 else 6
if not param_name.startswith(prefix):
# For subsequent module like head
return max_layer_id + 1, max_layer_id + 2
param_name = param_name[len(prefix):]
if param_name.startswith('downsample_layers'):
stage_id = int(param_name.split('.')[1])
if stage_id == 0:
layer_id = 0
elif stage_id == 1 or stage_id == 2:
layer_id = stage_id + 1
else: # stage_id == 3:
layer_id = max_layer_id
elif param_name.startswith('stages'):
stage_id = int(param_name.split('.')[1])
block_id = int(param_name.split('.')[2])
if stage_id == 0 or stage_id == 1:
layer_id = stage_id + 1
elif stage_id == 2:
layer_id = 3 + block_id // 3
else: # stage_id == 3:
layer_id = max_layer_id
# final norm layer
else:
layer_id = max_layer_id + 1
return layer_id, max_layer_id + 2
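# A minimal usage sketch (illustrative helper, assuming a 224x224 input):
def _demo_convnext():
    model = ConvNeXt(arch='tiny')
    inputs = torch.rand(1, 3, 224, 224)
    # With the default ``gap_before_final_norm=True`` the output is pooled.
    assert tuple(model(inputs)[-1].shape) == (1, 768)
    # ConvNeXt V2 variant, as described in the class docstring.
    model_v2 = ConvNeXt(arch='tiny', use_grn=True, layer_scale_init_value=0.)
    assert tuple(model_v2(inputs)[-1].shape) == (1, 768)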
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.registry import MODELS
from ..utils import to_ntuple
from .resnet import Bottleneck as ResNetBottleneck
from .resnext import Bottleneck as ResNeXtBottleneck
eps = 1.0e-5
class DarknetBottleneck(BaseModule):
"""The basic bottleneck block used in Darknet. Each DarknetBottleneck
consists of two ConvModules and the input is added to the final output.
    Each ConvModule is composed of Conv, BN, and LeakyReLU. The first conv
    layer has a 1x1 filter size and the second one has a 3x3 filter size.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the input channels of conv2.
            Defaults to 2.
add_identity (bool): Whether to add identity to the out.
Defaults to True.
use_depthwise (bool): Whether to use depthwise separable convolution.
Defaults to False.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
drop_path_rate (float): The ratio of the drop path layer. Default: 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN', eps=1e-5)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='LeakyReLU', inplace=True)``.
"""
def __init__(self,
in_channels,
out_channels,
expansion=2,
add_identity=True,
use_depthwise=False,
conv_cfg=None,
drop_path_rate=0,
norm_cfg=dict(type='BN', eps=1e-5),
act_cfg=dict(type='LeakyReLU', inplace=True),
init_cfg=None):
super().__init__(init_cfg)
hidden_channels = int(out_channels / expansion)
conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
self.conv1 = ConvModule(
in_channels,
hidden_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = conv(
hidden_channels,
out_channels,
3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.add_identity = \
add_identity and in_channels == out_channels
self.drop_path = DropPath(drop_prob=drop_path_rate
) if drop_path_rate > eps else nn.Identity()
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.drop_path(out)
if self.add_identity:
return out + identity
else:
return out
class CSPStage(BaseModule):
"""Cross Stage Partial Stage.
.. code:: text
Downsample Convolution (optional)
|
|
Expand Convolution
|
|
Split to xa, xb
| \
| \
| blocks(xb)
| /
| / transition
| /
Concat xa, blocks(xb)
|
Transition Convolution
Args:
block_fn (nn.module): The basic block function in the Stage.
in_channels (int): The input channels of the CSP layer.
out_channels (int): The output channels of the CSP layer.
        has_downsampler (bool): Whether to add a downsampler in the stage.
            Default: True.
        down_growth (bool): Whether to expand the channels in the
            downsampler layer of the stage. Default: False.
        expand_ratio (float): The expand ratio to adjust the number of
            channels of the expand conv layer. Default: 0.5
        bottle_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Default: 2.
block_dpr (float): The ratio of the drop path layer in the
blocks of the stage. Default: 0.
num_blocks (int): Number of blocks. Default: 1
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
act_cfg (dict): Config dict for activation layer.
Default: dict(type='LeakyReLU', inplace=True)
"""
def __init__(self,
block_fn,
in_channels,
out_channels,
has_downsampler=True,
down_growth=False,
expand_ratio=0.5,
bottle_ratio=2,
num_blocks=1,
block_dpr=0,
block_args={},
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-5),
act_cfg=dict(type='LeakyReLU', inplace=True),
init_cfg=None):
super().__init__(init_cfg)
# grow downsample channels to output channels
down_channels = out_channels if down_growth else in_channels
block_dpr = to_ntuple(num_blocks)(block_dpr)
if has_downsampler:
self.downsample_conv = ConvModule(
in_channels=in_channels,
out_channels=down_channels,
kernel_size=3,
stride=2,
padding=1,
groups=32 if block_fn is ResNeXtBottleneck else 1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
else:
self.downsample_conv = nn.Identity()
exp_channels = int(down_channels * expand_ratio)
self.expand_conv = ConvModule(
in_channels=down_channels,
out_channels=exp_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg if block_fn is DarknetBottleneck else None)
assert exp_channels % 2 == 0, \
'The channel number before blocks must be divisible by 2.'
block_channels = exp_channels // 2
blocks = []
for i in range(num_blocks):
block_cfg = dict(
in_channels=block_channels,
out_channels=block_channels,
expansion=bottle_ratio,
drop_path_rate=block_dpr[i],
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**block_args)
blocks.append(block_fn(**block_cfg))
self.blocks = Sequential(*blocks)
        self.after_blocks_conv = ConvModule(
block_channels,
block_channels,
1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.final_conv = ConvModule(
2 * block_channels,
out_channels,
1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
x = self.downsample_conv(x)
x = self.expand_conv(x)
split = x.shape[1] // 2
xa, xb = x[:, :split], x[:, split:]
xb = self.blocks(xb)
        xb = self.after_blocks_conv(xb).contiguous()
x_final = torch.cat((xa, xb), dim=1)
return self.final_conv(x_final)
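# A minimal sketch of the cross-stage flow (illustrative helper with
# hypothetical sizes): half of the expanded channels bypass the blocks.
def _demo_csp_stage():
    stage = CSPStage(
        block_fn=DarknetBottleneck,
        in_channels=32,
        out_channels=64,
        has_downsampler=True,
        expand_ratio=0.5,
        bottle_ratio=1,
        num_blocks=1)
    x = torch.rand(1, 32, 56, 56)
    # Downsampling halves the resolution; the transition conv maps the
    # 2 * 8 concatenated channels to the requested 64 output channels.
    assert tuple(stage(x).shape) == (1, 64, 28, 28)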
class CSPNet(BaseModule):
"""The abstract CSP Network class.
    A PyTorch implementation of `CSPNet: A New Backbone that can Enhance
Learning Capability of CNN <https://arxiv.org/abs/1911.11929>`_
    This class is an abstract class because the Cross Stage Partial Network
    (CSPNet) is a kind of universal network structure, and you can customize
    the network block to implement networks like CSPResNet, CSPResNeXt and
    CSPDarkNet.
Args:
arch (dict): The architecture of the CSPNet.
It should have the following keys:
- block_fn (Callable): A function or class to return a block
module, and it should accept at least ``in_channels``,
``out_channels``, ``expansion``, ``drop_path_rate``, ``norm_cfg``
and ``act_cfg``.
- in_channels (Tuple[int]): The number of input channels of each
stage.
- out_channels (Tuple[int]): The number of output channels of each
stage.
- num_blocks (Tuple[int]): The number of blocks in each stage.
- expansion_ratio (float | Tuple[float]): The expansion ratio in
the expand convolution of each stage. Defaults to 0.5.
- bottle_ratio (float | Tuple[float]): The expansion ratio of
blocks in each stage. Defaults to 2.
- has_downsampler (bool | Tuple[bool]): Whether to add a
downsample convolution in each stage. Defaults to True
- down_growth (bool | Tuple[bool]): Whether to expand the channels
in the downsampler layer of each stage. Defaults to False.
- block_args (dict | Tuple[dict], optional): The extra arguments to
the blocks in each stage. Defaults to None.
stem_fn (Callable): A function or class to return a stem module.
And it should accept ``in_channels``.
in_channels (int): Number of input image channels. Defaults to 3.
out_indices (int | Sequence[int]): Output from which stages.
Defaults to -1, which means the last stage.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
conv_cfg (dict, optional): The config dict for conv layers in blocks.
Defaults to None, which means use Conv2d.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='BN', eps=1e-5)``.
act_cfg (dict): The config dict for activation functions.
Defaults to ``dict(type='LeakyReLU', inplace=True)``.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
init_cfg (dict, optional): The initialization settings.
Defaults to ``dict(type='Kaiming', layer='Conv2d'))``.
Example:
>>> from functools import partial
>>> import torch
>>> import torch.nn as nn
>>> from mmpretrain.models import CSPNet
>>> from mmpretrain.models.backbones.resnet import Bottleneck
>>>
>>> # A simple example to build CSPNet.
>>> arch = dict(
... block_fn=Bottleneck,
... in_channels=[32, 64],
... out_channels=[64, 128],
... num_blocks=[3, 4]
... )
>>> stem_fn = partial(nn.Conv2d, out_channels=32, kernel_size=3)
>>> model = CSPNet(arch=arch, stem_fn=stem_fn, out_indices=(0, 1))
>>> inputs = torch.rand(1, 3, 224, 224)
>>> outs = model(inputs)
>>> for out in outs:
        ...     print(tuple(out.shape))
...
(1, 64, 111, 111)
(1, 128, 56, 56)
"""
def __init__(self,
arch,
stem_fn,
in_channels=3,
out_indices=-1,
frozen_stages=-1,
drop_path_rate=0.,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-5),
act_cfg=dict(type='LeakyReLU', inplace=True),
norm_eval=False,
init_cfg=dict(type='Kaiming', layer='Conv2d')):
super().__init__(init_cfg=init_cfg)
self.arch = self.expand_arch(arch)
self.num_stages = len(self.arch['in_channels'])
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
if frozen_stages not in range(-1, self.num_stages):
raise ValueError('frozen_stages must be in range(-1, '
f'{self.num_stages}). But received '
f'{frozen_stages}')
self.frozen_stages = frozen_stages
self.stem = stem_fn(in_channels)
stages = []
depths = self.arch['num_blocks']
dpr = torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
for i in range(self.num_stages):
stage_cfg = {k: v[i] for k, v in self.arch.items()}
csp_stage = CSPStage(
**stage_cfg,
block_dpr=dpr[i].tolist(),
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=init_cfg)
stages.append(csp_stage)
self.stages = Sequential(*stages)
if isinstance(out_indices, int):
out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
out_indices = list(out_indices)
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = len(self.stages) + index
assert 0 <= out_indices[i] <= len(self.stages), \
f'Invalid out_indices {index}.'
self.out_indices = out_indices
@staticmethod
def expand_arch(arch):
num_stages = len(arch['in_channels'])
def to_tuple(x, name=''):
if isinstance(x, (list, tuple)):
                assert len(x) == num_stages, \
                    f'The length of {name} ({len(x)}) does not ' \
                    f'equal the number of stages ({num_stages})'
return tuple(x)
else:
return (x, ) * num_stages
full_arch = {k: to_tuple(v, k) for k, v in arch.items()}
if 'block_args' not in full_arch:
full_arch['block_args'] = to_tuple({})
return full_arch
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.stem.eval()
for param in self.stem.parameters():
param.requires_grad = False
for i in range(self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super(CSPNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
def forward(self, x):
outs = []
x = self.stem(x)
for i, stage in enumerate(self.stages):
x = stage(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
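# A minimal sketch of ``CSPNet.expand_arch`` (illustrative helper with a
# hypothetical two-stage arch): scalar settings broadcast to every stage,
# per-stage sequences pass through unchanged.
def _demo_expand_arch():
    arch = dict(
        block_fn=DarknetBottleneck,
        in_channels=(32, 64),
        out_channels=(64, 128),
        num_blocks=2,  # scalar, broadcast to (2, 2)
        bottle_ratio=(2, 1))  # already per-stage, kept as-is
    full = CSPNet.expand_arch(arch)
    assert full['num_blocks'] == (2, 2)
    assert full['bottle_ratio'] == (2, 1)
    assert full['block_args'] == ({}, {})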
@MODELS.register_module()
class CSPDarkNet(CSPNet):
"""CSP-Darknet backbone used in YOLOv4.
Args:
depth (int): Depth of CSP-Darknet. Default: 53.
in_channels (int): Number of input image channels. Default: 3.
        out_indices (Sequence[int]): Output from which stages.
            Default: (4, ).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', eps=1e-5).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> from mmpretrain.models import CSPDarkNet
>>> import torch
>>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4))
>>> model.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 64, 208, 208)
(1, 128, 104, 104)
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""
arch_settings = {
53:
dict(
block_fn=DarknetBottleneck,
in_channels=(32, 64, 128, 256, 512),
out_channels=(64, 128, 256, 512, 1024),
num_blocks=(1, 2, 8, 8, 4),
expand_ratio=(2, 1, 1, 1, 1),
bottle_ratio=(2, 1, 1, 1, 1),
has_downsampler=True,
down_growth=True,
),
}
def __init__(self,
depth,
in_channels=3,
out_indices=(4, ),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-5),
act_cfg=dict(type='LeakyReLU', inplace=True),
norm_eval=False,
init_cfg=dict(
type='Kaiming',
layer='Conv2d',
a=math.sqrt(5),
distribution='uniform',
mode='fan_in',
nonlinearity='leaky_relu')):
        assert depth in self.arch_settings, 'depth must be one of ' \
            f'{list(self.arch_settings.keys())}, but got {depth}.'
super().__init__(
arch=self.arch_settings[depth],
stem_fn=self._make_stem_layer,
in_channels=in_channels,
out_indices=out_indices,
frozen_stages=frozen_stages,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
norm_eval=norm_eval,
init_cfg=init_cfg)
def _make_stem_layer(self, in_channels):
"""using a stride=1 conv as the stem in CSPDarknet."""
# `stem_channels` equals to the `in_channels` in the first stage.
stem_channels = self.arch['in_channels'][0]
stem = ConvModule(
in_channels=in_channels,
out_channels=stem_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
return stem
@MODELS.register_module()
class CSPResNet(CSPNet):
"""CSP-ResNet backbone.
Args:
depth (int): Depth of CSP-ResNet. Default: 50.
        in_channels (int): Number of input image channels. Default: 3.
        out_indices (Sequence[int]): Output from which stages.
            Default: (3, ).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        deep_stem (bool): Whether to use three 3x3 convolutions instead of a
            single 7x7 convolution in the stem. Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', eps=1e-5).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> from mmpretrain.models import CSPResNet
>>> import torch
>>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3))
>>> model.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 128, 104, 104)
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""
arch_settings = {
50:
dict(
block_fn=ResNetBottleneck,
in_channels=(64, 128, 256, 512),
out_channels=(128, 256, 512, 1024),
num_blocks=(3, 3, 5, 2),
expand_ratio=4,
bottle_ratio=2,
has_downsampler=(False, True, True, True),
down_growth=False),
}
def __init__(self,
depth,
in_channels=3,
out_indices=(3, ),
frozen_stages=-1,
deep_stem=False,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-5),
act_cfg=dict(type='LeakyReLU', inplace=True),
norm_eval=False,
init_cfg=dict(type='Kaiming', layer='Conv2d')):
        assert depth in self.arch_settings, 'depth must be one of ' \
            f'{list(self.arch_settings.keys())}, but got {depth}.'
self.deep_stem = deep_stem
super().__init__(
arch=self.arch_settings[depth],
stem_fn=self._make_stem_layer,
in_channels=in_channels,
out_indices=out_indices,
frozen_stages=frozen_stages,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
norm_eval=norm_eval,
init_cfg=init_cfg)
def _make_stem_layer(self, in_channels):
# `stem_channels` equals to the `in_channels` in the first stage.
stem_channels = self.arch['in_channels'][0]
if self.deep_stem:
stem = nn.Sequential(
ConvModule(
in_channels,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
stem_channels // 2,
stem_channels // 2,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
else:
stem = nn.Sequential(
ConvModule(
in_channels,
stem_channels,
kernel_size=7,
stride=2,
padding=3,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
return stem
@MODELS.register_module()
class CSPResNeXt(CSPResNet):
"""CSP-ResNeXt backbone.
Args:
depth (int): Depth of CSP-ResNeXt. Default: 50.
        out_indices (Sequence[int]): Output from which stages.
            Default: (3, ).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', eps=1e-5).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> from mmpretrain.models import CSPResNeXt
>>> import torch
>>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3))
>>> model.eval()
>>> inputs = torch.rand(1, 3, 224, 224)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 256, 56, 56)
(1, 512, 28, 28)
(1, 1024, 14, 14)
(1, 2048, 7, 7)
"""
arch_settings = {
50:
dict(
block_fn=ResNeXtBottleneck,
in_channels=(64, 256, 512, 1024),
out_channels=(256, 512, 1024, 2048),
num_blocks=(3, 3, 5, 2),
expand_ratio=(4, 2, 2, 2),
bottle_ratio=4,
has_downsampler=(False, True, True, True),
down_growth=False,
# the base_channels is changed from 64 to 32 in CSPNet
block_args=dict(base_channels=32),
),
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence, Tuple
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks import Conv2d
from mmcv.cnn.bricks.transformer import FFN, AdaptivePadding, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.utils import to_2tuple
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import ShiftWindowMSA
class DaViTWindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module for DaViT.
The differences between DaViTWindowMSA & WindowMSA:
1. Without relative position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
attn_drop (float, optional): Dropout ratio of attention weight.
Defaults to 0.
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
init_cfg=None):
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
Wh*Ww), value should be between (-inf, 0].
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
class ConvPosEnc(BaseModule):
"""DaViT conv pos encode block.
Args:
embed_dims (int): Number of input channels.
kernel_size (int): The kernel size of the first convolution.
Defaults to 3.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self, embed_dims, kernel_size=3, init_cfg=None):
super(ConvPosEnc, self).__init__(init_cfg)
self.proj = Conv2d(
embed_dims,
embed_dims,
kernel_size,
stride=1,
padding=kernel_size // 2,
groups=embed_dims)
def forward(self, x, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
assert N == H * W
feat = x.transpose(1, 2).view(B, C, H, W)
feat = self.proj(feat)
feat = feat.flatten(2).transpose(1, 2)
x = x + feat
return x
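# A minimal sketch (illustrative helper, hypothetical sizes): ConvPosEnc
# injects positional information into a flattened token sequence through a
# depthwise convolution, so the token count and width are unchanged.
def _demo_conv_pos_enc():
    cpe = ConvPosEnc(embed_dims=16, kernel_size=3)
    tokens = torch.rand(2, 7 * 7, 16)  # (B, N, C) with N == H * W
    out = cpe(tokens, size=(7, 7))
    assert out.shape == tokens.shape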
class DaViTDownSample(BaseModule):
"""DaViT down sampole block.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
        conv_type (str): The type of convolution
            to generate the downsampled feature. Default: "Conv2d".
        kernel_size (int): The kernel size of the downsample convolution.
            Defaults to 2.
        stride (int): The stride of the downsample convolution.
            Defaults to 2.
        padding (int | tuple | string ): The padding length of the
            downsample conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Defaults to "same".
dilation (int): Dilation of the convolution layers. Defaults to 1.
bias (bool): Bias of embed conv. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
conv_type='Conv2d',
kernel_size=2,
stride=2,
padding='same',
dilation=1,
bias=True,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.out_channels = out_channels
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adaptive_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of conv
padding = 0
else:
self.adaptive_padding = None
padding = to_2tuple(padding)
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, in_channels)[1]
else:
self.norm = None
def forward(self, x, input_size):
if self.adaptive_padding:
x = self.adaptive_padding(x)
H, W = input_size
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
        if self.norm is not None:
            x = self.norm(x)
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
x = self.projection(x)
output_size = (x.size(2), x.size(3))
x = x.flatten(2).transpose(1, 2)
return x, output_size
class ChannelAttention(BaseModule):
"""DaViT channel attention.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self, embed_dims, num_heads=8, qkv_bias=False, init_cfg=None):
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
self.head_dims = embed_dims // num_heads
self.scale = self.head_dims**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dims, embed_dims)
def forward(self, x):
B, N, _ = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
self.head_dims).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
k = k * self.scale
attention = k.transpose(-1, -2) @ v
attention = attention.softmax(dim=-1)
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
x = self.proj(x)
return x
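# A minimal sketch (illustrative helper, hypothetical sizes): channel
# attention builds a (C/h) x (C/h) attention map per head, so its cost
# scales linearly with the token count N instead of quadratically.
def _demo_channel_attention():
    attn = ChannelAttention(embed_dims=32, num_heads=4)
    tokens = torch.rand(2, 49, 32)  # (B, N, C)
    assert attn(tokens).shape == tokens.shape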
class ChannelBlock(BaseModule):
"""DaViT channel attention block.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
ffn_ratio=4.,
qkv_bias=False,
drop_path=0.,
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super().__init__(init_cfg)
self.with_cp = with_cp
self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ChannelAttention(
embed_dims, num_heads=num_heads, qkv_bias=qkv_bias)
self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
**ffn_cfgs
}
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(**_ffn_cfgs)
def forward(self, x, hw_shape):
def _inner_forward(x):
x = self.cpe1(x, hw_shape)
identity = x
x = self.norm1(x)
x = self.attn(x)
x = x + identity
x = self.cpe2(x, hw_shape)
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SpatialBlock(BaseModule):
"""DaViT spatial attention block.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=7,
ffn_ratio=4.,
qkv_bias=True,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super(SpatialBlock, self).__init__(init_cfg)
self.with_cp = with_cp
self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
_attn_cfgs = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'shift_size': 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'qkv_bias': qkv_bias,
'pad_small_map': pad_small_map,
'window_msa': DaViTWindowMSA,
**attn_cfgs
}
self.attn = ShiftWindowMSA(**_attn_cfgs)
self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
**ffn_cfgs
}
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(**_ffn_cfgs)
def forward(self, x, hw_shape):
def _inner_forward(x):
x = self.cpe1(x, hw_shape)
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
x = self.cpe2(x, hw_shape)
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class DaViTBlock(BaseModule):
"""DaViT block.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=7,
ffn_ratio=4.,
qkv_bias=True,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super(DaViTBlock, self).__init__(init_cfg)
self.spatial_block = SpatialBlock(
embed_dims,
num_heads,
window_size=window_size,
ffn_ratio=ffn_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path,
pad_small_map=pad_small_map,
attn_cfgs=attn_cfgs,
ffn_cfgs=ffn_cfgs,
norm_cfg=norm_cfg,
with_cp=with_cp)
self.channel_block = ChannelBlock(
embed_dims,
num_heads,
ffn_ratio=ffn_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path,
ffn_cfgs=ffn_cfgs,
norm_cfg=norm_cfg,
            with_cp=with_cp)
def forward(self, x, hw_shape):
x = self.spatial_block(x, hw_shape)
x = self.channel_block(x, hw_shape)
return x
class DaViTBlockSequence(BaseModule):
"""Module with successive DaViT blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
depth (int): Number of successive DaViT blocks.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
depth,
num_heads,
window_size=7,
ffn_ratio=4.,
qkv_bias=True,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
pad_small_map=False,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
self.embed_dims = embed_dims
self.blocks = ModuleList()
for i in range(depth):
_block_cfg = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'window_size': window_size,
'ffn_ratio': ffn_ratio,
'qkv_bias': qkv_bias,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
**block_cfgs[i]
}
block = DaViTBlock(**_block_cfg)
self.blocks.append(block)
if downsample:
_downsample_cfg = {
'in_channels': embed_dims,
'out_channels': 2 * embed_dims,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = DaViTDownSample(**_downsample_cfg)
else:
self.downsample = None
def forward(self, x, in_shape, do_downsample=True):
for block in self.blocks:
x = block(x, in_shape)
if self.downsample is not None and do_downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
return x, out_shape
@property
def out_channels(self):
if self.downsample:
return self.downsample.out_channels
else:
return self.embed_dims
@MODELS.register_module()
class DaViT(BaseBackbone):
"""DaViT.
    A PyTorch implementation of: `DaViT: Dual Attention Vision Transformers
<https://arxiv.org/abs/2204.03645v1>`_
Inspiration from
https://github.com/dingmyu/davit
Args:
        arch (str | dict): DaViT architecture. If use string, choose from
            'tiny', 'small', 'base', 'large', 'huge' and 'giant'. If use
            dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
Defaults to 't'.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
in_channels (int): The num of input channels. Defaults to 3.
window_size (int): The height and width of the window. Defaults to 7.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
out_after_downsample (bool): Whether to output the feature map of a
stage after the following downsample layer. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
        out_indices (Sequence | int): Output from which stages.
            Defaults to ``(3, )``, which means the last stage.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'], {
'embed_dims': 96,
'depths': [1, 1, 3, 1],
'num_heads': [3, 6, 12, 24]
}),
**dict.fromkeys(['s', 'small'], {
'embed_dims': 96,
'depths': [1, 1, 9, 1],
'num_heads': [3, 6, 12, 24]
}),
**dict.fromkeys(['b', 'base'], {
'embed_dims': 128,
'depths': [1, 1, 9, 1],
'num_heads': [4, 8, 16, 32]
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 192,
'depths': [1, 1, 9, 1],
'num_heads': [6, 12, 24, 48]
}),
**dict.fromkeys(
['h', 'huge'], {
'embed_dims': 256,
'depths': [1, 1, 9, 1],
'num_heads': [8, 16, 32, 64]
}),
**dict.fromkeys(
['g', 'giant'], {
'embed_dims': 384,
'depths': [1, 1, 12, 3],
'num_heads': [12, 24, 48, 96]
}),
}
def __init__(self,
arch='t',
patch_size=4,
in_channels=3,
window_size=7,
ffn_ratio=4.,
qkv_bias=True,
drop_path_rate=0.1,
out_after_downsample=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(),
frozen_stages=-1,
norm_eval=False,
out_indices=(3, ),
with_cp=False,
init_cfg=None):
super().__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'num_heads'}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.out_after_downsample = out_after_downsample
self.frozen_stages = frozen_stages
self.norm_eval = norm_eval
# stochastic depth decay rule
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
_patch_cfg = dict(
in_channels=in_channels,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=7,
stride=patch_size,
padding='same',
norm_cfg=dict(type='LN'),
)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.stages = ModuleList()
embed_dims = [self.embed_dims]
for i, (depth,
num_heads) in enumerate(zip(self.depths, self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
            downsample = i < self.num_layers - 1
_stage_cfg = {
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': window_size,
'ffn_ratio': ffn_ratio,
'qkv_bias': qkv_bias,
'downsample': downsample,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
**stage_cfg
}
stage = DaViTBlockSequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
self.num_features = embed_dims[:-1]
# add a norm layer for each output
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg,
self.num_features[i])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm{i}', norm_layer)
def train(self, mode=True):
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval has effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def forward(self, x):
x, hw_shape = self.patch_embed(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(
x, hw_shape, do_downsample=self.out_after_downsample)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
if stage.downsample is not None and not self.out_after_downsample:
x, hw_shape = stage.downsample(x, hw_shape)
return tuple(outs)
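# A minimal usage sketch (illustrative helper, assuming a 224x224 input):
def _demo_davit():
    model = DaViT(arch='t')
    inputs = torch.rand(1, 3, 224, 224)
    # Default out_indices=(3, ): the last stage at 1/32 resolution.
    assert tuple(model(inputs)[-1].shape) == (1, 768, 7, 7)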
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from .vision_transformer import VisionTransformer
@MODELS.register_module()
class DistilledVisionTransformer(VisionTransformer):
"""Distilled Vision Transformer.
    A PyTorch implementation of: `Training data-efficient image transformers &
distillation through attention <https://arxiv.org/abs/2012.12877>`_
Args:
arch (str | dict): Vision Transformer architecture. If use string,
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
and 'deit-base'. If use dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'deit-base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: A tuple with the class token and the
distillation token. The shapes of both tensor are (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding.
            Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
num_extra_tokens = 2 # class token and distillation token
def __init__(self, arch='deit-base', *args, **kwargs):
super(DistilledVisionTransformer, self).__init__(
arch=arch,
with_cls_token=True,
*args,
**kwargs,
)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
x = x + self.resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'cls_token':
return x[:, 0], x[:, 1]
return super()._format_output(x, hw)
def init_weights(self):
super(DistilledVisionTransformer, self).init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
trunc_normal_(self.dist_token, std=0.02)
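# A minimal usage sketch (illustrative helper, assuming a 224x224 input):
def _demo_deit():
    model = DistilledVisionTransformer(arch='deit-base')
    inputs = torch.rand(1, 3, 224, 224)
    # With the default out_type='cls_token', each output is a pair of the
    # class token and the distillation token.
    cls_token, dist_token = model(inputs)[-1]
    assert cls_token.shape == (1, 768) and dist_token.shape == (1, 768)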
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import numpy as np
import torch
from mmcv.cnn import Linear, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.utils import deprecated_api_warning
from torch import nn
from mmpretrain.registry import MODELS
from ..utils import (LayerScale, MultiheadAttention, build_norm_layer,
resize_pos_embed, to_2tuple)
from .vision_transformer import VisionTransformer
class DeiT3FFN(BaseModule):
"""FFN for DeiT3.
The differences between DeiT3FFN & FFN:
1. Use LayerScale.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`. Defaults: 256.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 1024.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Default: 2.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
add_identity (bool, optional): Whether to add the
identity connection. Default: `True`.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
use_layer_scale (bool): Whether to use layer_scale in
DeiT3FFN. Defaults to True.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
@deprecated_api_warning(
{
'dropout': 'ffn_drop',
'add_residual': 'add_identity'
},
cls_name='FFN')
def __init__(self,
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.,
dropout_layer=None,
add_identity=True,
use_layer_scale=True,
init_cfg=None,
**kwargs):
super().__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2, but got {num_fcs}.'
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
Sequential(
Linear(in_channels, feedforward_channels), self.activate,
nn.Dropout(ffn_drop)))
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop))
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity
if use_layer_scale:
self.gamma2 = LayerScale(embed_dims)
else:
self.gamma2 = nn.Identity()
@deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
def forward(self, x, identity=None):
"""Forward function for `FFN`.
The function adds x to the output tensor if `identity` is None.
"""
out = self.layers(x)
out = self.gamma2(out)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
identity = x
return identity + self.dropout_layer(out)
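# A quick shape sketch for DeiT3FFN (illustrative): the identity
# connection and LayerScale are applied inside forward, so the output
# keeps the input token shape.
#
#   >>> ffn = DeiT3FFN(embed_dims=256, feedforward_channels=1024)
#   >>> tokens = torch.rand(2, 197, 256)  # (B, L, C)
#   >>> ffn(tokens).shape
#   torch.Size([2, 197, 256])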
class DeiT3TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in DeiT3.
The differences between DeiT3TransformerEncoderLayer &
TransformerEncoderLayer:
1. Use LayerScale.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension for FFNs
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The dropout rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
use_layer_scale (bool): Whether to use layer_scale in
DeiT3TransformerEncoderLayer. Defaults to True.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
use_layer_scale=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
use_layer_scale=use_layer_scale)
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
self.ffn = DeiT3FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
use_layer_scale=use_layer_scale)
def init_weights(self):
super(DeiT3TransformerEncoderLayer, self).init_weights()
for m in self.ffn.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
return x
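# The encoder layer is shape-preserving: pre-norm attention and FFN are
# added residually, each scaled by LayerScale and regularized by
# DropPath. An illustrative check:
#
#   >>> layer = DeiT3TransformerEncoderLayer(
#   ...     embed_dims=384, num_heads=6, feedforward_channels=1536)
#   >>> layer(torch.rand(2, 197, 384)).shape
#   torch.Size([2, 197, 384])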
@MODELS.register_module()
class DeiT3(VisionTransformer):
"""DeiT3 backbone.
A PyTorch implementation of: `DeiT III: Revenge of the ViT
<https://arxiv.org/pdf/2204.07118.pdf>`_
The differences between DeiT3 & VisionTransformer:
1. Use LayerScale.
2. Concat cls token after adding pos_embed.
Args:
arch (str | dict): DeiT3 architecture. If use string,
choose from 'small', 'base', 'medium', 'large' and 'huge'.
If use dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, which means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
use_layer_scale (bool): Whether to use layer_scale in DeiT3.
Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': 1536,
}),
**dict.fromkeys(
['m', 'medium'], {
'embed_dims': 512,
'num_layers': 12,
'num_heads': 8,
'feedforward_channels': 2048,
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 3072
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}),
**dict.fromkeys(
['h', 'huge'], {
'embed_dims': 1280,
'num_layers': 32,
'num_heads': 16,
'feedforward_channels': 5120
}),
}
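# Besides the string presets above, `arch` also accepts a dict with the
# four essential keys; an illustrative equivalent of the 'small' preset:
#
#   >>> custom_arch = dict(embed_dims=384, num_layers=12, num_heads=6,
#   ...                    feedforward_channels=1536)
#   >>> model = DeiT3(arch=custom_arch)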
num_extra_tokens = 1 # class token
def __init__(self,
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='cls_token',
with_cls_token=True,
use_layer_scale=True,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
# Call BaseBackbone.__init__ directly and skip
# VisionTransformer.__init__, since DeiT3 builds all of its own
# modules below (the cls token is concatenated after pos_embed).
super(VisionTransformer, self).__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.num_layers = self.arch_settings['num_layers']
self.img_size = to_2tuple(img_size)
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must be a sequence or int, ' \
f'but got {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
use_layer_scale=use_layer_scale)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg))
self.final_norm = final_norm
if final_norm:
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=0)
x = self.drop_after_pos(x)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1])))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(
state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
num_extra_tokens=0,  # the cls token is concatenated after pos_embed is added
)
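# An end-to-end sketch for the backbone above (illustrative; assumes
# the standard mmpretrain.models namespace):
#
#   >>> import torch
#   >>> from mmpretrain.models import DeiT3
#   >>> model = DeiT3(arch='base')
#   >>> feats = model(torch.rand(1, 3, 224, 224))
#   >>> feats[-1].shape  # default out_type='cls_token'
#   torch.Size([1, 768])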
# Copyright (c) OpenMMLab. All rights reserved.
import math
from itertools import chain
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import build_activation_layer, build_norm_layer
from torch.jit.annotations import List
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
class DenseLayer(BaseBackbone):
"""DenseBlock layers."""
def __init__(self,
in_channels,
growth_rate,
bn_size,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_rate=0.,
memory_efficient=False):
super(DenseLayer, self).__init__()
self.norm1 = build_norm_layer(norm_cfg, in_channels)[1]
self.conv1 = nn.Conv2d(
in_channels,
bn_size * growth_rate,
kernel_size=1,
stride=1,
bias=False)
self.act = build_activation_layer(act_cfg)
self.norm2 = build_norm_layer(norm_cfg, bn_size * growth_rate)[1]
self.conv2 = nn.Conv2d(
bn_size * growth_rate,
growth_rate,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
def bottleneck_fn(self, xs):
# type: (List[torch.Tensor]) -> torch.Tensor
concated_features = torch.cat(xs, 1)
bottleneck_output = self.conv1(
self.act(self.norm1(concated_features))) # noqa: T484
return bottleneck_output
# todo: rewrite when torchscript supports any
def any_requires_grad(self, x):
# type: (List[torch.Tensor]) -> bool
for tensor in x:
if tensor.requires_grad:
return True
return False
# This decorator indicates to the compiler that a function or method
# should be ignored and replaced with the raising of an exception.
# Here this function is incompatible with torchscript.
@torch.jit.unused # noqa: T484
def call_checkpoint_bottleneck(self, x):
# type: (List[torch.Tensor]) -> torch.Tensor
def closure(*xs):
return self.bottleneck_fn(xs)
# Here use torch.utils.checkpoint to rerun a forward-pass during
# backward in the bottleneck to save memory.
return cp.checkpoint(closure, *x)
def forward(self, x): # noqa: F811
# type: (List[torch.Tensor]) -> torch.Tensor
# the input features must be a list of Tensors
assert isinstance(x, list)
if self.memory_efficient and self.any_requires_grad(x):
if torch.jit.is_scripting():
raise Exception('Memory Efficient not supported in JIT')
bottleneck_output = self.call_checkpoint_bottleneck(x)
else:
bottleneck_output = self.bottleneck_fn(x)
new_features = self.conv2(self.act(self.norm2(bottleneck_output)))
if self.drop_rate > 0:
new_features = F.dropout(
new_features, p=self.drop_rate, training=self.training)
return new_features
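# A shape sketch for DenseLayer (illustrative sizes): the layer takes
# the list of all previous feature maps and always emits growth_rate
# channels; with memory_efficient=True the bottleneck is recomputed
# during backward instead of caching its activations.
#
#   >>> layer = DenseLayer(in_channels=64, growth_rate=32, bn_size=4)
#   >>> feats = [torch.rand(2, 64, 56, 56)]
#   >>> layer(feats).shape
#   torch.Size([2, 32, 56, 56])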
class DenseBlock(nn.Module):
"""DenseNet Blocks."""
def __init__(self,
num_layers,
in_channels,
bn_size,
growth_rate,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_rate=0.,
memory_efficient=False):
super(DenseBlock, self).__init__()
self.block = nn.ModuleList([
DenseLayer(
in_channels + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
drop_rate=drop_rate,
memory_efficient=memory_efficient) for i in range(num_layers)
])
def forward(self, init_features):
features = [init_features]
for layer in self.block:
new_features = layer(features)
features.append(new_features)
return torch.cat(features, 1)
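# Dense connectivity arithmetic: the block concatenates the input with
# every layer's new features, so the output has
# in_channels + num_layers * growth_rate channels. Illustrative check:
#
#   >>> block = DenseBlock(num_layers=6, in_channels=64, bn_size=4,
#   ...                    growth_rate=32)
#   >>> block(torch.rand(2, 64, 56, 56)).shape  # 64 + 6 * 32 = 256
#   torch.Size([2, 256, 56, 56])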
class DenseTransition(nn.Sequential):
"""DenseNet Transition Layers."""
def __init__(self,
in_channels,
out_channels,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU')):
super(DenseTransition, self).__init__()
self.add_module('norm', build_norm_layer(norm_cfg, in_channels)[1])
self.add_module('act', build_activation_layer(act_cfg))
self.add_module(
'conv',
nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1,
bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
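# A transition compresses channels with a 1x1 conv and halves the
# spatial size with 2x2 average pooling. Illustrative check:
#
#   >>> trans = DenseTransition(in_channels=256, out_channels=128)
#   >>> trans(torch.rand(2, 256, 56, 56)).shape
#   torch.Size([2, 128, 28, 28])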
@MODELS.register_module()
class DenseNet(BaseBackbone):
"""DenseNet.
A PyTorch implementation of: `Densely Connected Convolutional Networks
<https://arxiv.org/pdf/1608.06993.pdf>`_
Modified from the `official repo
<https://github.com/liuzhuang13/DenseNet>`_
and `pytorch
<https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_.
Args:
arch (str | dict): The model's architecture. If string, it should be
one of the architectures in ``DenseNet.arch_settings``. And if dict,
it should include the following three keys:
- growth_rate (int): Each layer of a DenseBlock produces `k`
feature maps, where `k` is the growth rate of the network.
- depths (list[int]): Number of repeated layers in each DenseBlock.
- init_channels (int): The output channels of the stem layers.
Defaults to '121'.
in_channels (int): Number of input image channels. Defaults to 3.
bn_size (int): The channel expansion factor of the 1x1 bottleneck
convolution layer. Defaults to 4.
drop_rate (float): Drop rate of Dropout Layer. Defaults to 0.
compression_factor (float): The reduction rate of transition layers.
Defaults to 0.5.
memory_efficient (bool): If True, uses checkpointing. Much more memory
efficient, but slower. Defaults to False.
See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='BN')``.
act_cfg (dict): The config dict for activation after each convolution.
Defaults to ``dict(type='ReLU')``.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, which means the last stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
init_cfg (dict, optional): Initialization config dict.
"""
arch_settings = {
'121': {
'growth_rate': 32,
'depths': [6, 12, 24, 16],
'init_channels': 64,
},
'169': {
'growth_rate': 32,
'depths': [6, 12, 32, 32],
'init_channels': 64,
},
'201': {
'growth_rate': 32,
'depths': [6, 12, 48, 32],
'init_channels': 64,
},
'161': {
'growth_rate': 48,
'depths': [6, 12, 36, 24],
'init_channels': 96,
},
}
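# As with the string presets above, a custom arch dict is accepted; an
# illustrative equivalent of the '121' entry:
#
#   >>> custom_arch = dict(growth_rate=32, depths=[6, 12, 24, 16],
#   ...                    init_channels=64)
#   >>> model = DenseNet(arch=custom_arch)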
def __init__(self,
arch='121',
in_channels=3,
bn_size=4,
drop_rate=0,
compression_factor=0.5,
memory_efficient=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
out_indices=-1,
frozen_stages=0,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
essential_keys = {'growth_rate', 'depths', 'init_channels'}
assert essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.growth_rate = arch['growth_rate']
self.depths = arch['depths']
self.init_channels = arch['init_channels']
self.act = build_activation_layer(act_cfg)
self.num_stages = len(self.depths)
# check out indices and frozen stages
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must be a sequence or int, ' \
f'but got {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_stages + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# Set stem layers
self.stem = nn.Sequential(
nn.Conv2d(
in_channels,
self.init_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False),
build_norm_layer(norm_cfg, self.init_channels)[1], self.act,
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# Repetitions of DenseNet Blocks
self.stages = nn.ModuleList()
self.transitions = nn.ModuleList()
channels = self.init_channels
for i in range(self.num_stages):
depth = self.depths[i]
stage = DenseBlock(
num_layers=depth,
in_channels=channels,
bn_size=bn_size,
growth_rate=self.growth_rate,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
drop_rate=drop_rate,
memory_efficient=memory_efficient)
self.stages.append(stage)
channels += depth * self.growth_rate
if i != self.num_stages - 1:
transition = DenseTransition(
in_channels=channels,
out_channels=math.floor(channels * compression_factor),
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
channels = math.floor(channels * compression_factor)
else:
# The final layers after the last dense block are just BN with act.
# Unlike the paper, the original repo also puts this in a
# transition layer, whereas torchvision takes it out. We treat it
# as a transition layer here.
transition = nn.Sequential(
build_norm_layer(norm_cfg, channels)[1],
self.act,
)
self.transitions.append(transition)
self._freeze_stages()
def forward(self, x):
x = self.stem(x)
outs = []
for i in range(self.num_stages):
x = self.stages[i](x)
x = self.transitions[i](x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
for i in range(self.frozen_stages):
downsample_layer = self.transitions[i]
stage = self.stages[i]
downsample_layer.eval()
stage.eval()
for param in chain(downsample_layer.parameters(),
stage.parameters()):
param.requires_grad = False
def train(self, mode=True):
super(DenseNet, self).train(mode)
self._freeze_stages()
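# An end-to-end sketch for the backbone above (illustrative; assumes
# the standard mmpretrain.models namespace):
#
#   >>> import torch
#   >>> from mmpretrain.models import DenseNet
#   >>> model = DenseNet(arch='121')
#   >>> outs = model(torch.rand(1, 3, 224, 224))
#   >>> outs[-1].shape  # default out_indices=-1: the last stage
#   torch.Size([1, 1024, 7, 7])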