Commit 69326cef authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

speed up PFNLayer

parent f70902b8
...@@ -3,6 +3,7 @@ import torch.nn as nn ...@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .vfe_template import VFETemplate from .vfe_template import VFETemplate
class PFNLayer(nn.Module): class PFNLayer(nn.Module):
def __init__(self, def __init__(self,
in_channels, in_channels,
...@@ -28,12 +29,14 @@ class PFNLayer(nn.Module): ...@@ -28,12 +29,14 @@ class PFNLayer(nn.Module):
if inputs.shape[0] > self.part: if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large # nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)] part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part])
for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0) x = torch.cat(part_linear_out, dim=0)
else: else:
x = self.linear(inputs) x = self.linear(inputs)
total_points, voxel_points, channels = x.shape torch.backends.cudnn.enabled = False
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) if self.use_norm else x
torch.backends.cudnn.enabled = True
x = F.relu(x) x = F.relu(x)
x_max = torch.max(x, dim=1, keepdim=True)[0] x_max = torch.max(x, dim=1, keepdim=True)[0]
...@@ -44,6 +47,7 @@ class PFNLayer(nn.Module): ...@@ -44,6 +47,7 @@ class PFNLayer(nn.Module):
x_concatenated = torch.cat([x, x_repeat], dim=2) x_concatenated = torch.cat([x, x_repeat], dim=2)
return x_concatenated return x_concatenated
class PillarVFE(VFETemplate): class PillarVFE(VFETemplate):
def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range): def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range):
super().__init__(model_cfg=model_cfg) super().__init__(model_cfg=model_cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment