You need to sign in or sign up before continuing.
Commit 324235d0 authored by yan.yan's avatar yan.yan
Browse files

v2.0.2: fix a serious bug, and other small change

parent 68cbfcbc
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
[![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild) [![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild)
# !!!!!!!!!!
If you are using spconv < 2.0.2, update as soon as the spconv 2.0.2 build succeeds, or check [this issue](https://github.com/traveller59/spconv/issues/340#issuecomment-951493008) for how to fix a serious bug. I'm so sorry for this stupid bug.
## Breaking changes in Spconv 2.x ## Breaking changes in Spconv 2.x
* ```spconv.xxx``` move to ```spconv.pytorch.xxx```, change all ```import spconv``` to ```import spconv.pytorch as spconv``` and ```from spconv.xxx import``` to ```from spconv.pytorch.xxx import```. * ```spconv.xxx``` move to ```spconv.pytorch.xxx```, change all ```import spconv``` to ```import spconv.pytorch as spconv``` and ```from spconv.xxx import``` to ```from spconv.pytorch.xxx import```.
...@@ -42,6 +45,13 @@ ...@@ -42,6 +45,13 @@
* doesn't depend on pytorch binary. * doesn't depend on pytorch binary.
* since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference. * since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference.
Spconv 2.1.0 vs 1.x speed:
| | 1080Ti Spconv 1.x F32 | 1080Ti Spconv 2.0 F32 | 3080M* Spconv 2.1 F16 |
| -------------- |:---------------------:| ---------------------:| ----------:|
| 27x128x128 Fwd | 11ms | 5.4ms | 1.4ms |
\* 3080M (Laptop) ~= 3070 Desktop
## Usage ## Usage
......
...@@ -716,7 +716,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -716,7 +716,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
if (indice_pair_mask.ndim() == 2 && indice_pair_mask.dim(0) == 2){{ if (indice_pair_mask.ndim() == 2 && indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0]; auto mask_0 = indice_pair_mask[0];
tv::cuda::Launch lanucher_fill(mask_0.size(), custream); tv::cuda::Launch lanucher_fill(mask_0.size(), custream);
lanucher_fill(cudakers::fill_kernel<int>, mask_0.data_ptr<int>(), (1 << (kv / 2)), mask_0.size()); lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size());
indice_pair_mask[1].zero_(ctx); indice_pair_mask[1].zero_(ctx);
auto kernel = &calc_subm_conv_indices_split_mask<table_t>; auto kernel = &calc_subm_conv_indices_split_mask<table_t>;
launcher_num_act_in(kernel, loc_iter, hash, launcher_num_act_in(kernel, loc_iter, hash,
...@@ -725,7 +725,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -725,7 +725,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indices.dim(0), indice_pairs.dim(2), kv); indices.dim(0), indice_pairs.dim(2), kv);
}}else{{ }}else{{
tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream); tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream);
lanucher_fill(cudakers::fill_kernel<int>, indice_pair_mask.data_ptr<int>(), (1 << (kv / 2)), indice_pair_mask.size()); lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size());
TV_ASSERT_RT_ERR(indice_pair_mask.ndim() == 1, "error"); TV_ASSERT_RT_ERR(indice_pair_mask.ndim() == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash, launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def _parse_torch_version(version):
    """Return the numeric release components of *version* as a list of ints.

    The local build tag (everything from the first "+") and the nightly
    suffix (everything from the first ".dev") are dropped, e.g.
    "1.9.0+cu111" -> [1, 9, 0] and "1.10.0.dev20210901" -> [1, 10, 0].
    Only the leading digits of each dot-separated piece are used, so a
    pre-release such as "1.9.0a0" yields [1, 9, 0].

    Raises:
        ValueError: if no leading numeric component can be found.
    """
    release = str(version).split("+", 1)[0].split(".dev", 1)[0]
    numbers = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if ch not in "0123456789":
                break
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
        if len(digits) < len(piece):
            # a pre-release tag (e.g. "0a0") ends the numeric release part
            break
    if not numbers:
        raise ValueError(f"cannot parse torch version {version!r}")
    return numbers


# PYTORCH_VERSION is a list of ints so it can be compared against literals
# such as [1, 8, 0].
try:
    PYTORCH_VERSION = _parse_torch_version(torch.__version__)
except Exception:
    # for unknown errors, just set a version
    PYTORCH_VERSION = [1, 8, 0]
\ No newline at end of file
...@@ -191,6 +191,7 @@ class SparseConvolution(SparseModule): ...@@ -191,6 +191,7 @@ class SparseConvolution(SparseModule):
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
if self.inverse: if self.inverse:
assert datas is not None and self.indice_key is not None assert datas is not None and self.indice_key is not None
assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
outids = datas.indices outids = datas.indices
indice_pairs = datas.indice_pairs indice_pairs = datas.indice_pairs
indice_pair_num = datas.indice_pair_num indice_pair_num = datas.indice_pair_num
...@@ -226,7 +227,7 @@ class SparseConvolution(SparseModule): ...@@ -226,7 +227,7 @@ class SparseConvolution(SparseModule):
self.name]["indice_gen_time"].append(interval) self.name]["indice_gen_time"].append(interval)
indice_data = IndiceData(outids, indices, indice_pairs, indice_data = IndiceData(outids, indices, indice_pairs,
indice_pair_num, spatial_shape) indice_pair_num, spatial_shape, is_subm=self.subm)
input.indice_dict[self.indice_key] = indice_data input.indice_dict[self.indice_key] = indice_data
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
......
...@@ -12,18 +12,153 @@ ...@@ -12,18 +12,153 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch from typing import List, Optional
remove_plus = torch.__version__.find("+")
remove_dotdev = torch.__version__.find(".dev")
PYTORCH_VERSION = torch.__version__ import numpy as np
if remove_plus != -1: import torch
PYTORCH_VERSION = torch.__version__[:remove_plus] from spconv.pytorch.constants import PYTORCH_VERSION
if remove_dotdev != -1:
PYTORCH_VERSION = torch.__version__[:remove_dotdev]
PYTORCH_VERSION = list(map(int, PYTORCH_VERSION.split(".")))
if PYTORCH_VERSION >= [1, 8, 0]: if PYTORCH_VERSION >= [1, 8, 0]:
from .core_fx import * try:
import torch.fx
if PYTORCH_VERSION >= [1, 10, 0]:
from torch.fx import ProxyableClassMeta
else:
from torch.fx.symbolic_trace import ProxyableClassMeta
SpConvTensorMeta = ProxyableClassMeta
except:
class SpConvTensorMeta(type):
pass
else: else:
from .core import * class SpConvTensorMeta(type):
pass
class IndiceData(object):
    """Plain container for the indice data a sparse op stores for reuse.

    Every constructor argument is kept unchanged on the instance under an
    attribute of the same name.
    """

    def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
                 out_spatial_shape, is_subm: bool):
        # whether the producing op was a submanifold op
        self.is_subm = is_subm
        # spatial shape of the output of the producing op
        self.out_spatial_shape = out_spatial_shape
        # pair tables and their per-kernel-offset counts
        self.indice_pair_num = indice_pair_num
        self.indice_pairs = indice_pairs
        # input and output coordinates
        self.indices = indices
        self.out_indices = out_indices
def scatter_nd(indices, updates, shape):
    """PyTorch edition of TensorFlow's ``scatter_nd``.

    Creates a zero tensor of ``shape`` (same dtype and device as ``updates``)
    and writes ``updates`` into it at the positions listed in ``indices``.

    Args:
        indices: integer tensor whose last axis holds ``ndim`` coordinates;
            they address the first ``ndim`` axes of the result.
        updates: values to write, viewed as
            ``indices.shape[:-1] + shape[ndim:]`` before assignment.
        shape: output shape as a sequence of ints (list, tuple or
            ``torch.Size``).

    Returns:
        The newly allocated tensor with ``updates`` scattered into it.

    Note:
        There is no error handling here, so use this carefully. When an
        index repeats, the values are NOT accumulated (TensorFlow adds them).
    """
    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    # list(...) on the tail so tuple / torch.Size shapes concatenate too
    output_shape = list(indices.shape[:-1]) + list(shape[ndim:])
    flatted_indices = indices.view(-1, ndim)
    # Index with a tuple: a plain list of index tensors is the legacy
    # non-tuple form of multi-dimensional indexing.
    slices = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    ret[slices] = updates.view(*output_shape)
    return ret
# ProxyableClassMeta is used for TensorRT conversion in future.
class SparseConvTensor(metaclass=SpConvTensorMeta):
    """Sparse tensor: a feature matrix plus the integer coordinates of its rows."""

    def __init__(self,
                 features: torch.Tensor,
                 indices: torch.Tensor,
                 spatial_shape: List[int],
                 batch_size: int,
                 grid: Optional[torch.Tensor] = None,
                 voxel_num: Optional[torch.Tensor] = None,
                 indice_dict: Optional[dict] = None,
                 benchmark: bool = False):
        """
        Args:
            features: [num_points, num_features] feature tensor
            indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
            spatial_shape: spatial shape of your sparse data
            batch_size: batch size of your sparse data
            grid: pre-allocated grid tensor. should be used when the volume of spatial shape
                is very large. An empty tensor is stored when omitted.
            voxel_num: stored as-is (for tensorrt).
            indice_dict: dict looked up by ``find_indice_pair``; a fresh empty
                dict is created when omitted.
            benchmark: whether to enable benchmark. if enabled, all sparse operators will be
                recorded to SparseConvTensor.
        """
        self._features = features
        self.indices = indices
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        self.indice_dict = {} if indice_dict is None else indice_dict
        # an empty tensor stands in for "no grid"
        self.grid = torch.Tensor() if grid is None else grid
        self.voxel_num = voxel_num  # for tensorrt
        self.benchmark = benchmark
        self.benchmark_record = {}

    def replace_feature(self, feature):
        """Return a new SparseConvTensor carrying ``feature`` instead of the current features.

        Write ``x = x.replace_feature(F.relu(x.features))`` instead of
        ``x.features = F.relu(x)``; this is needed due to limit of torch.fx.
        All other members (including the benchmark state) are shared with ``self``.
        """
        replaced = SparseConvTensor(feature, self.indices, self.spatial_shape,
                                    self.batch_size, self.grid, self.voxel_num,
                                    self.indice_dict)
        replaced.benchmark = self.benchmark
        replaced.benchmark_record = self.benchmark_record
        return replaced

    @property
    def features(self):
        """Feature tensor (read-only; see ``replace_feature``)."""
        return self._features

    @features.setter
    def features(self, val):
        # direct assignment is always rejected
        raise ValueError(
            "you can't set feature directly, use 'x = x.replace_feature(your_new_feature)'"
            " to generate new SparseConvTensor instead.")

    @classmethod
    def from_dense(cls, x: torch.Tensor):
        """Create a sparse tensor from a channel-last dense tensor via ``to_sparse``.

        x must be NHWC tensor, channel last.
        """
        sparse = x.to_sparse(x.ndim - 1)
        coords = sparse.indices().permute(1, 0).contiguous().int()
        return cls(sparse.values(), coords, list(sparse.shape[1:-1]),
                   sparse.shape[0])

    @property
    def spatial_size(self):
        """Product of all entries of ``spatial_shape``."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key) -> Optional[IndiceData]:
        """Return the entry stored under ``key`` in ``indice_dict``, or None.

        None is returned both when ``key`` is None and when it is not present.
        """
        if key is None or key not in self.indice_dict:
            return None
        return self.indice_dict[key]

    def dense(self, channels_first: bool = True):
        """Scatter the features into a dense tensor.

        The result is laid out as [batch, *spatial_shape, channels]; when
        ``channels_first`` is true the channel axis is moved to position 1
        and the result is made contiguous.
        """
        feats = self.features
        full_shape = [self.batch_size, *self.spatial_shape, feats.shape[1]]
        coords = self.indices.to(feats.device).long()
        channel_last = scatter_nd(coords, feats, full_shape)
        if channels_first:
            ndim = len(self.spatial_shape)
            # [0, ndim + 1, 1, ..., ndim]: channel axis goes right after batch
            order = [0, ndim + 1] + list(range(1, ndim + 1))
            return channel_last.permute(*order).contiguous()
        return channel_last

    # remove this due to limit of torch.fx
    # @property
    # def sparity(self):
    #     return self.indices.shape[0] / np.prod(
    #         self.spatial_shape) / self.batch_size

    def shadow_copy(self) -> "SparseConvTensor":
        """create a new spconv tensor with all member unchanged"""
        clone = SparseConvTensor(self.features, self.indices,
                                 self.spatial_shape, self.batch_size,
                                 self.grid, self.voxel_num, self.indice_dict,
                                 self.benchmark)
        clone.benchmark_record = self.benchmark_record
        return clone
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import numpy as np
import torch
import torch.fx
from torch.fx.symbolic_trace import ProxyableClassMeta
class IndiceData(object):
    """Plain container for indice data.

    Every constructor argument is kept unchanged on the instance under an
    attribute of the same name.
    """

    def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
                 out_spatial_shape):
        # output spatial shape
        self.out_spatial_shape = out_spatial_shape
        # pair tables and their counts
        self.indice_pair_num = indice_pair_num
        self.indice_pairs = indice_pairs
        # input and output coordinates
        self.indices = indices
        self.out_indices = out_indices
def scatter_nd(indices, updates, shape):
    """PyTorch edition of TensorFlow's ``scatter_nd``.

    Creates a zero tensor of ``shape`` (same dtype and device as ``updates``)
    and writes ``updates`` into it at the positions listed in ``indices``.

    Args:
        indices: integer tensor whose last axis holds ``ndim`` coordinates;
            they address the first ``ndim`` axes of the result.
        updates: values to write, viewed as
            ``indices.shape[:-1] + shape[ndim:]`` before assignment.
        shape: output shape as a sequence of ints (list, tuple or
            ``torch.Size``).

    Returns:
        The newly allocated tensor with ``updates`` scattered into it.

    Note:
        There is no error handling here, so use this carefully. When an
        index repeats, the values are NOT accumulated (TensorFlow adds them).
    """
    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    # list(...) on the tail so tuple / torch.Size shapes concatenate too
    output_shape = list(indices.shape[:-1]) + list(shape[ndim:])
    flatted_indices = indices.view(-1, ndim)
    # Index with a tuple: a plain list of index tensors is the legacy
    # non-tuple form of multi-dimensional indexing.
    slices = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    ret[slices] = updates.view(*output_shape)
    return ret
class SparseConvTensor(metaclass=ProxyableClassMeta):
    """Sparse tensor: a feature matrix plus the integer coordinates of its rows."""

    def __init__(self,
                 features: torch.Tensor,
                 indices: torch.Tensor,
                 spatial_shape: List[int],
                 batch_size: int,
                 grid: Optional[torch.Tensor] = None,
                 voxel_num: Optional[torch.Tensor] = None,
                 indice_dict: Optional[dict] = None,
                 benchmark: bool = False):
        """
        Args:
            features: [num_points, num_features] feature tensor
            indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
            spatial_shape: spatial shape of your sparse data
            batch_size: batch size of your sparse data
            grid: pre-allocated grid tensor. should be used when the volume of spatial shape
                is very large. An empty tensor is stored when omitted.
            voxel_num: stored as-is (for tensorrt).
            indice_dict: dict looked up by ``find_indice_pair``; a fresh empty
                dict is created when omitted.
            benchmark: whether to enable benchmark. if enabled, all sparse operators will be
                recorded to SparseConvTensor.
        """
        self._features = features
        self.indices = indices
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        self.indice_dict = {} if indice_dict is None else indice_dict
        # an empty tensor stands in for "no grid"
        self.grid = torch.Tensor() if grid is None else grid
        self.voxel_num = voxel_num  # for tensorrt
        self.benchmark = benchmark
        self.benchmark_record = {}

    def replace_feature(self, feature):
        """Return a new SparseConvTensor carrying ``feature`` instead of the current features.

        Write ``x = x.replace_feature(F.relu(x.features))`` instead of
        ``x.features = F.relu(x)``; this is needed due to limit of torch.fx.
        All other members (including the benchmark state) are shared with ``self``.
        """
        replaced = SparseConvTensor(feature, self.indices, self.spatial_shape,
                                    self.batch_size, self.grid, self.voxel_num,
                                    self.indice_dict)
        replaced.benchmark = self.benchmark
        replaced.benchmark_record = self.benchmark_record
        return replaced

    @property
    def features(self):
        """Feature tensor (read-only; see ``replace_feature``)."""
        return self._features

    @features.setter
    def features(self, val):
        # direct assignment is always rejected
        raise ValueError(
            "you can't set feature directly, use 'x = x.replace_feature(your_new_feature)'"
            " to generate new SparseConvTensor instead.")

    @classmethod
    def from_dense(cls, x: torch.Tensor):
        """Create a sparse tensor from a channel-last dense tensor via ``to_sparse``.

        x must be NHWC tensor, channel last.
        """
        sparse = x.to_sparse(x.ndim - 1)
        coords = sparse.indices().permute(1, 0).contiguous().int()
        return cls(sparse.values(), coords, list(sparse.shape[1:-1]),
                   sparse.shape[0])

    @property
    def spatial_size(self):
        """Product of all entries of ``spatial_shape``."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key) -> Optional[IndiceData]:
        """Return the entry stored under ``key`` in ``indice_dict``, or None.

        None is returned both when ``key`` is None and when it is not present.
        """
        if key is None or key not in self.indice_dict:
            return None
        return self.indice_dict[key]

    def dense(self, channels_first: bool = True):
        """Scatter the features into a dense tensor.

        The result is laid out as [batch, *spatial_shape, channels]; when
        ``channels_first`` is true the channel axis is moved to position 1
        and the result is made contiguous.
        """
        feats = self.features
        full_shape = [self.batch_size, *self.spatial_shape, feats.shape[1]]
        coords = self.indices.to(feats.device).long()
        channel_last = scatter_nd(coords, feats, full_shape)
        if channels_first:
            ndim = len(self.spatial_shape)
            # [0, ndim + 1, 1, ..., ndim]: channel axis goes right after batch
            order = [0, ndim + 1] + list(range(1, ndim + 1))
            return channel_last.permute(*order).contiguous()
        return channel_last

    # remove this due to limit of torch.fx
    # @property
    # def sparity(self):
    #     return self.indices.shape[0] / np.prod(
    #         self.spatial_shape) / self.batch_size

    def shadow_copy(self) -> "SparseConvTensor":
        """create a new spconv tensor with all member unchanged"""
        clone = SparseConvTensor(self.features, self.indices,
                                 self.spatial_shape, self.batch_size,
                                 self.grid, self.voxel_num, self.indice_dict,
                                 self.benchmark)
        clone.benchmark_record = self.benchmark_record
        return clone
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import numpy as np
import torch
class IndiceData(object):
    """Plain container for indice data.

    Every constructor argument is kept unchanged on the instance under an
    attribute of the same name.
    """

    def __init__(self, out_indices, indices, indice_pairs, indice_pair_num,
                 out_spatial_shape):
        # output spatial shape
        self.out_spatial_shape = out_spatial_shape
        # pair tables and their counts
        self.indice_pair_num = indice_pair_num
        self.indice_pairs = indice_pairs
        # input and output coordinates
        self.indices = indices
        self.out_indices = out_indices
def scatter_nd(indices, updates, shape):
    """PyTorch edition of TensorFlow's ``scatter_nd``.

    Creates a zero tensor of ``shape`` (same dtype and device as ``updates``)
    and writes ``updates`` into it at the positions listed in ``indices``.

    Args:
        indices: integer tensor whose last axis holds ``ndim`` coordinates;
            they address the first ``ndim`` axes of the result.
        updates: values to write, viewed as
            ``indices.shape[:-1] + shape[ndim:]`` before assignment.
        shape: output shape as a sequence of ints (list, tuple or
            ``torch.Size``).

    Returns:
        The newly allocated tensor with ``updates`` scattered into it.

    Note:
        There is no error handling here, so use this carefully. When an
        index repeats, the values are NOT accumulated (TensorFlow adds them).
    """
    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    # list(...) on the tail so tuple / torch.Size shapes concatenate too
    output_shape = list(indices.shape[:-1]) + list(shape[ndim:])
    flatted_indices = indices.view(-1, ndim)
    # Index with a tuple: a plain list of index tensors is the legacy
    # non-tuple form of multi-dimensional indexing.
    slices = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    ret[slices] = updates.view(*output_shape)
    return ret
class SparseConvTensor:
    """Sparse tensor: a feature matrix plus the integer coordinates of its rows."""

    def __init__(self,
                 features: torch.Tensor,
                 indices: torch.Tensor,
                 spatial_shape: List[int],
                 batch_size: int,
                 grid: Optional[torch.Tensor] = None,
                 voxel_num: Optional[torch.Tensor] = None,
                 indice_dict: Optional[dict] = None,
                 benchmark: bool = False):
        """
        Args:
            features: [num_points, num_features] feature tensor
            indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
            spatial_shape: spatial shape of your sparse data
            batch_size: batch size of your sparse data
            grid: pre-allocated grid tensor. should be used when the volume of spatial shape
                is very large. An empty tensor is stored when omitted.
            voxel_num: stored as-is (for tensorrt).
            indice_dict: dict looked up by ``find_indice_pair``; a fresh empty
                dict is created when omitted.
            benchmark: whether to enable benchmark. if enabled, all sparse operators will be
                recorded to SparseConvTensor.
        """
        self._features = features
        self.indices = indices
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        self.indice_dict = {} if indice_dict is None else indice_dict
        # an empty tensor stands in for "no grid"
        self.grid = torch.Tensor() if grid is None else grid
        self.voxel_num = voxel_num  # for tensorrt
        self.benchmark = benchmark
        self.benchmark_record = {}

    def replace_feature(self, feature):
        """Return a new SparseConvTensor carrying ``feature`` instead of the current features.

        Write ``x = x.replace_feature(F.relu(x.features))`` instead of
        ``x.features = F.relu(x)``; this is needed due to limit of torch.fx.
        All other members (including the benchmark state) are shared with ``self``.
        """
        replaced = SparseConvTensor(feature, self.indices, self.spatial_shape,
                                    self.batch_size, self.grid, self.voxel_num,
                                    self.indice_dict)
        replaced.benchmark = self.benchmark
        replaced.benchmark_record = self.benchmark_record
        return replaced

    @property
    def features(self):
        """Feature tensor (read-only; see ``replace_feature``)."""
        return self._features

    @features.setter
    def features(self, val):
        # direct assignment is always rejected
        raise ValueError(
            "you can't set feature directly, use 'x = x.replace_feature(your_new_feature)'"
            " to generate new SparseConvTensor instead.")

    @classmethod
    def from_dense(cls, x: torch.Tensor):
        """Create a sparse tensor from a channel-last dense tensor via ``to_sparse``.

        x must be NHWC tensor, channel last.
        """
        sparse = x.to_sparse(x.ndim - 1)
        coords = sparse.indices().permute(1, 0).contiguous().int()
        return cls(sparse.values(), coords, list(sparse.shape[1:-1]),
                   sparse.shape[0])

    @property
    def spatial_size(self):
        """Product of all entries of ``spatial_shape``."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key) -> Optional[IndiceData]:
        """Return the entry stored under ``key`` in ``indice_dict``, or None.

        None is returned both when ``key`` is None and when it is not present.
        """
        if key is None or key not in self.indice_dict:
            return None
        return self.indice_dict[key]

    def dense(self, channels_first: bool = True):
        """Scatter the features into a dense tensor.

        The result is laid out as [batch, *spatial_shape, channels]; when
        ``channels_first`` is true the channel axis is moved to position 1
        and the result is made contiguous.
        """
        feats = self.features
        full_shape = [self.batch_size, *self.spatial_shape, feats.shape[1]]
        coords = self.indices.to(feats.device).long()
        channel_last = scatter_nd(coords, feats, full_shape)
        if channels_first:
            ndim = len(self.spatial_shape)
            # [0, ndim + 1, 1, ..., ndim]: channel axis goes right after batch
            order = [0, ndim + 1] + list(range(1, ndim + 1))
            return channel_last.permute(*order).contiguous()
        return channel_last

    # remove this due to limit of torch.fx
    # @property
    # def sparity(self):
    #     return self.indices.shape[0] / np.prod(
    #         self.spatial_shape) / self.batch_size

    def shadow_copy(self) -> "SparseConvTensor":
        """create a new spconv tensor with all member unchanged"""
        clone = SparseConvTensor(self.features, self.indices,
                                 self.spatial_shape, self.batch_size,
                                 self.grid, self.voxel_num, self.indice_dict,
                                 self.benchmark)
        clone.benchmark_record = self.benchmark_record
        return clone
...@@ -138,7 +138,7 @@ class SparseSequential(SparseModule): ...@@ -138,7 +138,7 @@ class SparseSequential(SparseModule):
else: else:
if isinstance(input, spconv.SparseConvTensor): if isinstance(input, spconv.SparseConvTensor):
if input.indices.shape[0] != 0: if input.indices.shape[0] != 0:
input.replace_feature(module(input.features)) input = input.replace_feature(module(input.features))
else: else:
input = module(input) input = module(input)
return input return input
......
...@@ -26,7 +26,8 @@ from spconv.core_cc.csrc.sparse.all import SpconvOps ...@@ -26,7 +26,8 @@ from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.algo import GEMM # , GATHER, SCATTER from spconv.algo import GEMM # , GATHER, SCATTER
import time import time
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO
import pickle
from pathlib import Path
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation): def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
ndim = len(input_size) ndim = len(input_size)
...@@ -78,6 +79,8 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -78,6 +79,8 @@ def get_indice_pairs(indices: torch.Tensor,
padding, dilation) padding, dilation)
else: else:
out_shape = spatial_shape out_shape = spatial_shape
if any([x == 0 for x in out_shape]):
raise ValueError(f"your out spatial shape {out_shape} reach zero!!! input shape: {spatial_shape}")
assert algo == ConvAlgo.Native, "TODO" assert algo == ConvAlgo.Native, "TODO"
stream = get_current_stream() stream = get_current_stream()
...@@ -205,6 +208,9 @@ def indice_conv(features: torch.Tensor, ...@@ -205,6 +208,9 @@ def indice_conv(features: torch.Tensor,
stream = get_current_stream() stream = get_current_stream()
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
if subm and all(x == 0 for x in indice_pair_num_cpu):
return out_features
arch = torch.cuda.get_device_capability() arch = torch.cuda.get_device_capability()
inited: bool = subm inited: bool = subm
a = torch_tensor_to_tv(features) a = torch_tensor_to_tv(features)
...@@ -214,7 +220,14 @@ def indice_conv(features: torch.Tensor, ...@@ -214,7 +220,14 @@ def indice_conv(features: torch.Tensor,
profile_idx = kv_center - 1 profile_idx = kv_center - 1
# profile_idx = first_n # profile_idx = first_n
nhot_profile = indice_pair_num_cpu[profile_idx] nhot_profile = indice_pair_num_cpu[profile_idx]
if nhot_profile == 0:
# find a non-zero profile index
profile_idx = 0
for i, nhot in enumerate(indice_pair_num_cpu):
if nhot > nhot_profile:
nhot_profile = nhot
profile_idx = i
assert nhot_profile > 0, "this shouldn't happen"
# print(nhot_profile, indice_pair_num_cpu) # print(nhot_profile, indice_pair_num_cpu)
profile_res = GEMM.get_profiled_algo( profile_res = GEMM.get_profiled_algo(
a.shape, a.shape,
...@@ -294,8 +307,6 @@ def indice_conv(features: torch.Tensor, ...@@ -294,8 +307,6 @@ def indice_conv(features: torch.Tensor,
# # print(stream, valid_count, maxnhot, features.shape[0], features.shape[1], out_channel, time.time() - t, total_times, txt) # # print(stream, valid_count, maxnhot, features.shape[0], features.shape[1], out_channel, time.time() - t, total_times, txt)
# # print(algo_desp, profile_res.external_gather, profile_res.splitk, features.shape[0], features.shape[1], out_channel, time.time() - t) # # print(algo_desp, profile_res.external_gather, profile_res.splitk, features.shape[0], features.shape[1], out_channel, time.time() - t)
# # print(indice_pair_num_cpu)
# print("G", time.time() - t)
return out_features return out_features
...@@ -312,15 +323,14 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -312,15 +323,14 @@ def indice_conv_backward(features: torch.Tensor,
inverse: bool = False, inverse: bool = False,
subm: bool = False, subm: bool = False,
algo: ConvAlgo = ConvAlgo.Native): algo: ConvAlgo = ConvAlgo.Native):
# torch.cuda.synchronize()
# t = time.time()
num_activate_out = out_bp.shape[0] num_activate_out = out_bp.shape[0]
out_channel = out_bp.shape[-1] out_channel = out_bp.shape[-1]
filters_shape = filters.shape filters_shape = filters.shape
filters = filters.reshape(-1, *filters.shape[-2:]) filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0] kv = filters.shape[0]
kv_center = kv // 2 kv_center = kv // 2
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
assert out_bp.is_contiguous() assert out_bp.is_contiguous()
assert filters.is_contiguous() assert filters.is_contiguous()
assert features.is_contiguous() assert features.is_contiguous()
...@@ -349,6 +359,9 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -349,6 +359,9 @@ def indice_conv_backward(features: torch.Tensor,
stream = get_current_stream() stream = get_current_stream()
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
if subm and all(x == 0 for x in indice_pair_num_cpu):
return (din, dfilters.reshape(filters_shape))
arch = torch.cuda.get_device_capability() arch = torch.cuda.get_device_capability()
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
...@@ -359,10 +372,18 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -359,10 +372,18 @@ def indice_conv_backward(features: torch.Tensor,
din_tv = torch_tensor_to_tv(din) din_tv = torch_tensor_to_tv(din)
profile_idx = kv_center profile_idx = kv_center
if subm: if subm or indice_pair_num_cpu[profile_idx] == 0:
profile_idx = kv_center - 1 profile_idx = kv_center - 1
# profile_idx = first_n # profile_idx = first_n
nhot_profile = indice_pair_num_cpu[profile_idx] nhot_profile = indice_pair_num_cpu[profile_idx]
if nhot_profile == 0:
# find a non-zero profile index
profile_idx = 0
for i, nhot in enumerate(indice_pair_num_cpu):
if nhot > nhot_profile:
nhot_profile = nhot
profile_idx = i
assert nhot_profile > 0, "this shouldn't happen"
# print(nhot_profile, indice_pair_num_cpu) # print(nhot_profile, indice_pair_num_cpu)
profile_res_dgrad = GEMM.get_profiled_algo( profile_res_dgrad = GEMM.get_profiled_algo(
...@@ -549,11 +570,12 @@ def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out): ...@@ -549,11 +570,12 @@ def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
for i, nhot in enumerate(indice_pair_num_cpu): for i, nhot in enumerate(indice_pair_num_cpu):
if nhot <= 0: if nhot <= 0:
continue continue
inp_indices = torch_tensor_to_tv(indice_pairs[0][i, :nhot]) inp_indices = indice_pairs_tv[0][i].slice_first_axis(0, nhot)
out_indices = torch_tensor_to_tv(indice_pairs[1][i, :nhot]) out_indices = indice_pairs_tv[1][i].slice_first_axis(0, nhot)
SpconvOps.maxpool_forward(out_features_tv, features_tv, out_indices, SpconvOps.maxpool_forward(out_features_tv, features_tv, out_indices,
inp_indices, stream) inp_indices, stream)
# torch.cuda.synchronize() # torch.cuda.synchronize()
...@@ -568,15 +590,18 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, ...@@ -568,15 +590,18 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs,
din = torch.zeros_like(features) din = torch.zeros_like(features)
stream = get_current_stream() stream = get_current_stream()
indice_pair_num_cpu = indice_pair_num.cpu().tolist() indice_pair_num_cpu = indice_pair_num.cpu().tolist()
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
out_bp_tv = torch_tensor_to_tv(out_bp) out_bp_tv = torch_tensor_to_tv(out_bp)
din_tv = torch_tensor_to_tv(din) din_tv = torch_tensor_to_tv(din)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
for i, nhot in enumerate(indice_pair_num_cpu): for i, nhot in enumerate(indice_pair_num_cpu):
if nhot <= 0: if nhot <= 0:
continue continue
inp_indices = torch_tensor_to_tv(indice_pairs[0][i, :nhot]) inp_indices = indice_pairs_tv[0][i].slice_first_axis(0, nhot)
out_indices = torch_tensor_to_tv(indice_pairs[1][i, :nhot]) out_indices = indice_pairs_tv[1][i].slice_first_axis(0, nhot)
SpconvOps.maxpool_backward(out_features_tv, features_tv, out_bp_tv, SpconvOps.maxpool_backward(out_features_tv, features_tv, out_bp_tv,
din_tv, out_indices, inp_indices, stream) din_tv, out_indices, inp_indices, stream)
......
...@@ -119,7 +119,7 @@ class SparseMaxPool(SparseModule): ...@@ -119,7 +119,7 @@ class SparseMaxPool(SparseModule):
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
if datas is None: if datas is None:
indice_data = IndiceData(outids, indices, indice_pairs, indice_data = IndiceData(outids, indices, indice_pairs,
indice_pairs_num, spatial_shape) indice_pairs_num, spatial_shape, is_subm=False)
input.indice_dict[self.indice_key] = indice_data input.indice_dict[self.indice_key] = indice_data
else: else:
raise ValueError("indice data exists") raise ValueError("indice data exists")
...@@ -147,12 +147,14 @@ class SparseMaxPool1d(SparseMaxPool): ...@@ -147,12 +147,14 @@ class SparseMaxPool1d(SparseMaxPool):
stride=None, stride=None,
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None,
name=None): name=None):
super(SparseMaxPool1d, self).__init__(1, super(SparseMaxPool1d, self).__init__(1,
kernel_size, kernel_size,
stride, stride,
padding, padding,
dilation, dilation,
indice_key=indice_key,
name=name) name=name)
class SparseMaxPool2d(SparseMaxPool): class SparseMaxPool2d(SparseMaxPool):
...@@ -161,12 +163,14 @@ class SparseMaxPool2d(SparseMaxPool): ...@@ -161,12 +163,14 @@ class SparseMaxPool2d(SparseMaxPool):
stride=None, stride=None,
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None,
name=None): name=None):
super(SparseMaxPool2d, self).__init__(2, super(SparseMaxPool2d, self).__init__(2,
kernel_size, kernel_size,
stride, stride,
padding, padding,
dilation, dilation,
indice_key=indice_key,
name=name) name=name)
...@@ -176,12 +180,14 @@ class SparseMaxPool3d(SparseMaxPool): ...@@ -176,12 +180,14 @@ class SparseMaxPool3d(SparseMaxPool):
stride=None, stride=None,
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None,
name=None): name=None):
super(SparseMaxPool3d, self).__init__(3, super(SparseMaxPool3d, self).__init__(3,
kernel_size, kernel_size,
stride, stride,
padding, padding,
dilation, dilation,
indice_key=indice_key,
name=name) name=name)
class SparseMaxPool4d(SparseMaxPool): class SparseMaxPool4d(SparseMaxPool):
...@@ -190,10 +196,12 @@ class SparseMaxPool4d(SparseMaxPool): ...@@ -190,10 +196,12 @@ class SparseMaxPool4d(SparseMaxPool):
stride=None, stride=None,
padding=0, padding=0,
dilation=1, dilation=1,
indice_key=None,
name=None): name=None):
super(SparseMaxPool4d, self).__init__(4, super(SparseMaxPool4d, self).__init__(4,
kernel_size, kernel_size,
stride, stride,
padding, padding,
dilation, dilation,
indice_key=indice_key,
name=name) name=name)
...@@ -62,8 +62,11 @@ class Net(nn.Module): ...@@ -62,8 +62,11 @@ class Net(nn.Module):
bias=False, bias=False,
indice_key="c0", indice_key="c0",
algo=algo), algo=algo),
# nn.BatchNorm1d(32),
# nn.ReLU(),
spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv.SparseMaxPool3d(2, 2), # spconv.SparseMaxPool3d(2, 2),
spconv.SubMConv3d(64, spconv.SubMConv3d(64,
96, 96,
3, 3,
...@@ -78,7 +81,9 @@ class Net(nn.Module): ...@@ -78,7 +81,9 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(64), # nn.BatchNorm1d(64),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseMaxPool3d(2, 2), spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1"),
# spconv.SparseMaxPool3d(2, 2),
spconv.SubMConv3d(96, spconv.SubMConv3d(96,
128, 128,
3, 3,
...@@ -93,7 +98,9 @@ class Net(nn.Module): ...@@ -93,7 +98,9 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseMaxPool3d(2, 2), spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"),
# spconv.SparseMaxPool3d(2, 2),
spconv.SubMConv3d(128, spconv.SubMConv3d(128,
160, 160,
3, 3,
...@@ -108,7 +115,9 @@ class Net(nn.Module): ...@@ -108,7 +115,9 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseMaxPool3d(2, 2), spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"),
# spconv.SparseMaxPool3d(2, 2),
spconv.SubMConv3d(160, spconv.SubMConv3d(160,
192, 192,
3, 3,
...@@ -123,7 +132,9 @@ class Net(nn.Module): ...@@ -123,7 +132,9 @@ class Net(nn.Module):
algo=algo), algo=algo),
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseMaxPool3d(2, 2), # spconv.SparseMaxPool3d(2, 2, indice_key="m4"),
spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"),
spconv.SubMConv3d(192, spconv.SubMConv3d(192,
224, 224,
3, 3,
...@@ -136,9 +147,10 @@ class Net(nn.Module): ...@@ -136,9 +147,10 @@ class Net(nn.Module):
bias=False, bias=False,
indice_key="c5", indice_key="c5",
algo=algo), algo=algo),
# nn.BatchNorm1d(128), nn.BatchNorm1d(224),
# nn.ReLU(), nn.ReLU(),
spconv.SparseMaxPool3d(2, 2), spconv.SparseConv3d(224, 224, 2, 2, bias=False, indice_key="m5"),
# spconv.SparseMaxPool3d(2, 2, indice_key="m5"),
spconv.SubMConv3d(224, spconv.SubMConv3d(224,
256, 256,
3, 3,
...@@ -151,8 +163,15 @@ class Net(nn.Module): ...@@ -151,8 +163,15 @@ class Net(nn.Module):
bias=False, bias=False,
indice_key="c6", indice_key="c6",
algo=algo), algo=algo),
spconv.SparseInverseConv3d(256, 128, 3, indice_key="c6", bias=False),
spconv.SparseInverseConv3d(128, 64, 3, indice_key="c5", bias=False), nn.BatchNorm1d(256),
nn.ReLU(),
spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False),
nn.BatchNorm1d(128),
nn.ReLU(),
spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False),
) )
max_batch_size = 1 max_batch_size = 1
...@@ -171,7 +190,7 @@ class Net2(nn.Module): ...@@ -171,7 +190,7 @@ class Net2(nn.Module):
def __init__(self, shape, algo): def __init__(self, shape, algo):
super().__init__() super().__init__()
self.net = spconv.SparseSequential( self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 256, 3, bias=False, indice_key="c0", spconv.SubMConv3d(3, 128, 3, bias=False, indice_key="c0",
algo=algo), algo=algo),
# spconv.SubMConv3d(32, # spconv.SubMConv3d(32,
# 32, # 32,
...@@ -185,27 +204,27 @@ class Net2(nn.Module): ...@@ -185,27 +204,27 @@ class Net2(nn.Module):
# # algo=algo), # # algo=algo),
# spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0", # spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
# algo=algo), # algo=algo),
spconv.SubMConv3d(256, spconv.SubMConv3d(128,
256, 128,
3, 3,
bias=False, bias=False,
indice_key="c0", indice_key="c0",
algo=algo), algo=algo),
# nn.BatchNorm1d(32), # nn.BatchNorm1d(32),
# nn.ReLU(), # nn.ReLU(),
spconv.SparseMaxPool3d(2, 2), # spconv.SparseMaxPool3d(2, 2),
spconv.SubMConv3d(256, # spconv.SubMConv3d(256,
512, # 512,
3, # 3,
bias=False, # bias=False,
indice_key="c1", # indice_key="c1",
algo=algo), # algo=algo),
spconv.SubMConv3d(512, # spconv.SubMConv3d(512,
512, # 512,
3, # 3,
bias=False, # bias=False,
indice_key="c1", # indice_key="c1",
algo=algo), # algo=algo),
) )
max_batch_size = 1 max_batch_size = 1
# grid (dense map) is used for indice generation. use pre-allocated grid can run faster. # grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
...@@ -249,19 +268,19 @@ def main(): ...@@ -249,19 +268,19 @@ def main():
dout_t = torch.from_numpy(dout).cuda().to(dtype) dout_t = torch.from_numpy(dout).cuda().to(dtype)
print(out.spatial_shape, out.features.mean(), out.features.max(), out.features.min()) print(out.spatial_shape, out.features.mean(), out.features.max(), out.features.min())
times = [] # times = []
with torch.no_grad(): # with torch.no_grad():
for i in range(20): # for i in range(20):
print("------------") # print("------------")
torch.cuda.synchronize() # torch.cuda.synchronize()
t = time.time() # t = time.time()
out_nograd = net(voxels_th, coors_th, 1) # out_nograd = net(voxels_th, coors_th, 1)
torch.cuda.synchronize() # torch.cuda.synchronize()
times.append(time.time() - t) # times.append(time.time() - t)
print("spconv time", np.mean(times[10:])) # print("spconv time", np.mean(times[10:]))
times = [] times = []
for i in range(10): for i in range(1):
out = net(voxels_th, coors_th, 1) out = net(voxels_th, coors_th, 1)
print("------------") print("------------")
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -270,9 +289,9 @@ def main(): ...@@ -270,9 +289,9 @@ def main():
torch.cuda.synchronize() torch.cuda.synchronize()
times.append(time.time() - t) times.append(time.time() - t)
# print((net.grid == -1).float().sum(), net.grid.numel()) # # print((net.grid == -1).float().sum(), net.grid.numel())
# print("spconv time", time.time() - t) # # print("spconv time", time.time() - t)
print("spconv bw time", np.mean(times[5:])) # print("spconv bw time", np.mean(times[5:]))
if __name__ == "__main__": if __name__ == "__main__":
......
2.0.1 2.0.2
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment