Unverified Commit 9f5d201e authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

add torch.no_grad context for proposal_layer (#160)

parent 2400fdf2
......@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
from ..backbones_2d import BaseBEVBackbone
import torch
class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
super().__init__(encode_conv_cfg, input_channels)
......@@ -75,6 +76,7 @@ class SingleHead(BaseBEVBackbone):
return ret_dict
class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True):
super().__init__(
......@@ -83,7 +85,6 @@ class AnchorHeadMulti(AnchorHeadTemplate):
self.model_cfg = model_cfg
self.make_multihead(input_channels)
def make_multihead(self, input_channels):
rpn_head_cfgs = self.model_cfg.RPN_HEAD_CFGS
rpn_heads = []
......@@ -123,7 +124,8 @@ class AnchorHeadMulti(AnchorHeadTemplate):
gt_boxes=data_dict['gt_boxes']
)
self.forward_ret_dict.update(targets_dict)
else:
if not self.training or self.predict_boxes_when_training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=data_dict['batch_size'],
cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
......
......@@ -39,6 +39,7 @@ class RoIHeadTemplate(nn.Module):
fc_layers = nn.Sequential(*fc_layers)
return fc_layers
@torch.no_grad()
def proposal_layer(self, batch_dict, nms_config):
"""
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment