Unverified Commit 631a5159 authored by Zhe Chen's avatar Zhe Chen Committed by GitHub
Browse files

Release code for InternImage-H + Mask2former (#198)

* Add InternImage-H + Mask2Former

* Update README.md

* Update configs and readme

* Update README_CN.md
parent 88dbd1ae
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
def split_combined_polys(polys, poly_lens, polys_per_mask):
"""Split the combined 1-D polys into masks.
A mask is represented as a list of polys, and a poly is represented as
a 1-D array. In dataset, all masks are concatenated into a single 1-D
tensor. Here we need to split the tensor into original representations.
Args:
polys (list): a list (length = image num) of 1-D tensors
poly_lens (list): a list (length = image num) of poly length
polys_per_mask (list): a list (length = image num) of poly number
of each mask
Returns:
list: a list (length = image num) of list (length = mask num) of \
list (length = poly num) of numpy array.
"""
mask_polys_list = []
for img_id in range(len(polys)):
polys_single = polys[img_id]
polys_lens_single = poly_lens[img_id].tolist()
polys_per_mask_single = polys_per_mask[img_id].tolist()
split_polys = mmcv.slice_list(polys_single, polys_lens_single)
mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single)
mask_polys_list.append(mask_polys)
return mask_polys_list
# TODO: move this function to more proper place
def encode_mask_results(mask_results):
"""Encode bitmap mask to RLE code.
Args:
mask_results (list | tuple[list]): bitmap mask results.
In mask scoring rcnn, mask_results is a tuple of (segm_results,
segm_cls_score).
Returns:
list | tuple: RLE encoded mask.
"""
if isinstance(mask_results, tuple): # mask scoring
cls_segms, cls_mask_scores = mask_results
else:
cls_segms = mask_results
num_classes = len(cls_segms)
encoded_mask_results = [[] for _ in range(num_classes)]
for i in range(len(cls_segms)):
for cls_segm in cls_segms[i]:
encoded_mask_results[i].append(
mask_util.encode(
np.array(
cls_segm[:, :, np.newaxis], order='F',
dtype='uint8'))[0]) # encoded with RLE
if isinstance(mask_results, tuple):
return encoded_mask_results, cls_mask_scores
else:
return encoded_mask_results
def mask2bbox(masks):
"""Obtain tight bounding boxes of binary masks.
Args:
masks (Tensor): Binary mask of shape (n, h, w).
Returns:
Tensor: Bboxe with shape (n, 4) of \
positive region in binary mask.
"""
N = masks.shape[0]
bboxes = masks.new_zeros((N, 4), dtype=torch.float32)
x_any = torch.any(masks, dim=1)
y_any = torch.any(masks, dim=2)
for i in range(N):
x = torch.where(x_any[i, :])[0]
y = torch.where(y_any[i, :])[0]
if len(x) > 0 and len(y) > 0:
bboxes[i, :] = bboxes.new_tensor(
[x[0], y[0], x[-1] + 1, y[-1] + 1])
return bboxes
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
reduce_mean)
from .misc import add_prefix, multi_apply
__all__ = [
'add_prefix', 'multi_apply', 'DistOptimizerHook', 'allreduce_grads',
'all_reduce_dict', 'reduce_mean'
]
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import pickle
import warnings
from collections import OrderedDict
import torch
import torch.distributed as dist
from mmcv.runner import OptimizerHook, get_dist_info
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
dist.all_reduce(flat_tensors)
flat_tensors.div_(world_size)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
"""Allreduce gradients.
Args:
params (list[torch.Parameters]): List of parameters of a model
coalesce (bool, optional): Whether allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
Defaults to -1.
"""
grads = [
param.grad.data for param in params
if param.requires_grad and param.grad is not None
]
world_size = dist.get_world_size()
if coalesce:
_allreduce_coalesced(grads, world_size, bucket_size_mb)
else:
for tensor in grads:
dist.all_reduce(tensor.div_(world_size))
class DistOptimizerHook(OptimizerHook):
"""Deprecated optimizer hook for distributed training."""
def __init__(self, *args, **kwargs):
warnings.warn('"DistOptimizerHook" is deprecated, please switch to'
'"mmcv.runner.OptimizerHook".')
super().__init__(*args, **kwargs)
def reduce_mean(tensor):
""""Obtain the mean of tensor on different GPUs."""
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor
def obj2tensor(pyobj, device='cuda'):
"""Serialize picklable python object to tensor."""
storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
return torch.ByteTensor(storage).to(device=device)
def tensor2obj(tensor):
"""Deserialize tensor to picklable python object."""
return pickle.loads(tensor.cpu().numpy().tobytes())
@functools.lru_cache()
def _get_global_gloo_group():
"""Return a process group based on gloo backend, containing all the ranks
The result is cached."""
if dist.get_backend() == 'nccl':
return dist.new_group(backend='gloo')
else:
return dist.group.WORLD
def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
"""Apply all reduce function for python dict object.
The code is modified from https://github.com/Megvii-
BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.
NOTE: make sure that py_dict in different ranks has the same keys and
the values should be in the same shape.
Args:
py_dict (dict): Dict to be applied all reduce op.
op (str): Operator, could be 'sum' or 'mean'. Default: 'sum'
group (:obj:`torch.distributed.group`, optional): Distributed group,
Default: None.
to_float (bool): Whether to convert all values of dict to float.
Default: True.
Returns:
OrderedDict: reduced python dict object.
"""
_, world_size = get_dist_info()
if world_size == 1:
return py_dict
if group is None:
# TODO: May try not to use gloo in the future
group = _get_global_gloo_group()
if dist.get_world_size(group) == 1:
return py_dict
# all reduce logic across different devices.
py_key = list(py_dict.keys())
py_key_tensor = obj2tensor(py_key)
dist.broadcast(py_key_tensor, src=0)
py_key = tensor2obj(py_key_tensor)
tensor_shapes = [py_dict[k].shape for k in py_key]
tensor_numels = [py_dict[k].numel() for k in py_key]
if to_float:
flatten_tensor = torch.cat(
[py_dict[k].flatten().float() for k in py_key])
else:
flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM)
if op == 'mean':
flatten_tensor /= world_size
split_tensors = [
x.reshape(shape) for x, shape in zip(
torch.split(flatten_tensor, tensor_numels), tensor_shapes)
]
return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
# Copyright (c) OpenMMLab. All rights reserved.
def multi_apply(func, *args, **kwargs):
"""Apply function to a list of arguments.
Note:
This function applies the ``func`` to multiple inputs and
map the multiple outputs of the ``func`` into different
list. Each list contains the same type of outputs corresponding
to different inputs.
Args:
func (Function): A function that will be applied to a list of
arguments
Returns:
tuple(list): A tuple containing multiple list, each list contains \
a kind of returned results by the function
"""
pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
return tuple(map(list, zip(*map_results)))
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f'{prefix}.{name}'] = value
return outputs
......@@ -2,3 +2,9 @@
from .mapillary import MapillaryDataset # noqa: F401,F403
from .nyu_depth_v2 import NYUDepthV2Dataset # noqa: F401,F403
from .pipelines import * # noqa: F401,F403
from .dataset_wrappers import ConcatDataset
__all__ = [
'MapillaryDataset', 'NYUDepthV2Dataset', 'ConcatDataset'
]
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
from itertools import chain
import mmcv
import numpy as np
from mmcv.utils import build_from_cfg, print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from mmseg.datasets.builder import DATASETS
@DATASETS.register_module(force=True)
class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
support evaluation and formatting results
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the concatenated
dataset results separately, Defaults to True.
"""
def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval
assert separate_eval in [True, False], \
f'separate_eval can only be True or False,' \
f'but get {separate_eval}'
def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[tuple[torch.Tensor]] | list[str]]): per image
pre_eval results or predict segmentation map for
computing evaluation metric.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str: float]: evaluate results of the total dataset
or each separate
dataset if `self.separate_eval=True`.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'
if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]
results_per_dataset = results[start_idx:end_idx]
print_log(
f'\nEvaluateing {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)
eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})
return total_eval_results
if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have same types when '
'self.separate_eval=False')
else:
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
# merge the generators of gt_seg_maps
gt_seg_maps = chain(
*[dataset.get_gt_seg_maps() for dataset in self.datasets])
else:
# if the results are `pre_eval` results,
# we do not need gt_seg_maps to evaluate
gt_seg_maps = None
eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results
def get_dataset_idx_and_sample_idx(self, indice):
"""Return dataset and sample index when given an indice of
ConcatDataset.
Args:
indice (int): indice of sample in ConcatDataset
Returns:
int: the index of sub dataset the sample belong to
int: the index of sample in its corresponding subset
"""
if indice < 0:
if -indice > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
indice = len(self) + indice
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
if dataset_idx == 0:
sample_idx = indice
else:
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""format result for every sample of ConcatDataset."""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].format_results(
[results[i]],
imgfile_prefix + f'/{dataset_idx}',
indices=[sample_idx],
**kwargs)
ret_res.append(res)
return sum(ret_res, [])
def pre_eval(self, preds, indices):
"""do pre eval for every sample of ConcatDataset."""
# In order to compat with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
ret_res.append(res)
return sum(ret_res, [])
......@@ -4,4 +4,9 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from .backbones import * # noqa: F401,F403
\ No newline at end of file
from .backbones import * # noqa: F401,F403
from .decode_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .plugins import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
from .utils import * # noqa: F401,F403
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import warnings # noqa: F401,F403
from mmcv.utils import Registry
TRANSFORMER = Registry('Transformer')
MASK_ASSIGNERS = Registry('mask_assigner')
MATCH_COST = Registry('match_cost')
def build_match_cost(cfg):
"""Build Match Cost."""
return MATCH_COST.build(cfg)
def build_assigner(cfg):
"""Build Assigner."""
return MASK_ASSIGNERS.build(cfg)
def build_transformer(cfg):
"""Build Transformer."""
return TRANSFORMER.build(cfg)
# Copyright (c) OpenMMLab. All rights reserved.
from .mask2former_head import Mask2FormerHead
from .maskformer_head import MaskFormerHead
__all__ = [
'MaskFormerHead',
'Mask2FormerHead',
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer_sequence)
from mmcv.ops import point_sample
from mmcv.runner import ModuleList, force_fp32
from mmseg.models.builder import HEADS, build_loss
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from ...core import build_sampler, multi_apply, reduce_mean
from ..builder import build_assigner
from ..utils import get_uncertain_point_coords_with_randomness
@HEADS.register_module()
class Mask2FormerHead(BaseDecodeHead):
"""Implements the Mask2Former head.
See `Masked-attention Mask Transformer for Universal Image
Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
Args:
in_channels (list[int]): Number of channels in the input feature map.
feat_channels (int): Number of channels for features.
out_channels (int): Number of channels for output.
num_classes (int): Number of classes.
num_queries (int): Number of query in Transformer decoder.
pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
decoder. Defaults to None.
enforce_decoder_input_project (bool, optional): Whether to add
a layer to change the embed_dim of tranformer encoder in
pixel decoder to the embed_dim of transformer decoder.
Defaults to False.
transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
transformer decoder. Defaults to None.
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
transformer decoder position encoding. Defaults to None.
loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
loss. Defaults to None.
loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
Defaults to None.
loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
Defaults to None.
train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
Mask2Former head.
test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
Mask2Former head.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels,
feat_channels,
out_channels,
num_classes=80,
num_queries=100,
num_transformer_feat_level=3,
pixel_decoder=None,
enforce_decoder_input_project=False,
transformer_decoder=None,
positional_encoding=None,
loss_cls=None,
loss_mask=None,
loss_dice=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
**kwargs):
super(Mask2FormerHead, self).__init__(
in_channels=in_channels,
channels=feat_channels,
num_classes=num_classes,
init_cfg=init_cfg,
input_transform='multiple_select',
**kwargs)
self.num_classes = num_classes
self.num_queries = num_queries
self.num_transformer_feat_level = num_transformer_feat_level
self.num_heads = transformer_decoder.transformerlayers. \
attn_cfgs.num_heads
self.num_transformer_decoder_layers = transformer_decoder.num_layers
assert pixel_decoder.encoder.transformerlayers. \
attn_cfgs.num_levels == num_transformer_feat_level
pixel_decoder_ = copy.deepcopy(pixel_decoder)
pixel_decoder_.update(
in_channels=in_channels,
feat_channels=feat_channels,
out_channels=out_channels)
self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
self.transformer_decoder = build_transformer_layer_sequence(
transformer_decoder)
self.decoder_embed_dims = self.transformer_decoder.embed_dims
self.decoder_input_projs = ModuleList()
# from low resolution to high resolution
for _ in range(num_transformer_feat_level):
if (self.decoder_embed_dims != feat_channels
or enforce_decoder_input_project):
self.decoder_input_projs.append(
Conv2d(
feat_channels, self.decoder_embed_dims, kernel_size=1))
else:
self.decoder_input_projs.append(nn.Identity())
self.decoder_positional_encoding = build_positional_encoding(
positional_encoding)
self.query_embed = nn.Embedding(self.num_queries, feat_channels)
self.query_feat = nn.Embedding(self.num_queries, feat_channels)
# from low resolution to high resolution
self.level_embed = nn.Embedding(self.num_transformer_feat_level,
feat_channels)
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
self.mask_embed = nn.Sequential(
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
nn.Linear(feat_channels, out_channels))
self.conv_seg = None # fix a bug here (conv_seg is not used)
self.test_cfg = test_cfg
self.train_cfg = train_cfg
if train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
self.sampler = build_sampler(self.train_cfg.sampler, context=self)
self.num_points = self.train_cfg.get('num_points', 12544)
self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
self.importance_sample_ratio = self.train_cfg.get(
'importance_sample_ratio', 0.75)
self.class_weight = loss_cls.class_weight
self.loss_cls = build_loss(loss_cls)
self.loss_mask = build_loss(loss_mask)
self.loss_dice = build_loss(loss_dice)
def init_weights(self):
for m in self.decoder_input_projs:
if isinstance(m, Conv2d):
caffe2_xavier_init(m, bias=0)
self.pixel_decoder.init_weights()
for p in self.transformer_decoder.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list,
gt_masks_list, img_metas):
"""Compute classification and mask targets for all images for a decoder
layer.
Args:
cls_scores_list (list[Tensor]): Mask score logits from a single
decoder layer for all images. Each with shape [num_queries,
cls_out_channels].
mask_preds_list (list[Tensor]): Mask logits from a single decoder
layer for all images. Each with shape [num_queries, h, w].
gt_labels_list (list[Tensor]): Ground truth class indices for all
images. Each with shape (n, ), n is the sum of number of stuff
type and number of instance in a image.
gt_masks_list (list[Tensor]): Ground truth mask for each image,
each with shape (n, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
tuple[list[Tensor]]: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels of all images.
Each with shape [num_queries, ].
- label_weights_list (list[Tensor]): Label weights of all
images.Each with shape [num_queries, ].
- mask_targets_list (list[Tensor]): Mask targets of all images.
Each with shape [num_queries, h, w].
- mask_weights_list (list[Tensor]): Mask weights of all images.
Each with shape [num_queries, ].
- num_total_pos (int): Number of positive samples in all
images.
- num_total_neg (int): Number of negative samples in all
images.
"""
(labels_list, label_weights_list, mask_targets_list, mask_weights_list,
pos_inds_list,
neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list,
mask_preds_list, gt_labels_list,
gt_masks_list, img_metas)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
return (labels_list, label_weights_list, mask_targets_list,
mask_weights_list, num_total_pos, num_total_neg)
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
img_metas):
"""Compute classification and mask targets for one image.
Args:
cls_score (Tensor): Mask score logits from a single decoder layer
for one image. Shape (num_queries, cls_out_channels).
mask_pred (Tensor): Mask logits for a single decoder layer for one
image. Shape (num_queries, h, w).
gt_labels (Tensor): Ground truth class indices for one image with
shape (num_gts, ).
gt_masks (Tensor): Ground truth mask for each image, each with
shape (num_gts, h, w).
img_metas (dict): Image informtation.
Returns:
tuple[Tensor]: A tuple containing the following for one image.
- labels (Tensor): Labels of each image. \
shape (num_queries, ).
- label_weights (Tensor): Label weights of each image. \
shape (num_queries, ).
- mask_targets (Tensor): Mask targets of each image. \
shape (num_queries, h, w).
- mask_weights (Tensor): Mask weights of each image. \
shape (num_queries, ).
- pos_inds (Tensor): Sampled positive indices for each \
image.
- neg_inds (Tensor): Sampled negative indices for each \
image.
"""
# sample points
num_queries = cls_score.shape[0]
num_gts = gt_labels.shape[0]
point_coords = torch.rand((1, self.num_points, 2),
device=cls_score.device)
# shape (num_queries, num_points)
mask_points_pred = point_sample(
mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
1)).squeeze(1)
# shape (num_gts, num_points)
gt_points_masks = point_sample(
gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
1)).squeeze(1)
# assign and sample
assign_result = self.assigner.assign(cls_score, mask_points_pred,
gt_labels, gt_points_masks,
img_metas)
sampling_result = self.sampler.sample(assign_result, mask_pred,
gt_masks)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
# label target
labels = gt_labels.new_full((self.num_queries, ),
self.num_classes,
dtype=torch.long)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
label_weights = gt_labels.new_ones((self.num_queries, ))
# mask target
mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
mask_weights = mask_pred.new_zeros((self.num_queries, ))
mask_weights[pos_inds] = 1.0
return (labels, label_weights, mask_targets, mask_weights, pos_inds,
neg_inds)
def loss_single(self, cls_scores, mask_preds, gt_labels_list,
gt_masks_list, img_metas):
"""Loss function for outputs from a single decoder layer.
Args:
cls_scores (Tensor): Mask score logits from a single decoder layer
for all images. Shape (batch_size, num_queries,
cls_out_channels). Note `cls_out_channels` should includes
background.
mask_preds (Tensor): Mask logits for a pixel decoder for all
images. Shape (batch_size, num_queries, h, w).
gt_labels_list (list[Tensor]): Ground truth class indices for each
image, each with shape (num_gts, ).
gt_masks_list (list[Tensor]): Ground truth mask for each image,
each with shape (num_gts, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
tuple[Tensor]: Loss components for outputs from a single \
decoder layer.
"""
num_imgs = cls_scores.size(0)
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
(labels_list, label_weights_list, mask_targets_list, mask_weights_list,
num_total_pos,
num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
gt_labels_list, gt_masks_list,
img_metas)
# shape (batch_size, num_queries)
labels = torch.stack(labels_list, dim=0)
# shape (batch_size, num_queries)
label_weights = torch.stack(label_weights_list, dim=0)
# shape (num_total_gts, h, w)
mask_targets = torch.cat(mask_targets_list, dim=0)
# shape (batch_size, num_queries)
mask_weights = torch.stack(mask_weights_list, dim=0)
# classfication loss
# shape (batch_size * num_queries, )
cls_scores = cls_scores.flatten(0, 1)
labels = labels.flatten(0, 1)
label_weights = label_weights.flatten(0, 1)
class_weight = cls_scores.new_tensor(self.class_weight)
loss_cls = self.loss_cls(
cls_scores,
labels,
label_weights,
avg_factor=class_weight[labels].sum())
num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
num_total_masks = max(num_total_masks, 1)
# extract positive ones
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
mask_preds = mask_preds[mask_weights > 0]
if mask_targets.shape[0] == 0:
# zero match
loss_dice = mask_preds.sum()
loss_mask = mask_preds.sum()
return loss_cls, loss_mask, loss_dice
with torch.no_grad():
points_coords = get_uncertain_point_coords_with_randomness(
mask_preds.unsqueeze(1), None, self.num_points,
self.oversample_ratio, self.importance_sample_ratio)
# shape (num_total_gts, h, w) -> (num_total_gts, num_points)
mask_point_targets = point_sample(
mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
# shape (num_queries, h, w) -> (num_queries, num_points)
mask_point_preds = point_sample(
mask_preds.unsqueeze(1), points_coords).squeeze(1)
# dice loss
loss_dice = self.loss_dice(
mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
# mask loss
# shape (num_queries, num_points) -> (num_queries * num_points, )
mask_point_preds = mask_point_preds.reshape(-1,1)
# shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
mask_point_targets = mask_point_targets.reshape(-1)
loss_mask = self.loss_mask(
mask_point_preds,
mask_point_targets,
avg_factor=num_total_masks * self.num_points)
return loss_cls, loss_mask, loss_dice
@force_fp32(apply_to=('all_cls_scores', 'all_mask_preds'))
def loss(self, all_cls_scores, all_mask_preds, gt_labels_list,
gt_masks_list, img_metas):
"""Loss function.
Args:
all_cls_scores (Tensor): Classification scores for all decoder
layers with shape [num_decoder, batch_size, num_queries,
cls_out_channels].
all_mask_preds (Tensor): Mask scores for all decoder layers with
shape [num_decoder, batch_size, num_queries, h, w].
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (n, ). n is the sum of number of stuff type
and number of instance in a image.
gt_masks_list (list[Tensor]): Ground truth mask for each image with
shape (n, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_dec_layers = len(all_cls_scores)
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
img_metas_list = [img_metas for _ in range(num_dec_layers)]
losses_cls, losses_mask, losses_dice = multi_apply(
self.loss_single, all_cls_scores, all_mask_preds,
all_gt_labels_list, all_gt_masks_list, img_metas_list)
loss_dict = dict()
# loss from the last decoder layer
loss_dict['loss_cls'] = losses_cls[-1]
loss_dict['loss_mask'] = losses_mask[-1]
loss_dict['loss_dice'] = losses_dice[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_mask_i, loss_dice_i in zip(
losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
num_dec_layer += 1
return loss_dict
def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
"""Forward for head part which is called after every decoder layer.
Args:
decoder_out (Tensor): in shape (num_queries, batch_size, c).
mask_feature (Tensor): in shape (batch_size, c, h, w).
attn_mask_target_size (tuple[int, int]): target attention
mask size.
Returns:
tuple: A tuple contain three elements.
- cls_pred (Tensor): Classification scores in shape \
(batch_size, num_queries, cls_out_channels). \
Note `cls_out_channels` should includes background.
- mask_pred (Tensor): Mask scores in shape \
(batch_size, num_queries,h, w).
- attn_mask (Tensor): Attention mask in shape \
(batch_size * num_heads, num_queries, h, w).
"""
decoder_out = self.transformer_decoder.post_norm(decoder_out)
decoder_out = decoder_out.transpose(0, 1)
# shape (num_queries, batch_size, c)
cls_pred = self.cls_embed(decoder_out)
# shape (num_queries, batch_size, c)
mask_embed = self.mask_embed(decoder_out)
# shape (num_queries, batch_size, h, w)
mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
attn_mask = F.interpolate(
mask_pred,
attn_mask_target_size,
mode='bilinear',
align_corners=False)
# shape (num_queries, batch_size, h, w) ->
# (batch_size * num_head, num_queries, h, w)
attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
(1, self.num_heads, 1, 1)).flatten(0, 1)
attn_mask = attn_mask.sigmoid() < 0.5
attn_mask = attn_mask.detach()
return cls_pred, mask_pred, attn_mask
def forward(self, feats, img_metas):
"""Forward function.
Args:
feats (list[Tensor]): Multi scale Features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
Returns:
tuple: A tuple contains two elements.
- cls_pred_list (list[Tensor)]: Classification logits \
for each decoder layer. Each is a 3D-tensor with shape \
(batch_size, num_queries, cls_out_channels). \
Note `cls_out_channels` should includes background.
- mask_pred_list (list[Tensor]): Mask logits for each \
decoder layer. Each with shape (batch_size, num_queries, \
h, w).
"""
batch_size = len(img_metas)
mask_features, multi_scale_memorys = self.pixel_decoder(feats)
# multi_scale_memorys (from low resolution to high resolution)
decoder_inputs = []
decoder_positional_encodings = []
for i in range(self.num_transformer_feat_level):
decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
level_embed = self.level_embed.weight[i].view(1, 1, -1)
decoder_input = decoder_input + level_embed
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
mask = decoder_input.new_zeros(
(batch_size, ) + multi_scale_memorys[i].shape[-2:],
dtype=torch.bool)
decoder_positional_encoding = self.decoder_positional_encoding(
mask)
decoder_positional_encoding = decoder_positional_encoding.flatten(
2).permute(2, 0, 1)
decoder_inputs.append(decoder_input)
decoder_positional_encodings.append(decoder_positional_encoding)
# shape (num_queries, c) -> (num_queries, batch_size, c)
query_feat = self.query_feat.weight.unsqueeze(1).repeat(
(1, batch_size, 1))
query_embed = self.query_embed.weight.unsqueeze(1).repeat(
(1, batch_size, 1))
cls_pred_list = []
mask_pred_list = []
cls_pred, mask_pred, attn_mask = self.forward_head(
query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
cls_pred_list.append(cls_pred)
mask_pred_list.append(mask_pred)
for i in range(self.num_transformer_decoder_layers):
level_idx = i % self.num_transformer_feat_level
# if a mask is all True(all background), then set it all False.
attn_mask[torch.where(
attn_mask.sum(-1) == attn_mask.shape[-1])] = False
# cross_attn + self_attn
layer = self.transformer_decoder.layers[i]
attn_masks = [attn_mask, None]
query_feat = layer(
query=query_feat,
key=decoder_inputs[level_idx],
value=decoder_inputs[level_idx],
query_pos=query_embed,
key_pos=decoder_positional_encodings[level_idx],
attn_masks=attn_masks,
query_key_padding_mask=None,
# here we do not apply masking on padded region
key_padding_mask=None)
cls_pred, mask_pred, attn_mask = self.forward_head(
query_feat, mask_features, multi_scale_memorys[
(i + 1) % self.num_transformer_feat_level].shape[-2:])
cls_pred_list.append(cls_pred)
mask_pred_list.append(mask_pred)
return cls_pred_list, mask_pred_list
def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels,
gt_masks):
"""Forward function for training mode.
Args:
x (list[Tensor]): Multi-level features from the upstream network,
each is a 4D-tensor.
img_metas (list[Dict]): List of image information.
gt_semantic_seg (list[tensor]):Each element is the ground truth
of semantic segmentation with the shape (N, H, W).
train_cfg (dict): The training config, which not been used in
maskformer.
gt_labels (list[Tensor]): Each element is ground truth labels of
each box, shape (num_gts,).
gt_masks (list[BitmapMasks]): Each element is masks of instances
of a image, shape (num_gts, h, w).
Returns:
losses (dict[str, Tensor]): a dictionary of loss components
"""
# forward
all_cls_scores, all_mask_preds = self(x, img_metas)
# loss
losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks,
img_metas)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
"""Test segment without test-time aumengtation.
Only the output of last decoder layers was used.
Args:
inputs (list[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
test_cfg (dict): Testing config.
Returns:
seg_mask (Tensor): Predicted semantic segmentation logits.
"""
all_cls_scores, all_mask_preds = self(inputs, img_metas)
cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]
ori_h, ori_w, _ = img_metas[0]['ori_shape']
# semantic inference
cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
return seg_mask
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, kaiming_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer_sequence)
from mmcv.runner import force_fp32
from mmseg.models.builder import HEADS, build_loss
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from ...core import multi_apply, reduce_mean
from ..builder import build_assigner, build_transformer
@HEADS.register_module()
class MaskFormerHead(BaseDecodeHead):
"""Implements the MaskFormer head.
See `paper: Per-Pixel Classification is Not All You Need
for Semantic Segmentation<https://arxiv.org/pdf/2107.06278>`
for details.
Args:
in_channels (list[int]): Number of channels in the input feature map.
feat_channels (int): Number channels for feature.
out_channels (int): Number channels for output.
num_things_classes (int): Number of things.
num_stuff_classes (int): Number of stuff.
num_queries (int): Number of query in Transformer.
pixel_decoder (obj:`mmcv.ConfigDict`|dict): Config for pixel decoder.
Defaults to None.
enforce_decoder_input_project (bool, optional): Whether to add a layer
to change the embed_dim of tranformer encoder in pixel decoder to
the embed_dim of transformer decoder. Defaults to False.
transformer_decoder (obj:`mmcv.ConfigDict`|dict): Config for
transformer decoder. Defaults to None.
positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for
transformer decoder position encoding. Defaults to None.
loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the classification
loss. Defaults to `CrossEntropyLoss`.
loss_mask (obj:`mmcv.ConfigDict`|dict): Config of the mask loss.
Defaults to `FocalLoss`.
loss_dice (obj:`mmcv.ConfigDict`|dict): Config of the dice loss.
Defaults to `DiceLoss`.
train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of Maskformer
head.
test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of Maskformer
head.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
out_channels,
num_queries=100,
pixel_decoder=None,
enforce_decoder_input_project=False,
transformer_decoder=None,
positional_encoding=None,
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
use_sigmoid=False,
loss_weight=1.0,
class_weight=1.0),
loss_mask=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=20.0),
loss_dice=dict(
type='DiceLoss',
use_sigmoid=True,
activate=True,
naive_dice=True,
loss_weight=1.0),
assigner=dict(
type='MaskHungarianAssigner',
cls_cost=dict(type='ClassificationCost', weight=1.),
dice_cost=dict(type='DiceCost', weight=1.0, pred_act=True,
eps=1.0),
mask_cost=dict(type='MaskFocalLossCost', weight=20.0)),
**kwargs):
super(MaskFormerHead, self).__init__(input_transform='multiple_select',
**kwargs)
self.num_queries = num_queries
pixel_decoder.update(
in_channels=self.in_channels,
feat_channels=self.channels,
out_channels=out_channels)
self.pixel_decoder = build_plugin_layer(pixel_decoder)[1]
self.transformer_decoder = build_transformer_layer_sequence(
transformer_decoder)
self.decoder_embed_dims = self.transformer_decoder.embed_dims
pixel_decoder_type = pixel_decoder.get('type')
if pixel_decoder_type == 'PixelDecoder' and (
self.decoder_embed_dims != self.in_channels[-1]
or enforce_decoder_input_project):
self.decoder_input_proj = Conv2d(
self.in_channels[-1], self.decoder_embed_dims, kernel_size=1)
else:
self.decoder_input_proj = nn.Identity()
self.decoder_pe = build_positional_encoding(positional_encoding)
self.query_embed = nn.Embedding(self.num_queries, out_channels)
self.cls_embed = nn.Linear(self.channels, self.num_classes + 1)
self.mask_embed = nn.Sequential(
nn.Linear(self.channels, self.channels), nn.ReLU(inplace=True),
nn.Linear(self.channels, self.channels), nn.ReLU(inplace=True),
nn.Linear(self.channels, out_channels))
self.assigner = build_assigner(assigner)
self.bg_cls_weight = 0
class_weight = loss_cls.get('class_weight', None)
if class_weight is not None and (self.__class__ is MaskFormerHead):
assert isinstance(class_weight, float), 'Expected ' \
'class_weight to have type float. Found ' \
f'{type(class_weight)}.'
# NOTE following the official MaskFormerHead repo, bg_cls_weight
# means relative classification weight of the VOID class.
bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
assert isinstance(bg_cls_weight, float), 'Expected ' \
'bg_cls_weight to have type float. Found ' \
f'{type(bg_cls_weight)}.'
class_weight = (self.num_classes + 1) * [class_weight]
# set VOID class as the last indice
class_weight[self.num_classes] = bg_cls_weight
loss_cls.update({'class_weight': class_weight})
if 'bg_cls_weight' in loss_cls:
loss_cls.pop('bg_cls_weight')
self.bg_cls_weight = bg_cls_weight
assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
'The classification weight for loss and matcher should be' \
'exactly the same.'
assert loss_dice['loss_weight'] == assigner['dice_cost']['weight'], \
f'The dice weight for loss and matcher' \
f'should be exactly the same.'
assert loss_mask['loss_weight'] == assigner['mask_cost']['weight'], \
'The focal weight for loss and matcher should be' \
'exactly the same.'
self.loss_cls = build_loss(loss_cls)
self.loss_mask = build_loss(loss_mask)
self.loss_dice = build_loss(loss_dice)
self.init_weights()
def init_weights(self):
kaiming_init(self.decoder_input_proj, a=1)
def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list,
gt_masks_list, img_metas):
"""Compute classification and mask targets for all images for a decoder
layer.
Args:
cls_scores_list (list[Tensor]): Mask score logits from a single
decoder layer for all images. Each with shape [num_queries,
cls_out_channels].
mask_preds_list (list[Tensor]): Mask logits from a single decoder
layer for all images. Each with shape [num_queries, h, w].
gt_labels_list (list[Tensor]): Ground truth class indices for all
images. Each with shape (n, ), n is the sum of number of stuff
type and number of instance in a image.
gt_masks_list (list[Tensor]): Ground truth mask for each image,
each with shape (n, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
tuple[list[Tensor]]: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels of all images.
Each with shape [num_queries, ].
- label_weights_list (list[Tensor]): Label weights of all
images.Each with shape [num_queries, ].
- mask_targets_list (list[Tensor]): Mask targets of all images.
Each with shape [num_queries, h, w].
- mask_weights_list (list[Tensor]): Mask weights of all images.
Each with shape [num_queries, ].
- num_total_pos (int): Number of positive samples in all
images.
- num_total_neg (int): Number of negative samples in all
images.
"""
(labels_list, label_weights_list, mask_targets_list, mask_weights_list,
pos_inds_list,
neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list,
mask_preds_list, gt_labels_list,
gt_masks_list, img_metas)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
return (labels_list, label_weights_list, mask_targets_list,
mask_weights_list, num_total_pos, num_total_neg)
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
img_metas):
"""Compute classification and mask targets for one image.
Args:
cls_score (Tensor): Mask score logits from a single decoder layer
for one image. Shape [num_queries, cls_out_channels].
mask_pred (Tensor): Mask logits for a single decoder layer for one
image. Shape [num_queries, h, w].
gt_labels (Tensor): Ground truth class indices for one image with
shape (n, ). n is the sum of number of stuff type and number
of instance in a image.
gt_masks (Tensor): Ground truth mask for each image, each with
shape (n, h, w).
img_metas (dict): Image informtation.
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (Tensor): Labels of each image.
shape [num_queries, ].
- label_weights (Tensor): Label weights of each image.
shape [num_queries, ].
- mask_targets (Tensor): Mask targets of each image.
shape [num_queries, h, w].
- mask_weights (Tensor): Mask weights of each image.
shape [num_queries, ].
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
target_shape = mask_pred.shape[-2:]
gt_masks_downsampled = F.interpolate(
gt_masks.unsqueeze(1).float(), target_shape,
mode='nearest').squeeze(1).long()
# assign and sample
assign_result = self.assigner.assign(cls_score, mask_pred, gt_labels,
gt_masks_downsampled, img_metas)
# pos_ind: range from 1 to (self.num_classes)
# which represents the positive index
pos_inds = torch.nonzero(assign_result.gt_inds > 0,
as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(assign_result.gt_inds == 0,
as_tuple=False).squeeze(-1).unique()
pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
# label target
labels = gt_labels.new_full((self.num_queries, ),
self.num_classes,
dtype=torch.long)
labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
label_weights = gt_labels.new_ones(self.num_queries)
# mask target
mask_targets = gt_masks[pos_assigned_gt_inds, :]
mask_weights = mask_pred.new_zeros((self.num_queries, ))
mask_weights[pos_inds] = 1.0
return (labels, label_weights, mask_targets, mask_weights, pos_inds,
neg_inds)
@force_fp32(apply_to=('all_cls_scores', 'all_mask_preds'))
def loss(self, all_cls_scores, all_mask_preds, gt_labels_list,
gt_masks_list, img_metas):
"""Loss function.
Args:
all_cls_scores (Tensor): Classification scores for all decoder
layers with shape [num_decoder, batch_size, num_queries,
cls_out_channels].
all_mask_preds (Tensor): Mask scores for all decoder layers with
shape [num_decoder, batch_size, num_queries, h, w].
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (n, ). n is the sum of number of stuff type
and number of instance in a image.
gt_masks_list (list[Tensor]): Ground truth mask for each image with
shape (n, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_dec_layers = len(all_cls_scores)
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
img_metas_list = [img_metas for _ in range(num_dec_layers)]
losses_cls, losses_mask, losses_dice = multi_apply(
self.loss_single, all_cls_scores, all_mask_preds,
all_gt_labels_list, all_gt_masks_list, img_metas_list)
loss_dict = dict()
# loss from the last decoder layer
loss_dict['loss_cls'] = losses_cls[-1]
loss_dict['loss_mask'] = losses_mask[-1]
loss_dict['loss_dice'] = losses_dice[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_mask_i, loss_dice_i in zip(
losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
num_dec_layer += 1
return loss_dict
def loss_single(self, cls_scores, mask_preds, gt_labels_list,
gt_masks_list, img_metas):
"""Loss function for outputs from a single decoder layer.
Args:
cls_scores (Tensor): Mask score logits from a single decoder layer
for all images. Shape [batch_size, num_queries,
cls_out_channels].
mask_preds (Tensor): Mask logits for a pixel decoder for all
images. Shape [batch_size, num_queries, h, w].
gt_labels_list (list[Tensor]): Ground truth class indices for each
image, each with shape (n, ). n is the sum of number of stuff
types and number of instances in a image.
gt_masks_list (list[Tensor]): Ground truth mask for each image,
each with shape (n, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
tuple[Tensor]:Loss components for outputs from a single decoder
layer.
"""
num_imgs = cls_scores.size(0)
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
(labels_list, label_weights_list, mask_targets_list, mask_weights_list,
num_total_pos,
num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
gt_labels_list, gt_masks_list,
img_metas)
# shape [batch_size, num_queries]
labels = torch.stack(labels_list, dim=0)
# shape [batch_size, num_queries]
label_weights = torch.stack(label_weights_list, dim=0)
# shape [num_gts, h, w]
mask_targets = torch.cat(mask_targets_list, dim=0)
# shape [batch_size, num_queries]
mask_weights = torch.stack(mask_weights_list, dim=0)
# classfication loss
# shape [batch_size * num_queries, ]
cls_scores = cls_scores.flatten(0, 1)
# shape [batch_size * num_queries, ]
labels = labels.flatten(0, 1)
# shape [batch_size* num_queries, ]
label_weights = label_weights.flatten(0, 1)
class_weight = cls_scores.new_ones(self.num_classes + 1)
class_weight[-1] = self.bg_cls_weight
loss_cls = self.loss_cls(
cls_scores,
labels,
label_weights,
avg_factor=class_weight[labels].sum())
num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
num_total_masks = max(num_total_masks, 1)
# extract positive ones
mask_preds = mask_preds[mask_weights > 0]
target_shape = mask_targets.shape[-2:]
if mask_targets.shape[0] == 0:
# zero match
loss_dice = mask_preds.sum()
loss_mask = mask_preds.sum()
return loss_cls, loss_mask, loss_dice
# upsample to shape of target
# shape [num_gts, h, w]
mask_preds = F.interpolate(
mask_preds.unsqueeze(1),
target_shape,
mode='bilinear',
align_corners=False).squeeze(1)
# dice loss
loss_dice = self.loss_dice(
mask_preds, mask_targets, avg_factor=num_total_masks)
# mask loss
# FocalLoss support input of shape [n, num_class]
h, w = mask_preds.shape[-2:]
# shape [num_gts, h, w] -> [num_gts * h * w, 1]
mask_preds = mask_preds.reshape(-1, 1)
# shape [num_gts, h, w] -> [num_gts * h * w]
mask_targets = mask_targets.reshape(-1)
# target is (1 - mask_targets) !!!
print("mask_pred:", mask_preds.shape)
print("mask_targets:", mask_targets.shape)
loss_mask = self.loss_mask(
mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)
return loss_cls, loss_mask, loss_dice
def forward(self, feats, img_metas):
"""Forward function.
Args:
feats (list[Tensor]): Features from the upstream network, each
is a 4D-tensor.
img_metas (list[dict]): List of image information.
Returns:
all_cls_scores (Tensor): Classification scores for each
scale level. Each is a 4D-tensor with shape
[num_decoder, batch_size, num_queries, cls_out_channels].
Note `cls_out_channels` should includes background.
all_mask_preds (Tensor): Mask scores for each decoder
layer. Each with shape [num_decoder, batch_size,
num_queries, h, w].
"""
batch_size = len(img_metas)
input_img_h, input_img_w = img_metas[0]['pad_shape'][:-1]
# input_img_h, input_img_w = img_metas[0]['batch_input_shape']
padding_mask = feats[-1].new_ones(
(batch_size, input_img_h, input_img_w), dtype=torch.float32)
for i in range(batch_size):
img_h, img_w, _ = img_metas[i]['img_shape']
padding_mask[i, :img_h, :img_w] = 0
padding_mask = F.interpolate(
padding_mask.unsqueeze(1),
size=feats[-1].shape[-2:],
mode='nearest').to(torch.bool).squeeze(1)
# when backbone is swin, memory is output of last stage of swin.
# when backbone is r50, memory is output of tranformer encoder.
mask_features, memory = self.pixel_decoder(feats, img_metas)
pos_embed = self.decoder_pe(padding_mask)
memory = self.decoder_input_proj(memory)
# shape [batch_size, c, h, w] -> [h*w, batch_size, c]
memory = memory.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
# shape [batch_size, h * w]
padding_mask = padding_mask.flatten(1)
# shape = [num_queries, embed_dims]
query_embed = self.query_embed.weight
# shape = [num_queries, batch_size, embed_dims]
query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1)
target = torch.zeros_like(query_embed)
# shape [num_decoder, num_queries, batch_size, embed_dims]
out_dec = self.transformer_decoder(
query=target,
key=memory,
value=memory,
key_pos=pos_embed,
query_pos=query_embed,
key_padding_mask=padding_mask)
# shape [num_decoder, batch_size, num_queries, embed_dims]
out_dec = out_dec.transpose(1, 2)
# cls_scores
all_cls_scores = self.cls_embed(out_dec)
# mask_preds
mask_embed = self.mask_embed(out_dec)
all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
mask_features)
return all_cls_scores, all_mask_preds
def forward_train(self,
x,
img_metas,
gt_semantic_seg,
gt_labels,
gt_masks):
"""Forward function for training mode.
Args:
x (list[Tensor]): Multi-level features from the upstream network,
each is a 4D-tensor.
img_metas (list[Dict]): List of image information.
gt_semantic_seg (list[tensor]):Each element is the ground truth
of semantic segmentation with the shape (N, H, W).
train_cfg (dict): The training config, which not been used in
maskformer.
gt_labels (list[Tensor]): Each element is ground truth labels of
each box, shape (num_gts,).
gt_masks (list[BitmapMasks]): Each element is masks of instances
of a image, shape (num_gts, h, w).
Returns:
losses (dict[str, Tensor]): a dictionary of loss components
"""
# forward
all_cls_scores, all_mask_preds = self(x, img_metas)
# loss
losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks,
img_metas)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
"""Test segment without test-time aumengtation.
Only the output of last decoder layers was used.
Args:
inputs (list[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
test_cfg (dict): Testing config.
Returns:
seg_mask (Tensor): Predicted semantic segmentation logits.
"""
all_cls_scores, all_mask_preds = self(inputs, img_metas)
cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]
ori_h, ori_w, _ = img_metas[0]['ori_shape']
# semantic inference
cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
return seg_mask
# Copyright (c) OpenMMLab. All rights reserved.
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
from .match_costs import (ClassificationCost, CrossEntropyLossCost, DiceCost,
MaskFocalLossCost)
__all__ = [
'cross_entropy', 'binary_cross_entropy', 'mask_cross_entropy',
'CrossEntropyLoss', 'DiceLoss', 'FocalLoss', 'ClassificationCost',
'MaskFocalLossCost', 'DiceCost', 'CrossEntropyLossCost'
]
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss
def cross_entropy(pred,
label,
weight=None,
class_weight=None,
reduction='mean',
avg_factor=None,
ignore_index=-100,
avg_non_ignore=False):
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
Default: None.
class_weight (list[float], optional): The weight for each class.
Default: None.
reduction (str, optional): The method used to reduce the loss.
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Default: None.
ignore_index (int): Specifies a target value that is ignored and
does not contribute to the input gradients. When
``avg_non_ignore `` is ``True``, and the ``reduction`` is
``''mean''``, the loss is averaged over non-ignored targets.
Defaults: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
# class_weight is a manual rescaling weight given to each class.
# If given, has to be a Tensor of size C element-wise losses
loss = F.cross_entropy(
pred,
label,
weight=class_weight,
reduction='none',
ignore_index=ignore_index)
# apply weights and do the reduction
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
if inds[0].numel() > 0:
if labels.dim() == 3:
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
else:
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights = bin_label_weights * valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False,
**kwargs):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
Note: In bce loss, label < 0 is invalid.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int): The label index to be ignored. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
Returns:
torch.Tensor: The calculated loss
"""
if pred.size(1) == 1:
# For binary class segmentation, the shape of pred is
# [N, 1, H, W] and that of label is [N, H, W].
assert label.max() <= 1, \
'For pred with shape [N, 1, H, W], its label must have at ' \
'most 2 classes'
pred = pred.squeeze()
if pred.dim() != label.dim():
assert (pred.dim() == 2 and label.dim() == 1) or (
pred.dim() == 4 and label.dim() == 3), \
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
'H, W], label shape [N, H, W] are supported'
# `weight` returned from `_expand_onehot_labels`
# has been treated for valid (non-ignore) pixels
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.shape, ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored and valid elements
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
avg_factor = valid_mask.sum().item()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
label (torch.Tensor): ``label`` indicates the class label of the mask'
corresponding object. This will be used to select the mask in the
of the class which the object belongs to when the mask prediction
if not class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight=class_weight, reduction='mean')[None]
@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):
"""CrossEntropyLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
of softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): . Defaults to 'mean'.
Options are "none", "mean" and "sum".
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_ce',
avg_non_ignore=False):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self.avg_non_ignore = avg_non_ignore
if not self.avg_non_ignore and self.reduction == 'mean':
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
self._loss_name = loss_name
def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=-100,
**kwargs):
"""Forward function."""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (reduction_override
if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# Note: for BCE loss, label < 0 is invalid.
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs)
return loss_cls
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import weight_reduce_loss
def dice_loss(pred,
target,
weight=None,
eps=1e-3,
reduction='mean',
avg_factor=None):
"""Calculate dice loss, which is proposed in
`V-Net: Fully Convolutional Neural Networks for Volumetric
Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *)
target (torch.Tensor): The learning label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
eps (float): Avoid dividing by zero. Default: 1e-3.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
input = pred.flatten(1)
target = target.flatten(1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + eps
c = torch.sum(target * target, 1) + eps
d = (2 * a) / (b + c)
loss = 1 - d
if weight is not None:
assert weight.ndim == loss.ndim
assert len(weight) == len(pred)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def naive_dice_loss(pred,
target,
weight=None,
eps=1e-3,
reduction='mean',
avg_factor=None):
"""Calculate naive dice loss, the coefficient in the denominator is the
first power instead of the second power.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *)
target (torch.Tensor): The learning label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
eps (float): Avoid dividing by zero. Default: 1e-3.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
input = pred.flatten(1)
target = target.flatten(1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input, 1)
c = torch.sum(target, 1)
d = (2 * a + eps) / (b + c + eps)
loss = 1 - d
if weight is not None:
assert weight.ndim == loss.ndim
assert len(weight) == len(pred)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module(force=True)
class DiceLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=False,
loss_weight=1.0,
eps=1e-3):
"""Dice Loss, there are two forms of dice loss is supported:
- the one proposed in `V-Net: Fully Convolutional Neural
Networks for Volumetric Medical Image Segmentation
<https://arxiv.org/abs/1606.04797>`_.
- the dice loss in which the power of the number in the
denominator is the first power instead of the second
power.
Args:
use_sigmoid (bool, optional): Whether to the prediction is
used for sigmoid or softmax. Defaults to True.
activate (bool): Whether to activate the predictions inside,
this will disable the inside sigmoid operation.
Defaults to True.
reduction (str, optional): The method used
to reduce the loss. Options are "none",
"mean" and "sum". Defaults to 'mean'.
naive_dice (bool, optional): If false, use the dice
loss defined in the V-Net paper, otherwise, use the
naive dice loss in which the power of the number in the
denominator is the first power instead of the second
power.Defaults to False.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
eps (float): Avoid dividing by zero. Defaults to 1e-3.
"""
super(DiceLoss, self).__init__()
self.use_sigmoid = use_sigmoid
self.reduction = reduction
self.naive_dice = naive_dice
self.loss_weight = loss_weight
self.eps = eps
self.activate = activate
def forward(self,
pred,
target,
weight=None,
reduction_override=None,
avg_factor=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *).
target (torch.Tensor): The label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (reduction_override
if reduction_override else self.reduction)
if self.activate:
if self.use_sigmoid:
pred = pred.sigmoid()
else:
raise NotImplementedError
if self.naive_dice:
loss = self.loss_weight * naive_dice_loss(
pred,
target,
weight,
eps=self.eps,
reduction=reduction,
avg_factor=avg_factor)
else:
loss = self.loss_weight * dice_loss(
pred,
target,
weight,
eps=self.eps,
reduction=reduction,
avg_factor=avg_factor)
return loss
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import weight_reduce_loss
# This method is only for debugging
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
r"""A warpper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
alpha, None, 'none')
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module(force=True)
class FocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
Args:
use_sigmoid (bool, optional): Whether to the prediction is
used for sigmoid or softmax. Defaults to True.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
"""
super(FocalLoss, self).__init__()
assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
self.use_sigmoid = use_sigmoid
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
if torch.cuda.is_available() and pred.is_cuda:
calculate_loss_func = sigmoid_focal_loss
else:
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = self.loss_weight * calculate_loss_func(
pred,
target,
weight,
gamma=self.gamma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
else:
raise NotImplementedError
return loss_cls
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import MATCH_COST
@MATCH_COST.register_module()
class FocalLossCost:
"""FocalLossCost.
Args:
weight (int | float, optional): loss_weight
alpha (int | float, optional): focal_loss alpha
gamma (int | float, optional): focal_loss gamma
eps (float, optional): default 1e-12
Examples:
>>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
>>> import torch
>>> self = FocalLossCost()
>>> cls_pred = torch.rand(4, 3)
>>> gt_labels = torch.tensor([0, 1, 2])
>>> factor = torch.tensor([10, 8, 10, 8])
>>> self(cls_pred, gt_labels)
tensor([[-0.3236, -0.3364, -0.2699],
[-0.3439, -0.3209, -0.4807],
[-0.4099, -0.3795, -0.2929],
[-0.1950, -0.1207, -0.2626]])
"""
def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12):
self.weight = weight
self.alpha = alpha
self.gamma = gamma
self.eps = eps
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
torch.Tensor: cls_cost value with weight
"""
cls_pred = cls_pred.sigmoid()
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
return cls_cost * self.weight
@MATCH_COST.register_module()
class MaskFocalLossCost(FocalLossCost):
"""Cost of mask assignments based on focal losses.
Args:
weight (int | float, optional): loss_weight.
alpha (int | float, optional): focal_loss alpha.
gamma (int | float, optional): focal_loss gamma.
eps (float, optional): default 1e-12.
"""
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classfication logits
in shape (N1, H, W), dtype=torch.float32.
gt_labels (Tensor): Ground truth in shape (N2, H, W),
dtype=torch.long.
Returns:
Tensor: classification cost matrix in shape (N1, N2).
"""
cls_pred = cls_pred.reshape((cls_pred.shape[0], -1))
gt_labels = gt_labels.reshape((gt_labels.shape[0], -1)).float()
hw = cls_pred.shape[1]
cls_pred = cls_pred.sigmoid()
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \
torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
return cls_cost / hw * self.weight
@MATCH_COST.register_module()
class ClassificationCost:
"""ClsSoftmaxCost.Borrow from
mmdet.core.bbox.match_costs.match_cost.ClassificationCost.
Args:
weight (int | float, optional): loss_weight
Examples:
>>> import torch
>>> self = ClassificationCost()
>>> cls_pred = torch.rand(4, 3)
>>> gt_labels = torch.tensor([0, 1, 2])
>>> factor = torch.tensor([10, 8, 10, 8])
>>> self(cls_pred, gt_labels)
tensor([[-0.3430, -0.3525, -0.3045],
[-0.3077, -0.2931, -0.3992],
[-0.3664, -0.3455, -0.2881],
[-0.3343, -0.2701, -0.3956]])
"""
def __init__(self, weight=1.):
self.weight = weight
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
torch.Tensor: cls_cost value with weight
"""
# Following the official DETR repo, contrary to the loss that
# NLL is used, we approximate it in 1 - cls_score[gt_label].
# The 1 is a constant that doesn't change the matching,
# so it can be omitted.
cls_score = cls_pred.softmax(-1)
cls_cost = -cls_score[:, gt_labels]
return cls_cost * self.weight
@MATCH_COST.register_module()
class DiceCost:
"""Cost of mask assignments based on dice losses.
Args:
weight (int | float, optional): loss_weight. Defaults to 1.
pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
Defaults to False.
eps (float, optional): default 1e-12.
"""
def __init__(self, weight=1., pred_act=False, eps=1e-3):
self.weight = weight
self.pred_act = pred_act
self.eps = eps
def binary_mask_dice_loss(self, mask_preds, gt_masks):
"""
Args:
mask_preds (Tensor): Mask prediction in shape (N1, H, W).
gt_masks (Tensor): Ground truth in shape (N2, H, W)
store 0 or 1, 0 for negative class and 1 for
positive class.
Returns:
Tensor: Dice cost matrix in shape (N1, N2).
"""
mask_preds = mask_preds.reshape((mask_preds.shape[0], -1))
gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float()
numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
return loss
def __call__(self, mask_preds, gt_masks):
"""
Args:
mask_preds (Tensor): Mask prediction logits in shape (N1, H, W).
gt_masks (Tensor): Ground truth in shape (N2, H, W).
Returns:
Tensor: Dice cost matrix in shape (N1, N2).
"""
if self.pred_act:
mask_preds = mask_preds.sigmoid()
dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
return dice_cost * self.weight
@MATCH_COST.register_module()
class CrossEntropyLossCost:
"""CrossEntropyLossCost.
Args:
weight (int | float, optional): loss weight. Defaults to 1.
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
of softmax. Defaults to True.
"""
def __init__(self, weight=1., use_sigmoid=True):
assert use_sigmoid, 'use_sigmoid = False is not supported yet.'
self.weight = weight
self.use_sigmoid = use_sigmoid
def _binary_cross_entropy(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
(num_query, *).
gt_labels (Tensor): The learning label of prediction with
shape (num_gt, *).
Returns:
Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
"""
cls_pred = cls_pred.flatten(1).float()
gt_labels = gt_labels.flatten(1).float()
n = cls_pred.shape[1]
pos = F.binary_cross_entropy_with_logits(
cls_pred, torch.ones_like(cls_pred), reduction='none')
neg = F.binary_cross_entropy_with_logits(
cls_pred, torch.zeros_like(cls_pred), reduction='none')
cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
cls_cost = cls_cost / n
return cls_cost
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits.
gt_labels (Tensor): Labels.
Returns:
Tensor: Cross entropy cost matrix with weight in
shape (num_query, num_gt).
"""
if self.use_sigmoid:
cls_cost = self._binary_cross_entropy(cls_pred, gt_labels)
else:
raise NotImplementedError
return cls_cost * self.weight
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import MATCH_COST
@MATCH_COST.register_module()
class FocalLossCost:
"""FocalLossCost.
Args:
weight (int | float, optional): loss_weight
alpha (int | float, optional): focal_loss alpha
gamma (int | float, optional): focal_loss gamma
eps (float, optional): default 1e-12
Examples:
>>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
>>> import torch
>>> self = FocalLossCost()
>>> cls_pred = torch.rand(4, 3)
>>> gt_labels = torch.tensor([0, 1, 2])
>>> factor = torch.tensor([10, 8, 10, 8])
>>> self(cls_pred, gt_labels)
tensor([[-0.3236, -0.3364, -0.2699],
[-0.3439, -0.3209, -0.4807],
[-0.4099, -0.3795, -0.2929],
[-0.1950, -0.1207, -0.2626]])
"""
def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12):
self.weight = weight
self.alpha = alpha
self.gamma = gamma
self.eps = eps
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
torch.Tensor: cls_cost value with weight
"""
cls_pred = cls_pred.sigmoid()
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
return cls_cost * self.weight
@MATCH_COST.register_module()
class MaskFocalLossCost(FocalLossCost):
"""Cost of mask assignments based on focal losses.
Args:
weight (int | float, optional): loss_weight.
alpha (int | float, optional): focal_loss alpha.
gamma (int | float, optional): focal_loss gamma.
eps (float, optional): default 1e-12.
"""
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classfication logits
in shape (N1, H, W), dtype=torch.float32.
gt_labels (Tensor): Ground truth in shape (N2, H, W),
dtype=torch.long.
Returns:
Tensor: classification cost matrix in shape (N1, N2).
"""
cls_pred = cls_pred.reshape((cls_pred.shape[0], -1))
gt_labels = gt_labels.reshape((gt_labels.shape[0], -1)).float()
hw = cls_pred.shape[1]
cls_pred = cls_pred.sigmoid()
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \
torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
return cls_cost / hw * self.weight
@MATCH_COST.register_module()
class ClassificationCost:
"""ClsSoftmaxCost.Borrow from
mmdet.core.bbox.match_costs.match_cost.ClassificationCost.
Args:
weight (int | float, optional): loss_weight
Examples:
>>> import torch
>>> self = ClassificationCost()
>>> cls_pred = torch.rand(4, 3)
>>> gt_labels = torch.tensor([0, 1, 2])
>>> factor = torch.tensor([10, 8, 10, 8])
>>> self(cls_pred, gt_labels)
tensor([[-0.3430, -0.3525, -0.3045],
[-0.3077, -0.2931, -0.3992],
[-0.3664, -0.3455, -0.2881],
[-0.3343, -0.2701, -0.3956]])
"""
def __init__(self, weight=1.):
self.weight = weight
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
torch.Tensor: cls_cost value with weight
"""
# Following the official DETR repo, contrary to the loss that
# NLL is used, we approximate it in 1 - cls_score[gt_label].
# The 1 is a constant that doesn't change the matching,
# so it can be omitted.
cls_score = cls_pred.softmax(-1)
cls_cost = -cls_score[:, gt_labels]
return cls_cost * self.weight
@MATCH_COST.register_module()
class DiceCost:
"""Cost of mask assignments based on dice losses.
Args:
weight (int | float, optional): loss_weight. Defaults to 1.
pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
Defaults to False.
eps (float, optional): default 1e-12.
"""
def __init__(self, weight=1., pred_act=False, eps=1e-3):
self.weight = weight
self.pred_act = pred_act
self.eps = eps
def binary_mask_dice_loss(self, mask_preds, gt_masks):
"""
Args:
mask_preds (Tensor): Mask prediction in shape (N1, H, W).
gt_masks (Tensor): Ground truth in shape (N2, H, W)
store 0 or 1, 0 for negative class and 1 for
positive class.
Returns:
Tensor: Dice cost matrix in shape (N1, N2).
"""
mask_preds = mask_preds.reshape((mask_preds.shape[0], -1))
gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float()
numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
return loss
def __call__(self, mask_preds, gt_masks):
"""
Args:
mask_preds (Tensor): Mask prediction logits in shape (N1, H, W).
gt_masks (Tensor): Ground truth in shape (N2, H, W).
Returns:
Tensor: Dice cost matrix in shape (N1, N2).
"""
if self.pred_act:
mask_preds = mask_preds.sigmoid()
dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
return dice_cost * self.weight
# Copyright (c) Shanghai AI Lab. All rights reserved.
from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder
from .pixel_decoder import PixelDecoder, TransformerEncoderPixelDecoder
__all__ = [
'PixelDecoder', 'TransformerEncoderPixelDecoder',
'MSDeformAttnPixelDecoder'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init,
normal_init, xavier_init)
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer_sequence)
from mmcv.runner import BaseModule, ModuleList
from ...core.anchor import MlvlPointGenerator
from ..utils.transformer import MultiScaleDeformableAttention
@PLUGIN_LAYERS.register_module()
class MSDeformAttnPixelDecoder(BaseModule):
"""Pixel decoder with multi-scale deformable attention.
Args:
in_channels (list[int] | tuple[int]): Number of channels in the
input feature maps.
strides (list[int] | tuple[int]): Output strides of feature from
backbone.
feat_channels (int): Number of channels for feature.
out_channels (int): Number of channels for output.
num_outs (int): Number of output scales.
norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
Defaults to dict(type='GN', num_groups=32).
act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
Defaults to dict(type='ReLU').
encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
encoder. Defaults to `DetrTransformerEncoder`.
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
transformer encoder position encoding. Defaults to
dict(type='SinePositionalEncoding', num_feats=128,
normalize=True).
init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
"""
def __init__(self,
in_channels=[256, 512, 1024, 2048],
strides=[4, 8, 16, 32],
feat_channels=256,
out_channels=256,
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
feedforward_channels=1024,
ffn_dropout=0.0,
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.strides = strides
self.num_input_levels = len(in_channels)
self.num_encoder_levels = \
encoder.transformerlayers.attn_cfgs.num_levels
assert self.num_encoder_levels >= 1, \
'num_levels in attn_cfgs must be at least one'
input_conv_list = []
# from top to down (low to high resolution)
for i in range(self.num_input_levels - 1,
self.num_input_levels - self.num_encoder_levels - 1,
-1):
input_conv = ConvModule(
in_channels[i],
feat_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=None,
bias=True)
input_conv_list.append(input_conv)
self.input_convs = ModuleList(input_conv_list)
self.encoder = build_transformer_layer_sequence(encoder)
self.postional_encoding = build_positional_encoding(
positional_encoding)
# high resolution to low resolution
self.level_encoding = nn.Embedding(self.num_encoder_levels,
feat_channels)
# fpn-like structure
self.lateral_convs = ModuleList()
self.output_convs = ModuleList()
self.use_bias = norm_cfg is None
# from top to down (low to high resolution)
# fpn for the rest features that didn't pass in encoder
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
-1):
lateral_conv = ConvModule(
in_channels[i],
feat_channels,
kernel_size=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=None)
output_conv = ConvModule(
feat_channels,
feat_channels,
kernel_size=3,
stride=1,
padding=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.lateral_convs.append(lateral_conv)
self.output_convs.append(output_conv)
self.mask_feature = Conv2d(
feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.num_outs = num_outs
self.point_generator = MlvlPointGenerator(strides)
def init_weights(self):
"""Initialize weights."""
for i in range(0, self.num_encoder_levels):
xavier_init(
self.input_convs[i].conv,
gain=1,
bias=0,
distribution='uniform')
for i in range(0, self.num_input_levels - self.num_encoder_levels):
caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
caffe2_xavier_init(self.output_convs[i].conv, bias=0)
caffe2_xavier_init(self.mask_feature, bias=0)
normal_init(self.level_encoding, mean=0, std=1)
for p in self.encoder.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
# init_weights defined in MultiScaleDeformableAttention
for layer in self.encoder.layers:
for attn in layer.attentions:
if isinstance(attn, MultiScaleDeformableAttention):
attn.init_weights()
def forward(self, feats):
"""
Args:
feats (list[Tensor]): Feature maps of each level. Each has
shape of (batch_size, c, h, w).
Returns:
tuple: A tuple containing the following:
- mask_feature (Tensor): shape (batch_size, c, h, w).
- multi_scale_features (list[Tensor]): Multi scale \
features, each in shape (batch_size, c, h, w).
"""
# generate padding mask for each level, for each image
batch_size = feats[0].shape[0]
encoder_input_list = []
padding_mask_list = []
level_positional_encoding_list = []
spatial_shapes = []
reference_points_list = []
for i in range(self.num_encoder_levels):
level_idx = self.num_input_levels - i - 1
feat = feats[level_idx]
feat_projected = self.input_convs[i](feat)
h, w = feat.shape[-2:]
# no padding
padding_mask_resized = feat.new_zeros(
(batch_size, ) + feat.shape[-2:], dtype=torch.bool)
pos_embed = self.postional_encoding(padding_mask_resized)
level_embed = self.level_encoding.weight[i]
level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
# (h_i * w_i, 2)
reference_points = self.point_generator.single_level_grid_priors(
feat.shape[-2:], level_idx, device=feat.device)
# normalize
factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
reference_points = reference_points / factor
# shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
padding_mask_resized = padding_mask_resized.flatten(1)
encoder_input_list.append(feat_projected)
padding_mask_list.append(padding_mask_resized)
level_positional_encoding_list.append(level_pos_embed)
spatial_shapes.append(feat.shape[-2:])
reference_points_list.append(reference_points)
# shape (batch_size, total_num_query),
# total_num_query=sum([., h_i * w_i,.])
padding_masks = torch.cat(padding_mask_list, dim=1)
# shape (total_num_query, batch_size, c)
encoder_inputs = torch.cat(encoder_input_list, dim=0)
level_positional_encodings = torch.cat(
level_positional_encoding_list, dim=0)
device = encoder_inputs.device
# shape (num_encoder_levels, 2), from low
# resolution to high resolution
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=device)
# shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
reference_points = torch.cat(reference_points_list, dim=0)
reference_points = reference_points[None, :, None].repeat(
batch_size, 1, self.num_encoder_levels, 1)
valid_radios = reference_points.new_ones(
(batch_size, self.num_encoder_levels, 2))
# shape (num_total_query, batch_size, c)
memory = self.encoder(
query=encoder_inputs,
key=None,
value=None,
query_pos=level_positional_encodings,
key_pos=None,
attn_masks=None,
key_padding_mask=None,
query_key_padding_mask=padding_masks,
spatial_shapes=spatial_shapes,
reference_points=reference_points,
level_start_index=level_start_index,
valid_radios=valid_radios)
# (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
memory = memory.permute(1, 2, 0)
# from low resolution to high resolution
num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
outs = torch.split(memory, num_query_per_level, dim=-1)
outs = [
x.reshape(batch_size, -1, spatial_shapes[i][0],
spatial_shapes[i][1]) for i, x in enumerate(outs)
]
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
-1):
x = feats[i]
cur_feat = self.lateral_convs[i](x)
y = cur_feat + F.interpolate(
outs[-1],
size=cur_feat.shape[-2:],
mode='bilinear',
align_corners=False)
y = self.output_convs[i](y)
outs.append(y)
multi_scale_features = outs[:self.num_outs]
mask_feature = self.mask_feature(outs[-1])
return mask_feature, multi_scale_features
\ No newline at end of file
import torch
import torch.nn.functional as F
from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, kaiming_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer_sequence)
from mmcv.runner import BaseModule, ModuleList
@PLUGIN_LAYERS.register_module()
class PixelDecoder(BaseModule):
"""Pixel decoder with a structure like fpn.
Args:
in_channels (list[int] | tuple[int]): Number of channels in the
input feature maps.
feat_channels (int): Number channels for feature.
out_channels (int): Number channels for output.
norm_cfg (obj:`mmcv.ConfigDict`|dict): Config for normalization.
Defaults to dict(type='GN', num_groups=32).
act_cfg (obj:`mmcv.ConfigDict`|dict): Config for activation.
Defaults to dict(type='ReLU').
encoder (obj:`mmcv.ConfigDict`|dict): Config for transorformer
encoder.Defaults to None.
positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for
transformer encoder position encoding. Defaults to
dict(type='SinePositionalEncoding', num_feats=128,
normalize=True).
init_cfg (obj:`mmcv.ConfigDict`|dict): Initialization config dict.
Default: None
"""
def __init__(self,
in_channels,
feat_channels,
out_channels,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.num_inputs = len(in_channels)
self.lateral_convs = ModuleList()
self.output_convs = ModuleList()
self.use_bias = norm_cfg is None
for i in range(0, self.num_inputs - 1):
l_conv = ConvModule(
in_channels[i],
feat_channels,
kernel_size=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=None)
o_conv = ConvModule(
feat_channels,
feat_channels,
kernel_size=3,
stride=1,
padding=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.lateral_convs.append(l_conv)
self.output_convs.append(o_conv)
self.last_feat_conv = ConvModule(
in_channels[-1],
feat_channels,
kernel_size=3,
padding=1,
stride=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.mask_feature = Conv2d(
feat_channels, out_channels, kernel_size=3, stride=1, padding=1)
def init_weights(self):
"""Initialize weights."""
for i in range(0, self.num_inputs - 2):
kaiming_init(self.lateral_convs[i].conv, a=1)
kaiming_init(self.output_convs[i].conv, a=1)
kaiming_init(self.mask_feature, a=1)
kaiming_init(self.last_feat_conv, a=1)
def forward(self, feats, img_metas):
"""
Args:
feats (list[Tensor]): Feature maps of each level. Each has
shape of [bs, c, h, w].
img_metas (list[dict]): List of image information. Pass in
for creating more accurate padding mask. #! not used here.
Returns:
tuple: a tuple containing the following:
- mask_feature (Tensor): Shape [bs, c, h, w].
- memory (Tensor): Output of last stage of backbone.
Shape [bs, c, h, w].
"""
y = self.last_feat_conv(feats[-1])
for i in range(self.num_inputs - 2, -1, -1):
x = feats[i]
cur_fpn = self.lateral_convs[i](x)
y = cur_fpn + \
F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest')
y = self.output_convs[i](y)
mask_feature = self.mask_feature(y)
memory = feats[-1]
return mask_feature, memory
@PLUGIN_LAYERS.register_module()
class TransformerEncoderPixelDecoder(PixelDecoder):
"""Pixel decoder with transormer encoder inside.
Args:
in_channels (list[int] | tuple[int]): Number of channels in the
input feature maps.
feat_channels (int): Number channels for feature.
out_channels (int): Number channels for output.
norm_cfg (obj:`mmcv.ConfigDict`|dict): Config for normalization.
Defaults to dict(type='GN', num_groups=32).
act_cfg (obj:`mmcv.ConfigDict`|dict): Config for activation.
Defaults to dict(type='ReLU').
encoder (obj:`mmcv.ConfigDict`|dict): Config for transorformer
encoder.Defaults to None.
positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for
transformer encoder position encoding. Defaults to
dict(type='SinePositionalEncoding', num_feats=128,
normalize=True).
init_cfg (obj:`mmcv.ConfigDict`|dict): Initialization config dict.
Default: None
"""
def __init__(self,
in_channels,
feat_channels,
out_channels,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=None,
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True),
init_cfg=None):
super(TransformerEncoderPixelDecoder, self).__init__(
in_channels,
feat_channels,
out_channels,
norm_cfg,
act_cfg,
init_cfg=init_cfg)
self.last_feat_conv = None
self.encoder = build_transformer_layer_sequence(encoder)
self.encoder_embed_dims = self.encoder.embed_dims
assert self.encoder_embed_dims == feat_channels, 'embed_dims({}) of ' \
'tranformer encoder must equal to feat_channels({})'.format(
feat_channels, self.encoder_embed_dims)
self.positional_encoding = build_positional_encoding(
positional_encoding)
self.encoder_in_proj = Conv2d(
in_channels[-1], feat_channels, kernel_size=1)
self.encoder_out_proj = ConvModule(
feat_channels,
feat_channels,
kernel_size=3,
stride=1,
padding=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def init_weights(self):
"""Initialize weights."""
for i in range(0, self.num_inputs - 2):
kaiming_init(self.lateral_convs[i].conv, a=1)
kaiming_init(self.output_convs[i].conv, a=1)
kaiming_init(self.mask_feature, a=1)
kaiming_init(self.encoder_in_proj, a=1)
kaiming_init(self.encoder_out_proj.conv, a=1)
def forward(self, feats, img_metas):
"""
Args:
feats (list[Tensor]): Feature maps of each level. Each has
shape of [bs, c, h, w].
img_metas (list[dict]): List of image information. Pass in
for creating more accurate padding mask.
Returns:
tuple: a tuple containing the following:
- mask_feature (Tensor): shape [bs, c, h, w].
- memory (Tensor): shape [bs, c, h, w].
"""
feat_last = feats[-1]
bs, c, h, w = feat_last.shape
input_img_h, input_img_w = img_metas[0]['pad_shape'][:-1]
# input_img_h, input_img_w = img_metas[0]['batch_input_shape']
padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w),
dtype=torch.float32)
for i in range(bs):
img_h, img_w, _ = img_metas[i]['img_shape']
padding_mask[i, :img_h, :img_w] = 0
padding_mask = F.interpolate(
padding_mask.unsqueeze(1),
size=feat_last.shape[-2:],
mode='nearest').to(torch.bool).squeeze(1)
pos_embed = self.positional_encoding(padding_mask)
feat_last = self.encoder_in_proj(feat_last)
# [bs, c, h, w] -> [nq, bs, dim]
feat_last = feat_last.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
padding_mask = padding_mask.flatten(1) # [bs, h, w] -> [bs, h*w]
memory = self.encoder(
query=feat_last,
key=None,
value=None,
query_pos=pos_embed,
query_key_padding_mask=padding_mask)
# [nq, bs, em] -> [bs, c, h, w]
memory = memory.permute(1, 2, 0).view(bs, self.encoder_embed_dims, h,
w)
y = self.encoder_out_proj(memory)
for i in range(self.num_inputs - 2, -1, -1):
x = feats[i]
cur_fpn = self.lateral_convs[i](x)
y = cur_fpn + \
F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest')
y = self.output_convs[i](y)
mask_feature = self.mask_feature(y)
return mask_feature, memory
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment