Commit 09385e1e authored by acivgin1's avatar acivgin1
Browse files

backward support for spconv import and features

parent c9674890
from functools import partial
import spconv.pytorch as spconv
import torch.nn as nn
from ...spconv_utils import replace_feature, spconv
def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0,
conv_type='subm', norm_fn=None):
......@@ -50,17 +51,17 @@ class SparseBasicBlock(spconv.SparseModule):
identity = x
out = self.conv1(x)
out = out.replace_feature(self.bn1(out.features))
out = out.replace_feature(self.relu(out.features))
out = replace_feature(out, self.bn1(out.features))
out = replace_feature(out, self.relu(out.features))
out = self.conv2(out)
out = out.replace_feature(self.bn2(out.features))
out = replace_feature(out, self.bn2(out.features))
if self.downsample is not None:
identity = self.downsample(x)
out = out.replace_feature(out.features + identity.features)
out = out.replace_feature(self.relu(out.features))
out = replace_feature(out, out.features + identity.features)
out = replace_feature(out, self.relu(out.features))
return out
......
from functools import partial
import spconv.pytorch as spconv
import torch
import torch.nn as nn
from ...spconv_utils import replace_feature, spconv
from ...utils import common_utils
from .spconv_backbone import post_act_block
......@@ -31,17 +31,17 @@ class SparseBasicBlock(spconv.SparseModule):
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
out = self.conv1(x)
out = out.replace_feature(self.bn1(out.features))
out = out.replace_feature(self.relu(out.features))
out = replace_feature(out, self.bn1(out.features))
out = replace_feature(out, self.relu(out.features))
out = self.conv2(out)
out = out.replace_feature(self.bn2(out.features))
out = replace_feature(out, self.bn2(out.features))
if self.downsample is not None:
identity = self.downsample(x)
out = out.replace_feature(out.features + identity)
out = out.replace_feature(self.relu(out.features))
out = replace_feature(out, out.features + identity)
out = replace_feature(out, self.relu(out.features))
return out
......@@ -52,6 +52,7 @@ class UNetV2(nn.Module):
Reference Paper: https://arxiv.org/abs/1907.03670 (Shaoshuai Shi, et. al)
From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network
"""
def __init__(self, model_cfg, input_channels, grid_size, voxel_size, point_cloud_range, **kwargs):
super().__init__()
self.model_cfg = model_cfg
......@@ -134,10 +135,10 @@ class UNetV2(nn.Module):
def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv):
x_trans = conv_t(x_lateral)
x = x_trans
x = x.replace_feature(torch.cat((x_bottom.features, x_trans.features), dim=1))
x = replace_feature(x, torch.cat((x_bottom.features, x_trans.features), dim=1))
x_m = conv_m(x)
x = self.channel_reduction(x, x_m.features.shape[1])
x = x.replace_feature(x_m.features + x.features)
x = replace_feature(x, x_m.features + x.features)
x = conv_inv(x)
return x
......@@ -155,7 +156,7 @@ class UNetV2(nn.Module):
n, in_channels = features.shape
assert (in_channels % out_channels == 0) and (in_channels >= out_channels)
x = x.replace_feature(features.view(n, out_channels, -1).sum(dim=2))
x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
return x
def forward(self, batch_dict):
......
import numpy as np
import spconv.pytorch as spconv
import torch
import torch.nn as nn
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...spconv_utils import spconv
from .roi_head_template import RoIHeadTemplate
......
from typing import Set
import spconv.pytorch as spconv
try:
import spconv.pytorch as spconv
except:
import spconv as spconv
import torch.nn as nn
......@@ -19,3 +23,12 @@ def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
found_keys.update(find_all_spconv_keys(child, prefix=new_prefix))
return found_keys
def replace_feature(out, new_features):
if "replace_feature" in out.__dir__():
# spconv 2.x behaviour
return out.replace_feature(new_features)
else:
out.features = new_features
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment