Unverified Commit 9f5d201e authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

add torch.no_grad context for proposal_layer (#160)

parent 2400fdf2
...@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate ...@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
from ..backbones_2d import BaseBEVBackbone from ..backbones_2d import BaseBEVBackbone
import torch import torch
class SingleHead(BaseBEVBackbone): class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None): def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
super().__init__(encode_conv_cfg, input_channels) super().__init__(encode_conv_cfg, input_channels)
...@@ -75,6 +76,7 @@ class SingleHead(BaseBEVBackbone): ...@@ -75,6 +76,7 @@ class SingleHead(BaseBEVBackbone):
return ret_dict return ret_dict
class AnchorHeadMulti(AnchorHeadTemplate): class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True): def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True):
super().__init__( super().__init__(
...@@ -82,7 +84,6 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -82,7 +84,6 @@ class AnchorHeadMulti(AnchorHeadTemplate):
) )
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.make_multihead(input_channels) self.make_multihead(input_channels)
def make_multihead(self, input_channels): def make_multihead(self, input_channels):
rpn_head_cfgs = self.model_cfg.RPN_HEAD_CFGS rpn_head_cfgs = self.model_cfg.RPN_HEAD_CFGS
...@@ -123,7 +124,8 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -123,7 +124,8 @@ class AnchorHeadMulti(AnchorHeadTemplate):
gt_boxes=data_dict['gt_boxes'] gt_boxes=data_dict['gt_boxes']
) )
self.forward_ret_dict.update(targets_dict) self.forward_ret_dict.update(targets_dict)
else:
if not self.training or self.predict_boxes_when_training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes( batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=data_dict['batch_size'], batch_size=data_dict['batch_size'],
cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
......
...@@ -39,6 +39,7 @@ class RoIHeadTemplate(nn.Module): ...@@ -39,6 +39,7 @@ class RoIHeadTemplate(nn.Module):
fc_layers = nn.Sequential(*fc_layers) fc_layers = nn.Sequential(*fc_layers)
return fc_layers return fc_layers
@torch.no_grad()
def proposal_layer(self, batch_dict, nms_config): def proposal_layer(self, batch_dict, nms_config):
""" """
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment