Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmpretrain
Commits
cbc25585
Commit
cbc25585
authored
Jun 24, 2025
by
limm
Browse files
add mmpretrain/ part
parent
1baf0566
Pipeline
#2801
canceled with stages
Changes
268
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6620 additions
and
0 deletions
+6620
-0
mmpretrain/evaluation/metrics/retrieval.py
mmpretrain/evaluation/metrics/retrieval.py
+445
-0
mmpretrain/evaluation/metrics/scienceqa.py
mmpretrain/evaluation/metrics/scienceqa.py
+170
-0
mmpretrain/evaluation/metrics/shape_bias_label.py
mmpretrain/evaluation/metrics/shape_bias_label.py
+172
-0
mmpretrain/evaluation/metrics/single_label.py
mmpretrain/evaluation/metrics/single_label.py
+776
-0
mmpretrain/evaluation/metrics/visual_grounding_eval.py
mmpretrain/evaluation/metrics/visual_grounding_eval.py
+85
-0
mmpretrain/evaluation/metrics/voc_multi_label.py
mmpretrain/evaluation/metrics/voc_multi_label.py
+98
-0
mmpretrain/evaluation/metrics/vqa.py
mmpretrain/evaluation/metrics/vqa.py
+315
-0
mmpretrain/models/__init__.py
mmpretrain/models/__init__.py
+20
-0
mmpretrain/models/backbones/__init__.py
mmpretrain/models/backbones/__init__.py
+129
-0
mmpretrain/models/backbones/alexnet.py
mmpretrain/models/backbones/alexnet.py
+56
-0
mmpretrain/models/backbones/base_backbone.py
mmpretrain/models/backbones/base_backbone.py
+33
-0
mmpretrain/models/backbones/beit.py
mmpretrain/models/backbones/beit.py
+697
-0
mmpretrain/models/backbones/conformer.py
mmpretrain/models/backbones/conformer.py
+621
-0
mmpretrain/models/backbones/convmixer.py
mmpretrain/models/backbones/convmixer.py
+176
-0
mmpretrain/models/backbones/convnext.py
mmpretrain/models/backbones/convnext.py
+412
-0
mmpretrain/models/backbones/cspnet.py
mmpretrain/models/backbones/cspnet.py
+679
-0
mmpretrain/models/backbones/davit.py
mmpretrain/models/backbones/davit.py
+834
-0
mmpretrain/models/backbones/deit.py
mmpretrain/models/backbones/deit.py
+116
-0
mmpretrain/models/backbones/deit3.py
mmpretrain/models/backbones/deit3.py
+454
-0
mmpretrain/models/backbones/densenet.py
mmpretrain/models/backbones/densenet.py
+332
-0
No files found.
Too many changes to show.
To preserve performance only
268 of 268+
files are displayed.
Plain diff
Email patch
mmpretrain/evaluation/metrics/retrieval.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
,
Optional
,
Sequence
,
Union
import
mmengine
import
numpy
as
np
import
torch
from
mmengine.evaluator
import
BaseMetric
from
mmengine.utils
import
is_seq_of
from
mmpretrain.registry
import
METRICS
from
mmpretrain.structures
import
label_to_onehot
from
.single_label
import
to_tensor
@METRICS.register_module()
class RetrievalRecall(BaseMetric):
    r"""Recall evaluation metric for image retrieval.

    Args:
        topk (int | Sequence[int]): If the ground truth label matches one of
            the best **k** predictions, the sample will be regarded as a
            positive prediction. If the parameter is a tuple, all of the
            top-k recalls will be calculated and output together.
            Defaults to 1.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Examples:
        Use in the code:

        >>> import torch
        >>> from mmpretrain.evaluation import RetrievalRecall
        >>> # -------------------- The Basic Usage --------------------
        >>> y_pred = [[0], [1], [2], [3]]
        >>> y_true = [[0, 1], [2], [1], [0, 3]]
        >>> RetrievalRecall.calculate(
        >>>     y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
        [tensor([50.])]
        >>> # Calculate the recall@1 and recall@5 for non-indices input.
        >>> y_score = torch.rand((1000, 10))
        >>> import torch.nn.functional as F
        >>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10)
        >>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5))
        [tensor(9.3000), tensor(48.4000)]
        >>>
        >>> # ------------------- Use with Evalutor -------------------
        >>> from mmpretrain.structures import DataSample
        >>> from mmengine.evaluator import Evaluator
        >>> data_samples = [
        ...     DataSample().set_gt_label([0, 1]).set_pred_score(
        ...         torch.rand(10))
        ...     for i in range(1000)
        ... ]
        >>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5)))
        >>> evaluator.process(data_samples)
        >>> evaluator.evaluate(1000)
        {'retrieval/Recall@1': 20.700000762939453,
         'retrieval/Recall@5': 78.5999984741211}

        Use in OpenMMLab configs:

        .. code:: python

            val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
            test_evaluator = val_evaluator
    """
    default_prefix: Optional[str] = 'retrieval'

    def __init__(self,
                 topk: Union[int, Sequence[int]],
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        # Normalize a single int to a 1-tuple so the rest of the class can
        # always treat ``self.topk`` as a sequence.
        topk = (topk, ) if isinstance(topk, int) else topk

        for k in topk:
            if k <= 0:
                raise ValueError('`topk` must be an integer larger than 0 '
                                 'or a sequence of integers larger than 0.')

        self.topk = topk
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]):
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_score = data_sample['pred_score'].clone()
            gt_label = data_sample['gt_label']

            if 'gt_score' in data_sample:
                target = data_sample.get('gt_score').clone()
            else:
                # Fall back to a one-hot target built from the label indices.
                num_classes = pred_score.size()[-1]
                target = label_to_onehot(gt_label, num_classes)

            # Because the retrieval output logit vector will be much larger
            # compared to the normal classification, to save resources, the
            # evaluation results are computed each batch here and then reduce
            # all results at the end.
            result = RetrievalRecall.calculate(
                pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        result_metrics = dict()
        for i, k in enumerate(self.topk):
            # Each entry of ``results`` holds one recall value per k.
            recall_at_k = sum([r[i].item() for r in results]) / len(results)
            result_metrics[f'Recall@{k}'] = recall_at_k

        return result_metrics

    @staticmethod
    def calculate(pred: Union[np.ndarray, torch.Tensor],
                  target: Union[np.ndarray, torch.Tensor],
                  topk: Union[int, Sequence[int]],
                  pred_indices: bool = False,
                  target_indices: bool = False) -> float:
        """Calculate the average recall.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
            target (torch.Tensor | np.ndarray | Sequence): The ground-truth
                labels. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
            topk (int, Sequence[int]): Predictions with the k-th highest
                scores are considered as positive.
            pred_indices (bool): Whether the ``pred`` is a sequence of
                category index labels. Defaults to False.
            target_indices (bool): Whether the ``target`` is a sequence of
                category index labels. Defaults to False.

        Returns:
            List[float]: the average recalls.
        """
        topk = (topk, ) if isinstance(topk, int) else topk
        for k in topk:
            if k <= 0:
                raise ValueError('`topk` must be an integer larger than 0 '
                                 'or a sequence of integers larger than 0.')

        max_keep = max(topk)
        pred = _format_pred(pred, max_keep, pred_indices)
        target = _format_target(target, target_indices)

        assert len(pred) == len(target), (
            f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
            f'must be the same.')

        num_samples = len(pred)
        results = []
        for k in topk:
            recalls = torch.zeros(num_samples)
            for i, (sample_pred,
                    sample_target) in enumerate(zip(pred, target)):
                sample_pred = np.array(to_tensor(sample_pred).cpu())
                sample_target = np.array(to_tensor(sample_target).cpu())
                # A sample counts as a hit if any of its top-k predictions is
                # relevant. ``np.isin`` replaces the deprecated ``np.in1d``.
                recalls[i] = int(np.isin(sample_pred[:k], sample_target).max())
            results.append(recalls.mean() * 100)

        return results
@METRICS.register_module()
class RetrievalAveragePrecision(BaseMetric):
    r"""Calculate the average precision for image retrieval.

    Args:
        topk (int): Predictions with the k-th highest scores are considered
            as positive. Although annotated as optional for signature
            compatibility, a positive integer is required; passing ``None``
            raises :exc:`ValueError`.
        mode (str, optional): The mode to calculate AP, choose from
            'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Note:
        If the ``mode`` set to 'IR', use the stanford AP calculation of
        information retrieval as in wikipedia page[1]; if set to 'integrate',
        the method implemented integrates over the precision-recall curve
        by averaging two adjacent precision points, then multiplying by the
        recall step like mAP in Detection task. This is the convention for
        the Revisited Oxford/Paris datasets[2].

    References:
        [1] `Wikipedia entry for the Average precision <https://en.wikipedia.
        org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_

        [2] `The Oxford Buildings Dataset
        <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_

    Examples:
        Use in code:

        >>> import torch
        >>> import numpy as np
        >>> from mmcls.evaluation import RetrievalAveragePrecision
        >>> # using index format inputs
        >>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3
        >>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
        >>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
        29.246031746031747
        >>> # using tensor format inputs
        >>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        >>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
        >>> RetrievalAveragePrecision.calculate(pred, target, 10)
        62.222222222222214

        Use in OpenMMLab config files:

        .. code:: python

            val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
            test_evaluator = val_evaluator
    """

    default_prefix: Optional[str] = 'retrieval'

    def __init__(self,
                 topk: Optional[int] = None,
                 mode: Optional[str] = 'IR',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        # ``topk`` is required in practice: ``None`` (the default) and
        # non-positive values are both rejected here.
        if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')

        mode_options = ['IR', 'integrate']
        assert mode in mode_options, \
            f'Invalid `mode` argument, please specify from {mode_options}.'

        self.topk = topk
        self.mode = mode
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]):
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_score = data_sample.get('pred_score').clone()

            if 'gt_score' in data_sample:
                target = data_sample.get('gt_score').clone()
            else:
                # Fall back to a one-hot target built from the label indices.
                gt_label = data_sample.get('gt_label')
                num_classes = pred_score.size()[-1]
                target = label_to_onehot(gt_label, num_classes)

            # Because the retrieval output logit vector will be much larger
            # compared to the normal classification, to save resources, the
            # evaluation results are computed each batch here and then reduce
            # all results at the end.
            result = RetrievalAveragePrecision.calculate(
                pred_score.unsqueeze(0),
                target.unsqueeze(0),
                self.topk,
                mode=self.mode)
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        result_metrics = dict()
        result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item()

        return result_metrics

    @staticmethod
    def calculate(pred: Union[np.ndarray, torch.Tensor],
                  target: Union[np.ndarray, torch.Tensor],
                  topk: Optional[int] = None,
                  pred_indices: bool = False,
                  target_indices: bool = False,
                  mode: str = 'IR') -> float:
        """Calculate the average precision.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
            target (torch.Tensor | np.ndarray | Sequence): The ground-truth
                labels. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
            topk (int): Predictions with the k-th highest scores are
                considered as positive. Required; ``None`` raises.
            pred_indices (bool): Whether the ``pred`` is a sequence of
                category index labels. Defaults to False.
            target_indices (bool): Whether the ``target`` is a sequence of
                category index labels. Defaults to False.
            mode (Optional[str]): The mode to calculate AP, choose from
                'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.

        Note:
            If the ``mode`` set to 'IR', use the stanford AP calculation of
            information retrieval as in wikipedia page; if set to 'integrate',
            the method implemented integrates over the precision-recall curve
            by averaging two adjacent precision points, then multiplying by the
            recall step like mAP in Detection task. This is the convention for
            the Revisited Oxford/Paris datasets.

        Returns:
            float: the average precision of the query image.

        References:
            [1] `Wikipedia entry for Average precision(information_retrieval)
            <https://en.wikipedia.org/wiki/Evaluation_measures_
            (information_retrieval)#Average_precision>`_

            [2] `The Oxford Buildings Dataset <https://www.robots.ox.ac.uk/
            ~vgg/data/oxbuildings/`_
        """
        if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')

        mode_options = ['IR', 'integrate']
        assert mode in mode_options, \
            f'Invalid `mode` argument, please specify from {mode_options}.'

        pred = _format_pred(pred, topk, pred_indices)
        target = _format_target(target, target_indices)

        assert len(pred) == len(target), (
            f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
            f'must be the same.')

        num_samples = len(pred)
        aps = np.zeros(num_samples)
        for i, (sample_pred, sample_target) in enumerate(zip(pred, target)):
            aps[i] = _calculateAp_for_sample(sample_pred, sample_target, mode)

        return aps.mean()
def _calculateAp_for_sample(pred, target, mode):
    """Compute the average precision (in percent) for a single query.

    Args:
        pred: Ranked prediction indices for one sample (tensor-convertible).
        target: Relevant (ground-truth) indices for one sample.
        mode (str): 'IR' for the standard information-retrieval AP, or
            'integrate' for the precision-recall-curve integration used by
            the Revisited Oxford/Paris benchmarks.

    Returns:
        float: average precision scaled to [0, 100].
    """
    pred = np.array(to_tensor(pred).cpu())
    target = np.array(to_tensor(target).cpu())

    num_preds = len(pred)

    # 0-based ranks at which a relevant item appears among the predictions.
    # ``np.isin`` replaces the deprecated ``np.in1d``.
    positive_ranks = np.arange(num_preds)[np.isin(pred, target)]

    ap = 0
    for i, rank in enumerate(positive_ranks):
        if mode == 'IR':
            # Precision at this recall point: (i+1) hits in (rank+1) preds.
            precision = (i + 1) / (rank + 1)
            ap += precision
        elif mode == 'integrate':
            # code are modified from https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/compute_ap.cpp # noqa:
            old_precision = i / rank if rank > 0 else 1
            cur_precision = (i + 1) / (rank + 1)
            # Trapezoidal average of two adjacent precision points.
            avg_precision = (old_precision + cur_precision) / 2
            ap += avg_precision

    # Normalize by the number of relevant items, scale to percent.
    ap = ap / len(target)
    return ap * 100
def
_format_pred
(
label
,
topk
=
None
,
is_indices
=
False
):
"""format various label to List[indices]."""
if
is_indices
:
assert
isinstance
(
label
,
Sequence
),
\
'`pred` must be Sequence of indices when'
\
f
' `pred_indices` set to True, but get
{
type
(
label
)
}
'
for
i
,
sample_pred
in
enumerate
(
label
):
assert
is_seq_of
(
sample_pred
,
int
)
or
isinstance
(
sample_pred
,
(
np
.
ndarray
,
torch
.
Tensor
)),
\
'`pred` should be Sequence of indices when `pred_indices`'
\
f
'set to True. but pred[
{
i
}
] is
{
sample_pred
}
'
if
topk
:
label
[
i
]
=
sample_pred
[:
min
(
topk
,
len
(
sample_pred
))]
return
label
if
isinstance
(
label
,
np
.
ndarray
):
label
=
torch
.
from_numpy
(
label
)
elif
not
isinstance
(
label
,
torch
.
Tensor
):
raise
TypeError
(
f
'The pred must be type of torch.tensor, '
f
'np.ndarray or Sequence but get
{
type
(
label
)
}
.'
)
topk
=
topk
if
topk
else
label
.
size
()[
-
1
]
_
,
indices
=
label
.
topk
(
topk
)
return
indices
def
_format_target
(
label
,
is_indices
=
False
):
"""format various label to List[indices]."""
if
is_indices
:
assert
isinstance
(
label
,
Sequence
),
\
'`target` must be Sequence of indices when'
\
f
' `target_indices` set to True, but get
{
type
(
label
)
}
'
for
i
,
sample_gt
in
enumerate
(
label
):
assert
is_seq_of
(
sample_gt
,
int
)
or
isinstance
(
sample_gt
,
(
np
.
ndarray
,
torch
.
Tensor
)),
\
'`target` should be Sequence of indices when '
\
f
'`target_indices` set to True. but target[
{
i
}
] is
{
sample_gt
}
'
return
label
if
isinstance
(
label
,
np
.
ndarray
):
label
=
torch
.
from_numpy
(
label
)
elif
isinstance
(
label
,
Sequence
)
and
not
mmengine
.
is_str
(
label
):
label
=
torch
.
tensor
(
label
)
elif
not
isinstance
(
label
,
torch
.
Tensor
):
raise
TypeError
(
f
'The pred must be type of torch.tensor, '
f
'np.ndarray or Sequence but get
{
type
(
label
)
}
.'
)
indices
=
[
sample_gt
.
nonzero
().
squeeze
(
-
1
)
for
sample_gt
in
label
]
return
indices
mmpretrain/evaluation/metrics/scienceqa.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
random
from
typing
import
List
,
Optional
from
mmengine.evaluator
import
BaseMetric
from
mmpretrain.registry
import
METRICS
def get_pred_idx(prediction: str, choices: List[str],
                 options: List[str]) -> int:  # noqa
    """Map a predicted option letter (e.g. 'C') to its choice index (e.g. 2).

    Args:
        prediction (str): The prediction from the model,
            from ['A', 'B', 'C', 'D', 'E']
        choices (List(str)): The choices for the question,
            from ['A', 'B', 'C', 'D', 'E']
        options (List(str)): The options for the question,
            from ['A', 'B', 'C', 'D', 'E']

    Returns:
        int: The index of the prediction, from [0, 1, 2, 3, 4]
    """
    # Only the first len(choices) option letters are valid for this question.
    valid_options = options[:len(choices)]
    if prediction not in valid_options:
        # Unparsable/invalid prediction: fall back to a random valid index.
        return random.choice(range(len(choices)))
    return options.index(prediction)
@METRICS.register_module()
class ScienceQAMetric(BaseMetric):
    """Evaluation Metric for ScienceQA.

    Args:
        options (List(str), optional): Options for each question. Defaults to
            ``['A', 'B', 'C', 'D', 'E']``.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Should be modified according to the
            `retrieval_type` for unambiguous results. Defaults to TR.
    """

    # Grade buckets used to split per-grade accuracy.
    _LOWER_GRADES = frozenset(
        ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'])
    _UPPER_GRADES = frozenset(
        ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12'])

    def __init__(self,
                 options: Optional[List[str]] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        # ``None`` sentinel instead of a mutable default argument; the copy
        # also protects against the caller mutating the list later.
        if options is None:
            self.options = ['A', 'B', 'C', 'D', 'E']
        else:
            self.options = list(options)

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        data_samples should contain the following keys:
        1. pred_answer (str): The prediction from the model,
            from ['A', 'B', 'C', 'D', 'E']
        2. choices (List(str)): The choices for the question,
            from ['A', 'B', 'C', 'D', 'E']
        3. grade (int): The grade for the question, from grade1 to grade12
        4. subject (str): The subject for the question, from
            ['natural science', 'social science', 'language science']
        5. answer (str): The answer for the question, from
            ['A', 'B', 'C', 'D', 'E']
        6. hint (str): The hint for the question
        7. has_image (bool): Whether or not the question has image

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            choices = data_sample.get('choices')
            result['prediction'] = get_pred_idx(
                data_sample.get('pred_answer'), choices, self.options)
            result['grade'] = data_sample.get('grade')
            result['subject'] = data_sample.get('subject')
            result['answer'] = data_sample.get('gt_answer')
            # Treat a missing hint as empty instead of crashing on len(None).
            hint = data_sample.get('hint') or ''
            has_image = data_sample.get('has_image', False)
            result['no_context'] = not has_image and len(hint) == 0
            result['has_text'] = len(hint) > 0
            result['has_image'] = has_image

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (dict): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
        subject_to_key = {
            'natural science': 'acc_natural',
            'social science': 'acc_social',
            'language science': 'acc_language',
        }
        bucket_keys = ('acc_natural', 'acc_social', 'acc_language',
                       'acc_has_text', 'acc_has_image', 'acc_no_context',
                       'acc_grade_1_6', 'acc_grade_7_12')
        buckets = {key: [] for key in bucket_keys}
        all_acc = []

        for result in results:
            correct = result['prediction'] == result['answer']
            all_acc.append(correct)

            # different subjects
            subject_key = subject_to_key.get(result['subject'])
            if subject_key is not None:
                buckets[subject_key].append(correct)

            # different context (mutually exclusive; text takes priority)
            if result['has_text']:
                buckets['acc_has_text'].append(correct)
            elif result['has_image']:
                buckets['acc_has_image'].append(correct)
            elif result['no_context']:
                buckets['acc_no_context'].append(correct)

            # different grade
            if result['grade'] in self._LOWER_GRADES:
                buckets['acc_grade_1_6'].append(correct)
            elif result['grade'] in self._UPPER_GRADES:
                buckets['acc_grade_7_12'].append(correct)

        metrics = dict()
        metrics['all_acc'] = sum(all_acc) / len(all_acc)
        # Only report a bucket when it received at least one sample.
        for key in bucket_keys:
            if len(buckets[key]) > 0:
                metrics[key] = sum(buckets[key]) / len(buckets[key])

        return metrics
mmpretrain/evaluation/metrics/shape_bias_label.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
csv
import
os
import
os.path
as
osp
from
typing
import
List
,
Sequence
import
numpy
as
np
import
torch
from
mmengine.dist.utils
import
get_rank
from
mmengine.evaluator
import
BaseMetric
from
mmpretrain.registry
import
METRICS
@METRICS.register_module()
class ShapeBiasMetric(BaseMetric):
    """Evaluate the model on ``cue_conflict`` dataset.

    This module will evaluate the model on an OOD dataset, cue_conflict, in
    order to measure the shape bias of the model. In addition to computing the
    Top-1 accuracy, this module also generates a csv file to record the
    detailed prediction results, such that this csv file can be used to
    generate the shape bias curve.

    Args:
        csv_dir (str): The directory to save the csv file.
        model_name (str): The name of the csv file. Please note that the
            model name should be an unique identifier.
        dataset_name (str): The name of the dataset. Default: 'cue_conflict'.
    """

    # mapping several classes from ImageNet-1K to the same category
    airplane_indices = [404]
    bear_indices = [294, 295, 296, 297]
    bicycle_indices = [444, 671]
    bird_indices = [
        8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
        87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
        130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
        145
    ]
    boat_indices = [472, 554, 625, 814, 914]
    bottle_indices = [440, 720, 737, 898, 899, 901, 907]
    car_indices = [436, 511, 817]
    cat_indices = [281, 282, 283, 284, 285, 286]
    chair_indices = [423, 559, 765, 857]
    clock_indices = [409, 530, 892]
    dog_indices = [
        152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
        166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
        180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
        195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
        210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
        224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
        239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
        255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
    ]
    elephant_indices = [385, 386]
    keyboard_indices = [508, 878]
    knife_indices = [499]
    oven_indices = [766]
    truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]

    def __init__(self,
                 csv_dir: str,
                 model_name: str,
                 dataset_name: str = 'cue_conflict',
                 **kwargs) -> None:
        """Initialize the metric and (on rank 0 only) create the output csv."""
        super().__init__(**kwargs)

        # The 16 coarse categories used by the cue_conflict benchmark, kept
        # sorted so that prediction indices map deterministically to names.
        self.categories = sorted([
            'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
            'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
            'bird', 'dog'
        ])
        self.csv_dir = csv_dir
        self.model_name = model_name
        self.dataset_name = dataset_name
        # Only rank 0 owns the csv file to avoid concurrent writes.
        if get_rank() == 0:
            self.csv_path = self.create_csv()

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to computed the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            # NOTE(review): the gather below requires 'pred_score'; the
            # 'pred_label'-only branch would later raise KeyError — confirm
            # upstream always provides scores for this metric.
            if 'pred_score' in data_sample:
                result['pred_score'] = data_sample['pred_score'].cpu()
            else:
                result['pred_label'] = data_sample['pred_label'].cpu()
            result['gt_label'] = data_sample['gt_label'].cpu()
            # The directory name of the image encodes its ground-truth
            # category (…/<category>/<image>).
            result['gt_category'] = data_sample['img_path'].split('/')[-2]
            result['img_name'] = data_sample['img_path'].split('/')[-1]

            aggregated_category_probabilities = []
            # get the prediction for each category of current instance
            for category in self.categories:
                category_indices = getattr(self, f'{category}_indices')
                # Mean score over all ImageNet classes mapped to the category.
                category_probabilities = torch.gather(
                    result['pred_score'], 0,
                    torch.tensor(category_indices)).mean()
                aggregated_category_probabilities.append(
                    category_probabilities)
            # sort the probabilities in descending order
            pred_indices = torch.stack(aggregated_category_probabilities
                                       ).argsort(descending=True).numpy()
            result['pred_category'] = np.take(self.categories, pred_indices)

            # Save the result to `self.results`.
            self.results.append(result)

    def create_csv(self) -> str:
        """Create a csv file to store the results.

        Returns:
            str: Path of the created csv file
            (``<csv_dir>/<dataset>_<model>_<session>.csv``); any pre-existing
            file at this path is removed first.
        """
        session_name = 'session-1'
        csv_path = osp.join(
            self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
            session_name + '.csv')
        # Start from a fresh file each run.
        if osp.exists(csv_path):
            os.remove(csv_path)
        directory = osp.dirname(csv_path)
        if not osp.exists(directory):
            os.makedirs(directory, exist_ok=True)
        with open(csv_path, 'w') as f:
            writer = csv.writer(f)
            # Header follows the shape-bias toolbox format.
            writer.writerow([
                'subj', 'session', 'trial', 'rt', 'object_response',
                'category', 'condition', 'imagename'
            ])
        return csv_path

    def dump_results_to_csv(self, results: List[dict]) -> None:
        """Dump the results to a csv file.

        Args:
            results (List[dict]): A list of results.
        """
        for i, result in enumerate(results):
            img_name = result['img_name']
            category = result['gt_category']
            condition = 'NaN'
            with open(self.csv_path, 'a') as f:
                writer = csv.writer(f)
                # One row per sample: top-1 predicted category vs gt category.
                writer.writerow([
                    self.model_name,
                    1,
                    i + 1,
                    'NaN',
                    result['pred_category'][0],
                    category,
                    condition,
                    img_name,
                ])

    def compute_metrics(self, results: List[dict]) -> dict:
        """Compute the metrics from the results.

        Args:
            results (List[dict]): A list of results.

        Returns:
            dict: A dict of metrics with key ``accuracy/top1`` (fraction of
            samples whose top-1 predicted category equals the gt category).
        """
        # Only rank 0 writes the csv (matches file ownership in __init__).
        if get_rank() == 0:
            self.dump_results_to_csv(results)
        metrics = dict()
        metrics['accuracy/top1'] = np.mean([
            result['pred_category'][0] == result['gt_category']
            for result in results
        ])

        return metrics
mmpretrain/evaluation/metrics/single_label.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
itertools
import
product
from
typing
import
List
,
Optional
,
Sequence
,
Union
import
mmengine
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
mmengine.evaluator
import
BaseMetric
from
mmpretrain.registry
import
METRICS
def to_tensor(value):
    """Convert ``value`` (ndarray / sequence / tensor) to ``torch.Tensor``."""
    if isinstance(value, torch.Tensor):
        # Already a tensor: return as-is.
        return value
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value)
    if isinstance(value, Sequence) and not mmengine.is_str(value):
        return torch.tensor(value)
    raise TypeError(f'{type(value)} is not an available argument.')
def
_precision_recall_f1_support
(
pred_positive
,
gt_positive
,
average
):
"""calculate base classification task metrics, such as precision, recall,
f1_score, support."""
average_options
=
[
'micro'
,
'macro'
,
None
]
assert
average
in
average_options
,
'Invalid `average` argument, '
\
f
'please specify from
{
average_options
}
.'
# ignore -1 target such as difficult sample that is not wanted
# in evaluation results.
# only for calculate multi-label without affecting single-label behavior
ignored_index
=
gt_positive
==
-
1
pred_positive
[
ignored_index
]
=
0
gt_positive
[
ignored_index
]
=
0
class_correct
=
(
pred_positive
&
gt_positive
)
if
average
==
'micro'
:
tp_sum
=
class_correct
.
sum
()
pred_sum
=
pred_positive
.
sum
()
gt_sum
=
gt_positive
.
sum
()
else
:
tp_sum
=
class_correct
.
sum
(
0
)
pred_sum
=
pred_positive
.
sum
(
0
)
gt_sum
=
gt_positive
.
sum
(
0
)
precision
=
tp_sum
/
torch
.
clamp
(
pred_sum
,
min
=
1
).
float
()
*
100
recall
=
tp_sum
/
torch
.
clamp
(
gt_sum
,
min
=
1
).
float
()
*
100
f1_score
=
2
*
precision
*
recall
/
torch
.
clamp
(
precision
+
recall
,
min
=
torch
.
finfo
(
torch
.
float32
).
eps
)
if
average
in
[
'macro'
,
'micro'
]:
precision
=
precision
.
mean
(
0
)
recall
=
recall
.
mean
(
0
)
f1_score
=
f1_score
.
mean
(
0
)
support
=
gt_sum
.
sum
(
0
)
else
:
support
=
gt_sum
return
precision
,
recall
,
f1_score
,
support
@METRICS.register_module()
class Accuracy(BaseMetric):
    r"""Accuracy evaluation metric.

    For either binary classification or multi-class classification, the
    accuracy is the fraction of correct predictions in all predictions:

    .. math::

        \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}}

    Args:
        topk (int | Sequence[int]): If the ground truth label matches one of
            the best **k** predictions, the sample will be regard as a positive
            prediction. If the parameter is a tuple, all of top-k accuracy will
            be calculated and outputted together. Defaults to 1.
        thrs (Sequence[float | None] | float | None): If a float, predictions
            with score lower than the threshold will be regard as the negative
            prediction. If None, not apply threshold. If the parameter is a
            tuple, accuracy based on all thresholds will be calculated and
            outputted together. Defaults to 0.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Examples:
        >>> import torch
        >>> from mmpretrain.evaluation import Accuracy
        >>> # -------------------- The Basic Usage --------------------
        >>> y_pred = [0, 2, 1, 3]
        >>> y_true = [0, 1, 2, 3]
        >>> Accuracy.calculate(y_pred, y_true)
        tensor([50.])
        >>> # Calculate the top1 and top5 accuracy.
        >>> y_score = torch.rand((1000, 10))
        >>> y_true = torch.zeros((1000, ))
        >>> Accuracy.calculate(y_score, y_true, topk=(1, 5))
        [[tensor([9.9000])], [tensor([51.5000])]]
        >>>
        >>> # ------------------- Use with Evalutor -------------------
        >>> from mmpretrain.structures import DataSample
        >>> from mmengine.evaluator import Evaluator
        >>> data_samples = [
        ...     DataSample().set_gt_label(0).set_pred_score(torch.rand(10))
        ...     for i in range(1000)
        ... ]
        >>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
        >>> evaluator.process(data_samples)
        >>> evaluator.evaluate(1000)
        {
            'accuracy/top1': 9.300000190734863,
            'accuracy/top5': 51.20000076293945
        }
    """
    default_prefix: Optional[str] = 'accuracy'

    def __init__(self,
                 topk: Union[int, Sequence[int]] = (1, ),
                 thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        # Normalize both options to tuples so `compute_metrics` can always
        # iterate over them.
        if isinstance(topk, int):
            self.topk = (topk, )
        else:
            self.topk = tuple(topk)

        if isinstance(thrs, float) or thrs is None:
            self.thrs = (thrs, )
        else:
            self.thrs = tuple(thrs)

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to computed the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            if 'pred_score' in data_sample:
                # Keep the full score vector so top-k/threshold accuracy can
                # be computed later.
                result['pred_score'] = data_sample['pred_score'].cpu()
            else:
                # Only a hard label is available; top-k/thresholds will be
                # ignored in `compute_metrics`.
                result['pred_label'] = data_sample['pred_label'].cpu()
            result['gt_label'] = data_sample['gt_label'].cpu()
            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
        metrics = {}

        # concat
        target = torch.cat([res['gt_label'] for res in results])
        if 'pred_score' in results[0]:
            pred = torch.stack([res['pred_score'] for res in results])

            try:
                acc = self.calculate(pred, target, self.topk, self.thrs)
            except ValueError as e:
                # If the topk is invalid.
                raise ValueError(
                    str(e) + ' Please check the `val_evaluator` and '
                    '`test_evaluator` fields in your config file.')

            multi_thrs = len(self.thrs) > 1
            for i, k in enumerate(self.topk):
                for j, thr in enumerate(self.thrs):
                    name = f'top{k}'
                    if multi_thrs:
                        # Disambiguate metric names when several thresholds
                        # are evaluated together.
                        name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
                    metrics[name] = acc[i][j].item()
        else:
            # If only label in the `pred_label`.
            pred = torch.cat([res['pred_label'] for res in results])
            acc = self.calculate(pred, target, self.topk, self.thrs)
            metrics['top1'] = acc.item()

        return metrics

    @staticmethod
    def calculate(
        pred: Union[torch.Tensor, np.ndarray, Sequence],
        target: Union[torch.Tensor, np.ndarray, Sequence],
        topk: Sequence[int] = (1, ),
        thrs: Sequence[Union[float, None]] = (0., ),
    ) -> Union[torch.Tensor, List[List[torch.Tensor]]]:
        """Calculate the accuracy.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. It can be labels (N, ), or scores of every
                class (N, C).
            target (torch.Tensor | np.ndarray | Sequence): The target of
                each prediction with shape (N, ).
            topk (Sequence[int]): Accuracy is reported for every value of
                **k**. It's only used when ``pred`` is scores.
                Defaults to (1, ).
            thrs (Sequence[float | None]): Predictions with scores under
                the thresholds are considered negative. It's only used
                when ``pred`` is scores. None means no thresholds.
                Defaults to (0., ).

        Returns:
            torch.Tensor | List[List[torch.Tensor]]: Accuracy.

            - torch.Tensor: If the ``pred`` is a sequence of label instead of
              score (number of dimensions is 1). Only return a top-1 accuracy
              tensor, and ignore the argument ``topk`` and ``thrs``.
            - List[List[torch.Tensor]]: If the ``pred`` is a sequence of score
              (number of dimensions is 2). Return the accuracy on each ``topk``
              and ``thrs``. And the first dim is ``topk``, the second dim is
              ``thrs``.
        """
        pred = to_tensor(pred)
        target = to_tensor(target).to(torch.int64)
        num = pred.size(0)
        assert pred.size(0) == target.size(0), \
            f"The size of pred ({pred.size(0)}) doesn't match " \
            f'the target ({target.size(0)}).'

        if pred.ndim == 1:
            # For pred label, ignore topk and thrs: a hard label either
            # matches the target or not.
            correct = pred.eq(target).float().sum(0, keepdim=True)
            acc = correct.mul_(100. / num)
            return acc
        else:
            # For pred score, calculate on all topk and thresholds.
            pred = pred.float()
            maxk = max(topk)

            if maxk > pred.size(1):
                raise ValueError(
                    f'Top-{maxk} accuracy is unavailable since the number of '
                    f'categories is {pred.size(1)}.')

            # (maxk, N) layout so `[:k]` selects the best-k rows directly.
            pred_score, pred_label = pred.topk(maxk, dim=1)
            pred_label = pred_label.t()
            correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
            results = []
            for k in topk:
                results.append([])
                for thr in thrs:
                    # Only prediction values larger than thr are counted
                    # as correct
                    _correct = correct
                    if thr is not None:
                        _correct = _correct & (pred_score.t() > thr)
                    correct_k = _correct[:k].reshape(-1).float().sum(
                        0, keepdim=True)
                    acc = correct_k.mul_(100. / num)
                    results[-1].append(acc)
            return results
@METRICS.register_module()
class SingleLabelMetric(BaseMetric):
    r"""A collection of precision, recall, f1-score and support for
    single-label tasks.

    The collection of metrics is for single-label multi-class classification.
    And all these metrics are based on the confusion matrix of every category:

    .. image:: ../../_static/image/confusion-matrix.png
       :width: 60%
       :align: center

    All metrics can be formulated use variables above:

    **Precision** is the fraction of correct predictions in all predictions:

    .. math::
        \text{Precision} = \frac{TP}{TP+FP}

    **Recall** is the fraction of correct predictions in all targets:

    .. math::
        \text{Recall} = \frac{TP}{TP+FN}

    **F1-score** is the harmonic mean of the precision and recall:

    .. math::
        \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}

    **Support** is the number of samples:

    .. math::
        \text{Support} = TP + TN + FN + FP

    Args:
        thrs (Sequence[float | None] | float | None): If a float, predictions
            with score lower than the threshold will be regard as the negative
            prediction. If None, only the top-1 prediction will be regard as
            the positive prediction. If the parameter is a tuple, accuracy
            based on all thresholds will be calculated and outputted together.
            Defaults to 0.
        items (Sequence[str]): The detailed metric items to evaluate, select
            from "precision", "recall", "f1-score" and "support".
            Defaults to ``('precision', 'recall', 'f1-score')``.
        average (str | None): How to calculate the final metrics from the
            confusion matrix of every category. It supports three modes:

            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `"micro"`: Average the confusion matrix over all categories and
              calculate metrics on the mean confusion matrix.
            - `None`: Calculate metrics of every category and output directly.

            Defaults to "macro".
        num_classes (int, optional): The number of classes. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Examples:
        >>> import torch
        >>> from mmpretrain.evaluation import SingleLabelMetric
        >>> # -------------------- The Basic Usage --------------------
        >>> y_pred = [0, 1, 1, 3]
        >>> y_true = [0, 2, 1, 3]
        >>> # Output precision, recall, f1-score and support.
        >>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4)
        (tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4))
        >>> # Calculate with different thresholds.
        >>> y_score = torch.rand((1000, 10))
        >>> y_true = torch.zeros((1000, ))
        >>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9))
        [(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)),
         (tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
        >>>
        >>> # ------------------- Use with Evalutor -------------------
        >>> from mmpretrain.structures import DataSample
        >>> from mmengine.evaluator import Evaluator
        >>> data_samples = [
        ...     DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
        ...     for i in range(1000)
        ... ]
        >>> evaluator = Evaluator(metrics=SingleLabelMetric())
        >>> evaluator.process(data_samples)
        >>> evaluator.evaluate(1000)
        {'single-label/precision': 19.650691986083984,
         'single-label/recall': 19.600000381469727,
         'single-label/f1-score': 19.619548797607422}
        >>> # Evaluate on each class
        >>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None))
        >>> evaluator.process(data_samples)
        >>> evaluator.evaluate(1000)
        {
            'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1],
            'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0],
            'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0]
        }
    """  # noqa: E501
    default_prefix: Optional[str] = 'single-label'

    def __init__(self,
                 thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
                 items: Sequence[str] = ('precision', 'recall', 'f1-score'),
                 average: Optional[str] = 'macro',
                 num_classes: Optional[int] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        # Normalize thresholds to a tuple so later code can always iterate.
        if isinstance(thrs, float) or thrs is None:
            self.thrs = (thrs, )
        else:
            self.thrs = tuple(thrs)

        for item in items:
            assert item in ['precision', 'recall', 'f1-score', 'support'], \
                f'The metric {item} is not supported by `SingleLabelMetric`,' \
                ' please specify from "precision", "recall", "f1-score" and ' \
                '"support".'
        self.items = tuple(items)
        self.average = average
        self.num_classes = num_classes

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to computed the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            if 'pred_score' in data_sample:
                result['pred_score'] = data_sample['pred_score'].cpu()
            else:
                # Without scores, the number of classes cannot be inferred
                # from the prediction, so it must come from the metric or
                # the sample itself.
                num_classes = self.num_classes or data_sample.get(
                    'num_classes')
                assert num_classes is not None, \
                    'The `num_classes` must be specified if no `pred_score`.'
                result['pred_label'] = data_sample['pred_label'].cpu()
                result['num_classes'] = num_classes
            result['gt_label'] = data_sample['gt_label'].cpu()
            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method. `self.results`
        # are a list of results from multiple batch, while the input `results`
        # are the collected results.
        metrics = {}

        def pack_results(precision, recall, f1_score, support):
            # Keep only the items the user asked for.
            single_metrics = {}
            if 'precision' in self.items:
                single_metrics['precision'] = precision
            if 'recall' in self.items:
                single_metrics['recall'] = recall
            if 'f1-score' in self.items:
                single_metrics['f1-score'] = f1_score
            if 'support' in self.items:
                single_metrics['support'] = support
            return single_metrics

        # concat
        target = torch.cat([res['gt_label'] for res in results])
        if 'pred_score' in results[0]:
            pred = torch.stack([res['pred_score'] for res in results])
            metrics_list = self.calculate(
                pred, target, thrs=self.thrs, average=self.average)

            multi_thrs = len(self.thrs) > 1
            for i, thr in enumerate(self.thrs):
                if multi_thrs:
                    # Disambiguate metric names when several thresholds
                    # are evaluated together.
                    suffix = '_no-thr' if thr is None else f'_thr-{thr:.2f}'
                else:
                    suffix = ''

                for k, v in pack_results(*metrics_list[i]).items():
                    metrics[k + suffix] = v
        else:
            # If only label in the `pred_label`.
            pred = torch.cat([res['pred_label'] for res in results])
            res = self.calculate(
                pred,
                target,
                average=self.average,
                num_classes=results[0]['num_classes'])
            metrics = pack_results(*res)

        result_metrics = dict()
        for k, v in metrics.items():
            if self.average is None:
                # Classwise mode: report the whole per-class vector.
                result_metrics[k + '_classwise'] = v.cpu().detach().tolist()
            elif self.average == 'micro':
                result_metrics[k + f'_{self.average}'] = v.item()
            else:
                result_metrics[k] = v.item()

        return result_metrics

    @staticmethod
    def calculate(
        pred: Union[torch.Tensor, np.ndarray, Sequence],
        target: Union[torch.Tensor, np.ndarray, Sequence],
        thrs: Sequence[Union[float, None]] = (0., ),
        average: Optional[str] = 'macro',
        num_classes: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Calculate the precision, recall, f1-score and support.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. It can be labels (N, ), or scores of every
                class (N, C).
            target (torch.Tensor | np.ndarray | Sequence): The target of
                each prediction with shape (N, ).
            thrs (Sequence[float | None]): Predictions with scores under
                the thresholds are considered negative. It's only used
                when ``pred`` is scores. None means no thresholds.
                Defaults to (0., ).
            average (str | None): How to calculate the final metrics from
                the confusion matrix of every category. It supports three
                modes:

                - `"macro"`: Calculate metrics for each category, and
                  calculate the mean value over all categories.
                - `"micro"`: Average the confusion matrix over all categories
                  and calculate metrics on the mean confusion matrix.
                - `None`: Calculate metrics of every category and output
                  directly.

                Defaults to "macro".
            num_classes (Optional, int): The number of classes. If the ``pred``
                is label instead of scores, this argument is required.
                Defaults to None.

        Returns:
            Tuple: The tuple contains precision, recall and f1-score.
            And the type of each item is:

            - torch.Tensor: If the ``pred`` is a sequence of label instead of
              score (number of dimensions is 1). Only returns a tensor for
              each metric. The shape is (1, ) if ``average`` is not None, and
              (C, ) if ``average`` is None.
            - List[torch.Tensor]: If the ``pred`` is a sequence of score
              (number of dimensions is 2). Return the metrics on each ``thrs``.
              The shape of tensor is (1, ) if ``average`` is not None, and
              (C, ) if ``average`` is None.
        """
        average_options = ['micro', 'macro', None]
        assert average in average_options, 'Invalid `average` argument, ' \
            f'please specify from {average_options}.'

        pred = to_tensor(pred)
        target = to_tensor(target).to(torch.int64)
        assert pred.size(0) == target.size(0), \
            f"The size of pred ({pred.size(0)}) doesn't match " \
            f'the target ({target.size(0)}).'

        if pred.ndim == 1:
            assert num_classes is not None, \
                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
            gt_positive = F.one_hot(target.flatten(), num_classes)
            pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
            return _precision_recall_f1_support(pred_positive, gt_positive,
                                                average)
        else:
            # For pred score, calculate on all thresholds.
            num_classes = pred.size(1)
            pred_score, pred_label = torch.topk(pred, k=1)
            pred_score = pred_score.flatten()
            pred_label = pred_label.flatten()

            gt_positive = F.one_hot(target.flatten(), num_classes)

            results = []
            for thr in thrs:
                pred_positive = F.one_hot(pred_label, num_classes)
                if thr is not None:
                    # Scores not exceeding the threshold count as negative.
                    pred_positive[pred_score <= thr] = 0
                results.append(
                    _precision_recall_f1_support(pred_positive, gt_positive,
                                                 average))

            return results
@METRICS.register_module()
class ConfusionMatrix(BaseMetric):
    r"""A metric to calculate confusion matrix for single-label tasks.

    Args:
        num_classes (int, optional): The number of classes. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Examples:

        1. The basic usage.

        >>> import torch
        >>> from mmpretrain.evaluation import ConfusionMatrix
        >>> y_pred = [0, 1, 1, 3]
        >>> y_true = [0, 2, 1, 3]
        >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
        tensor([[1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1]])
        >>> # plot the confusion matrix
        >>> import matplotlib.pyplot as plt
        >>> y_score = torch.rand((1000, 10))
        >>> y_true = torch.randint(10, (1000, ))
        >>> matrix = ConfusionMatrix.calculate(y_score, y_true)
        >>> ConfusionMatrix().plot(matrix)
        >>> plt.show()

        2. In the config file

        .. code:: python

            val_evaluator = dict(type='ConfusionMatrix')
            test_evaluator = dict(type='ConfusionMatrix')
    """  # noqa: E501
    default_prefix = 'confusion_matrix'

    def __init__(self,
                 num_classes: Optional[int] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)

        self.num_classes = num_classes

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        """Extract the predicted and ground-truth labels of one batch and
        store them in ``self.results``."""
        for data_sample in data_samples:
            if 'pred_score' in data_sample:
                pred_score = data_sample['pred_score']
                pred_label = pred_score.argmax(dim=0, keepdim=True)
                # Infer the number of classes from the score vector length.
                self.num_classes = pred_score.size(0)
            else:
                pred_label = data_sample['pred_label']

            self.results.append({
                'pred_label': pred_label,
                'gt_label': data_sample['gt_label'],
            })

    def compute_metrics(self, results: list) -> dict:
        """Aggregate the stored labels into a single confusion matrix.

        Returns:
            dict: ``{'result': confusion_matrix}`` where the value is a
            (C, C) tensor.
        """
        pred_labels = []
        gt_labels = []
        for result in results:
            pred_labels.append(result['pred_label'])
            gt_labels.append(result['gt_label'])
        confusion_matrix = ConfusionMatrix.calculate(
            torch.cat(pred_labels),
            torch.cat(gt_labels),
            num_classes=self.num_classes)
        return {'result': confusion_matrix}

    @staticmethod
    def calculate(pred, target, num_classes=None) -> torch.Tensor:
        """Calculate the confusion matrix for single-label task.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. It can be labels (N, ), or scores of every
                class (N, C).
            target (torch.Tensor | np.ndarray | Sequence): The target of
                each prediction with shape (N, ).
            num_classes (Optional, int): The number of classes. If the ``pred``
                is label instead of scores, this argument is required.
                Defaults to None.

        Returns:
            torch.Tensor: The confusion matrix. Rows are true labels and
            columns are predicted labels.
        """
        pred = to_tensor(pred)
        target_label = to_tensor(target).int()

        assert pred.size(0) == target_label.size(0), \
            f"The size of pred ({pred.size(0)}) doesn't match " \
            f'the target ({target_label.size(0)}).'
        assert target_label.ndim == 1

        if pred.ndim == 1:
            assert num_classes is not None, \
                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
            pred_label = pred
        else:
            num_classes = num_classes or pred.size(1)
            pred_label = torch.argmax(pred, dim=1).flatten()

        with torch.no_grad():
            # Encode each (target, pred) pair as a single index and count
            # the occurrences with one `bincount` pass.
            indices = num_classes * target_label + pred_label
            matrix = torch.bincount(indices, minlength=num_classes**2)
            matrix = matrix.reshape(num_classes, num_classes)
        return matrix

    @staticmethod
    def plot(confusion_matrix: torch.Tensor,
             include_values: bool = False,
             cmap: str = 'viridis',
             classes: Optional[List[str]] = None,
             colorbar: bool = True,
             show: bool = True):
        """Draw a confusion matrix by matplotlib.

        Modified from `Scikit-Learn
        <https://github.com/scikit-learn/scikit-learn/blob/dc580a8ef/sklearn/metrics/_plot/confusion_matrix.py#L81>`_

        Args:
            confusion_matrix (torch.Tensor): The confusion matrix to draw.
            include_values (bool): Whether to draw the values in the figure.
                Defaults to False.
            cmap (str): The color map to use. Defaults to use "viridis".
            classes (list[str], optional): The names of categories.
                Defaults to None, which means to use index number.
            colorbar (bool): Whether to show the colorbar. Defaults to True.
            show (bool): Whether to show the figure immediately.
                Defaults to True.
        """  # noqa: E501
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(10, 10))

        num_classes = confusion_matrix.size(0)

        im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)
        text_ = None
        cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0)

        if include_values:
            text_ = np.empty_like(confusion_matrix, dtype=object)

            # print text with appropriate color depending on background
            thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0

            for i, j in product(range(num_classes), range(num_classes)):
                color = cmap_max if confusion_matrix[i,
                                                     j] < thresh else cmap_min

                # Prefer the shorter of the general/decimal formats.
                text_cm = format(confusion_matrix[i, j], '.2g')
                text_d = format(confusion_matrix[i, j], 'd')
                if len(text_d) < len(text_cm):
                    text_cm = text_d

                text_[i, j] = ax.text(
                    j, i, text_cm, ha='center', va='center', color=color)

        display_labels = classes or np.arange(num_classes)

        if colorbar:
            fig.colorbar(im_, ax=ax)
        ax.set(
            xticks=np.arange(num_classes),
            yticks=np.arange(num_classes),
            xticklabels=display_labels,
            yticklabels=display_labels,
            ylabel='True label',
            xlabel='Predicted label',
        )
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        ax.set_ylim((num_classes - 0.5, -0.5))
        # Automatically rotate the x labels.
        fig.autofmt_xdate(ha='center')

        if show:
            plt.show()
        return fig
mmpretrain/evaluation/metrics/visual_grounding_eval.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
import
torch
import
torchvision.ops.boxes
as
boxes
from
mmengine.evaluator
import
BaseMetric
from
mmpretrain.registry
import
METRICS
def aligned_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor):
    """Compute the element-wise IoU of two aligned sets of boxes.

    Unlike the pairwise ``torchvision.ops.box_iou`` (which returns a
    (B1, B2) matrix), this pairs ``boxes1[i]`` with ``boxes2[i]``.

    Args:
        boxes1 (torch.Tensor): Boxes in ``(x1, y1, x2, y2)`` format with
            shape (B, 4).
        boxes2 (torch.Tensor): Boxes in the same format, aligned with
            ``boxes1``, shape (B, 4).

    Returns:
        torch.Tensor: IoU of each aligned pair, shape (B, ).
    """

    def _area(b: torch.Tensor) -> torch.Tensor:
        # (x2 - x1) * (y2 - y1) per box.
        return (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    # Promote integral boxes to float so the width/height products cannot
    # overflow. This replaces the previous use of the *private* torchvision
    # helper `torchvision.ops.boxes._upcast`, which may break across
    # torchvision versions.
    if not boxes1.is_floating_point():
        boxes1 = boxes1.float()
    if not boxes2.is_floating_point():
        boxes2 = boxes2.float()

    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # (B, 2)
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # (B, 2)
    wh = (rb - lt).clamp(min=0)  # (B, 2), zero when boxes don't overlap

    inter = wh[:, 0] * wh[:, 1]  # (B, )
    union = _area(boxes1) + _area(boxes2) - inter
    return inter / union
@METRICS.register_module()
class VisualGroundingMetric(BaseMetric):
    """Visual Grounding evaluator.

    Calculate the box mIOU and box grounding accuracy for visual grounding
    model.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Should be modified according to the
            `retrieval_type` for unambiguous results. Defaults to TR.
    """
    default_prefix = 'visual-grounding'

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to computed the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            # Move predictions to CPU and drop singleton dims so that
            # `compute_metrics` can stack them into (N, 4).
            predicted = sample['pred_bboxes'].squeeze().to('cpu')
            target = torch.Tensor(sample['gt_bboxes']).squeeze()
            self.results.append({
                'box': predicted,
                'box_target': target,
            })

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (dict): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        box_preds = torch.stack([r['box'] for r in results])
        box_targets = torch.stack([r['box_target'] for r in results])

        iou = aligned_box_iou(box_preds, box_targets)
        # A prediction counts as correct when IoU with the gt box >= 0.5.
        num_correct = torch.sum(iou >= 0.5)

        return {
            'miou': torch.mean(iou),
            'acc': num_correct / len(box_targets),
        }
mmpretrain/evaluation/metrics/voc_multi_label.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Optional
,
Sequence
from
mmpretrain.registry
import
METRICS
from
mmpretrain.structures
import
label_to_onehot
from
.multi_label
import
AveragePrecision
,
MultiLabelMetric
class VOCMetricMixin:
    """A mixin class for VOC dataset metrics. VOC annotations have an extra
    `difficult` attribute for each object, therefore an extra option is needed
    for calculating VOC metrics.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If
            set to True, map difficult gt labels to positive ones (1); if
            set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels are set
            to '-1'.
    """

    def __init__(self,
                 *arg,
                 difficult_as_positive: Optional[bool] = None,
                 **kwarg):
        # Store the option, then let the concrete metric class (e.g.
        # MultiLabelMetric / AveragePrecision) consume the remaining args.
        self.difficult_as_positive = difficult_as_positive
        super().__init__(*arg, **kwarg)

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to computed the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each sample must carry 'pred_score', 'gt_label' and
                'gt_label_difficult'; 'gt_score' is used when present.
        """
        for data_sample in data_samples:
            result = dict()
            gt_label = data_sample['gt_label']
            gt_label_difficult = data_sample['gt_label_difficult']

            result['pred_score'] = data_sample['pred_score'].clone()
            # Infer the number of classes from the score vector length.
            num_classes = result['pred_score'].size()[-1]

            if 'gt_score' in data_sample:
                result['gt_score'] = data_sample['gt_score'].clone()
            else:
                # Build a one-hot target from the hard labels.
                result['gt_score'] = label_to_onehot(gt_label, num_classes)

            # VOC annotation labels all the objects in a single image
            # therefore, some categories are appeared both in
            # difficult objects and non-difficult objects.
            # Here we reckon those labels which are only exists in difficult
            # objects as difficult labels.
            difficult_label = set(gt_label_difficult) - (
                set(gt_label_difficult) & set(gt_label.tolist()))

            # set difficult label for better eval
            # Overwrite only the classes that are exclusively difficult:
            # -1 (ignored) by default, or 1 when difficult_as_positive=True.
            if self.difficult_as_positive is None:
                result['gt_score'][[*difficult_label]] = -1
            elif self.difficult_as_positive:
                result['gt_score'][[*difficult_label]] = 1

            # Save the result to `self.results`.
            self.results.append(result)
@METRICS.register_module()
class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
    """A collection of metrics for multi-label multi-class classification task
    based on confusion matrix for VOC dataset.

    It includes precision, recall, f1-score and support.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If
            set to True, map difficult gt labels to positive ones (1); if
            set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels are set
            to '-1'.
        **kwarg: Refers to `MultiLabelMetric` for detailed docstrings.
    """
@METRICS.register_module()
class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
    """Calculate the average precision with respect of classes for the VOC
    dataset.

    The VOC-specific behavior (handling of "difficult" objects) comes from
    ``VOCMetricMixin``.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If it
            is set to True, map difficult gt labels to positive ones (1); if
            it is set to False, map difficult gt labels to negative ones (0).
            Defaults to None, in which case the difficult labels are set to
            '-1' (ignored during evaluation).
        **kwarg: Refers to `AveragePrecision` for detailed docstrings.
    """
mmpretrain/evaluation/metrics/vqa.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal
from
typing
import
List
,
Optional
import
mmengine
from
mmengine.evaluator
import
BaseMetric
from
mmengine.logging
import
MMLogger
from
mmpretrain.registry
import
METRICS
def
_process_punctuation
(
inText
):
import
re
outText
=
inText
punct
=
[
';'
,
r
'/'
,
'['
,
']'
,
'"'
,
'{'
,
'}'
,
'('
,
')'
,
'='
,
'+'
,
'
\\
'
,
'_'
,
'-'
,
'>'
,
'<'
,
'@'
,
'`'
,
','
,
'?'
,
'!'
]
commaStrip
=
re
.
compile
(
'(\d)(,)(\d)'
)
# noqa: W605
periodStrip
=
re
.
compile
(
'(?!<=\d)(\.)(?!\d)'
)
# noqa: W605
for
p
in
punct
:
if
(
p
+
' '
in
inText
or
' '
+
p
in
inText
)
or
(
re
.
search
(
commaStrip
,
inText
)
is
not
None
):
outText
=
outText
.
replace
(
p
,
''
)
else
:
outText
=
outText
.
replace
(
p
,
' '
)
outText
=
periodStrip
.
sub
(
''
,
outText
,
re
.
UNICODE
)
return
outText
def
_process_digit_article
(
inText
):
outText
=
[]
tempText
=
inText
.
lower
().
split
()
articles
=
[
'a'
,
'an'
,
'the'
]
manualMap
=
{
'none'
:
'0'
,
'zero'
:
'0'
,
'one'
:
'1'
,
'two'
:
'2'
,
'three'
:
'3'
,
'four'
:
'4'
,
'five'
:
'5'
,
'six'
:
'6'
,
'seven'
:
'7'
,
'eight'
:
'8'
,
'nine'
:
'9'
,
'ten'
:
'10'
,
}
contractions
=
{
'aint'
:
"ain't"
,
'arent'
:
"aren't"
,
'cant'
:
"can't"
,
'couldve'
:
"could've"
,
'couldnt'
:
"couldn't"
,
"couldn'tve"
:
"couldn't've"
,
"couldnt've"
:
"couldn't've"
,
'didnt'
:
"didn't"
,
'doesnt'
:
"doesn't"
,
'dont'
:
"don't"
,
'hadnt'
:
"hadn't"
,
"hadnt've"
:
"hadn't've"
,
"hadn'tve"
:
"hadn't've"
,
'hasnt'
:
"hasn't"
,
'havent'
:
"haven't"
,
'hed'
:
"he'd"
,
"hed've"
:
"he'd've"
,
"he'dve"
:
"he'd've"
,
'hes'
:
"he's"
,
'howd'
:
"how'd"
,
'howll'
:
"how'll"
,
'hows'
:
"how's"
,
"Id've"
:
"I'd've"
,
"I'dve"
:
"I'd've"
,
'Im'
:
"I'm"
,
'Ive'
:
"I've"
,
'isnt'
:
"isn't"
,
'itd'
:
"it'd"
,
"itd've"
:
"it'd've"
,
"it'dve"
:
"it'd've"
,
'itll'
:
"it'll"
,
"let's"
:
"let's"
,
'maam'
:
"ma'am"
,
'mightnt'
:
"mightn't"
,
"mightnt've"
:
"mightn't've"
,
"mightn'tve"
:
"mightn't've"
,
'mightve'
:
"might've"
,
'mustnt'
:
"mustn't"
,
'mustve'
:
"must've"
,
'neednt'
:
"needn't"
,
'notve'
:
"not've"
,
'oclock'
:
"o'clock"
,
'oughtnt'
:
"oughtn't"
,
"ow's'at"
:
"'ow's'at"
,
"'ows'at"
:
"'ow's'at"
,
"'ow'sat"
:
"'ow's'at"
,
'shant'
:
"shan't"
,
"shed've"
:
"she'd've"
,
"she'dve"
:
"she'd've"
,
"she's"
:
"she's"
,
'shouldve'
:
"should've"
,
'shouldnt'
:
"shouldn't"
,
"shouldnt've"
:
"shouldn't've"
,
"shouldn'tve"
:
"shouldn't've"
,
"somebody'd"
:
'somebodyd'
,
"somebodyd've"
:
"somebody'd've"
,
"somebody'dve"
:
"somebody'd've"
,
'somebodyll'
:
"somebody'll"
,
'somebodys'
:
"somebody's"
,
'someoned'
:
"someone'd"
,
"someoned've"
:
"someone'd've"
,
"someone'dve"
:
"someone'd've"
,
'someonell'
:
"someone'll"
,
'someones'
:
"someone's"
,
'somethingd'
:
"something'd"
,
"somethingd've"
:
"something'd've"
,
"something'dve"
:
"something'd've"
,
'somethingll'
:
"something'll"
,
'thats'
:
"that's"
,
'thered'
:
"there'd"
,
"thered've"
:
"there'd've"
,
"there'dve"
:
"there'd've"
,
'therere'
:
"there're"
,
'theres'
:
"there's"
,
'theyd'
:
"they'd"
,
"theyd've"
:
"they'd've"
,
"they'dve"
:
"they'd've"
,
'theyll'
:
"they'll"
,
'theyre'
:
"they're"
,
'theyve'
:
"they've"
,
'twas'
:
"'twas"
,
'wasnt'
:
"wasn't"
,
"wed've"
:
"we'd've"
,
"we'dve"
:
"we'd've"
,
'weve'
:
"we've"
,
'werent'
:
"weren't"
,
'whatll'
:
"what'll"
,
'whatre'
:
"what're"
,
'whats'
:
"what's"
,
'whatve'
:
"what've"
,
'whens'
:
"when's"
,
'whered'
:
"where'd"
,
'wheres'
:
"where's"
,
'whereve'
:
"where've"
,
'whod'
:
"who'd"
,
"whod've"
:
"who'd've"
,
"who'dve"
:
"who'd've"
,
'wholl'
:
"who'll"
,
'whos'
:
"who's"
,
'whove'
:
"who've"
,
'whyll'
:
"why'll"
,
'whyre'
:
"why're"
,
'whys'
:
"why's"
,
'wont'
:
"won't"
,
'wouldve'
:
"would've"
,
'wouldnt'
:
"wouldn't"
,
"wouldnt've"
:
"wouldn't've"
,
"wouldn'tve"
:
"wouldn't've"
,
'yall'
:
"y'all"
,
"yall'll"
:
"y'all'll"
,
"y'allll"
:
"y'all'll"
,
"yall'd've"
:
"y'all'd've"
,
"y'alld've"
:
"y'all'd've"
,
"y'all'dve"
:
"y'all'd've"
,
'youd'
:
"you'd"
,
"youd've"
:
"you'd've"
,
"you'dve"
:
"you'd've"
,
'youll'
:
"you'll"
,
'youre'
:
"you're"
,
'youve'
:
"you've"
,
}
for
word
in
tempText
:
word
=
manualMap
.
setdefault
(
word
,
word
)
if
word
not
in
articles
:
outText
.
append
(
word
)
for
wordId
,
word
in
enumerate
(
outText
):
if
word
in
contractions
:
outText
[
wordId
]
=
contractions
[
word
]
outText
=
' '
.
join
(
outText
)
return
outText
@METRICS.register_module()
class VQAAcc(BaseMetric):
    '''VQA Acc metric.

    Computes the soft accuracy used by the official VQA challenge: each
    sample scores ``min(1, matched_gt_weight / full_score_weight)``, where
    ``matched_gt_weight`` is the summed weight of ground-truth answers equal
    to the (normalized) prediction. The reported value is the mean over all
    samples, scaled to a percentage.

    Args:
        full_score_weight (float): The accumulated ground-truth answer
            weight required for a prediction to receive the full score of
            1.0. Defaults to 0.3 (i.e. 3 of 10 equally-weighted annotator
            answers must match).
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    '''
    default_prefix = 'VQA'

    def __init__(self,
                 full_score_weight: float = 0.3,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.full_score_weight = full_score_weight

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            gt_answer = sample.get('gt_answer')
            gt_answer_weight = sample.get('gt_answer_weight')
            if isinstance(gt_answer, str):
                gt_answer = [gt_answer]
            if gt_answer_weight is None:
                # No per-answer weights given: distribute weight uniformly
                # over all ground-truth answers.
                gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer)

            result = {
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': gt_answer,
                'gt_answer_weight': gt_answer_weight,
            }

            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (dict): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        acc = []
        for result in results:
            pred_answer = self._process_answer(result['pred_answer'])
            gt_answer = [
                self._process_answer(answer) for answer in result['gt_answer']
            ]
            answer_weight = result['gt_answer_weight']

            weight_sum = 0
            for i, gt in enumerate(gt_answer):
                if gt == pred_answer:
                    weight_sum += answer_weight[i]
            vqa_acc = min(1.0, weight_sum / self.full_score_weight)
            acc.append(vqa_acc)

        # Guard against an empty result list (would previously raise
        # ZeroDivisionError); report 0.0 in that case.
        accuracy = sum(acc) / len(acc) * 100 if acc else 0.0

        metrics = {'acc': accuracy}
        return metrics

    def _process_answer(self, answer):
        # Normalize whitespace, punctuation, number words and contractions
        # following the official VQA evaluation pipeline.
        answer = answer.replace('\n', ' ')
        answer = answer.replace('\t', ' ')
        answer = answer.strip()
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer
@METRICS.register_module()
class ReportVQA(BaseMetric):
    """Dump VQA result to the standard json format for VQA evaluation.

    This "metric" computes nothing; it collects ``question_id`` /
    ``pred_answer`` pairs and writes them to a json file at the end of the
    evaluation loop.

    Args:
        file_path (str): The file path to save the result file. Must end
            with ``.json``.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    default_prefix = 'VQA'

    def __init__(self,
                 file_path: str,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        if not file_path.endswith('.json'):
            raise ValueError('The output file must be a json file.')
        self.file_path = file_path

    def process(self, data_batch, data_samples) -> None:
        """transfer tensors in predictions to CPU."""
        for data_sample in data_samples:
            self.results.append({
                'question_id': int(data_sample['question_id']),
                'answer': data_sample['pred_answer'],
            })

    def compute_metrics(self, results: List):
        """Dump the result to json file."""
        mmengine.dump(results, self.file_path)
        logger = MMLogger.get_current_instance()
        logger.info(f'Results has been saved to {self.file_path}.')
        return {}
mmpretrain/models/__init__.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
.backbones
import
*
# noqa: F401,F403
from
.builder
import
(
BACKBONES
,
CLASSIFIERS
,
HEADS
,
LOSSES
,
NECKS
,
build_backbone
,
build_classifier
,
build_head
,
build_loss
,
build_neck
)
from
.classifiers
import
*
# noqa: F401,F403
from
.heads
import
*
# noqa: F401,F403
from
.losses
import
*
# noqa: F401,F403
from
.multimodal
import
*
# noqa: F401,F403
from
.necks
import
*
# noqa: F401,F403
from
.peft
import
*
# noqa: F401,F403
from
.retrievers
import
*
# noqa: F401,F403
from
.selfsup
import
*
# noqa: F401,F403
from
.tta
import
*
# noqa: F401,F403
from
.utils
import
*
# noqa: F401,F403
__all__
=
[
'BACKBONES'
,
'HEADS'
,
'NECKS'
,
'LOSSES'
,
'CLASSIFIERS'
,
'build_backbone'
,
'build_head'
,
'build_neck'
,
'build_loss'
,
'build_classifier'
]
mmpretrain/models/backbones/__init__.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
.alexnet
import
AlexNet
from
.beit
import
BEiTViT
from
.conformer
import
Conformer
from
.convmixer
import
ConvMixer
from
.convnext
import
ConvNeXt
from
.cspnet
import
CSPDarkNet
,
CSPNet
,
CSPResNet
,
CSPResNeXt
from
.davit
import
DaViT
from
.deit
import
DistilledVisionTransformer
from
.deit3
import
DeiT3
from
.densenet
import
DenseNet
from
.edgenext
import
EdgeNeXt
from
.efficientformer
import
EfficientFormer
from
.efficientnet
import
EfficientNet
from
.efficientnet_v2
import
EfficientNetV2
from
.hivit
import
HiViT
from
.hornet
import
HorNet
from
.hrnet
import
HRNet
from
.inception_v3
import
InceptionV3
from
.lenet
import
LeNet5
from
.levit
import
LeViT
from
.mixmim
import
MixMIMTransformer
from
.mlp_mixer
import
MlpMixer
from
.mobilenet_v2
import
MobileNetV2
from
.mobilenet_v3
import
MobileNetV3
from
.mobileone
import
MobileOne
from
.mobilevit
import
MobileViT
from
.mvit
import
MViT
from
.poolformer
import
PoolFormer
from
.regnet
import
RegNet
from
.replknet
import
RepLKNet
from
.repmlp
import
RepMLPNet
from
.repvgg
import
RepVGG
from
.res2net
import
Res2Net
from
.resnest
import
ResNeSt
from
.resnet
import
ResNet
,
ResNetV1c
,
ResNetV1d
from
.resnet_cifar
import
ResNet_CIFAR
from
.resnext
import
ResNeXt
from
.revvit
import
RevVisionTransformer
from
.riformer
import
RIFormer
from
.seresnet
import
SEResNet
from
.seresnext
import
SEResNeXt
from
.shufflenet_v1
import
ShuffleNetV1
from
.shufflenet_v2
import
ShuffleNetV2
from
.sparse_convnext
import
SparseConvNeXt
from
.sparse_resnet
import
SparseResNet
from
.swin_transformer
import
SwinTransformer
from
.swin_transformer_v2
import
SwinTransformerV2
from
.t2t_vit
import
T2T_ViT
from
.timm_backbone
import
TIMMBackbone
from
.tinyvit
import
TinyViT
from
.tnt
import
TNT
from
.twins
import
PCPVT
,
SVT
from
.van
import
VAN
from
.vgg
import
VGG
from
.vig
import
PyramidVig
,
Vig
from
.vision_transformer
import
VisionTransformer
from
.vit_eva02
import
ViTEVA02
from
.vit_sam
import
ViTSAM
from
.xcit
import
XCiT
__all__
=
[
'LeNet5'
,
'AlexNet'
,
'VGG'
,
'RegNet'
,
'ResNet'
,
'ResNeXt'
,
'ResNetV1d'
,
'ResNeSt'
,
'ResNet_CIFAR'
,
'SEResNet'
,
'SEResNeXt'
,
'ShuffleNetV1'
,
'ShuffleNetV2'
,
'MobileNetV2'
,
'MobileNetV3'
,
'VisionTransformer'
,
'SwinTransformer'
,
'TNT'
,
'TIMMBackbone'
,
'T2T_ViT'
,
'Res2Net'
,
'RepVGG'
,
'Conformer'
,
'MlpMixer'
,
'DistilledVisionTransformer'
,
'PCPVT'
,
'SVT'
,
'EfficientNet'
,
'EfficientNetV2'
,
'ConvNeXt'
,
'HRNet'
,
'ResNetV1c'
,
'ConvMixer'
,
'EdgeNeXt'
,
'CSPDarkNet'
,
'CSPResNet'
,
'CSPResNeXt'
,
'CSPNet'
,
'RepLKNet'
,
'RepMLPNet'
,
'PoolFormer'
,
'RIFormer'
,
'DenseNet'
,
'VAN'
,
'InceptionV3'
,
'MobileOne'
,
'EfficientFormer'
,
'SwinTransformerV2'
,
'MViT'
,
'DeiT3'
,
'HorNet'
,
'MobileViT'
,
'DaViT'
,
'BEiTViT'
,
'RevVisionTransformer'
,
'MixMIMTransformer'
,
'TinyViT'
,
'LeViT'
,
'Vig'
,
'PyramidVig'
,
'XCiT'
,
'ViTSAM'
,
'ViTEVA02'
,
'HiViT'
,
'SparseResNet'
,
'SparseConvNeXt'
,
]
mmpretrain/models/backbones/alexnet.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
@MODELS.register_module()
class AlexNet(BaseBackbone):
    """`AlexNet <https://en.wikipedia.org/wiki/AlexNet>`_ backbone.

    The input for AlexNet is a 224x224 RGB image.

    Args:
        num_classes (int): number of classes for classification.
            The default value is -1, which uses the backbone as
            a feature extractor without the top classifier.
    """

    def __init__(self, num_classes=-1):
        super().__init__()
        self.num_classes = num_classes
        # Convolutional feature extractor; same layer order and indices as
        # torchvision's AlexNet so pretrained state dicts remain loadable.
        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)
        if self.num_classes > 0:
            # Optional classification head, only built when a positive
            # number of classes is requested.
            fc_layers = [
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            ]
            self.classifier = nn.Sequential(*fc_layers)

    def forward(self, x):
        feats = self.features(x)
        if self.num_classes > 0:
            # Flatten the 256x6x6 feature map before the fully-connected head.
            flat = feats.view(feats.size(0), 256 * 6 * 6)
            return (self.classifier(flat), )
        return (feats, )
mmpretrain/models/backbones/base_backbone.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
abc
import
ABCMeta
,
abstractmethod
from
mmengine.model
import
BaseModule
class BaseBackbone(BaseModule, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)

    @abstractmethod
    def forward(self, x):
        """Forward computation.

        Args:
            x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of
                Torch.tensor, containing input data for forward computation.
        """

    def train(self, mode=True):
        """Set module status before forward computation.

        Args:
            mode (bool): Whether it is train_mode or test_mode
        """
        super().train(mode)
mmpretrain/models/backbones/beit.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn.bricks.drop
import
build_dropout
from
mmcv.cnn.bricks.transformer
import
FFN
,
PatchEmbed
from
mmengine.model
import
BaseModule
,
ModuleList
from
mmengine.model.weight_init
import
trunc_normal_
from
mmpretrain.registry
import
MODELS
from
..utils
import
(
BEiTAttention
,
build_norm_layer
,
resize_pos_embed
,
resize_relative_position_bias_table
,
to_2tuple
)
from
.base_backbone
import
BaseBackbone
from
.vision_transformer
import
TransformerEncoderLayer
class RelativePositionBias(BaseModule):
    """Relative Position Bias.

    This module is copied from
    https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py#L209.

    Args:
        window_size (Sequence[int]): The window size of the relative
            position bias.
        num_heads (int): The number of head in multi-head attention.
        with_cls_token (bool): To indicate the backbone has cls_token or not.
            Defaults to True.
    """

    def __init__(
        self,
        window_size: Sequence[int],
        num_heads: int,
        with_cls_token: bool = True,
    ) -> None:
        super().__init__()
        self.window_size = window_size
        if with_cls_token:
            # Three extra bias entries for cls<->patch interactions:
            # cls to token & token to cls & cls to cls
            num_extra_tokens = 3
        else:
            num_extra_tokens = 0
        # One learnable bias per distinct 2-D relative offset, plus the
        # extra cls-token entries.
        self.num_relative_distance = (2 * window_size[0] - 1) * (
            2 * window_size[1] - 1) + num_extra_tokens
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance,
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each
        # token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - \
            coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        # Shift both offset axes to be non-negative, then fold the 2-D
        # offset into a single flat index (row-major over 2*Ww-1 columns).
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        if with_cls_token:
            # Row/column 0 of the index map belongs to the cls token; the
            # last three table entries are reserved for cls->token,
            # token->cls and cls->cls.
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1, ) * 2,
                dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(
                -1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1
        else:
            # NOTE(review): the zeros tensor allocated here is immediately
            # overwritten by the next assignment; kept as in the upstream
            # implementation.
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1], ) * 2,
                dtype=relative_coords.dtype)
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

        # Buffer (not a parameter): saved with the module but not trained.
        self.register_buffer('relative_position_index',
                             relative_position_index)

    def forward(self) -> torch.Tensor:
        # Gather per-pair biases from the table and reshape to the attention
        # map layout.
        # NOTE(review): the view below hard-codes the `+ 1` cls-token
        # sequence length, which assumes `with_cls_token=True`; confirm
        # before using this module without a cls token.
        # Wh*Ww,Wh*Ww,nH
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
    """Implements one encoder layer in BEiT.

    Comparing with conventional ``TransformerEncoderLayer``, this module
    adds weights to the shortcut connection. In addition, ``BEiTAttention``
    is used to replace the original ``MultiheadAttention`` in
    ``TransformerEncoderLayer``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        layer_scale_init_value (float): The initialization value for
            the learnable scaling of attention and FFN. 1 means no scaling.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        window_size (tuple[int]): The height and width of the window.
            Defaults to None.
        use_rel_pos_bias (bool): Whether to use unique relative position bias,
            if False, use shared relative position bias defined in backbone.
        attn_drop_rate (float): The drop out rate for attention layer.
            Defaults to 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        bias (bool | str): The option to add leanable bias for q, k, v. If bias
            is True, it will add leanable bias. If bias is 'qv_bias', it will
            only add leanable bias for q, v. If bias is False, it will not add
            bias for q, k, v. Default to 'qv_bias'.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='LN').
        attn_cfg (dict): The configuration for the attention layer.
            Defaults to an empty dict.
        ffn_cfg (dict): The configuration for the ffn layer.
            Defaults to ``dict(add_identity=False)``.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 layer_scale_init_value: float,
                 window_size: Tuple[int, int],
                 use_rel_pos_bias: bool,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 bias: Union[str, bool] = 'qv_bias',
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 attn_cfg: dict = dict(),
                 ffn_cfg: dict = dict(add_identity=False),
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        # drop_path_rate and drop_rate are deliberately passed as 0. to the
        # parent: the parent's attn/ffn modules are replaced below, where the
        # real drop rates are applied.
        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=feedforward_channels,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=0.,
            drop_rate=0.,
            num_fcs=num_fcs,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        # Build the BEiT attention module; entries from the user-supplied
        # `attn_cfg` override the defaults because the unpacking comes last.
        attn_cfg = {
            'window_size': window_size,
            'use_rel_pos_bias': use_rel_pos_bias,
            'qk_scale': None,
            'embed_dims': embed_dims,
            'num_heads': num_heads,
            'attn_drop': attn_drop_rate,
            'proj_drop': drop_rate,
            'bias': bias,
            **attn_cfg,
        }
        self.attn = BEiTAttention(**attn_cfg)

        # Same override pattern for the FFN; `add_identity=False` (default
        # ffn_cfg) keeps the residual add in this layer's forward instead.
        ffn_cfg = {
            'embed_dims': embed_dims,
            'feedforward_channels': feedforward_channels,
            'num_fcs': num_fcs,
            'ffn_drop': drop_rate,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path_rate),
            'act_cfg': act_cfg,
            **ffn_cfg,
        }
        self.ffn = FFN(**ffn_cfg)

        # NOTE: drop path for stochastic depth, we shall see if
        # this is better than dropout here
        dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
        self.drop_path = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
        if layer_scale_init_value > 0:
            # Learnable per-channel scales on the attention and FFN branches
            # (LayerScale); disabled entirely when the init value is <= 0.
            self.gamma_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((embed_dims)),
                requires_grad=True)
            self.gamma_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((embed_dims)),
                requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x: torch.Tensor,
                rel_pos_bias: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual blocks; gamma_1/gamma_2 scale the branch outputs
        # when LayerScale is enabled.
        if self.gamma_1 is None:
            x = x + self.drop_path(
                self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.ffn(self.ln2(x)))
        else:
            x = x + self.drop_path(
                self.gamma_1 * self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x)))
        return x
@
MODELS
.
register_module
()
class
BEiTViT
(
BaseBackbone
):
"""Backbone for BEiT.
A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers
<https://arxiv.org/abs/2106.08254>`_
A PyTorch implement of : `BEiT v2: Masked Image Modeling with
Vector-Quantized Visual Tokenizers <https://arxiv.org/abs/2208.06366>`_
Args:
arch (str | dict): BEiT architecture. If use string, choose from
'base', 'large'. If use dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
bias (bool | str): The option to add leanable bias for q, k, v. If bias
is True, it will add leanable bias. If bias is 'qv_bias', it will
only add leanable bias for q, v. If bias is False, it will not add
bias for q, k, v. Default to 'qv_bias'.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"avg_featmap"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
Defaults to False.
use_rel_pos_bias (bool): Use relative position embedding in each
transformer encoder layer. Defaults to True.
use_shared_rel_pos_bias (bool): Use shared relative position embedding,
all transformer encoder layers share the same relative position
embedding. Defaults to False.
layer_scale_init_value (float): The initialization value for
the learnable scaling of attention and FFN. Defaults to 0.1.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo
=
{
**
dict
.
fromkeys
(
[
's'
,
'small'
],
{
'embed_dims'
:
768
,
'num_layers'
:
8
,
'num_heads'
:
8
,
'feedforward_channels'
:
768
*
3
,
}),
**
dict
.
fromkeys
(
[
'b'
,
'base'
],
{
'embed_dims'
:
768
,
'num_layers'
:
12
,
'num_heads'
:
12
,
'feedforward_channels'
:
3072
}),
**
dict
.
fromkeys
(
[
'l'
,
'large'
],
{
'embed_dims'
:
1024
,
'num_layers'
:
24
,
'num_heads'
:
16
,
'feedforward_channels'
:
4096
}),
**
dict
.
fromkeys
(
[
'eva-g'
,
'eva-giant'
],
{
# The implementation in EVA
# <https://arxiv.org/abs/2211.07636>
'embed_dims'
:
1408
,
'num_layers'
:
40
,
'num_heads'
:
16
,
'feedforward_channels'
:
6144
}),
**
dict
.
fromkeys
(
[
'deit-t'
,
'deit-tiny'
],
{
'embed_dims'
:
192
,
'num_layers'
:
12
,
'num_heads'
:
3
,
'feedforward_channels'
:
192
*
4
}),
**
dict
.
fromkeys
(
[
'deit-s'
,
'deit-small'
],
{
'embed_dims'
:
384
,
'num_layers'
:
12
,
'num_heads'
:
6
,
'feedforward_channels'
:
384
*
4
}),
**
dict
.
fromkeys
(
[
'deit-b'
,
'deit-base'
],
{
'embed_dims'
:
768
,
'num_layers'
:
12
,
'num_heads'
:
12
,
'feedforward_channels'
:
768
*
4
}),
}
num_extra_tokens
=
1
# class token
OUT_TYPES
=
{
'raw'
,
'cls_token'
,
'featmap'
,
'avg_featmap'
}
def
__init__
(
self
,
arch
=
'base'
,
img_size
=
224
,
patch_size
=
16
,
in_channels
=
3
,
out_indices
=-
1
,
drop_rate
=
0
,
drop_path_rate
=
0
,
bias
=
'qv_bias'
,
norm_cfg
=
dict
(
type
=
'LN'
,
eps
=
1e-6
),
final_norm
=
False
,
out_type
=
'avg_featmap'
,
with_cls_token
=
True
,
frozen_stages
=-
1
,
use_abs_pos_emb
=
False
,
use_rel_pos_bias
=
True
,
use_shared_rel_pos_bias
=
False
,
interpolate_mode
=
'bicubic'
,
layer_scale_init_value
=
0.1
,
patch_cfg
=
dict
(),
layer_cfgs
=
dict
(),
init_cfg
=
None
):
super
(
BEiTViT
,
self
).
__init__
(
init_cfg
)
if
isinstance
(
arch
,
str
):
arch
=
arch
.
lower
()
assert
arch
in
set
(
self
.
arch_zoo
),
\
f
'Arch
{
arch
}
is not in default archs
{
set
(
self
.
arch_zoo
)
}
'
self
.
arch_settings
=
self
.
arch_zoo
[
arch
]
else
:
essential_keys
=
{
'embed_dims'
,
'num_layers'
,
'num_heads'
,
'feedforward_channels'
}
assert
isinstance
(
arch
,
dict
)
and
essential_keys
<=
set
(
arch
),
\
f
'Custom arch needs a dict with keys
{
essential_keys
}
'
self
.
arch_settings
=
arch
self
.
embed_dims
=
self
.
arch_settings
[
'embed_dims'
]
self
.
num_layers
=
self
.
arch_settings
[
'num_layers'
]
self
.
img_size
=
to_2tuple
(
img_size
)
# Set patch embedding
_patch_cfg
=
dict
(
in_channels
=
in_channels
,
input_size
=
img_size
,
embed_dims
=
self
.
embed_dims
,
conv_type
=
'Conv2d'
,
kernel_size
=
patch_size
,
stride
=
patch_size
,
)
_patch_cfg
.
update
(
patch_cfg
)
self
.
patch_embed
=
PatchEmbed
(
**
_patch_cfg
)
self
.
patch_resolution
=
self
.
patch_embed
.
init_out_size
num_patches
=
self
.
patch_resolution
[
0
]
*
self
.
patch_resolution
[
1
]
# Set out type
if
out_type
not
in
self
.
OUT_TYPES
:
raise
ValueError
(
f
'Unsupported `out_type`
{
out_type
}
, please '
f
'choose from
{
self
.
OUT_TYPES
}
'
)
self
.
out_type
=
out_type
# Set cls token
self
.
with_cls_token
=
with_cls_token
if
with_cls_token
:
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
self
.
embed_dims
))
self
.
num_extra_tokens
=
1
elif
out_type
!=
'cls_token'
:
self
.
cls_token
=
None
self
.
num_extra_tokens
=
0
else
:
raise
ValueError
(
'with_cls_token must be True when `out_type="cls_token"`.'
)
# Set position embedding
self
.
interpolate_mode
=
interpolate_mode
if
use_abs_pos_emb
:
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_patches
+
self
.
num_extra_tokens
,
self
.
embed_dims
))
self
.
_register_load_state_dict_pre_hook
(
self
.
_prepare_pos_embed
)
else
:
self
.
pos_embed
=
None
self
.
drop_after_pos
=
nn
.
Dropout
(
p
=
drop_rate
)
assert
not
(
use_rel_pos_bias
and
use_shared_rel_pos_bias
),
(
'`use_rel_pos_bias` and `use_shared_rel_pos_bias` cannot be set '
'to True at the same time'
)
self
.
use_rel_pos_bias
=
use_rel_pos_bias
if
use_shared_rel_pos_bias
:
self
.
rel_pos_bias
=
RelativePositionBias
(
window_size
=
self
.
patch_resolution
,
num_heads
=
self
.
arch_settings
[
'num_heads'
])
else
:
self
.
rel_pos_bias
=
None
self
.
_register_load_state_dict_pre_hook
(
self
.
_prepare_relative_position_bias_table
)
if
isinstance
(
out_indices
,
int
):
out_indices
=
[
out_indices
]
assert
isinstance
(
out_indices
,
Sequence
),
\
f
'"out_indices" must by a sequence or int, '
\
f
'get
{
type
(
out_indices
)
}
instead.'
for
i
,
index
in
enumerate
(
out_indices
):
if
index
<
0
:
out_indices
[
i
]
=
self
.
num_layers
+
index
assert
0
<=
out_indices
[
i
]
<=
self
.
num_layers
,
\
f
'Invalid out_indices
{
index
}
'
self
.
out_indices
=
out_indices
# stochastic depth decay rule
dpr
=
np
.
linspace
(
0
,
drop_path_rate
,
self
.
num_layers
)
self
.
layers
=
ModuleList
()
if
isinstance
(
layer_cfgs
,
dict
):
layer_cfgs
=
[
layer_cfgs
]
*
self
.
num_layers
for
i
in
range
(
self
.
num_layers
):
_layer_cfg
=
dict
(
embed_dims
=
self
.
embed_dims
,
num_heads
=
self
.
arch_settings
[
'num_heads'
],
feedforward_channels
=
self
.
arch_settings
[
'feedforward_channels'
],
layer_scale_init_value
=
layer_scale_init_value
,
window_size
=
self
.
patch_resolution
,
use_rel_pos_bias
=
use_rel_pos_bias
,
drop_rate
=
drop_rate
,
drop_path_rate
=
dpr
[
i
],
bias
=
bias
,
norm_cfg
=
norm_cfg
)
_layer_cfg
.
update
(
layer_cfgs
[
i
])
self
.
layers
.
append
(
BEiTTransformerEncoderLayer
(
**
_layer_cfg
))
self
.
frozen_stages
=
frozen_stages
self
.
final_norm
=
final_norm
if
final_norm
:
self
.
ln1
=
build_norm_layer
(
norm_cfg
,
self
.
embed_dims
)
if
out_type
==
'avg_featmap'
:
self
.
ln2
=
build_norm_layer
(
norm_cfg
,
self
.
embed_dims
)
# freeze stages only when self.frozen_stages > 0
if
self
.
frozen_stages
>
0
:
self
.
_freeze_stages
()
@property
def norm1(self):
    """nn.Module: Alias of the final layer norm ``self.ln1``.

    NOTE(review): presumably kept so configs/checkpoints that refer to
    ``norm1`` keep working — confirm against callers.
    """
    return self.ln1
@property
def norm2(self):
    """nn.Module: Alias of the ``avg_featmap`` output norm ``self.ln2``.

    NOTE(review): presumably kept so configs/checkpoints that refer to
    ``norm2`` keep working — confirm against callers.
    """
    return self.ln2
def init_weights(self):
    """Initialize the weights.

    Runs the base-class initialization first. When the model is NOT
    loaded from a ``Pretrained`` checkpoint, additionally applies a
    truncated-normal init (std=0.02) to the position embedding.
    """
    super(BEiTViT, self).init_weights()

    from_pretrained = (
        isinstance(self.init_cfg, dict)
        and self.init_cfg['type'] == 'Pretrained')
    if from_pretrained:
        # Weights come from a checkpoint; skip the random init below.
        return

    if self.pos_embed is not None:
        trunc_normal_(self.pos_embed, std=0.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
    """Load-state-dict pre-hook: adapt a checkpoint's ``pos_embed``.

    Two adaptations are performed in place on ``state_dict``:

    1. If this model has no cls token but the checkpoint's position
       embedding has exactly one extra token, drop that first token.
    2. If the (possibly trimmed) shape still differs from this model's
       ``pos_embed``, resize it to the current patch grid.

    Args:
        state_dict (dict): The checkpoint state dict (modified in place).
        prefix (str): Key prefix of this module inside ``state_dict``.
    """
    name = prefix + 'pos_embed'
    if name not in state_dict.keys():
        return

    ckpt_pos_embed_shape = state_dict[name].shape
    if (not self.with_cls_token and ckpt_pos_embed_shape[1]
            == self.pos_embed.shape[1] + 1):
        # Remove cls token from state dict if it's not used.
        state_dict[name] = state_dict[name][:, 1:]
        ckpt_pos_embed_shape = state_dict[name].shape

    if self.pos_embed.shape != ckpt_pos_embed_shape:
        # Local import to avoid a hard dependency at module-import time.
        from mmengine.logging import MMLogger
        logger = MMLogger.get_current_instance()
        logger.info(
            f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
            f'to {self.pos_embed.shape}.')

        # Recover the checkpoint's square patch grid from the token count
        # (assumes the pretrained grid was square — standard for BEiT).
        ckpt_pos_embed_shape = to_2tuple(
            int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
        pos_embed_shape = self.patch_embed.init_out_size

        state_dict[name] = resize_pos_embed(state_dict[name],
                                            ckpt_pos_embed_shape,
                                            pos_embed_shape,
                                            self.interpolate_mode,
                                            self.num_extra_tokens)
@staticmethod
def resize_pos_embed(*args, **kwargs):
    """Interface for backward-compatibility.

    Delegates to the module-level :func:`resize_pos_embed`; kept so older
    code that called this as a static method keeps working.
    """
    return resize_pos_embed(*args, **kwargs)
def _freeze_stages(self):
    """Freeze the stem and the first ``self.frozen_stages`` layers.

    Position embedding, dropout, patch embedding and the cls token are
    always frozen; transformer layers 0..frozen_stages-1 are frozen; the
    final norm layer(s) are frozen only when every layer is frozen.
    """

    def _freeze(module):
        # Put the module in eval mode and stop gradient updates.
        module.eval()
        for p in module.parameters():
            p.requires_grad = False

    # Position embedding is a bare Parameter, not a module.
    if self.pos_embed is not None:
        self.pos_embed.requires_grad = False
    # Dropout after position embedding only needs eval mode.
    self.drop_after_pos.eval()
    _freeze(self.patch_embed)
    if self.with_cls_token:
        self.cls_token.requires_grad = False

    # Freeze the first `frozen_stages` transformer layers.
    for idx in range(self.frozen_stages):
        _freeze(self.layers[idx])

    # Freeze the trailing norm layer(s) only when all layers are frozen.
    if self.frozen_stages == len(self.layers):
        if self.final_norm:
            _freeze(self.ln1)
        if self.out_type == 'avg_featmap':
            _freeze(self.ln2)
def forward(self, x):
    """Forward computation.

    Args:
        x (torch.Tensor): Input images.

    Returns:
        tuple: Formatted outputs of the stages listed in
        ``self.out_indices`` (see ``_format_output``).
    """
    B = x.shape[0]
    x, patch_resolution = self.patch_embed(x)

    if self.cls_token is not None:
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)

    if self.pos_embed is not None:
        # Resize the learned pos_embed from the init-time grid to the
        # actual grid of this input (supports variable input sizes).
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
    x = self.drop_after_pos(x)

    # Shared relative position bias, computed once and fed to every layer.
    rel_pos_bias = self.rel_pos_bias() \
        if self.rel_pos_bias is not None else None

    outs = []
    for i, layer in enumerate(self.layers):
        x = layer(x, rel_pos_bias)

        # Final norm is applied BEFORE collecting the last stage's output.
        if i == len(self.layers) - 1 and self.final_norm:
            x = self.ln1(x)

        if i in self.out_indices:
            outs.append(self._format_output(x, patch_resolution))

    return tuple(outs)
def _format_output(self, x, hw):
    """Convert the token sequence to the configured output format.

    Args:
        x (torch.Tensor): Token sequence of shape (B, N, C).
        hw (tuple): Spatial resolution (H, W) of the patch grid.

    Returns:
        torch.Tensor: Raw tokens, the cls token, a (B, C, H, W) feature
        map, or the normed mean patch token, per ``self.out_type``.
    """
    if self.out_type == 'raw':
        return x
    if self.out_type == 'cls_token':
        return x[:, 0]

    # Strip extra (cls) tokens; keep only the patch tokens.
    patch_token = x[:, self.num_extra_tokens:]
    if self.out_type == 'featmap':
        # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
        return patch_token.reshape(x.size(0), *hw, -1).permute(0, 3, 1, 2)
    if self.out_type == 'avg_featmap':
        # Average over patch tokens, then apply the dedicated norm.
        return self.ln2(patch_token.mean(dim=1))
def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
                                          **kwargs):
    """Load-state-dict pre-hook: adapt relative position bias tables.

    Two adaptations are performed in place on ``state_dict``:

    1. If this model uses per-layer bias but the checkpoint stored one
       shared table, copy the shared table into every layer's key.
    2. If a checkpoint table's length differs from the current model's,
       resize it to the new window size and drop the stale index buffer
       so it gets re-generated.

    Args:
        state_dict (dict): The checkpoint state dict (modified in place).
        prefix (str): Key prefix of this module inside ``state_dict``.
    """
    from mmengine.logging import MMLogger
    logger = MMLogger.get_current_instance()

    if self.use_rel_pos_bias and 'rel_pos_bias.relative_position_bias_table' in state_dict:  # noqa:E501
        logger.info('Expand the shared relative position embedding to '
                    'each transformer block.')
        rel_pos_bias = state_dict[
            'rel_pos_bias.relative_position_bias_table']
        for i in range(self.num_layers):
            state_dict[
                f'layers.{i}.attn.relative_position_bias_table'] = \
                rel_pos_bias.clone()
        # The shared entries are fully consumed; remove them.
        state_dict.pop('rel_pos_bias.relative_position_bias_table')
        state_dict.pop('rel_pos_bias.relative_position_index')

    state_dict_model = self.state_dict()
    all_keys = list(state_dict_model.keys())
    for key in all_keys:
        if 'relative_position_bias_table' in key:
            ckpt_key = prefix + key
            if ckpt_key not in state_dict:
                continue
            rel_pos_bias_pretrained = state_dict[ckpt_key]
            rel_pos_bias_current = state_dict_model[key]
            # Table layout: L rows x num_heads columns.
            L1, nH1 = rel_pos_bias_pretrained.size()
            L2, nH2 = rel_pos_bias_current.size()
            # The last 3 rows are special entries (see the -3 slicing
            # below); the rest form a square (2*Wh-1) x (2*Ww-1) grid —
            # presumably cls-related positions, confirm against
            # RelativePositionBias.
            src_size = int((L1 - 3)**0.5)
            dst_size = int((L2 - 3)**0.5)
            if L1 != L2:
                # Keep the 3 special rows untouched; resize the rest.
                extra_tokens = rel_pos_bias_pretrained[-3:, :]
                rel_pos_bias = rel_pos_bias_pretrained[:-3, :]

                new_rel_pos_bias = resize_relative_position_bias_table(
                    src_size, dst_size, rel_pos_bias, nH1)
                new_rel_pos_bias = torch.cat(
                    (new_rel_pos_bias, extra_tokens), dim=0)
                logger.info('Resize the relative_position_bias_table from '
                            f'{state_dict[ckpt_key].shape} to '
                            f'{new_rel_pos_bias.shape}')
                state_dict[ckpt_key] = new_rel_pos_bias

                # The index buffer need to be re-generated.
                index_buffer = ckpt_key.replace('bias_table', 'index')
                if index_buffer in state_dict:
                    del state_dict[index_buffer]
def get_layer_depth(self, param_name: str, prefix: str = ''):
    """Get the layer-wise depth of a parameter.

    Args:
        param_name (str): The name of the parameter.
        prefix (str): The prefix for the parameter.
            Defaults to an empty string.

    Returns:
        Tuple[int, int]: The layer-wise depth and the num of layers.

    Note:
        The first depth is the stem module (``layer_depth=0``), and the
        last depth is the subsequent module (``layer_depth=num_layers-1``)
    """
    total_depth = self.num_layers + 2

    if not param_name.startswith(prefix):
        # For subsequent module like head
        return total_depth - 1, total_depth

    name = param_name[len(prefix):]
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        # Stem-level parameters.
        depth = 0
    elif name.startswith('layers'):
        # 'layers.<i>. ...' -> depth i + 1 (stem occupies depth 0).
        depth = int(name.split('.')[1]) + 1
    else:
        # Everything else (e.g. final norms) follows the last layer.
        depth = total_depth - 1

    return depth, total_depth
mmpretrain/models/backbones/conformer.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Sequence
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
build_activation_layer
,
build_norm_layer
from
mmcv.cnn.bricks.drop
import
DropPath
from
mmcv.cnn.bricks.transformer
import
AdaptivePadding
from
mmengine.model
import
BaseModule
from
mmengine.model.weight_init
import
trunc_normal_
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
from
.vision_transformer
import
TransformerEncoderLayer
class ConvBlock(BaseModule):
    """Basic convolution block used in Conformer.

    This block includes three convolution modules, and supports three new
    functions:

    1. Returns the output of both the final layers and the second
       convolution module.
    2. Fuses the input of the second convolution module with an extra input
       feature map.
    3. Supports to add an extra convolution module to the identity
       connection.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        stride (int): The stride of the second convolution module.
            Defaults to 1.
        groups (int): The groups of the second convolution module.
            Defaults to 1.
        drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
        with_residual_conv (bool): Whether to add an extra convolution module
            to the identity connection. Defaults to False.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN', eps=1e-6)``.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='ReLU', inplace=True))``.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 groups=1,
                 drop_path_rate=0.,
                 with_residual_conv=False,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super(ConvBlock, self).__init__(init_cfg=init_cfg)

        # Bottleneck design: squeeze to out_channels // 4 in the middle.
        expansion = 4
        mid_channels = out_channels // expansion

        # 1x1 reduce
        self.conv1 = nn.Conv2d(
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
        self.act1 = build_activation_layer(act_cfg)

        # 3x3 spatial conv; the only one with configurable stride/groups.
        self.conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            groups=groups,
            padding=1,
            bias=False)
        self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
        self.act2 = build_activation_layer(act_cfg)

        # 1x1 expand
        self.conv3 = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
        self.act3 = build_activation_layer(act_cfg)

        if with_residual_conv:
            # Projection shortcut to match channels/stride of the main path.
            self.residual_conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False)
            self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]

        self.with_residual_conv = with_residual_conv
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def zero_init_last_bn(self):
        """Zero-init the last BN so the block starts as identity-like."""
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, fusion_features=None, out_conv2=True):
        """Forward.

        Args:
            x (torch.Tensor): Input feature map.
            fusion_features (torch.Tensor, optional): Extra feature map
                added to conv2's input (transformer-branch fusion).
            out_conv2 (bool): Whether to also return the activated conv2
                output. Defaults to True.

        Returns:
            torch.Tensor or tuple: ``(out, conv2_out)`` if ``out_conv2``,
            else ``out``.
        """
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Optionally fuse the transformer-branch features before conv2.
        x = self.conv2(x) if fusion_features is None else self.conv2(
            x + fusion_features)
        x = self.bn2(x)
        x2 = self.act2(x)

        x = self.conv3(x2)
        x = self.bn3(x)

        # drop_path is always DropPath or nn.Identity, so this check is
        # effectively always true; kept as-is for byte-compatibility.
        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.with_residual_conv:
            identity = self.residual_conv(identity)
            identity = self.residual_bn(identity)

        x += identity
        x = self.act3(x)

        if out_conv2:
            return x, x2
        else:
            return x
class FCUDown(BaseModule):
    """CNN feature maps -> Transformer patch embeddings.

    Projects a CNN feature map to the transformer embedding dimension,
    average-pools it down to the token grid, and (optionally) prepends the
    class token taken from the transformer-branch input.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 down_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(FCUDown, self).__init__(init_cfg=init_cfg)
        self.down_stride = down_stride
        self.with_cls_token = with_cls_token

        # 1x1 conv maps CNN channels to the transformer embedding dim.
        self.conv_project = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Average pooling reduces the spatial size to the token grid.
        self.sample_pooling = nn.AvgPool2d(
            kernel_size=down_stride, stride=down_stride)

        self.ln = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, x_t):
        tokens = self.conv_project(x)  # [N, C, H, W]
        # [N, C, H, W] -> [N, C, L] -> [N, L, C]
        tokens = self.sample_pooling(tokens).flatten(2).transpose(1, 2)
        tokens = self.act(self.ln(tokens))

        if self.with_cls_token:
            # Reuse the class token carried by the transformer branch.
            cls_token = x_t[:, 0][:, None, :]
            tokens = torch.cat([cls_token, tokens], dim=1)

        return tokens
class FCUUp(BaseModule):
    """Transformer patch embeddings -> CNN feature maps.

    Drops the class token (if any), restores the spatial layout, projects
    back to CNN channels, and up-samples to the CNN branch resolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 up_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super(FCUUp, self).__init__(init_cfg=init_cfg)

        self.up_stride = up_stride
        self.with_cls_token = with_cls_token

        # 1x1 conv maps the embedding dim back to CNN channels.
        self.conv_project = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, H, W):
        B, _, C = x.shape
        # [N, 197, 384] -> [N, 196, 384]: drop the class token if present.
        patch_tokens = x[:, 1:] if self.with_cls_token else x
        # [N, L, C] -> [N, C, L] -> [N, C, H, W]
        feat = patch_tokens.transpose(1, 2).reshape(B, C, H, W)
        feat = self.act(self.bn(self.conv_project(feat)))
        # F.interpolate without `mode` uses nearest-neighbour up-sampling.
        return F.interpolate(
            feat, size=(H * self.up_stride, W * self.up_stride))
class ConvTransBlock(BaseModule):
    """Basic module for Conformer.

    This module is a fusion of CNN block transformer encoder block.

    Args:
        in_channels (int): The number of input channels in conv blocks.
        out_channels (int): The number of output channels in conv blocks.
        embed_dims (int): The embedding dimension in transformer blocks.
        conv_stride (int): The stride of conv2d layers. Defaults to 1.
        groups (int): The groups of conv blocks. Defaults to 1.
        with_residual_conv (bool): Whether to add a conv-bn layer to the
            identity connect in the conv block. Defaults to False.
        down_stride (int): The stride of the downsample pooling layer.
            Defaults to 4.
        num_heads (int): The number of heads in transformer attention layers.
            Defaults to 12.
        mlp_ratio (float): The expansion ratio in transformer FFN module.
            Defaults to 4.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        with_cls_token (bool): Whether use class token or not.
            Defaults to True.
        drop_rate (float): The dropout rate of the output projection and
            FFN in the transformer block. Defaults to 0.
        attn_drop_rate (float): The dropout rate after the attention
            calculation in the transformer block. Defaults to 0.
        drop_path_rate (float): The drop path rate in both the conv block
            and the transformer block. Defaults to 0.
        last_fusion (bool): Whether this block is the last stage. If so,
            downsample the fusion feature map.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 embed_dims,
                 conv_stride=1,
                 groups=1,
                 with_residual_conv=False,
                 down_stride=4,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 with_cls_token=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 last_fusion=False,
                 init_cfg=None):
        super(ConvTransBlock, self).__init__(init_cfg=init_cfg)
        # Bottleneck expansion factor of ConvBlock (see ConvBlock).
        expansion = 4
        self.cnn_block = ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            with_residual_conv=with_residual_conv,
            stride=conv_stride,
            groups=groups)

        if last_fusion:
            # Last stage: the fusion conv also downsamples (stride=2).
            self.fusion_block = ConvBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                stride=2,
                with_residual_conv=True,
                groups=groups,
                drop_path_rate=drop_path_rate)
        else:
            self.fusion_block = ConvBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                groups=groups,
                drop_path_rate=drop_path_rate)

        # CNN -> transformer: operates on conv2's (mid-channel) output.
        self.squeeze_block = FCUDown(
            in_channels=out_channels // expansion,
            out_channels=embed_dims,
            down_stride=down_stride,
            with_cls_token=with_cls_token)

        # Transformer -> CNN: restores the mid-channel feature map.
        self.expand_block = FCUUp(
            in_channels=embed_dims,
            out_channels=out_channels // expansion,
            up_stride=down_stride,
            with_cls_token=with_cls_token)

        self.trans_block = TransformerEncoderLayer(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=int(embed_dims * mlp_ratio),
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        self.down_stride = down_stride
        self.embed_dim = embed_dims
        self.last_fusion = last_fusion

    def forward(self, cnn_input, trans_input):
        """Run both branches and fuse them.

        Args:
            cnn_input (torch.Tensor): CNN-branch feature map.
            trans_input (torch.Tensor): Transformer-branch token sequence.

        Returns:
            tuple: ``(cnn_output, trans_output)`` for the next block.
        """
        x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True)

        _, _, H, W = x_conv2.shape

        # Convert the feature map of conv2 to transformer embedding
        # and concat with class token.
        conv2_embedding = self.squeeze_block(x_conv2, trans_input)

        # Residual-style fusion on the transformer branch.
        trans_output = self.trans_block(conv2_embedding + trans_input)

        # Convert the transformer output embedding to feature map
        trans_features = self.expand_block(trans_output,
                                           H // self.down_stride,
                                           W // self.down_stride)
        # Feed the transformer features back into the CNN fusion block.
        x = self.fusion_block(
            x, fusion_features=trans_features, out_conv2=False)

        return x, trans_output
@MODELS.register_module()
class Conformer(BaseBackbone):
    """Conformer backbone.

    A PyTorch implementation of : `Conformer: Local Features Coupling Global
    Representations for Visual Recognition <https://arxiv.org/abs/2105.03889>`_

    Args:
        arch (str | dict): Conformer architecture. Defaults to 'tiny'.
        patch_size (int): The patch size. Defaults to 16.
        base_channels (int): The base number of channels in CNN network.
            Defaults to 64.
        mlp_ratio (float): The expansion ratio of FFN network in transformer
            block. Defaults to 4.
        qkv_bias (bool): Enable bias for qkv in attention. Defaults to True.
        with_cls_token (bool): Whether use class token or not.
            Defaults to True.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_eval (bool): Stored flag; NOTE(review): not used inside this
            class — confirm whether a ``train()`` override elsewhere
            consumes it.
        frozen_stages (int): Stored value; NOTE(review): no freezing logic
            is implemented in this class — confirm intended behavior.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    # yapf: disable
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 384,
                         'channel_ratio': 1,
                         'num_heads': 6,
                         'depths': 12}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 384,
                         'channel_ratio': 4,
                         'num_heads': 6,
                         'depths': 12}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 576,
                         'channel_ratio': 6,
                         'num_heads': 9,
                         'depths': 12}),
    }
    # yapf: enable

    _version = 1

    def __init__(self,
                 arch='tiny',
                 patch_size=16,
                 base_channels=64,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 with_cls_token=True,
                 drop_path_rate=0.,
                 norm_eval=True,
                 frozen_stages=0,
                 out_indices=-1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads', 'channel_ratio'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.num_features = self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.channel_ratio = self.arch_settings['channel_ratio']

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                # Stages are 1-based, hence the extra +1 for negatives.
                out_indices[i] = self.depths + index + 1
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

        self.with_cls_token = with_cls_token
        if self.with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # stochastic depth decay rule
        self.trans_dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, self.depths)
        ]

        # Stem stage: get the feature maps by conv block
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3,
            bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

        assert patch_size % 16 == 0, 'The patch size of Conformer must ' \
            'be divisible by 16.'
        # The stem already downsamples by 4, so the patch conv only needs
        # patch_size // 4 stride to reach the full patch size.
        trans_down_stride = patch_size // 4

        # To solve the issue #680
        # Auto pad the feature map to be divisible by trans_down_stride
        self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride)

        # 1 stage
        stage1_channels = int(base_channels * self.channel_ratio)
        self.conv_1 = ConvBlock(
            in_channels=64,
            out_channels=stage1_channels,
            with_residual_conv=True,
            stride=1)
        self.trans_patch_conv = nn.Conv2d(
            64,
            self.embed_dims,
            kernel_size=trans_down_stride,
            stride=trans_down_stride,
            padding=0)
        self.trans_1 = TransformerEncoderLayer(
            embed_dims=self.embed_dims,
            num_heads=self.num_heads,
            feedforward_channels=int(self.embed_dims * mlp_ratio),
            drop_path_rate=self.trans_dpr[0],
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        # 2~4 stage
        init_stage = 2
        fin_stage = self.depths // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=stage1_channels,
                    out_channels=stage1_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=1,
                    with_residual_conv=False,
                    down_stride=trans_down_stride,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage2_channels = int(base_channels * self.channel_ratio * 2)
        # 5~8 stage
        init_stage = fin_stage  # 5
        fin_stage = fin_stage + self.depths // 3  # 9
        for i in range(init_stage, fin_stage):
            # The first block of the stage downsamples and projects
            # channels; the rest keep the resolution.
            if i == init_stage:
                conv_stride = 2
                in_channels = stage1_channels
            else:
                conv_stride = 1
                in_channels = stage2_channels
            with_residual_conv = True if i == init_stage else False
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage2_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 2,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
        # 9~12 stage
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + self.depths // 3  # 13
        for i in range(init_stage, fin_stage):
            if i == init_stage:
                conv_stride = 2
                in_channels = stage2_channels
                with_residual_conv = True
            else:
                conv_stride = 1
                in_channels = stage3_channels
                with_residual_conv = False

            # The very last block also downsamples the fusion output.
            last_fusion = (i == self.depths)
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage3_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 4,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token,
                    last_fusion=last_fusion))
        self.fin_stage = fin_stage

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.trans_norm = nn.LayerNorm(self.embed_dims)

        if self.with_cls_token:
            trunc_normal_(self.cls_token, std=.02)

    def _init_weights(self, m):
        """Per-module weight init applied via ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

        # ConvBlock exposes zero_init_last_bn(); zero its last BN so each
        # residual branch starts near identity.
        if hasattr(m, 'zero_init_last_bn'):
            m.zero_init_last_bn()

    def init_weights(self):
        """Initialize weights unless loading from a pretrained checkpoint."""
        super(Conformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return
        self.apply(self._init_weights)

    def forward(self, x):
        """Forward both branches; collect outputs at ``out_indices``.

        Returns:
            tuple: Each element is ``[cnn_feature, transformer_feature]``.
        """
        output = []
        B = x.shape[0]
        if self.with_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)

        # stem
        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
        x_base = self.auto_pad(x_base)

        # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
        x = self.conv_1(x_base, out_conv2=False)
        x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
        if self.with_cls_token:
            x_t = torch.cat([cls_tokens, x_t], dim=1)
        x_t = self.trans_1(x_t)

        # 2 ~ final
        for i in range(2, self.fin_stage):
            stage = getattr(self, f'conv_trans_{i}')
            x, x_t = stage(x, x_t)
            if i in self.out_indices:
                if self.with_cls_token:
                    output.append([
                        self.pooling(x).flatten(1),
                        self.trans_norm(x_t)[:, 0]
                    ])
                else:
                    # if no class token, use the mean patch token
                    # as the transformer feature.
                    output.append([
                        self.pooling(x).flatten(1),
                        self.trans_norm(x_t).mean(dim=1)
                    ])

        return tuple(output)
mmpretrain/models/backbones/convmixer.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Sequence
import
torch
import
torch.nn
as
nn
from
mmcv.cnn.bricks
import
(
Conv2dAdaptivePadding
,
build_activation_layer
,
build_norm_layer
)
from
mmengine.utils
import
digit_version
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
class Residual(nn.Module):
    """Additive skip connection: ``y = x + fn(x)``.

    Wraps an arbitrary callable/module ``fn`` so its output is summed
    with its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # Identity shortcut around the wrapped transform.
        return x + self.fn(x)
@MODELS.register_module()
class ConvMixer(BaseBackbone):
    """ConvMixer. .

    A PyTorch implementation of : `Patches Are All You Need?
    <https://arxiv.org/pdf/2201.09792.pdf>`_

    Modified from the `official repo
    <https://github.com/locuslab/convmixer/blob/main/convmixer.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convmixer.py>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``ConvMixer.arch_settings``. And if dict, it
            should include the following keys:

            - embed_dims (int): The dimensions of patch embedding.
            - depth (int): Number of repetitions of ConvMixer Layer.
            - patch_size (int): The patch size.
            - kernel_size (int): The kernel size of depthwise conv layers.

            Defaults to '768/32'.
        in_channels (int): Number of input image channels. Defaults to 3.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation after each convolution.
            Defaults to ``dict(type='GELU')``.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): Initialization config dict.
    """

    arch_settings = {
        '768/32': {
            'embed_dims': 768,
            'depth': 32,
            'patch_size': 7,
            'kernel_size': 7
        },
        '1024/20': {
            'embed_dims': 1024,
            'depth': 20,
            'patch_size': 14,
            'kernel_size': 9
        },
        '1536/20': {
            'embed_dims': 1536,
            'depth': 20,
            'patch_size': 7,
            'kernel_size': 9
        },
    }

    def __init__(self,
                 arch='768/32',
                 in_channels=3,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 out_indices=-1,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            essential_keys = {
                'embed_dims', 'depth', 'patch_size', 'kernel_size'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'

        self.embed_dims = arch['embed_dims']
        self.depth = arch['depth']
        self.patch_size = arch['patch_size']
        self.kernel_size = arch['kernel_size']
        # NOTE(review): this single activation module instance is shared by
        # the stem and every stage below — fine for stateless activations
        # like GELU, but worth confirming for stateful act_cfg choices.
        self.act = build_activation_layer(act_cfg)

        # check out indices and frozen stages
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.depth + index
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # Set stem layers
        self.stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.embed_dims,
                kernel_size=self.patch_size,
                stride=self.patch_size), self.act,
            build_norm_layer(norm_cfg, self.embed_dims)[1])

        # Set conv2d according to torch version
        # (`padding='same'` for nn.Conv2d requires torch >= 1.9.0; fall
        # back to mmcv's adaptive-padding conv on older versions).
        convfunc = nn.Conv2d
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            convfunc = Conv2dAdaptivePadding

        # Repetitions of ConvMixer Layer: depthwise conv with a residual
        # connection, followed by a pointwise (1x1) conv.
        self.stages = nn.Sequential(*[
            nn.Sequential(
                Residual(
                    nn.Sequential(
                        convfunc(
                            self.embed_dims,
                            self.embed_dims,
                            self.kernel_size,
                            groups=self.embed_dims,
                            padding='same'), self.act,
                        build_norm_layer(norm_cfg, self.embed_dims)[1])),
                nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
                self.act,
                build_norm_layer(norm_cfg, self.embed_dims)[1])
            for _ in range(self.depth)
        ])

        self._freeze_stages()

    def forward(self, x):
        """Forward; returns the feature maps at ``self.out_indices``."""
        x = self.stem(x)
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        # x = self.pooling(x).flatten(1)
        return tuple(outs)

    def train(self, mode=True):
        """Keep frozen stages in eval mode when switching train/eval."""
        super(ConvMixer, self).train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        """Freeze the first ``self.frozen_stages`` ConvMixer layers."""
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False
mmpretrain/models/backbones/convnext.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
functools
import
partial
from
itertools
import
chain
from
typing
import
Sequence
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn.bricks
import
DropPath
from
mmengine.model
import
BaseModule
,
ModuleList
,
Sequential
from
mmpretrain.registry
import
MODELS
from
..utils
import
GRN
,
build_norm_layer
from
.base_backbone
import
BaseBackbone
class ConvNeXtBlock(BaseModule):
    """ConvNeXt Block.

    Args:
        in_channels (int): The number of input channels.
        dw_conv_cfg (dict): Config of depthwise convolution.
            Defaults to ``dict(kernel_size=7, padding=3)``.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        mlp_ratio (float): The expansion ratio in both pointwise convolution.
            Defaults to 4.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. More details can be found in the note.
            Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6. A value <= 0 disables Layer Scale.
        use_grn (bool): Whether to add Global Response Normalization
            (ConvNeXt V2). Defaults to False.
        with_cp (bool): Use activation checkpointing to save memory.
            Defaults to False.

    Note:
        There are two equivalent implementations:

        1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
           all outputs are in (N, C, H, W).
        2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU
           -> Linear; Permute back

        As default, we use the second to align with the official repository.
        And it may be slightly faster.
    """

    def __init__(self,
                 in_channels,
                 dw_conv_cfg=dict(kernel_size=7, padding=3),
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 mlp_ratio=4.,
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 use_grn=False,
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        # Depthwise conv: one filter per channel (groups == channels).
        self.depthwise_conv = nn.Conv2d(
            in_channels, in_channels, groups=in_channels, **dw_conv_cfg)

        self.linear_pw_conv = linear_pw_conv
        self.norm = build_norm_layer(norm_cfg, in_channels)

        mid_channels = int(mlp_ratio * in_channels)
        if self.linear_pw_conv:
            # Use linear layer to do pointwise conv.
            pw_conv = nn.Linear
        else:
            pw_conv = partial(nn.Conv2d, kernel_size=1)

        self.pointwise_conv1 = pw_conv(in_channels, mid_channels)
        self.act = MODELS.build(act_cfg)
        self.pointwise_conv2 = pw_conv(mid_channels, in_channels)

        if use_grn:
            # ConvNeXt V2 adds GRN on the expanded features.
            self.grn = GRN(mid_channels)
        else:
            self.grn = None

        # Layer Scale (gamma); disabled when layer_scale_init_value <= 0,
        # which is the ConvNeXt V2 configuration.
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones((in_channels)),
            requires_grad=True) if layer_scale_init_value > 0 else None

        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):

        def _inner_forward(x):
            shortcut = x
            x = self.depthwise_conv(x)

            if self.linear_pw_conv:
                # Linear path: move channels last so nn.Linear acts as a
                # pointwise conv, then move them back.
                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
                x = self.norm(x, data_format='channel_last')
                x = self.pointwise_conv1(x)
                x = self.act(x)
                if self.grn is not None:
                    x = self.grn(x, data_format='channel_last')
                x = self.pointwise_conv2(x)
                x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
            else:
                # Conv path: everything stays channels-first.
                x = self.norm(x, data_format='channel_first')
                x = self.pointwise_conv1(x)
                x = self.act(x)

                if self.grn is not None:
                    x = self.grn(x, data_format='channel_first')
                x = self.pointwise_conv2(x)

            if self.gamma is not None:
                # Per-channel Layer Scale.
                x = x.mul(self.gamma.view(1, -1, 1, 1))

            x = shortcut + self.drop_path(x)
            return x

        if self.with_cp and x.requires_grad:
            # Trade compute for memory via activation checkpointing.
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x
@MODELS.register_module()
class ConvNeXt(BaseBackbone):
    """ConvNeXt v1&v2 backbone.

    A PyTorch implementation of `A ConvNet for the 2020s
    <https://arxiv.org/abs/2201.03545>`_ and
    `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
    <http://arxiv.org/abs/2301.00808>`_

    Modified from the `official repo
    <https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py>`_.

    To use ConvNeXt v2, please set ``use_grn=True`` and
    ``layer_scale_init_value=0.``.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
            should include the following two keys:

            - depths (list[int]): Number of blocks at each stage.
            - channels (list[int]): The number of channels at each stage.

            Defaults to 'tiny'.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_patch_size (int): The size of one patch in the stem layer.
            Defaults to 4.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. Defaults to True.
        use_grn (bool): Whether to add Global Response Normalization in the
            blocks. Defaults to False.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. In the official repo, it's only
            used in classification task. Defaults to True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): Initialization config dict
    """  # noqa: E501

    arch_settings = {
        'atto': {
            'depths': [2, 2, 6, 2],
            'channels': [40, 80, 160, 320]
        },
        'femto': {
            'depths': [2, 2, 6, 2],
            'channels': [48, 96, 192, 384]
        },
        'pico': {
            'depths': [2, 2, 6, 2],
            'channels': [64, 128, 256, 512]
        },
        'nano': {
            'depths': [2, 2, 8, 2],
            'channels': [80, 160, 320, 640]
        },
        'tiny': {
            'depths': [3, 3, 9, 3],
            'channels': [96, 192, 384, 768]
        },
        'small': {
            'depths': [3, 3, 27, 3],
            'channels': [96, 192, 384, 768]
        },
        'base': {
            'depths': [3, 3, 27, 3],
            'channels': [128, 256, 512, 1024]
        },
        'large': {
            'depths': [3, 3, 27, 3],
            'channels': [192, 384, 768, 1536]
        },
        'xlarge': {
            'depths': [3, 3, 27, 3],
            'channels': [256, 512, 1024, 2048]
        },
        'huge': {
            'depths': [3, 3, 27, 3],
            'channels': [352, 704, 1408, 2816]
        }
    }

    def __init__(self,
                 arch='tiny',
                 in_channels=3,
                 stem_patch_size=4,
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 linear_pw_conv=True,
                 use_grn=False,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 with_cp=False,
                 init_cfg=[
                     dict(
                         type='TruncNormal',
                         layer=['Conv2d', 'Linear'],
                         std=.02,
                         bias=0.),
                     dict(
                         type='Constant',
                         layer=['LayerNorm'],
                         val=1.,
                         bias=0.),
                 ]):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'depths' in arch and 'channels' in arch, \
                f'The arch dict must have "depths" and "channels", ' \
                f'but got {list(arch.keys())}.'

        self.depths = arch['depths']
        self.channels = arch['channels']
        assert (isinstance(self.depths, Sequence)
                and isinstance(self.channels, Sequence)
                and len(self.depths) == len(self.channels)), \
            f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
            'should be both sequence with the same length.'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # FIX: copy to a mutable list first -- callers commonly pass a tuple,
        # and the in-place normalization below would raise TypeError on it.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                # FIX: normalize negative indices against the actual number of
                # stages instead of a hard-coded 4, so custom arch dicts with
                # fewer/more stages are handled correctly.
                out_indices[i] = self.num_stages + index
            # Validate both bounds early; an out-of-range positive index used
            # to fail only later, deep inside forward() via getattr.
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        # stochastic depth decay rule: one linearly-increasing rate per block
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        block_idx = 0

        # 4 downsample layers between stages, including the stem layer.
        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.channels[0],
                kernel_size=stem_patch_size,
                stride=stem_patch_size),
            build_norm_layer(norm_cfg, self.channels[0]),
        )
        self.downsample_layers.append(stem)

        # 4 feature resolution stages, each consisting of multiple residual
        # blocks
        self.stages = nn.ModuleList()
        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                # Patch-merging style downsampler: norm then strided 2x2 conv.
                downsample_layer = nn.Sequential(
                    build_norm_layer(norm_cfg, self.channels[i - 1]),
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2),
                )
                self.downsample_layers.append(downsample_layer)

            stage = Sequential(*[
                ConvNeXtBlock(
                    in_channels=channels,
                    drop_path_rate=dpr[block_idx + j],
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    linear_pw_conv=linear_pw_conv,
                    layer_scale_init_value=layer_scale_init_value,
                    use_grn=use_grn,
                    with_cp=with_cp) for j in range(depth)
            ])
            block_idx += depth

            self.stages.append(stage)

            # One extra norm layer per requested output stage.
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, channels)
                self.add_module(f'norm{i}', norm_layer)

        self._freeze_stages()

    def forward(self, x):
        """Forward through stem, stages and the per-output norm layers.

        Returns:
            tuple: One tensor per index in ``self.out_indices``. When
            ``gap_before_final_norm`` is True each output is a (N, C)
            globally-pooled feature, otherwise a (N, C, H, W) map.
        """
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    # Global average pool before the norm (classification).
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap).flatten(1))
                else:
                    # The output of LayerNorm2d may be discontiguous, which
                    # may cause some problem in the downstream tasks
                    outs.append(norm_layer(x))

        return tuple(outs)

    def _freeze_stages(self):
        """Set the first ``frozen_stages`` stages to eval and stop grads."""
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode, re-freezing frozen stages afterwards."""
        super(ConvNeXt, self).train(mode)
        self._freeze_stages()

    def get_layer_depth(self, param_name: str, prefix: str = ''):
        """Get the layer-wise depth of a parameter.

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix for the parameter.
                Defaults to an empty string.

        Returns:
            Tuple[int, int]: The layer-wise depth and the num of layers.
        """
        # Deep variants (27-block stage 3) use 12 layer-decay groups, the
        # shallow ones use 6.
        max_layer_id = 12 if self.depths[-2] > 9 else 6

        if not param_name.startswith(prefix):
            # For subsequent module like head
            return max_layer_id + 1, max_layer_id + 2

        param_name = param_name[len(prefix):]
        if param_name.startswith('downsample_layers'):
            stage_id = int(param_name.split('.')[1])
            if stage_id == 0:
                layer_id = 0
            elif stage_id == 1 or stage_id == 2:
                layer_id = stage_id + 1
            else:
                # stage_id == 3:
                layer_id = max_layer_id

        elif param_name.startswith('stages'):
            stage_id = int(param_name.split('.')[1])
            block_id = int(param_name.split('.')[2])
            if stage_id == 0 or stage_id == 1:
                layer_id = stage_id + 1
            elif stage_id == 2:
                # Stage 2 blocks are grouped three-by-three.
                layer_id = 3 + block_id // 3
            else:
                # stage_id == 3:
                layer_id = max_layer_id

        # final norm layer
        else:
            layer_id = max_layer_id + 1

        return layer_id, max_layer_id + 2
mmpretrain/models/backbones/cspnet.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
math
from
typing
import
Sequence
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
,
DepthwiseSeparableConvModule
from
mmcv.cnn.bricks
import
DropPath
from
mmengine.model
import
BaseModule
,
Sequential
from
torch.nn.modules.batchnorm
import
_BatchNorm
from
mmpretrain.registry
import
MODELS
from
..utils
import
to_ntuple
from
.resnet
import
Bottleneck
as
ResNetBottleneck
from
.resnext
import
Bottleneck
as
ResNeXtBottleneck
eps
=
1.0e-5
class DarknetBottleneck(BaseModule):
    """The basic bottleneck block used in Darknet.

    Each DarknetBottleneck is a 1x1 squeeze conv followed by a 3x3 conv, and
    the input is added to the final output when the channel numbers allow it.
    Each ConvModule is composed of Conv, BN, and LeakyReLU.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the input/output channels of conv2.
            Defaults to 2.
        add_identity (bool): Whether to add identity to the out.
            Defaults to True.
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the 3x3 conv. Defaults to False.
        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
            which means using conv2d.
        drop_path_rate (float): The ratio of the drop path layer. Default: 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN', eps=1e-5)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='LeakyReLU', inplace=True)``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=2,
                 add_identity=True,
                 use_depthwise=False,
                 conv_cfg=None,
                 drop_path_rate=0,
                 norm_cfg=dict(type='BN', eps=1e-5),
                 act_cfg=dict(type='LeakyReLU', inplace=True),
                 init_cfg=None):
        super().__init__(init_cfg)
        mid_channels = int(out_channels / expansion)
        if use_depthwise:
            conv3x3_layer = DepthwiseSeparableConvModule
        else:
            conv3x3_layer = ConvModule

        # 1x1 reduction conv.
        self.conv1 = ConvModule(
            in_channels,
            mid_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # 3x3 conv restoring the output channel number.
        self.conv2 = conv3x3_layer(
            mid_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # The shortcut is only usable when input/output shapes match.
        self.add_identity = add_identity and (in_channels == out_channels)

        # ``eps`` is the module-level threshold below which drop-path is a
        # no-op (nn.Identity).
        if drop_path_rate > eps:
            self.drop_path = DropPath(drop_prob=drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Run the two convs (with drop-path) and optional residual add."""
        out = self.drop_path(self.conv2(self.conv1(x)))
        if self.add_identity:
            out = out + x
        return out
class CSPStage(BaseModule):
    """Cross Stage Partial Stage.

    .. code:: text

        Downsample Convolution (optional)
                    |
                    |
            Expand Convolution
                    |
                    |
           Split to xa, xb
                    |  \
                    |   \
                    |  blocks(xb)
                    |   / transition
                    |  /
            Concat xa, blocks(xb)
                    |
         Transition Convolution

    Args:
        block_fn (nn.module): The basic block function in the Stage.
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        has_downsampler (bool): Whether to add a downsampler in the stage.
            Default: True.
        down_growth (bool): Whether to expand the channels in the
            downsampler layer of the stage. Default: False.
        expand_ratio (float): The expand ratio to adjust the number of
            channels of the expand conv layer. Default: 0.5
        bottle_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Default: 2
        num_blocks (int): Number of blocks. Default: 1
        block_dpr (float): The ratio of the drop path layer in the
            blocks of the stage. Default: 0.
        block_args (dict): Extra keyword arguments forwarded to every block.
            Default: {} (never mutated here, so the shared default is safe).
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', eps=1e-5)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', inplace=True)
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 block_fn,
                 in_channels,
                 out_channels,
                 has_downsampler=True,
                 down_growth=False,
                 expand_ratio=0.5,
                 bottle_ratio=2,
                 num_blocks=1,
                 block_dpr=0,
                 block_args={},
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-5),
                 act_cfg=dict(type='LeakyReLU', inplace=True),
                 init_cfg=None):
        super().__init__(init_cfg)
        # grow downsample channels to output channels
        down_channels = out_channels if down_growth else in_channels
        # Broadcast a scalar drop-path rate to one value per block.
        block_dpr = to_ntuple(num_blocks)(block_dpr)

        if not has_downsampler:
            self.downsample_conv = nn.Identity()
        else:
            # ResNeXt-style stages use grouped (cardinality 32) downsampling.
            num_groups = 32 if block_fn is ResNeXtBottleneck else 1
            self.downsample_conv = ConvModule(
                in_channels=in_channels,
                out_channels=down_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=num_groups,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        exp_channels = int(down_channels * expand_ratio)
        # Darknet stages activate the expand conv; ResNet/ResNeXt ones don't.
        expand_act = act_cfg if block_fn is DarknetBottleneck else None
        self.expand_conv = ConvModule(
            in_channels=down_channels,
            out_channels=exp_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=expand_act)

        assert exp_channels % 2 == 0, \
            'The channel number before blocks must be divisible by 2.'
        block_channels = exp_channels // 2

        stage_blocks = [
            block_fn(
                in_channels=block_channels,
                out_channels=block_channels,
                expansion=bottle_ratio,
                drop_path_rate=block_dpr[i],
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **block_args) for i in range(num_blocks)
        ]
        self.blocks = Sequential(*stage_blocks)
        # NOTE: the attribute name keeps the historical typo ('atfer') so
        # released checkpoints keep loading; renaming would change state-dict
        # keys.
        self.atfer_blocks_conv = ConvModule(
            block_channels,
            block_channels,
            1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.final_conv = ConvModule(
            2 * block_channels,
            out_channels,
            1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x):
        """Downsample, expand, split, run blocks on one half, then merge."""
        x = self.downsample_conv(x)
        x = self.expand_conv(x)

        half = x.shape[1] // 2
        xa, xb = x[:, :half], x[:, half:]

        xb = self.blocks(xb)
        xb = self.atfer_blocks_conv(xb).contiguous()

        merged = torch.cat((xa, xb), dim=1)
        return self.final_conv(merged)
class CSPNet(BaseModule):
    """The abstract CSP Network class.

    A Pytorch implementation of `CSPNet: A New Backbone that can Enhance
    Learning Capability of CNN <https://arxiv.org/abs/1911.11929>`_

    This class is an abstract class because the Cross Stage Partial Network
    (CSPNet) is a kind of universal network structure, and you can customize
    the network block to implement networks like CSPResNet, CSPResNeXt and
    CSPDarkNet.

    Args:
        arch (dict): The architecture of the CSPNet.
            It should have the following keys:

            - block_fn (Callable): A function or class to return a block
              module, and it should accept at least ``in_channels``,
              ``out_channels``, ``expansion``, ``drop_path_rate``, ``norm_cfg``
              and ``act_cfg``.
            - in_channels (Tuple[int]): The number of input channels of each
              stage.
            - out_channels (Tuple[int]): The number of output channels of each
              stage.
            - num_blocks (Tuple[int]): The number of blocks in each stage.
            - expansion_ratio (float | Tuple[float]): The expansion ratio in
              the expand convolution of each stage. Defaults to 0.5.
            - bottle_ratio (float | Tuple[float]): The expansion ratio of
              blocks in each stage. Defaults to 2.
            - has_downsampler (bool | Tuple[bool]): Whether to add a
              downsample convolution in each stage. Defaults to True
            - down_growth (bool | Tuple[bool]): Whether to expand the channels
              in the downsampler layer of each stage. Defaults to False.
            - block_args (dict | Tuple[dict], optional): The extra arguments to
              the blocks in each stage. Defaults to None.

        stem_fn (Callable): A function or class to return a stem module.
            And it should accept ``in_channels``.
        in_channels (int): Number of input image channels. Defaults to 3.
        out_indices (int | Sequence[int]): Output from which stages.
            Defaults to -1, which means the last stage.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        conv_cfg (dict, optional): The config dict for conv layers in blocks.
            Defaults to None, which means use Conv2d.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN', eps=1e-5)``.
        act_cfg (dict): The config dict for activation functions.
            Defaults to ``dict(type='LeakyReLU', inplace=True)``.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict, optional): The initialization settings.
            Defaults to ``dict(type='Kaiming', layer='Conv2d'))``.

    Example:
        >>> from functools import partial
        >>> import torch
        >>> import torch.nn as nn
        >>> from mmpretrain.models import CSPNet
        >>> from mmpretrain.models.backbones.resnet import Bottleneck
        >>>
        >>> # A simple example to build CSPNet.
        >>> arch = dict(
        ...     block_fn=Bottleneck,
        ...     in_channels=[32, 64],
        ...     out_channels=[64, 128],
        ...     num_blocks=[3, 4]
        ... )
        >>> stem_fn = partial(nn.Conv2d, out_channels=32, kernel_size=3)
        >>> model = CSPNet(arch=arch, stem_fn=stem_fn, out_indices=(0, 1))
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outs = model(inputs)
        >>> for out in outs:
        ...     print(out.shape)
        ...
        (1, 64, 111, 111)
        (1, 128, 56, 56)
    """

    def __init__(self,
                 arch,
                 stem_fn,
                 in_channels=3,
                 out_indices=-1,
                 frozen_stages=-1,
                 drop_path_rate=0.,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-5),
                 act_cfg=dict(type='LeakyReLU', inplace=True),
                 norm_eval=False,
                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
        super().__init__(init_cfg=init_cfg)
        # Normalize the arch dict so every per-stage option is a tuple with
        # one entry per stage.
        self.arch = self.expand_arch(arch)
        self.num_stages = len(self.arch['in_channels'])
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        if frozen_stages not in range(-1, self.num_stages):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{self.num_stages}). But received '
                             f'{frozen_stages}')
        self.frozen_stages = frozen_stages

        # ``stem_fn`` may be a bound method of a subclass; it can rely on
        # ``self.arch``/``self.norm_cfg`` etc. which are already set above.
        self.stem = stem_fn(in_channels)

        stages = []
        depths = self.arch['num_blocks']
        # Linearly-increasing stochastic depth rates, split per stage.
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).split(depths)

        for i in range(self.num_stages):
            stage_cfg = {k: v[i] for k, v in self.arch.items()}
            csp_stage = CSPStage(
                **stage_cfg,
                block_dpr=dpr[i].tolist(),
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                init_cfg=init_cfg)
            stages.append(csp_stage)
        self.stages = Sequential(*stages)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Copy to a list so negative indices can be normalized in place.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = len(self.stages) + index
            # NOTE(review): the upper bound permits index == len(self.stages),
            # which can never be produced by forward() -- confirm whether this
            # should be a strict ``<``.
            assert 0 <= out_indices[i] <= len(self.stages), \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

    @staticmethod
    def expand_arch(arch):
        """Broadcast scalar arch options into per-stage tuples."""
        num_stages = len(arch['in_channels'])

        def to_tuple(x, name=''):
            # Per-stage sequences must match the stage count; scalars are
            # repeated for every stage.
            if isinstance(x, (list, tuple)):
                assert len(x) == num_stages, \
                    f'The length of {name} ({len(x)}) does not ' \
                    f'equals to the number of stages ({num_stages})'
                return tuple(x)
            else:
                return (x, ) * num_stages

        full_arch = {k: to_tuple(v, k) for k, v in arch.items()}
        if 'block_args' not in full_arch:
            full_arch['block_args'] = to_tuple({})
        return full_arch

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages + 1`` stages."""
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False

            for i in range(self.frozen_stages + 1):
                m = self.stages[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode, keeping frozen stages and (optionally)
        BatchNorm layers in eval mode."""
        super(CSPNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def forward(self, x):
        """Run stem and stages, collecting outputs at ``out_indices``.

        Returns:
            tuple: Feature maps from the requested stages, in stage order.
        """
        outs = []
        x = self.stem(x)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
@MODELS.register_module()
class CSPDarkNet(CSPNet):
    """CSP-Darknet backbone used in YOLOv4.

    Args:
        depth (int): Depth of CSP-Darknet. Default: 53.
        in_channels (int): Number of input image channels. Default: 3.
        out_indices (Sequence[int]): Output from which stages.
            Default: (4, ).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', eps=1e-5).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: a Kaiming-uniform init on Conv2d layers.

    Example:
        >>> from mmpretrain.models import CSPDarkNet
        >>> import torch
        >>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4))
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 64, 208, 208)
        (1, 128, 104, 104)
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # Five stages; only the first stage uses channel expansion (ratio 2) and
    # down_growth, matching the original CSPDarknet-53 layout.
    arch_settings = {
        53:
        dict(
            block_fn=DarknetBottleneck,
            in_channels=(32, 64, 128, 256, 512),
            out_channels=(64, 128, 256, 512, 1024),
            num_blocks=(1, 2, 8, 8, 4),
            expand_ratio=(2, 1, 1, 1, 1),
            bottle_ratio=(2, 1, 1, 1, 1),
            has_downsampler=True,
            down_growth=True,
        ),
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 out_indices=(4, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-5),
                 act_cfg=dict(type='LeakyReLU', inplace=True),
                 norm_eval=False,
                 init_cfg=dict(
                     type='Kaiming',
                     layer='Conv2d',
                     a=math.sqrt(5),
                     distribution='uniform',
                     mode='fan_in',
                     nonlinearity='leaky_relu')):
        assert depth in self.arch_settings, 'depth must be one of ' \
            f'{list(self.arch_settings.keys())}, but get {depth}.'
        super().__init__(
            arch=self.arch_settings[depth],
            stem_fn=self._make_stem_layer,
            in_channels=in_channels,
            out_indices=out_indices,
            frozen_stages=frozen_stages,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def _make_stem_layer(self, in_channels):
        """using a stride=1 conv as the stem in CSPDarknet."""
        # `stem_channels` equals to the `in_channels` in the first stage.
        stem_channels = self.arch['in_channels'][0]
        stem = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        return stem
@MODELS.register_module()
class CSPResNet(CSPNet):
    """CSP-ResNet backbone.

    Args:
        depth (int): Depth of CSP-ResNet. Default: 50.
        in_channels (int): Number of input image channels. Default: 3.
        out_indices (Sequence[int]): Output from which stages.
            Default: (3, ).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        deep_stem (bool): Use three stacked 3x3 convs instead of a single
            7x7 conv in the stem. Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', eps=1e-5).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Kaiming', layer='Conv2d').

    Example:
        >>> from mmpretrain.models import CSPResNet
        >>> import torch
        >>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3))
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 128, 104, 104)
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # Four stages; the first stage skips its downsampler because the stem
    # already reduces resolution.
    arch_settings = {
        50:
        dict(
            block_fn=ResNetBottleneck,
            in_channels=(64, 128, 256, 512),
            out_channels=(128, 256, 512, 1024),
            num_blocks=(3, 3, 5, 2),
            expand_ratio=4,
            bottle_ratio=2,
            has_downsampler=(False, True, True, True),
            down_growth=False),
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 deep_stem=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-5),
                 act_cfg=dict(type='LeakyReLU', inplace=True),
                 norm_eval=False,
                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
        assert depth in self.arch_settings, 'depth must be one of ' \
            f'{list(self.arch_settings.keys())}, but get {depth}.'
        # ``deep_stem`` must be set before super().__init__, which calls
        # ``self._make_stem_layer`` through the ``stem_fn`` argument.
        self.deep_stem = deep_stem
        super().__init__(
            arch=self.arch_settings[depth],
            stem_fn=self._make_stem_layer,
            in_channels=in_channels,
            out_indices=out_indices,
            frozen_stages=frozen_stages,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def _make_stem_layer(self, in_channels):
        """Build a ResNet-style stem (7x7/2 + maxpool, or deep 3x3 stack)."""
        # `stem_channels` equals to the `in_channels` in the first stage.
        stem_channels = self.arch['in_channels'][0]
        if self.deep_stem:
            stem = nn.Sequential(
                ConvModule(
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                ConvModule(
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                ConvModule(
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        else:
            stem = nn.Sequential(
                ConvModule(
                    in_channels,
                    stem_channels,
                    kernel_size=7,
                    stride=2,
                    padding=3,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        return stem
@MODELS.register_module()
class CSPResNeXt(CSPResNet):
    """CSP-ResNeXt backbone.

    Args:
        depth (int): Depth of CSP-ResNeXt. Default: 50.
        out_indices (Sequence[int]): Output from which stages.
            Default: (3, ) (inherited from :class:`CSPResNet`).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', eps=1e-5).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Kaiming', layer='Conv2d').

    Example:
        >>> from mmpretrain.models import CSPResNeXt
        >>> import torch
        >>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3))
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 56, 56)
        (1, 512, 28, 28)
        (1, 1024, 14, 14)
        (1, 2048, 7, 7)
    """
    # Same stage layout as CSPResNet, but with ResNeXt bottlenecks.
    arch_settings = {
        50:
        dict(
            block_fn=ResNeXtBottleneck,
            in_channels=(64, 256, 512, 1024),
            out_channels=(256, 512, 1024, 2048),
            num_blocks=(3, 3, 5, 2),
            expand_ratio=(4, 2, 2, 2),
            bottle_ratio=4,
            has_downsampler=(False, True, True, True),
            down_growth=False,
            # the base_channels is changed from 64 to 32 in CSPNet
            block_args=dict(base_channels=32),
        ),
    }

    def __init__(self, *args, **kwargs):
        # Pure pass-through: only the arch table differs from CSPResNet.
        super().__init__(*args, **kwargs)
mmpretrain/models/backbones/davit.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
from
typing
import
Sequence
,
Tuple
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmcv.cnn.bricks
import
Conv2d
from
mmcv.cnn.bricks.transformer
import
FFN
,
AdaptivePadding
,
PatchEmbed
from
mmengine.model
import
BaseModule
,
ModuleList
from
mmengine.utils
import
to_2tuple
from
mmengine.utils.dl_utils.parrots_wrapper
import
_BatchNorm
from
mmpretrain.models.backbones.base_backbone
import
BaseBackbone
from
mmpretrain.registry
import
MODELS
from
..utils
import
ShiftWindowMSA
class DaViTWindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module for DaViT.

    The differences between DaViTWindowMSA & WindowMSA:

        1. Without relative position bias.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        # Default scale is 1/sqrt(head_dim) unless explicitly overridden.
        self.scale = qk_scale or head_embed_dims**-0.5

        # Fused projection producing q, k and v in one matmul.
        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
                Wh*Ww), value should be between (-inf, 0].
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if mask is not None:
            # Broadcast the per-window additive mask over batch and heads.
            num_windows = mask.shape[0]
            attn = attn.view(B // num_windows, num_windows, self.num_heads, N,
                             N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.attn_drop(self.softmax(attn))

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        """Flattened outer sum of two arithmetic sequences.

        NOTE(review): appears unused in this bias-free variant -- likely a
        leftover from WindowMSA's relative position bias; confirm before
        removing.
        """
        rows = torch.arange(0, step1 * len1, step1)
        cols = torch.arange(0, step2 * len2, step2)
        return (rows.unsqueeze(1) + cols.unsqueeze(0)).reshape(1, -1)
class ConvPosEnc(BaseModule):
    """DaViT conv pos encode block.

    Adds a depthwise-convolution-based positional signal to a flattened
    token sequence.

    Args:
        embed_dims (int): Number of input channels.
        kernel_size (int): The kernel size of the first convolution.
            Defaults to 3.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self, embed_dims, kernel_size=3, init_cfg=None):
        super(ConvPosEnc, self).__init__(init_cfg)
        # Depthwise conv (groups == channels): mixes only spatial neighbors,
        # leaving channels independent.
        self.proj = Conv2d(
            embed_dims,
            embed_dims,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=embed_dims)

    def forward(self, x, size: Tuple[int, int]):
        """Add the convolutional position encoding to ``x``.

        Args:
            x (tensor): Token sequence of shape (B, N, C) with N == H * W.
            size (tuple[int, int]): Spatial size (H, W) of the token grid.
        """
        batch, num_tokens, channels = x.shape
        height, width = size
        assert num_tokens == height * width

        # (B, N, C) -> (B, C, H, W) for the conv, then back to (B, N, C).
        pos = x.transpose(1, 2).view(batch, channels, height, width)
        pos = self.proj(pos)
        pos = pos.flatten(2).transpose(1, 2)
        return x + pos
class DaViTDownSample(BaseModule):
    """DaViT down sample block.

    Normalizes the flattened tokens, reshapes them to a 2D map, reduces
    the resolution with a strided convolution and flattens back.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        conv_type (str): The type of convolution
            to generate patch embedding. Default: "Conv2d".
        kernel_size (int): The kernel size of the first convolution.
            Defaults to 2.
        stride (int): The stride of the second convluation module.
            Defaults to 2.
        padding (int | tuple | string ): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Defaults to "same".
        dilation (int): Dilation of the convolution layers. Defaults to 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to None (no normalization).
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_type='Conv2d',
                 kernel_size=2,
                 stride=2,
                 padding='same',
                 dilation=1,
                 bias=True,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.out_channels = out_channels
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            # String padding selects an adaptive-padding mode applied to
            # the token sequence before the conv.
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adaptive_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            # NOTE: the norm operates on *in_channels* (applied before
            # the projection in forward).
            self.norm = build_norm_layer(norm_cfg, in_channels)[1]
        else:
            self.norm = None

    def forward(self, x, input_size):
        """Downsample tokens.

        Args:
            x (tensor): input tokens with shape (B, H*W, C).
            input_size (tuple[int, int]): spatial shape (H, W) of ``x``.

        Returns:
            tuple: downsampled tokens and their new spatial shape.
        """
        if self.adaptive_padding:
            x = self.adaptive_padding(x)
        H, W = input_size
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'

        # Fix: ``self.norm`` is None when ``norm_cfg`` is None (the
        # default); the original code called it unconditionally and
        # crashed in that case.
        if self.norm is not None:
            x = self.norm(x)
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

        x = self.projection(x)
        output_size = (x.size(2), x.size(3))
        x = x.flatten(2).transpose(1, 2)
        return x, output_size
class ChannelAttention(BaseModule):
    """DaViT channel attention.

    Attention is computed across the channel dimension instead of the
    token dimension: keys are reduced against values to a small
    (head_dims, head_dims) map which then re-weights the queries.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads. Defaults to 8.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self, embed_dims, num_heads=8, qkv_bias=False,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.head_dims = embed_dims // num_heads
        self.scale = self.head_dims**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dims, embed_dims)

    def forward(self, x):
        batch, tokens, _ = x.shape

        # (B, N, 3C) -> (3, B, num_heads, N, head_dims).
        qkv = self.qkv(x) \
            .reshape(batch, tokens, 3, self.num_heads, self.head_dims) \
            .permute(2, 0, 3, 1, 4)
        query, key, value = qkv.unbind(0)

        # Channel attention map: (B, heads, head_dims, head_dims),
        # softmax over the last (channel) axis. Keys carry the scale.
        scaled_key = key * self.scale
        channel_attn = (scaled_key.transpose(-1, -2) @ value).softmax(dim=-1)

        # Re-weight query channels, then restore (B, N, C) layout.
        out = (channel_attn @ query.transpose(-1, -2)).transpose(-1, -2)
        out = out.transpose(1, 2).reshape(batch, tokens, self.embed_dims)
        return self.proj(out)
class ChannelBlock(BaseModule):
    """DaViT channel attention block.

    Pipeline: conv position encoding -> pre-norm channel attention with
    residual -> conv position encoding -> pre-norm FFN with residual.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 ffn_ratio=4.,
                 qkv_bias=False,
                 drop_path=0.,
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.with_cp = with_cp

        # Convolutional position encodings before each sub-block.
        self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ChannelAttention(
            embed_dims, num_heads=num_heads, qkv_bias=qkv_bias)
        self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)

        # User-supplied ``ffn_cfgs`` entries override these defaults.
        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            **ffn_cfgs
        }
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**_ffn_cfgs)

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            # Position encoding + pre-norm channel attention + residual.
            x = self.cpe1(x, hw_shape)
            identity = x
            x = self.norm1(x)
            x = self.attn(x)
            x = x + identity

            # Second position encoding + pre-norm FFN (the residual is
            # added inside FFN via ``identity``).
            x = self.cpe2(x, hw_shape)
            identity = x
            x = self.norm2(x)
            x = self.ffn(x, identity=identity)

            return x

        # Gradient checkpointing trades compute for memory in training.
        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x
class SpatialBlock(BaseModule):
    """DaViT spatial attention block.

    Pipeline: conv position encoding -> pre-norm shifted-window spatial
    attention with residual -> conv position encoding -> pre-norm FFN
    with residual.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        attn_cfgs (dict): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size=7,
                 ffn_ratio=4.,
                 qkv_bias=True,
                 drop_path=0.,
                 pad_small_map=False,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super(SpatialBlock, self).__init__(init_cfg)
        self.with_cp = with_cp

        # Convolutional position encodings before each sub-block.
        self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        # Window MSA without shifting (shift_size=0), using the DaViT
        # window attention implementation. ``attn_cfgs`` overrides.
        _attn_cfgs = {
            'embed_dims': embed_dims,
            'num_heads': num_heads,
            'shift_size': 0,
            'window_size': window_size,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'qkv_bias': qkv_bias,
            'pad_small_map': pad_small_map,
            'window_msa': DaViTWindowMSA,
            **attn_cfgs
        }
        self.attn = ShiftWindowMSA(**_attn_cfgs)
        self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)

        # User-supplied ``ffn_cfgs`` entries override these defaults.
        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            **ffn_cfgs
        }
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**_ffn_cfgs)

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            # Position encoding + pre-norm window attention + residual.
            x = self.cpe1(x, hw_shape)
            identity = x
            x = self.norm1(x)
            x = self.attn(x, hw_shape)
            x = x + identity

            # Second position encoding + pre-norm FFN (the residual is
            # added inside FFN via ``identity``).
            x = self.cpe2(x, hw_shape)
            identity = x
            x = self.norm2(x)
            x = self.ffn(x, identity=identity)

            return x

        # Gradient checkpointing trades compute for memory in training.
        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x
class DaViTBlock(BaseModule):
    """DaViT block: a spatial (window) attention block followed by a
    channel attention block.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        attn_cfgs (dict): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size=7,
                 ffn_ratio=4.,
                 qkv_bias=True,
                 drop_path=0.,
                 pad_small_map=False,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg)
        # Window-based spatial attention over tokens.
        self.spatial_block = SpatialBlock(
            embed_dims,
            num_heads,
            window_size=window_size,
            ffn_ratio=ffn_ratio,
            qkv_bias=qkv_bias,
            drop_path=drop_path,
            pad_small_map=pad_small_map,
            attn_cfgs=attn_cfgs,
            ffn_cfgs=ffn_cfgs,
            norm_cfg=norm_cfg,
            with_cp=with_cp)
        # NOTE(review): with_cp is hard-coded to False for the channel
        # block in the original implementation — kept as-is.
        self.channel_block = ChannelBlock(
            embed_dims,
            num_heads,
            ffn_ratio=ffn_ratio,
            qkv_bias=qkv_bias,
            drop_path=drop_path,
            ffn_cfgs=ffn_cfgs,
            norm_cfg=norm_cfg,
            with_cp=False)

    def forward(self, x, hw_shape):
        """Apply spatial then channel attention; token count unchanged."""
        out = self.spatial_block(x, hw_shape)
        out = self.channel_block(out, hw_shape)
        return out
class DaViTBlockSequence(BaseModule):
    """Module with successive DaViT blocks and downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        depth (int): Number of successive DaViT blocks.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        downsample (bool): Downsample the output of blocks by patch merging.
            Defaults to False.
        downsample_cfg (dict): The extra config of the patch merging layer.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float): The drop path rate in each block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 depth,
                 num_heads,
                 window_size=7,
                 ffn_ratio=4.,
                 qkv_bias=True,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 with_cp=False,
                 pad_small_map=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        # Broadcast a scalar drop-path rate to one value per block.
        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth

        # Broadcast a single block config to one (deep) copy per block.
        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        self.embed_dims = embed_dims
        self.blocks = ModuleList()
        for i in range(depth):
            # Per-block entries in ``block_cfgs[i]`` override defaults.
            _block_cfg = {
                'embed_dims': embed_dims,
                'num_heads': num_heads,
                'window_size': window_size,
                'ffn_ratio': ffn_ratio,
                'qkv_bias': qkv_bias,
                'drop_path': drop_paths[i],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                **block_cfgs[i]
            }
            block = DaViTBlock(**_block_cfg)
            self.blocks.append(block)

        if downsample:
            # Doubles the channels while halving the resolution;
            # ``downsample_cfg`` entries override.
            _downsample_cfg = {
                'in_channels': embed_dims,
                'out_channels': 2 * embed_dims,
                'norm_cfg': dict(type='LN'),
                **downsample_cfg
            }
            self.downsample = DaViTDownSample(**_downsample_cfg)
        else:
            self.downsample = None

    def forward(self, x, in_shape, do_downsample=True):
        """Run all blocks, optionally followed by the downsample layer.

        Returns:
            tuple: output tokens and their spatial shape (unchanged when
            no downsampling is performed).
        """
        for block in self.blocks:
            x = block(x, in_shape)

        if self.downsample is not None and do_downsample:
            x, out_shape = self.downsample(x, in_shape)
        else:
            out_shape = in_shape
        return x, out_shape

    @property
    def out_channels(self):
        # Channel count of this stage's output, after downsampling if any.
        if self.downsample:
            return self.downsample.out_channels
        else:
            return self.embed_dims
@MODELS.register_module()
class DaViT(BaseBackbone):
    """DaViT.

    A PyTorch implement of : `DaViT: Dual Attention Vision Transformers
    <https://arxiv.org/abs/2204.03645v1>`_

    Inspiration from
    https://github.com/dingmyu/davit

    Args:
        arch (str | dict): DaViT architecture. If use string, choose from
            'tiny', 'small', 'base' and 'large', 'huge', 'giant'. If use dict,
            it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.

            Defaults to 't'.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (int): The height and width of the window. Defaults to 7.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        out_after_downsample (bool): Whether to output the feature map of a
            stage after the following downsample layer. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``
        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
            stage. Defaults to an empty dict.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    # Named presets; each maps to the dict consumed as ``arch_settings``.
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'], {
            'embed_dims': 96,
            'depths': [1, 1, 3, 1],
            'num_heads': [3, 6, 12, 24]
        }),
        **dict.fromkeys(['s', 'small'], {
            'embed_dims': 96,
            'depths': [1, 1, 9, 1],
            'num_heads': [3, 6, 12, 24]
        }),
        **dict.fromkeys(['b', 'base'], {
            'embed_dims': 128,
            'depths': [1, 1, 9, 1],
            'num_heads': [4, 8, 16, 32]
        }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 192,
                'depths': [1, 1, 9, 1],
                'num_heads': [6, 12, 24, 48]
            }),
        **dict.fromkeys(
            ['h', 'huge'], {
                'embed_dims': 256,
                'depths': [1, 1, 9, 1],
                'num_heads': [8, 16, 32, 64]
            }),
        **dict.fromkeys(
            ['g', 'giant'], {
                'embed_dims': 384,
                'depths': [1, 1, 12, 3],
                'num_heads': [12, 24, 48, 96]
            }),
    }

    def __init__(self,
                 arch='t',
                 patch_size=4,
                 in_channels=3,
                 window_size=7,
                 ffn_ratio=4.,
                 qkv_bias=True,
                 drop_path_rate=0.1,
                 out_after_downsample=False,
                 pad_small_map=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 frozen_stages=-1,
                 norm_eval=False,
                 out_indices=(3, ),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        # Resolve the architecture: a preset name or a full settings dict.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.out_after_downsample = out_after_downsample
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        # stochastic depth decay rule
        # Linearly increasing drop-path rates, one per block over all
        # stages; consumed stage-by-stage below.
        total_depth = sum(self.depths)
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        _patch_cfg = dict(
            in_channels=in_channels,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=7,
            stride=patch_size,
            padding='same',
            norm_cfg=dict(type='LN'),
        )
        self.patch_embed = PatchEmbed(**_patch_cfg)

        self.stages = ModuleList()
        # Running list of per-stage input channels; grown as stages are
        # built (each downsample doubles the width).
        embed_dims = [self.embed_dims]
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            # Every stage except the last ends with a downsample layer.
            downsample = True if i < self.num_layers - 1 else False
            _stage_cfg = {
                'embed_dims': embed_dims[-1],
                'depth': depth,
                'num_heads': num_heads,
                'window_size': window_size,
                'ffn_ratio': ffn_ratio,
                'qkv_bias': qkv_bias,
                'downsample': downsample,
                'drop_paths': dpr[:depth],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                **stage_cfg
            }

            stage = DaViTBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            # Consume this stage's share of the drop-path schedule.
            dpr = dpr[depth:]
            embed_dims.append(stage.out_channels)

        self.num_features = embed_dims[:-1]

        # add a norm layer for each output
        for i in out_indices:
            if norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg,
                                              self.num_features[i])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def train(self, mode=True):
        # Re-apply freezing every time the module is switched to train
        # mode, so frozen stages stay in eval.
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _freeze_stages(self):
        """Freeze patch embedding, the first ``frozen_stages + 1`` stages
        and their output norms (stop grad and set eval mode)."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(0, self.frozen_stages + 1):
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        for i in self.out_indices:
            if i <= self.frozen_stages:
                for param in getattr(self, f'norm{i}').parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Return a tuple of (B, C, H, W) feature maps, one per index in
        ``out_indices``."""
        x, hw_shape = self.patch_embed(x)

        outs = []
        for i, stage in enumerate(self.stages):
            # When ``out_after_downsample`` is False, downsampling is
            # deferred until after the output is taken (see below).
            x, hw_shape = stage(
                x, hw_shape, do_downsample=self.out_after_downsample)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                # Tokens -> (B, C, H, W) feature map.
                out = out.view(-1, *hw_shape,
                               self.num_features[i]).permute(0, 3, 1,
                                                             2).contiguous()
                outs.append(out)
            if stage.downsample is not None and not self.out_after_downsample:
                x, hw_shape = stage.downsample(x, hw_shape)

        return tuple(outs)
mmpretrain/models/backbones/deit.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
from
mmengine.model.weight_init
import
trunc_normal_
from
mmpretrain.registry
import
MODELS
from
.vision_transformer
import
VisionTransformer
@MODELS.register_module()
class DistilledVisionTransformer(VisionTransformer):
    """Distilled Vision Transformer.

    A PyTorch implement of : `Training data-efficient image transformers &
    distillation through attention <https://arxiv.org/abs/2012.12877>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If use string,
            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
            and 'deit-base'. If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'deit-base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: A tuple with the class token and the
              distillation token. The shapes of both tensor are (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"cls_token"``.
        interpolate_mode (str): Select the interpolate mode for position
            embeding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    num_extra_tokens = 2  # class token and distillation token

    def __init__(self, arch='deit-base', *args, **kwargs):
        super(DistilledVisionTransformer, self).__init__(
            arch=arch,
            with_cls_token=True,
            *args,
            **kwargs,
        )
        # Learnable distillation token, prepended next to the cls token.
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        # Token layout: [cls, dist, patch tokens...].
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        # Resize the position embedding from the build-time resolution
        # (self.patch_resolution, set by the parent class) to the actual
        # input resolution, keeping the two extra tokens untouched.
        x = x + self.resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)

    def _format_output(self, x, hw):
        if self.out_type == 'cls_token':
            # Return (cls_token, dist_token) as a pair of (B, C) tensors.
            return x[:, 0], x[:, 1]

        # Other output types are handled by the parent implementation.
        return super()._format_output(x, hw)

    def init_weights(self):
        super(DistilledVisionTransformer, self).init_weights()

        # Skip manual init when loading pretrained weights, which would
        # overwrite them anyway.
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.dist_token, std=0.02)
mmpretrain/models/backbones/deit3.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Sequence
import
numpy
as
np
import
torch
from
mmcv.cnn
import
Linear
,
build_activation_layer
from
mmcv.cnn.bricks.drop
import
build_dropout
from
mmcv.cnn.bricks.transformer
import
PatchEmbed
from
mmengine.model
import
BaseModule
,
ModuleList
,
Sequential
from
mmengine.utils
import
deprecated_api_warning
from
torch
import
nn
from
mmpretrain.registry
import
MODELS
from
..utils
import
(
LayerScale
,
MultiheadAttention
,
build_norm_layer
,
resize_pos_embed
,
to_2tuple
)
from
.vision_transformer
import
VisionTransformer
class DeiT3FFN(BaseModule):
    """FFN for DeiT3.

    The differences between DeiT3FFN & FFN:
        1. Use LayerScale.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        use_layer_scale (bool): Whether to use layer_scale in
            DeiT3FFN. Defaults to True.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='FFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 use_layer_scale=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        # (num_fcs - 1) hidden Linear+act+dropout stages, then the
        # output projection back to ``embed_dims``.
        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        # Shortcut dropout (e.g. DropPath); identity when not configured.
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

        # Learnable per-channel scaling of the FFN branch (DeiT3).
        if use_layer_scale:
            self.gamma2 = LayerScale(embed_dims)
        else:
            self.gamma2 = nn.Identity()

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        The function would add x to the output tensor if residue is None.
        """
        out = self.layers(x)
        out = self.gamma2(out)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
class DeiT3TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in DeiT3.

    The differences between DeiT3TransformerEncoderLayer &
    TransformerEncoderLayer:
        1. Use LayerScale.

    Args:
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension for FFNs
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        use_layer_scale (bool): Whether to use layer_scale in
            DeiT3TransformerEncoderLayer. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 use_layer_scale=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        # Pre-norm for the attention branch.
        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            use_layer_scale=use_layer_scale)

        # Pre-norm for the FFN branch.
        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        self.ffn = DeiT3FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            use_layer_scale=use_layer_scale)

    def init_weights(self):
        super(DeiT3TransformerEncoderLayer, self).init_weights()
        # DeiT-style init of FFN linear layers.
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        """Pre-norm attention then pre-norm FFN, each with a residual."""
        x = x + self.attn(self.ln1(x))
        # Fix: the FFN branch must use its own norm layer ``ln2``. The
        # original code reused ``ln1`` here, leaving ``ln2`` (built in
        # __init__) unused — each sub-layer of a pre-norm transformer
        # block has a separate LayerNorm.
        x = self.ffn(self.ln2(x), identity=x)
        return x
@MODELS.register_module()
class DeiT3(VisionTransformer):
    """DeiT3 backbone.

    A PyTorch implement of : `DeiT III: Revenge of the ViT
    <https://arxiv.org/pdf/2204.07118.pdf>`_

    The differences between DeiT3 & VisionTransformer:

    1. Use LayerScale.
    2. Concat cls token after adding pos_embed.

    Args:
        arch (str | dict): DeiT3 architecture. If use string,
            choose from 'small', 'base', 'medium', 'large' and 'huge'.
            If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"cls_token"``.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        use_layer_scale (bool): Whether to use layer_scale in DeiT3.
            Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embeding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': 1536,
            }),
        **dict.fromkeys(
            ['m', 'medium'], {
                'embed_dims': 512,
                'num_layers': 12,
                'num_heads': 8,
                'feedforward_channels': 2048,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
        **dict.fromkeys(
            ['h', 'huge'], {
                'embed_dims': 1280,
                'num_layers': 32,
                'num_heads': 16,
                'feedforward_channels': 5120
            }),
    }
    # One extra (class) token is prepended to the patch tokens. Reset to 0
    # in __init__ when ``with_cls_token=False``.
    num_extra_tokens = 1  # class token

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 out_type='cls_token',
                 with_cls_token=True,
                 use_layer_scale=True,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        # Deliberately bypass VisionTransformer.__init__ (note the explicit
        # ``super(VisionTransformer, self)``): DeiT3 rebuilds every submodule
        # itself below because its pos_embed has no cls-token slot.
        super(VisionTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            # A custom arch dict must provide at least these keys.
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type (``OUT_TYPES`` is an inherited class attribute).
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        elif out_type != 'cls_token':
            # No class token at all; every token is a patch token.
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding. Unlike VisionTransformer, pos_embed covers
        # only the ``num_patches`` patch tokens (no cls-token entry) because
        # the cls token is concatenated *after* pos_embed is added.
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, self.embed_dims))
        # Resize checkpoint pos_embed on load when shapes differ.
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                # Normalize negative indices relative to the layer count.
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule: drop-path rate grows linearly from 0
        # at the first layer to ``drop_path_rate`` at the last.
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            # Broadcast a single cfg dict to every layer.
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
                use_layer_scale=use_layer_scale)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg))

        self.final_norm = final_norm
        if final_norm:
            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

    def forward(self, x):
        """Forward computation.

        Returns a tuple with one formatted feature per index in
        ``self.out_indices``.
        """
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # Add position embedding to patch tokens only (num_extra_tokens=0):
        # the cls token is concatenated afterwards, which is the key DeiT3
        # difference from VisionTransformer.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=0)
        x = self.drop_after_pos(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            # Apply the final norm before formatting the last layer's output.
            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """``load_state_dict`` pre-hook: resize a checkpoint's pos_embed to
        match this model's patch resolution when the shapes differ."""
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # Recover the checkpoint's (H, W) grid assuming a square layout
            # of its pos_embed tokens.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1])))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(
                state_dict[name],
                ckpt_pos_embed_shape,
                pos_embed_shape,
                self.interpolate_mode,
                num_extra_tokens=0,  # The cls token adding is after pos_embed
            )
mmpretrain/models/backbones/densenet.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
math
from
itertools
import
chain
from
typing
import
Sequence
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn.bricks
import
build_activation_layer
,
build_norm_layer
from
torch.jit.annotations
import
List
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
class DenseLayer(BaseBackbone):
    """DenseBlock layers.

    One densely-connected layer: it concatenates all incoming feature maps,
    squeezes them through a 1x1 bottleneck conv to ``bn_size * growth_rate``
    channels, then produces ``growth_rate`` new channels with a 3x3 conv.
    Optionally recomputes the bottleneck during backward (checkpointing)
    to save memory.
    """

    def __init__(self,
                 in_channels,
                 growth_rate,
                 bn_size,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_rate=0.,
                 memory_efficient=False):
        super(DenseLayer, self).__init__()

        # norm1 -> act -> 1x1 conv: the bottleneck.
        self.norm1 = build_norm_layer(norm_cfg, in_channels)[1]
        self.conv1 = nn.Conv2d(
            in_channels,
            bn_size * growth_rate,
            kernel_size=1,
            stride=1,
            bias=False)
        self.act = build_activation_layer(act_cfg)
        # norm2 -> act -> 3x3 conv: produces ``growth_rate`` output channels.
        self.norm2 = build_norm_layer(norm_cfg, bn_size * growth_rate)[1]
        self.conv2 = nn.Conv2d(
            bn_size * growth_rate,
            growth_rate,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bottleneck_fn(self, xs):
        # type: (List[torch.Tensor]) -> torch.Tensor
        # Concatenate all previous feature maps along the channel dim and
        # run the 1x1 bottleneck on the result.
        concated_features = torch.cat(xs, 1)
        bottleneck_output = self.conv1(
            self.act(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, x):
        # type: (List[torch.Tensor]) -> bool
        # True if any input tensor needs gradients; checkpointing is only
        # worthwhile (and valid) in that case.
        for tensor in x:
            if tensor.requires_grad:
                return True
        return False

    # This decorator indicates to the compiler that a function or method
    # should be ignored and replaced with the raising of an exception.
    # Here this function is incompatible with torchscript.
    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, x):
        # type: (List[torch.Tensor]) -> torch.Tensor
        def closure(*xs):
            return self.bottleneck_fn(xs)

        # Here use torch.utils.checkpoint to rerun a forward-pass during
        # backward in bottleneck to save memories.
        return cp.checkpoint(closure, *x)

    def forward(self, x):  # noqa: F811
        # type: (List[torch.Tensor]) -> torch.Tensor
        # assert input features is a list of Tensor
        assert isinstance(x, list)

        if self.memory_efficient and self.any_requires_grad(x):
            if torch.jit.is_scripting():
                # Checkpointing relies on Python closures, which TorchScript
                # cannot compile.
                raise Exception('Memory Efficient not supported in JIT')
            bottleneck_output = self.call_checkpoint_bottleneck(x)
        else:
            bottleneck_output = self.bottleneck_fn(x)

        new_features = self.conv2(self.act(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(
                new_features, p=self.drop_rate, training=self.training)
        return new_features
class DenseBlock(nn.Module):
    """DenseNet Blocks.

    Holds ``num_layers`` densely-connected layers. Every layer receives the
    list of all feature maps produced so far, and the block returns their
    channel-wise concatenation.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 bn_size,
                 growth_rate,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_rate=0.,
                 memory_efficient=False):
        super(DenseBlock, self).__init__()
        layers = []
        for idx in range(num_layers):
            # The idx-th layer sees the block input plus the ``growth_rate``
            # channels contributed by each of the previous idx layers.
            layers.append(
                DenseLayer(
                    in_channels + idx * growth_rate,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient))
        self.block = nn.ModuleList(layers)

    def forward(self, init_features):
        # Accumulate every layer's output; each layer consumes the whole
        # running list of feature maps.
        collected = [init_features]
        for dense_layer in self.block:
            collected.append(dense_layer(collected))
        return torch.cat(collected, 1)
class DenseTransition(nn.Sequential):
    """DenseNet Transition Layers.

    Sits between two dense blocks: normalizes and activates the input,
    reduces its channels with a 1x1 conv, then halves the spatial size
    with 2x2 average pooling.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU')):
        super(DenseTransition, self).__init__()
        # Register submodules in execution order: norm -> act -> conv -> pool.
        named_children = (
            ('norm', build_norm_layer(norm_cfg, in_channels)[1]),
            ('act', build_activation_layer(act_cfg)),
            ('conv',
             nn.Conv2d(
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 bias=False)),
            ('pool', nn.AvgPool2d(kernel_size=2, stride=2)),
        )
        for child_name, child in named_children:
            self.add_module(child_name, child)
@MODELS.register_module()
class DenseNet(BaseBackbone):
    """DenseNet.

    A PyTorch implementation of : `Densely Connected Convolutional Networks
    <https://arxiv.org/pdf/1608.06993.pdf>`_

    Modified from the `official repo
    <https://github.com/liuzhuang13/DenseNet>`_
    and `pytorch
    <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``DenseNet.arch_settings``. And if dict, it
            should include the following two keys:

            - growth_rate (int): Each layer of DenseBlock produce `k` feature
              maps. Here refers `k` as the growth rate of the network.
            - depths (list[int]): Number of repeated layers in each DenseBlock.
            - init_channels (int): The output channels of stem layers.

            Defaults to '121'.
        in_channels (int): Number of input image channels. Defaults to 3.
        bn_size (int): Refers to channel expansion parameter of 1x1
            convolution layer. Defaults to 4.
        drop_rate (float): Drop rate of Dropout Layer. Defaults to 0.
        compression_factor (float): The reduction rate of transition layers.
            Defaults to 0.5.
        memory_efficient (bool): If True, uses checkpointing. Much more memory
            efficient, but slower. Defaults to False.
            See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation after each convolution.
            Defaults to ``dict(type='ReLU')``.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): Initialization config dict.
    """
    arch_settings = {
        '121': {
            'growth_rate': 32,
            'depths': [6, 12, 24, 16],
            'init_channels': 64,
        },
        '169': {
            'growth_rate': 32,
            'depths': [6, 12, 32, 32],
            'init_channels': 64,
        },
        '201': {
            'growth_rate': 32,
            'depths': [6, 12, 48, 32],
            'init_channels': 64,
        },
        '161': {
            'growth_rate': 48,
            'depths': [6, 12, 36, 24],
            'init_channels': 96,
        },
    }

    def __init__(self,
                 arch='121',
                 in_channels=3,
                 bn_size=4,
                 drop_rate=0,
                 compression_factor=0.5,
                 memory_efficient=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 out_indices=-1,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            # NOTE(review): an ``arch`` that is neither str nor dict falls
            # through both branches and only fails below on subscripting.
            essential_keys = {'growth_rate', 'depths', 'init_channels'}
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'

        self.growth_rate = arch['growth_rate']
        self.depths = arch['depths']
        self.init_channels = arch['init_channels']
        self.act = build_activation_layer(act_cfg)

        self.num_stages = len(self.depths)

        # check out indices and frozen stages
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                # Normalize negative indices relative to the stage count.
                out_indices[i] = self.num_stages + index
            # NOTE(review): only the lower bound is validated here; an index
            # >= num_stages is silently ignored in forward().
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # Set stem layers: 7x7/2 conv -> norm -> act -> 3x3/2 max-pool.
        self.stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.init_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False),
            build_norm_layer(norm_cfg, self.init_channels)[1], self.act,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Repetitions of DenseNet Blocks
        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()

        # Running channel count entering each stage.
        channels = self.init_channels
        for i in range(self.num_stages):
            depth = self.depths[i]

            stage = DenseBlock(
                num_layers=depth,
                in_channels=channels,
                bn_size=bn_size,
                growth_rate=self.growth_rate,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient)
            self.stages.append(stage)
            # Every dense layer adds ``growth_rate`` channels.
            channels += depth * self.growth_rate

            if i != self.num_stages - 1:
                # Compress channels and halve spatial size between stages.
                transition = DenseTransition(
                    in_channels=channels,
                    out_channels=math.floor(channels * compression_factor),
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                )
                channels = math.floor(channels * compression_factor)
            else:
                # Final layers after dense block is just bn with act.
                # Unlike the paper, the original repo also put this in
                # transition layer, whereas torchvision take this out.
                # We reckon this as transition layer here.
                transition = nn.Sequential(
                    build_norm_layer(norm_cfg, channels)[1],
                    self.act,
                )
            self.transitions.append(transition)

        self._freeze_stages()

    def forward(self, x):
        """Run the stem, then each (stage, transition) pair, collecting the
        outputs of the stages listed in ``self.out_indices``."""
        x = self.stem(x)
        outs = []
        for i in range(self.num_stages):
            x = self.stages[i](x)
            x = self.transitions[i](x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stage/transition pairs in eval
        mode and disable their gradients."""
        for i in range(self.frozen_stages):
            downsample_layer = self.transitions[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        # Re-apply freezing after every train/eval switch so frozen stages
        # stay in eval mode (e.g. BN statistics stay fixed).
        super(DenseNet, self).train(mode)
        self._freeze_stages()
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment