Commit 520065c6 authored by yan.yan's avatar yan.yan
Browse files

fix a bug in backward

parent 35a9fa40
......@@ -24,6 +24,7 @@
* ```spconv.xxx``` move to ```spconv.pytorch.xxx```, change all ```import spconv``` to ```import spconv.pytorch as spconv``` and ```from spconv.xxx import``` to ```from spconv.pytorch.xxx import```.
* ```use_hash``` in Sparse Convolution is removed, we only use hash table in 2.x.
* ```x.features = F.relu(x)``` now raise error. use ```x = x.replace_feature(F.relu(x.features))``` instead.
* weight layout has been changed to RSKC (native algorithm) or KRSC (implicit gemm), no longer RSCK (spconv 1.x). RS is kernel size, C is input channel, K is output channel.
* all util ops are removed (pillar scatter/nms/...)
* VoxelGenerator has been replaced by Point2VoxelGPU[1-4]d/Point2VoxelCPU[1-4]d.
......@@ -47,7 +48,7 @@
## Install
You need to install python >= 3.6 first to use spconv 2.x.
You need to install python >= 3.7 first to use spconv 2.x.
You need to install CUDA toolkit first before using prebuilt binaries or build from source.
......@@ -55,7 +56,7 @@ You need at least CUDA 10.2 to build and run spconv 2.x. We won't offer any supp
### Prebuilt
We offer python 3.6-3.10 and cuda 10.2/11.1/11.4 prebuilt binaries for linux (manylinux) and windows 10/11.
We offer python 3.7-3.10 and cuda 10.2/11.1/11.4 prebuilt binaries for linux (manylinux) and windows 10/11.
We will offer prebuilts for CUDA versions supported by latest pytorch release. For example, pytorch 1.9 support cuda 10.2 and 11.1, so we support them too.
......
......@@ -28,14 +28,14 @@ class Net(nn.Module):
super(Net, self).__init__()
self.net = spconv.SparseSequential(
nn.BatchNorm1d(1),
spconv.SparseConv2d(1, 32, 3, 1),
spconv.SubMConv2d(1, 32, 3, 1),
nn.ReLU(),
spconv.SparseConv2d(32, 64, 3, 1),
spconv.SubMConv2d(32, 64, 3, 1),
nn.ReLU(),
spconv.SparseMaxPool2d(2, 2),
spconv.ToDense(),
)
self.fc1 = nn.Linear(9216, 128)
self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
......
......@@ -12,114 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import numpy as np
import torch
class IndiceData(object):
def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
out_spatial_shape):
self.out_indices = out_indices
self.indices = indices
self.indice_pairs = indice_pairs
self.indice_pair_num = indice_pair_num
self.out_spatial_shape = out_spatial_shape
def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
when indice repeats, don't support repeat add which is supported
in tensorflow.
"""
ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
ndim = indices.shape[-1]
output_shape = list(indices.shape[:-1]) + shape[indices.shape[-1]:]
flatted_indices = indices.view(-1, ndim)
slices = [flatted_indices[:, i] for i in range(ndim)]
slices += [Ellipsis]
ret[slices] = updates.view(*output_shape)
return ret
class SparseConvTensor(object):
def __init__(self,
features,
indices,
spatial_shape,
batch_size,
grid=None,
benchmark=False):
"""
Args:
features: [num_points, num_features] feature tensor
indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
spatial_shape: spatial shape of your sparse data
batch_size: batch size of your sparse data
grid: pre-allocated grid tensor. should be used when the volume of spatial shape
is very large.
benchmark: whether to enable benchmark. if enabled, all sparse operators will be record to
SparseConvTensor.
"""
self.features = features
self.indices = indices
self.spatial_shape = spatial_shape
self.batch_size = batch_size
self.indice_dict = {}
if grid is None:
grid = torch.Tensor() # empty tensor
self.grid = grid
self.benchmark = benchmark
self.benchmark_record = {}
@classmethod
def from_dense(cls, x: torch.Tensor):
"""create sparse tensor fron channel last dense tensor by to_sparse
x must be NHWC tensor, channel last
"""
x = x.to_sparse(x.ndim - 1)
spatial_shape = x.shape[1:-1]
batch_size = x.shape[0]
indices_th = x.indices().permute(1, 0).contiguous().int()
features_th = x.values()
return cls(features_th, indices_th, spatial_shape, batch_size)
@property
def spatial_size(self):
return np.prod(self.spatial_shape)
def find_indice_pair(self, key) -> Optional[IndiceData]:
if key is None:
return None
if key in self.indice_dict:
return self.indice_dict[key]
return None
def dense(self, channels_first=True):
output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(
self.indices.to(self.features.device).long(), self.features,
output_shape)
if not channels_first:
return res
ndim = len(self.spatial_shape)
trans_params = list(range(0, ndim + 1))
trans_params.insert(1, ndim + 1)
return res.permute(*trans_params).contiguous()
@property
def sparity(self):
return self.indices.shape[0] / np.prod(
self.spatial_shape) / self.batch_size
def shadow_copy(self) -> "SparseConvTensor":
"""create a new spconv tensor with all member unchanged"""
tensor = SparseConvTensor(self.features, self.indices,
self.spatial_shape, self.batch_size,
self.grid, self.benchmark)
tensor.benchmark_record = self.benchmark_record
tensor.indice_dict = self.indice_dict
return tensor
if torch.__version__ >= "1.8.0":
from .core_fx import *
else:
from .core import *
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import numpy as np
import torch
import torch.fx
from torch.fx.symbolic_trace import ProxyableClassMeta
class IndiceData(object):
def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
out_spatial_shape):
self.out_indices = out_indices
self.indices = indices
self.indice_pairs = indice_pairs
self.indice_pair_num = indice_pair_num
self.out_spatial_shape = out_spatial_shape
def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
when indice repeats, don't support repeat add which is supported
in tensorflow.
"""
ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
ndim = indices.shape[-1]
output_shape = list(indices.shape[:-1]) + shape[indices.shape[-1]:]
flatted_indices = indices.view(-1, ndim)
slices = [flatted_indices[:, i] for i in range(ndim)]
slices += [Ellipsis]
ret[slices] = updates.view(*output_shape)
return ret
class SparseConvTensor(metaclass=ProxyableClassMeta):
def __init__(self,
features,
indices,
spatial_shape,
batch_size,
grid=None,
voxel_num=None,
benchmark=False):
"""
Args:
features: [num_points, num_features] feature tensor
indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
spatial_shape: spatial shape of your sparse data
batch_size: batch size of your sparse data
grid: pre-allocated grid tensor. should be used when the volume of spatial shape
is very large.
benchmark: whether to enable benchmark. if enabled, all sparse operators will be record to
SparseConvTensor.
"""
self._features = features
self.indices = indices
self.spatial_shape = spatial_shape
self.batch_size = batch_size
self.indice_dict = {}
if grid is None:
grid = torch.Tensor() # empty tensor
self.grid = grid
self.voxel_num = voxel_num # for tensorrt
self.benchmark = benchmark
self.benchmark_record = {}
def replace_feature(self, feature):
"""we need to replace x.features = F.relu(x) with x = x.replace_feature(F.relu(x.features))
due to limit of torch.fx
"""
new_spt = SparseConvTensor(feature, self.indices, self.spatial_shape, self.batch_size, self.grid, self.voxel_num, self.indice_dict)
new_spt.benchmark = self.benchmark
new_spt.benchmark_record = self.benchmark_record
return new_spt
@property
def features(self):
return self._features
@features.setter
def features(self, val):
msg = ("you can't set feature directly, use 'x = x.replace_feature(your_new_feature)'"
" to generate new SparseConvTensor instead.")
raise ValueError(msg)
@classmethod
def from_dense(cls, x: torch.Tensor):
"""create sparse tensor fron channel last dense tensor by to_sparse
x must be NHWC tensor, channel last
"""
x = x.to_sparse(x.ndim - 1)
spatial_shape = x.shape[1:-1]
batch_size = x.shape[0]
indices_th = x.indices().permute(1, 0).contiguous().int()
features_th = x.values()
return cls(features_th, indices_th, spatial_shape, batch_size)
@property
def spatial_size(self):
return np.prod(self.spatial_shape)
def find_indice_pair(self, key) -> Optional[IndiceData]:
if key is None:
return None
if key in self.indice_dict:
return self.indice_dict[key]
return None
def dense(self, channels_first=True):
output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(
self.indices.to(self.features.device).long(), self.features,
output_shape)
if not channels_first:
return res
ndim = len(self.spatial_shape)
trans_params = list(range(0, ndim + 1))
trans_params.insert(1, ndim + 1)
return res.permute(*trans_params).contiguous()
# remove this due to limit of torch.fx
# @property
# def sparity(self):
# return self.indices.shape[0] / np.prod(
# self.spatial_shape) / self.batch_size
def shadow_copy(self) -> "SparseConvTensor":
"""create a new spconv tensor with all member unchanged"""
tensor = SparseConvTensor(self.features, self.indices,
self.spatial_shape, self.batch_size,
self.grid, self.benchmark)
tensor.benchmark_record = self.benchmark_record
tensor.indice_dict = self.indice_dict
tensor.voxel_num = self.voxel_num
return tensor
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import numpy as np
import torch
class IndiceData(object):
def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
out_spatial_shape):
self.out_indices = out_indices
self.indices = indices
self.indice_pairs = indice_pairs
self.indice_pair_num = indice_pair_num
self.out_spatial_shape = out_spatial_shape
def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
when indice repeats, don't support repeat add which is supported
in tensorflow.
"""
ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
ndim = indices.shape[-1]
output_shape = list(indices.shape[:-1]) + shape[indices.shape[-1]:]
flatted_indices = indices.view(-1, ndim)
slices = [flatted_indices[:, i] for i in range(ndim)]
slices += [Ellipsis]
ret[slices] = updates.view(*output_shape)
return ret
class SparseConvTensor(object):
def __init__(self,
features,
indices,
spatial_shape,
batch_size,
grid=None,
voxel_num=None,
benchmark=False):
"""
Args:
features: [num_points, num_features] feature tensor
indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
spatial_shape: spatial shape of your sparse data
batch_size: batch size of your sparse data
grid: pre-allocated grid tensor. should be used when the volume of spatial shape
is very large.
benchmark: whether to enable benchmark. if enabled, all sparse operators will be record to
SparseConvTensor.
"""
self._features = features
self.indices = indices
self.spatial_shape = spatial_shape
self.batch_size = batch_size
self.indice_dict = {}
if grid is None:
grid = torch.Tensor() # empty tensor
self.grid = grid
self.voxel_num = voxel_num
self.benchmark = benchmark
self.benchmark_record = {}
def replace_feature(self, feature):
"""we need to replace x.features = F.relu(x) with x = x.replace_feature(F.relu(x))
due to limit of torch.fx
"""
new_spt = SparseConvTensor(feature, self.indices, self.spatial_shape, self.batch_size, self.grid, self.voxel_num, self.indice_dict)
new_spt.benchmark = self.benchmark
new_spt.benchmark_record = self.benchmark_record
return new_spt
@property
def features(self):
return self._features
@features.setter
def features(self, val):
msg = ("you can't set feature directly, use 'x = x.replace_feature(F.relu(x.feature))'"
" to generate new SparseConvTensor instead.")
raise ValueError(msg)
@classmethod
def from_dense(cls, x: torch.Tensor):
"""create sparse tensor fron channel last dense tensor by to_sparse
x must be NHWC tensor, channel last
"""
x = x.to_sparse(x.ndim - 1)
spatial_shape = x.shape[1:-1]
batch_size = x.shape[0]
indices_th = x.indices().permute(1, 0).contiguous().int()
features_th = x.values()
return cls(features_th, indices_th, spatial_shape, batch_size)
@property
def spatial_size(self):
return np.prod(self.spatial_shape)
def find_indice_pair(self, key) -> Optional[IndiceData]:
if key is None:
return None
if key in self.indice_dict:
return self.indice_dict[key]
return None
def dense(self, channels_first=True):
output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(
self.indices.to(self.features.device).long(), self.features,
output_shape)
if not channels_first:
return res
ndim = len(self.spatial_shape)
trans_params = list(range(0, ndim + 1))
trans_params.insert(1, ndim + 1)
return res.permute(*trans_params).contiguous()
# @property
# def sparity(self):
# return self.indices.shape[0] / np.prod(
# self.spatial_shape) / self.batch_size
def shadow_copy(self) -> "SparseConvTensor":
"""create a new spconv tensor with all member unchanged"""
tensor = SparseConvTensor(self.features, self.indices,
self.spatial_shape, self.batch_size,
self.grid, self.benchmark)
tensor.benchmark_record = self.benchmark_record
tensor.indice_dict = self.indice_dict
return tensor
......@@ -390,8 +390,8 @@ def indice_conv_backward(features: torch.Tensor,
False,
arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC,
a_inds=inp_indices,
c_inds=out_indices,
a_inds=out_indices,
c_inds=inp_indices,
alpha=1.0,
beta=0.0,
hint=AlgoHint.BackwardInput.value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment