"router/src/infer/v2/queue.rs" did not exist on "1a2d68250aa7dfbe1fa52b22eec07edfb7b895fb"
Unverified Commit 04cf8929 authored by Yezhen Cong's avatar Yezhen Cong Committed by GitHub
Browse files

Added axis-aligned IoU loss to VoteNet (#194)

* support axis-aligned iou loss for votenet

* added doc for iou loss

* fixed problems in format

* updated docstring

* rename and format fix

* rename and format fix

* rename and format fix

* rename and format fix

* modified config

* abstracted axis_aligned_iou3d

* abstracted a bbox corner decode func

* fix docstring
parent 23768cba
@@ -38,3 +38,17 @@ Then you can use the converted checkpoints following [getting_started.md](../../

## Indeterminism
Since test data preparation randomly downsamples the points, and the test script uses fixed random seeds while the random seeds used for validation during training are not fixed, the test results may differ slightly from the results reported above.
## IoU loss
Adding an IoU loss (simply 1 - IoU) boosts VoteNet's performance. To use it, add the following loss term to the config file:
```python
iou_loss=dict(type='AxisAlignedIoULoss', reduction='sum', loss_weight=10.0 / 3.0)
```
| Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 | AP@0.5 | Download |
| :---------: | :-----: | :------: | :------------: | :-----: | :----: | :------: |
| [PointNet++](./votenet_iouloss_8x8_scannet-3d-18class.py) | 3x | 4.1 | | 63.81 | 44.21 | / |
For now, we only support computing the IoU loss for axis-aligned bounding boxes, since the CUDA op for general 3D IoU calculation does not implement the backward method. Therefore, the IoU loss can currently only be used on the ScanNet dataset.
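The restriction comes down to differentiability: axis-aligned IoU is computed purely from elementwise `min`/`max`/`clamp` and products, so autograd produces the backward pass without any custom CUDA kernel. A minimal, self-contained sketch (not code from this commit; the tensors are made up) showing that gradients flow through 1 - IoU:

```python
import torch

# Two axis-aligned boxes in (x1, y1, z1, x2, y2, z2) format; pred requires grad.
pred = torch.tensor([[0., 0., 0., 9., 9., 9.]], requires_grad=True)
target = torch.tensor([[0., 0., 0., 10., 10., 10.]])

lt = torch.max(pred[..., :3], target[..., :3])    # intersection min corner
rb = torch.min(pred[..., 3:], target[..., 3:])    # intersection max corner
whd = (rb - lt).clamp(min=0)                      # intersection extents
inter = whd.prod(dim=-1)
vol_pred = (pred[..., 3:] - pred[..., :3]).prod(dim=-1)
vol_target = (target[..., 3:] - target[..., :3]).prod(dim=-1)
iou = inter / (vol_pred + vol_target - inter)

(1 - iou).sum().backward()   # IoU loss = 1 - IoU
print(pred.grad)             # non-zero gradients, no custom backward op needed
```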
_base_ = ['./votenet_8x8_scannet-3d-18class.py']

# model settings, add iou loss
model = dict(
    bbox_head=dict(
        iou_loss=dict(
            type='AxisAlignedIoULoss', reduction='sum', loss_weight=10.0 /
            3.0)))
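For reference, the `iou_loss` dict above is resolved through the same registry machinery as the other loss terms in the head. A hedged sketch, assuming `mmdet3d.models` has been imported so that `AxisAlignedIoULoss` is registered, and that `build_loss` is importable from `mmdet.models` as it is used by `VoteHead` below:

```python
import mmdet3d.models  # noqa: F401, importing registers AxisAlignedIoULoss
from mmdet.models import build_loss

iou_loss_cfg = dict(
    type='AxisAlignedIoULoss', reduction='sum', loss_weight=10.0 / 3.0)
iou_loss = build_loss(iou_loss_cfg)  # -> an AxisAlignedIoULoss module
```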
from .assigners import AssignResult, BaseAssigner, MaxIoUAssigner
from .coders import DeltaXYZWLHRBBoxCoder
# from .bbox_target import bbox_target
from .iou_calculators import (AxisAlignedBboxOverlaps3D, BboxOverlaps3D,
                              BboxOverlapsNearest3D,
                              axis_aligned_bbox_overlaps_3d, bbox_overlaps_3d,
                              bbox_overlaps_nearest_3d)
from .samplers import (BaseSampler, CombinedSampler,
                       InstanceBalancedPosSampler, IoUBalancedNegSampler,
                       PseudoSampler, RandomSampler, SamplingResult)
@@ -17,7 +19,8 @@ __all__ = [
    'PseudoSampler', 'RandomSampler', 'InstanceBalancedPosSampler',
    'IoUBalancedNegSampler', 'CombinedSampler', 'SamplingResult',
    'DeltaXYZWLHRBBoxCoder', 'BboxOverlapsNearest3D', 'BboxOverlaps3D',
    'bbox_overlaps_nearest_3d', 'bbox_overlaps_3d',
    'AxisAlignedBboxOverlaps3D', 'axis_aligned_bbox_overlaps_3d', 'Box3DMode',
    'LiDARInstance3DBoxes', 'CameraInstance3DBoxes', 'bbox3d2roi',
    'bbox3d2result', 'DepthInstance3DBoxes', 'BaseInstance3DBoxes',
    'bbox3d_mapping_back', 'xywhr2xyxyr', 'limit_period', 'points_cam2img',
...
@@ -98,6 +98,44 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
        bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1)
        return bbox3d

    def decode_corners(self, center, size_res, size_class):
        """Decode center, size residuals and class to corners. Only useful
        for axis-aligned bounding boxes, so the angle is not considered.

        Args:
            center (torch.Tensor): Box centers with shape [B, N, 3].
            size_res (torch.Tensor): Size residuals with shape [B, N, 3]
                or [B, N, C, 3].
            size_class (torch.Tensor): Size classes with shape [B, N],
                [B, N, 1], or already one-hot expanded with shape
                [B, N, C, 3].

        Returns:
            torch.Tensor: Corners with shape [B, N, 6].
        """
        if len(size_class.shape) == 2 or size_class.shape[-1] == 1:
            batch_size, proposal_num = size_class.shape[:2]
            one_hot_size_class = size_res.new_zeros(
                (batch_size, proposal_num, self.num_sizes))
            if len(size_class.shape) == 2:
                size_class = size_class.unsqueeze(-1)
            one_hot_size_class.scatter_(2, size_class, 1)
            one_hot_size_class_expand = one_hot_size_class.unsqueeze(
                -1).repeat(1, 1, 1, 3).contiguous()
        else:
            one_hot_size_class_expand = size_class

        if len(size_res.shape) == 4:
            size_res = torch.sum(size_res * one_hot_size_class_expand, 2)
        mean_sizes = size_res.new_tensor(self.mean_sizes)
        mean_sizes = torch.sum(mean_sizes * one_hot_size_class_expand, 2)
        size_full = (size_res + 1) * mean_sizes
        size_full = torch.clamp(size_full, 0)
        half_size_full = size_full / 2
        corner1 = center - half_size_full
        corner2 = center + half_size_full
        corners = torch.cat([corner1, corner2], dim=-1)
        return corners
    def split_pred(self, cls_preds, reg_preds, base_xyz):
        """Split predicted features to specific parts.
...
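For the simple case where `size_res` already has shape `[B, N, 3]` and `size_class` holds hard labels of shape `[B, N]`, the decoding above reduces to looking up the class mean size, applying the residual, and taking center ± half size. A standalone sketch of that math with made-up tensors (not a call into the coder itself):

```python
import torch

mean_sizes = torch.tensor([[0.8, 0.9, 0.4],
                           [1.9, 0.8, 1.1]])    # [C, 3], hypothetical class means

center = torch.tensor([[[1.0, 2.0, 0.5]]])      # [B=1, N=1, 3]
size_res = torch.tensor([[[0.1, -0.1, 0.0]]])   # [B, N, 3], normalized residuals
size_class = torch.tensor([[1]])                # [B, N], picks mean_sizes[1]

size_full = (size_res + 1) * mean_sizes[size_class]   # [B, N, 3]
size_full = size_full.clamp(min=0)
corners = torch.cat([center - size_full / 2,
                     center + size_full / 2], dim=-1)  # [B, N, 6]
```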
from .iou3d_calculator import (AxisAlignedBboxOverlaps3D, BboxOverlaps3D,
                               BboxOverlapsNearest3D,
                               axis_aligned_bbox_overlaps_3d, bbox_overlaps_3d,
                               bbox_overlaps_nearest_3d)

__all__ = [
    'BboxOverlapsNearest3D', 'BboxOverlaps3D', 'bbox_overlaps_nearest_3d',
    'bbox_overlaps_3d', 'AxisAlignedBboxOverlaps3D',
    'axis_aligned_bbox_overlaps_3d'
]
import torch

from mmdet.core.bbox import bbox_overlaps
from mmdet.core.bbox.iou_calculators.builder import IOU_CALCULATORS

from ..structures import get_box_type
@@ -163,3 +165,156 @@ def bbox_overlaps_3d(bboxes1, bboxes2, mode='iou', coordinate='camera'):
        bboxes2 = box_type(bboxes2, box_dim=bboxes2.shape[-1])

    return bboxes1.overlaps(bboxes1, bboxes2, mode=mode)

@IOU_CALCULATORS.register_module()
class AxisAlignedBboxOverlaps3D(object):
    """Axis-aligned 3D Overlaps (IoU) Calculator."""

    def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False):
        """Calculate IoU between axis-aligned 3D bboxes.

        Args:
            bboxes1 (Tensor): shape (B, m, 6) in <x1, y1, z1, x2, y2, z2>
                format or empty.
            bboxes2 (Tensor): shape (B, n, 6) in <x1, y1, z1, x2, y2, z2>
                format or empty.
                B indicates the batch dim, in shape (B1, B2, ..., Bn).
                If ``is_aligned`` is ``True``, then m and n must be equal.
            mode (str): "iou" (intersection over union) or "giou" (generalized
                intersection over union).
            is_aligned (bool, optional): If True, then m and n must be equal.
                Defaults to False.

        Returns:
            Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,).
        """
        assert bboxes1.size(-1) == bboxes2.size(-1) == 6
        return axis_aligned_bbox_overlaps_3d(bboxes1, bboxes2, mode,
                                             is_aligned)

    def __repr__(self):
        """str: A string describing the module."""
        repr_str = self.__class__.__name__ + '()'
        return repr_str

def axis_aligned_bbox_overlaps_3d(bboxes1,
                                  bboxes2,
                                  mode='iou',
                                  is_aligned=False,
                                  eps=1e-6):
    """Calculate overlaps between two sets of axis-aligned 3D bboxes. If
    ``is_aligned`` is ``False``, calculate the overlaps between each bbox of
    bboxes1 and bboxes2, otherwise the overlaps between each aligned pair of
    bboxes1 and bboxes2.

    Args:
        bboxes1 (Tensor): shape (B, m, 6) in <x1, y1, z1, x2, y2, z2>
            format or empty.
        bboxes2 (Tensor): shape (B, n, 6) in <x1, y1, z1, x2, y2, z2>
            format or empty.
            B indicates the batch dim, in shape (B1, B2, ..., Bn).
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union) or "giou" (generalized
            intersection over union).
        is_aligned (bool, optional): If True, then m and n must be equal.
            Defaults to False.
        eps (float, optional): A value added to the denominator for numerical
            stability. Defaults to 1e-6.

    Returns:
        Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,).

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 0, 10, 10, 10],
        >>>     [10, 10, 10, 20, 20, 20],
        >>>     [32, 32, 32, 38, 40, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 0, 10, 20, 20],
        >>>     [0, 10, 10, 10, 19, 20],
        >>>     [10, 10, 10, 20, 20, 20],
        >>> ])
        >>> overlaps = axis_aligned_bbox_overlaps_3d(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = axis_aligned_bbox_overlaps_3d(
        >>>     bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 6)
        >>> nonempty = torch.FloatTensor([[0, 0, 0, 10, 9, 10]])
        >>> assert tuple(axis_aligned_bbox_overlaps_3d(
        >>>     empty, nonempty).shape) == (0, 1)
        >>> assert tuple(axis_aligned_bbox_overlaps_3d(
        >>>     nonempty, empty).shape) == (1, 0)
        >>> assert tuple(axis_aligned_bbox_overlaps_3d(
        >>>     empty, empty).shape) == (0, 0)
    """
    assert mode in ['iou', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 6
    assert (bboxes1.size(-1) == 6 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 6 or bboxes2.size(0) == 0)

    # Batch dim must be the same
    # Batch dim: (B1, B2, ... Bn)
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = bboxes1.shape[:-2]

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    if rows * cols == 0:
        if is_aligned:
            return bboxes1.new(batch_shape + (rows, ))
        else:
            return bboxes1.new(batch_shape + (rows, cols))

    area1 = (bboxes1[..., 3] -
             bboxes1[..., 0]) * (bboxes1[..., 4] - bboxes1[..., 1]) * (
                 bboxes1[..., 5] - bboxes1[..., 2])
    area2 = (bboxes2[..., 3] -
             bboxes2[..., 0]) * (bboxes2[..., 4] - bboxes2[..., 1]) * (
                 bboxes2[..., 5] - bboxes2[..., 2])

    if is_aligned:
        lt = torch.max(bboxes1[..., :3], bboxes2[..., :3])  # [B, rows, 3]
        rb = torch.min(bboxes1[..., 3:], bboxes2[..., 3:])  # [B, rows, 3]

        wh = (rb - lt).clamp(min=0)  # [B, rows, 3]
        overlap = wh[..., 0] * wh[..., 1] * wh[..., 2]

        if mode in ['iou', 'giou']:
            union = area1 + area2 - overlap
        else:
            union = area1
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :3], bboxes2[..., :3])
            enclosed_rb = torch.max(bboxes1[..., 3:], bboxes2[..., 3:])
    else:
        lt = torch.max(bboxes1[..., :, None, :3],
                       bboxes2[..., None, :, :3])  # [B, rows, cols, 3]
        rb = torch.min(bboxes1[..., :, None, 3:],
                       bboxes2[..., None, :, 3:])  # [B, rows, cols, 3]

        wh = (rb - lt).clamp(min=0)  # [B, rows, cols, 3]
        overlap = wh[..., 0] * wh[..., 1] * wh[..., 2]

        if mode in ['iou', 'giou']:
            union = area1[..., None] + area2[..., None, :] - overlap
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :, None, :3],
                                    bboxes2[..., None, :, :3])
            enclosed_rb = torch.max(bboxes1[..., :, None, 3:],
                                    bboxes2[..., None, :, 3:])

    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode in ['iou']:
        return ious
    # calculate gious
    enclose_wh = (enclosed_rb - enclosed_lt).clamp(min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] * enclose_wh[..., 2]
    enclose_area = torch.max(enclose_area, eps)
    gious = ious - (enclose_area - union) / enclose_area
    return gious
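As a quick sanity check of the calculator above: the first boxes of the docstring example intersect in a 10 x 10 x 10 region, so intersection = 1000 and union = 1000 + 4000 - 1000 = 4000, giving IoU = 0.25. A short usage sketch reproducing that number:

```python
import torch

bboxes1 = torch.tensor([[0., 0., 0., 10., 10., 10.]])
bboxes2 = torch.tensor([[0., 0., 0., 10., 20., 20.]])

# pairwise mode: a [1, 1] IoU matrix
ious = axis_aligned_bbox_overlaps_3d(bboxes1, bboxes2)
assert torch.allclose(ious, torch.tensor([[0.25]]))

# aligned mode: one IoU per corresponding pair
ious_aligned = axis_aligned_bbox_overlaps_3d(bboxes1, bboxes2, is_aligned=True)
assert torch.allclose(ious_aligned, torch.tensor([0.25]))
```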
@@ -55,7 +55,8 @@ class VoteHead(nn.Module):
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 iou_loss=None):
        super(VoteHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
@@ -72,6 +73,10 @@ class VoteHead(nn.Module):
            self.size_class_loss = build_loss(size_class_loss)
        if semantic_loss is not None:
            self.semantic_loss = build_loss(semantic_loss)
        if iou_loss is not None:
            self.iou_loss = build_loss(iou_loss)
        else:
            self.iou_loss = None

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
@@ -241,9 +246,10 @@ class VoteHead(nn.Module):
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)
        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, valid_gt_masks,
         objectness_targets, objectness_weights, box_loss_weights,
         valid_gt_weights) = targets

        # calculate vote loss
        vote_loss = self.vote_module.get_loss(bbox_preds['seed_points'],
@@ -318,6 +324,17 @@ class VoteHead(nn.Module):
            size_class_loss=size_class_loss,
            size_res_loss=size_res_loss)

        if self.iou_loss:
            corners_pred = self.bbox_coder.decode_corners(
                bbox_preds['center'], size_residual_norm,
                one_hot_size_targets_expand)
            corners_target = self.bbox_coder.decode_corners(
                assigned_center_targets, size_res_targets,
                one_hot_size_targets_expand)
            iou_loss = self.iou_loss(
                corners_pred, corners_target, weight=box_loss_weights)
            losses['iou_loss'] = iou_loss

        if ret_target:
            losses['targets'] = targets
@@ -373,10 +390,12 @@ class VoteHead(nn.Module):
        ]

        (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
         dir_class_targets, dir_res_targets, center_targets,
         assigned_center_targets, mask_targets, objectness_targets,
         objectness_masks) = multi_apply(self.get_targets_single, points,
                                         gt_bboxes_3d, gt_labels_3d,
                                         pts_semantic_mask, pts_instance_mask,
                                         aggregated_points)

        # pad targets as original code of votenet.
        for index in range(len(gt_labels_3d)):
@@ -390,6 +409,7 @@ class VoteHead(nn.Module):
        center_targets = torch.stack(center_targets)
        valid_gt_masks = torch.stack(valid_gt_masks)
        assigned_center_targets = torch.stack(assigned_center_targets)
        objectness_targets = torch.stack(objectness_targets)
        objectness_weights = torch.stack(objectness_masks)
        objectness_weights /= (torch.sum(objectness_weights) + 1e-6)
@@ -405,9 +425,9 @@ class VoteHead(nn.Module):
        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets, dir_res_targets,
                center_targets, assigned_center_targets, mask_targets,
                valid_gt_masks, objectness_targets, objectness_weights,
                box_loss_weights, valid_gt_weights)

    def get_targets_single(self,
                           points,
@@ -526,10 +546,11 @@ class VoteHead(nn.Module):
        size_res_targets /= pos_mean_sizes

        mask_targets = gt_labels_3d[assignment]
        assigned_center_targets = center_targets[assignment]

        return (vote_targets, vote_target_masks, size_class_targets,
                size_res_targets, dir_class_targets,
                dir_res_targets, center_targets, assigned_center_targets,
                mask_targets.long(), objectness_targets, objectness_masks)

    def get_bboxes(self,
...
from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy
from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss
from .chamfer_distance import ChamferDistance, chamfer_distance

__all__ = [
    'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance',
    'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss'
]
import torch
from torch import nn as nn

from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D


@weighted_loss
def axis_aligned_iou_loss(pred, target):
    """Calculate the IoU loss (1 - IoU) of two sets of axis-aligned bounding
    boxes. Note that predictions and targets are in one-to-one correspondence.

    Args:
        pred (torch.Tensor): Bbox predictions with shape [..., 6]
            (x1, y1, z1, x2, y2, z2).
        target (torch.Tensor): Bbox targets (gt) with shape [..., 6]
            (x1, y1, z1, x2, y2, z2).

    Returns:
        torch.Tensor: IoU loss between predictions and targets.
    """
    axis_aligned_iou = AxisAlignedBboxOverlaps3D()(
        pred, target, is_aligned=True)
    iou_loss = 1 - axis_aligned_iou
    return iou_loss


@LOSSES.register_module()
class AxisAlignedIoULoss(nn.Module):
    """Calculate the IoU loss (1 - IoU) of axis-aligned bounding boxes.

    Args:
        reduction (str): Method to reduce losses.
            Valid reduction methods are 'none', 'sum' and 'mean'.
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(AxisAlignedIoULoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            pred (torch.Tensor): Bbox predictions with shape [..., 6]
                (x1, y1, z1, x2, y2, z2).
            target (torch.Tensor): Bbox targets (gt) with shape [..., 6]
                (x1, y1, z1, x2, y2, z2).
            weight (torch.Tensor | float, optional): Weight of loss.
                Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): Method to reduce losses.
                Valid reduction methods are 'none', 'sum' or 'mean'.
                Defaults to None.

        Returns:
            torch.Tensor: IoU loss between predictions and targets.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            return (pred * weight).sum()
        return axis_aligned_iou_loss(
            pred,
            target,
            weight=weight,
            avg_factor=avg_factor,
            reduction=reduction) * self.loss_weight
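A short usage sketch of the loss module above with corner-format boxes (`[..., 6]`) and a per-box weight, mirroring how `VoteHead` passes `box_loss_weights`; the tensor values here are made up:

```python
import torch

loss_fn = AxisAlignedIoULoss(reduction='sum', loss_weight=10.0 / 3.0)

pred = torch.tensor([[0., 0., 0., 9., 9., 9.],
                     [0., 0., 0., 5., 5., 5.]], requires_grad=True)
target = torch.tensor([[0., 0., 0., 10., 10., 10.],
                       [0., 0., 0., 5., 5., 5.]])
weight = torch.tensor([0.5, 0.5])  # e.g. box_loss_weights from the head

loss = loss_fn(pred, target, weight=weight)  # weighted sum of (1 - IoU)
loss.backward()                              # gradients reach the predicted corners
```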