Commit bd8fd27e authored by Cao Yuhang's avatar Cao Yuhang
Browse files

rename pos_mask to pos_inds

parent cb68807f
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target, from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy) weighted_cross_entropy, weighted_smoothl1, accuracy)
from ..registry import HEADS from ..registry import HEADS
...@@ -94,16 +94,16 @@ class BBoxHead(nn.Module): ...@@ -94,16 +94,16 @@ class BBoxHead(nn.Module):
cls_score, labels, label_weights, reduce=reduce) cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels) losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None: if bbox_pred is not None:
pos_mask = labels > 0 pos_inds = labels > 0
if self.reg_class_agnostic: if self.reg_class_agnostic:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask] pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds]
else: else:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
4)[pos_mask, labels[pos_mask]] 4)[pos_inds, labels[pos_inds]]
losses['loss_reg'] = weighted_smoothl1( losses['loss_reg'] = weighted_smoothl1(
pos_bbox_pred, pos_bbox_pred,
bbox_targets[pos_mask], bbox_targets[pos_inds],
bbox_weights[pos_mask], bbox_weights[pos_inds],
avg_factor=bbox_targets.size(0)) avg_factor=bbox_targets.size(0))
return losses return losses
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment