# Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn import torch.nn.functional as F from mmaction.core.bbox import bbox_target try: from mmdet.models.builder import HEADS as MMDET_HEADS mmdet_imported = True except (ImportError, ModuleNotFoundError): mmdet_imported = False # Resolve cross-entropy function to support multi-target in Torch < 1.10 # This is a very basic 'hack', with minimal functionality to support the # procedure under prior torch versions from packaging import version as pv if pv.parse(torch.__version__) < pv.parse('1.10'): def cross_entropy_loss(input, target, reduction='None'): input = input.log_softmax(dim=-1) # Compute Log of Softmax loss = -(input * target).sum(dim=-1) # Compute Loss manually if reduction.lower() == 'mean': return loss.mean() elif reduction.lower() == 'sum': return loss.sum() else: return loss else: cross_entropy_loss = F.cross_entropy class BBoxHeadAVA(nn.Module): """Simplest RoI head, with only two fc layers for classification and regression respectively. Args: temporal_pool_type (str): The temporal pool type. Choices are 'avg' or 'max'. Default: 'avg'. spatial_pool_type (str): The spatial pool type. Choices are 'avg' or 'max'. Default: 'max'. in_channels (int): The number of input channels. Default: 2048. focal_alpha (float): The hyper-parameter alpha for Focal Loss. When alpha == 1 and gamma == 0, Focal Loss degenerates to BCELossWithLogits. Default: 1. focal_gamma (float): The hyper-parameter gamma for Focal Loss. When alpha == 1 and gamma == 0, Focal Loss degenerates to BCELossWithLogits. Default: 0. num_classes (int): The number of classes. Default: 81. dropout_ratio (float): A float in [0, 1], indicates the dropout_ratio. Default: 0. dropout_before_pool (bool): Dropout Feature before spatial temporal pooling. Default: True. topk (int or tuple[int]): Parameter for evaluating Top-K accuracy. Default: (3, 5) multilabel (bool): Whether used for a multilabel task. Default: True. """ def __init__( self, temporal_pool_type='avg', spatial_pool_type='max', in_channels=2048, focal_gamma=0., focal_alpha=1., num_classes=81, # First class reserved (BBox as pos/neg) dropout_ratio=0, dropout_before_pool=True, topk=(3, 5), multilabel=True): super(BBoxHeadAVA, self).__init__() assert temporal_pool_type in ['max', 'avg'] assert spatial_pool_type in ['max', 'avg'] self.temporal_pool_type = temporal_pool_type self.spatial_pool_type = spatial_pool_type self.in_channels = in_channels self.num_classes = num_classes self.dropout_ratio = dropout_ratio self.dropout_before_pool = dropout_before_pool self.multilabel = multilabel self.focal_gamma = focal_gamma self.focal_alpha = focal_alpha if topk is None: self.topk = () elif isinstance(topk, int): self.topk = (topk, ) elif isinstance(topk, tuple): assert all([isinstance(k, int) for k in topk]) self.topk = topk else: raise TypeError('topk should be int or tuple[int], ' f'but get {type(topk)}') # Class 0 is ignored when calculating accuracy, # so topk cannot be equal to num_classes. assert all([k < num_classes for k in self.topk]) in_channels = self.in_channels # Pool by default if self.temporal_pool_type == 'avg': self.temporal_pool = nn.AdaptiveAvgPool3d((1, None, None)) else: self.temporal_pool = nn.AdaptiveMaxPool3d((1, None, None)) if self.spatial_pool_type == 'avg': self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) else: self.spatial_pool = nn.AdaptiveMaxPool3d((None, 1, 1)) if dropout_ratio > 0: self.dropout = nn.Dropout(dropout_ratio) self.fc_cls = nn.Linear(in_channels, num_classes) self.debug_imgs = None def init_weights(self): nn.init.normal_(self.fc_cls.weight, 0, 0.01) nn.init.constant_(self.fc_cls.bias, 0) def forward(self, x): if self.dropout_before_pool and self.dropout_ratio > 0: x = self.dropout(x) x = self.temporal_pool(x) x = self.spatial_pool(x) if not self.dropout_before_pool and self.dropout_ratio > 0: x = self.dropout(x) x = x.view(x.size(0), -1) cls_score = self.fc_cls(x) # We do not predict bbox, so return None return cls_score, None @staticmethod def get_targets(sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg): pos_proposals = [res.pos_bboxes for res in sampling_results] neg_proposals = [res.neg_bboxes for res in sampling_results] pos_gt_labels = [res.pos_gt_labels for res in sampling_results] cls_reg_targets = bbox_target(pos_proposals, neg_proposals, pos_gt_labels, rcnn_train_cfg) return cls_reg_targets @staticmethod def get_recall_prec(pred_vec, target_vec): """Computes the Recall/Precision for both multi-label and single label scenarios. Note that the computation calculates the micro average. Note, that in both cases, the concept of correct/incorrect is the same. Args: pred_vec (tensor[N x C]): each element is either 0 or 1 target_vec (tensor[N x C]): each element is either 0 or 1 - for single label it is expected that only one element is on (1) although this is not enforced. """ correct = pred_vec & target_vec recall = correct.sum(1) / target_vec.sum(1).float() # Enforce Float prec = correct.sum(1) / (pred_vec.sum(1) + 1e-6) return recall.mean(), prec.mean() @staticmethod def topk_to_matrix(probs, k): """Converts top-k to binary matrix.""" topk_labels = probs.topk(k, 1, True, True)[1] topk_matrix = probs.new_full(probs.size(), 0, dtype=torch.bool) for i in range(probs.shape[0]): topk_matrix[i, topk_labels[i]] = 1 return topk_matrix def topk_accuracy(self, pred, target, thr=0.5): """Computes the Top-K Accuracies for both single and multi-label scenarios.""" # Define Target vector: target_bool = target > 0.5 # Branch on Multilabel for computing output classification if self.multilabel: pred = pred.sigmoid() else: pred = pred.softmax(dim=1) # Compute at threshold (K=1 for single) if self.multilabel: pred_bool = pred > thr else: pred_bool = self.topk_to_matrix(pred, 1) recall_thr, prec_thr = self.get_recall_prec(pred_bool, target_bool) # Compute at various K recalls_k, precs_k = [], [] for k in self.topk: pred_bool = self.topk_to_matrix(pred, k) recall, prec = self.get_recall_prec(pred_bool, target_bool) recalls_k.append(recall) precs_k.append(prec) # Return all return recall_thr, prec_thr, recalls_k, precs_k def loss(self, cls_score, bbox_pred, rois, labels, label_weights, bbox_targets=None, bbox_weights=None, reduce=True): losses = dict() # Only use the cls_score if cls_score is not None: labels = labels[:, 1:] # Get valid labels (ignore first one) pos_inds = torch.sum(labels, dim=-1) > 0 cls_score = cls_score[pos_inds, 1:] labels = labels[pos_inds] # Compute First Recall/Precisions # This has to be done first before normalising the label-space. recall_thr, prec_thr, recall_k, prec_k = self.topk_accuracy( cls_score, labels, thr=0.5) losses['recall@thr=0.5'] = recall_thr losses['prec@thr=0.5'] = prec_thr for i, k in enumerate(self.topk): losses[f'recall@top{k}'] = recall_k[i] losses[f'prec@top{k}'] = prec_k[i] # If Single-label, need to ensure that target labels sum to 1: ie # that they are valid probabilities. if not self.multilabel: labels = labels / labels.sum(dim=1, keepdim=True) # Select Loss function based on single/multi-label # NB. Both losses auto-compute sigmoid/softmax on prediction if self.multilabel: loss_func = F.binary_cross_entropy_with_logits else: loss_func = cross_entropy_loss # Compute loss loss = loss_func(cls_score, labels, reduction='none') pt = torch.exp(-loss) F_loss = self.focal_alpha * (1 - pt)**self.focal_gamma * loss losses['loss_action_cls'] = torch.mean(F_loss) return losses def get_det_bboxes(self, rois, cls_score, img_shape, flip=False, crop_quadruple=None, cfg=None): # might be used by testing w. augmentation if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) # Handle Multi/Single Label if cls_score is not None: if self.multilabel: scores = cls_score.sigmoid() else: scores = cls_score.softmax(dim=-1) else: scores = None bboxes = rois[:, 1:] assert bboxes.shape[-1] == 4 # First reverse the flip img_h, img_w = img_shape if flip: bboxes_ = bboxes.clone() bboxes_[:, 0] = img_w - 1 - bboxes[:, 2] bboxes_[:, 2] = img_w - 1 - bboxes[:, 0] bboxes = bboxes_ # Then normalize the bbox to [0, 1] bboxes[:, 0::2] /= img_w bboxes[:, 1::2] /= img_h def _bbox_crop_undo(bboxes, crop_quadruple): decropped = bboxes.clone() if crop_quadruple is not None: x1, y1, tw, th = crop_quadruple decropped[:, 0::2] = bboxes[..., 0::2] * tw + x1 decropped[:, 1::2] = bboxes[..., 1::2] * th + y1 return decropped bboxes = _bbox_crop_undo(bboxes, crop_quadruple) return bboxes, scores if mmdet_imported: MMDET_HEADS.register_module()(BBoxHeadAVA)