"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "b6f114e581484b1d303cf3798bc2f34764306036"
Commit 6b3e52c9 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: typo in anchor_head_multi

parent 3e92d169
......@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
from ..backbones_2d import BaseBEVBackbone
import torch
class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
super().__init__(encode_conv_cfg, input_channels)
......@@ -75,6 +76,7 @@ class SingleHead(BaseBEVBackbone):
return ret_dict
class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True):
super().__init__(
......@@ -157,7 +159,7 @@ class AnchorHeadMulti(AnchorHeadTemplate):
reg_weights = positives.float()
if self.num_class == 1:
# class agnostic
box_cls_labels[positive] = 1
box_cls_labels[positives] = 1
pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_weights /= torch.clamp(pos_normalizer, min=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment