Commit 69326cef authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

speed up PFNLayer

parent f70902b8
......@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .vfe_template import VFETemplate
class PFNLayer(nn.Module):
def __init__(self,
in_channels,
......@@ -28,12 +29,14 @@ class PFNLayer(nn.Module):
if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)]
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part])
for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0)
else:
x = self.linear(inputs)
total_points, voxel_points, channels = x.shape
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x
x = self.linear(inputs)
torch.backends.cudnn.enabled = False
x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) if self.use_norm else x
torch.backends.cudnn.enabled = True
x = F.relu(x)
x_max = torch.max(x, dim=1, keepdim=True)[0]
......@@ -44,6 +47,7 @@ class PFNLayer(nn.Module):
x_concatenated = torch.cat([x, x_repeat], dim=2)
return x_concatenated
class PillarVFE(VFETemplate):
def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range):
super().__init__(model_cfg=model_cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment