Commit bab09b63 authored by yan.yan's avatar yan.yan
Browse files

Merge branch 'develop'

parents 7af751dc 66529500
...@@ -131,7 +131,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -131,7 +131,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
indice_dict: Optional[dict] = None, indice_dict: Optional[dict] = None,
benchmark: bool = False, benchmark: bool = False,
permanent_thrust_allocator: bool = False, permanent_thrust_allocator: bool = False,
enable_timer: bool = False): enable_timer: bool = False,
force_algo: Optional[ConvAlgo] = None):
""" """
Args: Args:
features: [num_points, num_features] feature tensor features: [num_points, num_features] feature tensor
...@@ -142,6 +143,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -142,6 +143,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
is very large. is very large.
benchmark: whether to enable benchmark. if enabled, all sparse operators will be record to benchmark: whether to enable benchmark. if enabled, all sparse operators will be record to
SparseConvTensor. SparseConvTensor.
enable_timer: if exists, all spconv internal ops run time will be record in _timer.
force_algo: force conv/pool layers use this algo, should only used for debug.
""" """
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
assert features.ndim == 2 assert features.ndim == 2
...@@ -166,6 +169,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -166,6 +169,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
if permanent_thrust_allocator: if permanent_thrust_allocator:
self.thrust_allocator = ThrustSortAllocator(features.device) self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer) self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo
def replace_feature(self, feature: torch.Tensor): def replace_feature(self, feature: torch.Tensor):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...@@ -179,6 +183,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -179,6 +183,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
new_spt.benchmark_record = self.benchmark_record new_spt.benchmark_record = self.benchmark_record
new_spt.thrust_allocator = self.thrust_allocator new_spt.thrust_allocator = self.thrust_allocator
new_spt._timer = self._timer new_spt._timer = self._timer
new_spt.force_algo = self.force_algo
return new_spt return new_spt
@property @property
...@@ -244,6 +250,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -244,6 +250,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
tensor.benchmark_record = self.benchmark_record tensor.benchmark_record = self.benchmark_record
tensor.thrust_allocator = self.thrust_allocator tensor.thrust_allocator = self.thrust_allocator
tensor._timer = self._timer tensor._timer = self._timer
tensor.force_algo = self.force_algo
return tensor return tensor
def expand_nd(ndim: int, val: Union[int, List[int], Tuple[int, ...], np.ndarray]) -> List[int]: def expand_nd(ndim: int, val: Union[int, List[int], Tuple[int, ...], np.ndarray]) -> List[int]:
......
...@@ -36,8 +36,9 @@ _ALL_INTS = {tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint ...@@ -36,8 +36,9 @@ _ALL_INTS = {tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint
def torch_tensor_to_tv(ten: torch.Tensor, def torch_tensor_to_tv(ten: torch.Tensor,
dtype: Optional[int] = None, dtype: Optional[int] = None,
shape: Optional[List[int]] = None): shape: Optional[List[int]] = None,
assert ten.is_contiguous(), "must be contiguous tensor" stride: Optional[List[int]] = None):
# assert ten.is_contiguous(), "must be contiguous tensor"
ptr = ten.data_ptr() ptr = ten.data_ptr()
device = ten.device device = ten.device
if device.type == "cpu": if device.type == "cpu":
...@@ -46,12 +47,20 @@ def torch_tensor_to_tv(ten: torch.Tensor, ...@@ -46,12 +47,20 @@ def torch_tensor_to_tv(ten: torch.Tensor,
tv_device = 0 tv_device = 0
else: else:
raise NotImplementedError raise NotImplementedError
if shape is None:
shape = list(ten.shape)
if dtype is None: if dtype is None:
dtype = _TORCH_DTYPE_TO_TV[ten.dtype] dtype = _TORCH_DTYPE_TO_TV[ten.dtype]
stride = ten.stride() if stride is None:
return tv.from_blob_strided(ptr, shape, list(stride), dtype, tv_device) stride = list(ten.stride())
if shape is None:
shape = list(ten.shape)
else:
if not ten.is_contiguous():
msg = "if you provide custom shape for non-contig tensor, stride must not None"
assert stride is not None, msg
else:
# custom shape, if tensor is contiguous, we use from_blob and calc strides
return tv.from_blob(ptr, shape, dtype, tv_device)
return tv.from_blob_strided(ptr, shape, stride, dtype, tv_device)
def torch_tensors_to_tv(*tens: torch.Tensor): def torch_tensors_to_tv(*tens: torch.Tensor):
return (torch_tensor_to_tv(t) for t in tens) return (torch_tensor_to_tv(t) for t in tens)
......
...@@ -19,6 +19,7 @@ import torch ...@@ -19,6 +19,7 @@ import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
from typing import Optional, TypeVar from typing import Optional, TypeVar
from spconv.pytorch.core import SparseConvTensor
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
from spconv.pytorch import ops, SparseConvTensor from spconv.pytorch import ops, SparseConvTensor
from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.constants import PYTORCH_VERSION
......
...@@ -80,7 +80,7 @@ class HashTable: ...@@ -80,7 +80,7 @@ class HashTable:
def query(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None): def query(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None):
"""query value by keys, if values is not None, create a new one. """query value by keys, if values is not None, create a new one.
return values and a uint8 tensor that whether query success. return values and a uint8 tensor that whether query fail.
""" """
keys_tv = torch_tensor_to_tv(keys) keys_tv = torch_tensor_to_tv(keys)
if values is None: if values is None:
...@@ -96,17 +96,17 @@ class HashTable: ...@@ -96,17 +96,17 @@ class HashTable:
def insert_exist_keys(self, keys: torch.Tensor, values: torch.Tensor): def insert_exist_keys(self, keys: torch.Tensor, values: torch.Tensor):
"""insert kv that k exists in table. return a uint8 tensor that """insert kv that k exists in table. return a uint8 tensor that
whether insert success. whether insert fail.
""" """
keys_tv = torch_tensor_to_tv(keys) keys_tv = torch_tensor_to_tv(keys)
values_tv = torch_tensor_to_tv(values) values_tv = torch_tensor_to_tv(values)
stream = 0 stream = 0
if not self.is_cpu: if not self.is_cpu:
stream = get_current_stream() stream = get_current_stream()
is_success = torch.empty([keys.shape[0]], dtype=torch.uint8, device=keys.device) is_empty = torch.empty([keys.shape[0]], dtype=torch.uint8, device=keys.device)
is_success_tv = torch_tensor_to_tv(is_success) is_empty_tv = torch_tensor_to_tv(is_empty)
self._table.insert_exist_keys(keys_tv, values_tv, is_success_tv, stream) self._table.insert_exist_keys(keys_tv, values_tv, is_empty_tv, stream)
return is_success > 0 return is_empty
def assign_arange_(self): def assign_arange_(self):
"""iterate table, assign values with "arange" value. """iterate table, assign values with "arange" value.
......
...@@ -137,6 +137,7 @@ class SparseSequential(SparseModule): ...@@ -137,6 +137,7 @@ class SparseSequential(SparseModule):
input = module(input) input = module(input)
else: else:
if isinstance(input, spconv.SparseConvTensor): if isinstance(input, spconv.SparseConvTensor):
print(input.features.shape)
if input.indices.shape[0] != 0: if input.indices.shape[0] != 0:
input = input.replace_feature(module(input.features)) input = input.replace_feature(module(input.features))
else: else:
......
...@@ -39,7 +39,7 @@ else: ...@@ -39,7 +39,7 @@ else:
GEMM = None GEMM = None
CONV = None CONV = None
import time import time
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC
from cumm.gemm import codeops from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
...@@ -630,21 +630,40 @@ def indice_conv(features: torch.Tensor, ...@@ -630,21 +630,40 @@ def indice_conv(features: torch.Tensor,
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
if FILTER_HWIO:
out_channel = filters.shape[-1] if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
if FILTER_HWIO:
out_channel = filters.shape[-1]
filter_shape_per_kv = [filters.shape[-2], out_channel]
else:
out_channel = filters.shape[-2]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0]
else: else:
out_channel = filters.shape[-2] kv_dim = 1
filters = filters.reshape(-1, *filters.shape[-2:]) out_channel = filters.shape[0]
kv = filters.shape[0] filters = filters.reshape(out_channel, -1, filters.shape[-1])
is_KC_not_CK = True
kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2 kv_center = kv // 2
if subm: if subm:
# out_features = torch.zeros((num_activate_out, out_channel), # out_features = torch.zeros((num_activate_out, out_channel),
# dtype=features.dtype, # dtype=features.dtype,
# device=features.device) # device=features.device)
if FILTER_HWIO: if not ALL_WEIGHT_IS_KRSC:
out_features = torch.mm(features, filters[kv_center]) if not is_KC_not_CK:
out_features = torch.mm(features, filters[kv_center])
else:
out_features = torch.mm(features, filters[kv_center].T)
else: else:
out_features = torch.mm(features, filters[kv_center].T) out_features = torch.mm(features, filters[:, kv_center].T)
else: else:
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
dtype=features.dtype, dtype=features.dtype,
...@@ -664,7 +683,6 @@ def indice_conv(features: torch.Tensor, ...@@ -664,7 +683,6 @@ def indice_conv(features: torch.Tensor,
pair_in = indice_pairs_tv[int(inverse)] pair_in = indice_pairs_tv[int(inverse)]
pair_out = indice_pairs_tv[int(not inverse)] pair_out = indice_pairs_tv[int(not inverse)]
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
if not features.is_cuda: if not features.is_cuda:
# perform gather-mm-scatter_add for cpu data # perform gather-mm-scatter_add for cpu data
assert not filters.is_cuda assert not filters.is_cuda
...@@ -686,7 +704,8 @@ def indice_conv(features: torch.Tensor, ...@@ -686,7 +704,8 @@ def indice_conv(features: torch.Tensor,
inp_indices = pair_in[i].slice_first_axis(0, nhot) inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices) SpconvOps.gather_cpu(inp_buffer_tv, a, inp_indices)
filters_cur = filters[i] if FILTER_HWIO else filters[i].T filters_i = filters.select(kv_dim, i)
filters_cur = filters_i if not is_KC_not_CK else filters_i.T
torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot]) torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])
SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices) SpconvOps.scatter_add_cpu(c, out_buffer_tv, out_indices)
...@@ -713,10 +732,10 @@ def indice_conv(features: torch.Tensor, ...@@ -713,10 +732,10 @@ def indice_conv(features: torch.Tensor,
filters_tv.dtype, filters_tv.dtype,
c.dtype, c.dtype,
a.shape, a.shape,
filters.shape[-2:], filter_shape_per_kv,
c.shape, c.shape,
False, False,
False if FILTER_HWIO else True, is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -732,13 +751,14 @@ def indice_conv(features: torch.Tensor, ...@@ -732,13 +751,14 @@ def indice_conv(features: torch.Tensor,
inp_indices = torch_tensor_to_tv(inp_indices_th) inp_indices = torch_tensor_to_tv(inp_indices_th)
out_indices = torch_tensor_to_tv(out_indices_th) out_indices = torch_tensor_to_tv(out_indices_th)
filter_tv = torch_tensor_to_tv(filters)[profile_idx] filter_tv = torch_tensor_to_tv(filters)[profile_idx]
filter_tv = torch_tensor_to_tv(filters).select(kv_dim, profile_idx)
tuned_res, min_time = GEMM.tune_and_cache( tuned_res, min_time = GEMM.tune_and_cache(
a, a,
filter_tv, filter_tv,
c, c,
False, False,
False if FILTER_HWIO else True, is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -760,7 +780,7 @@ def indice_conv(features: torch.Tensor, ...@@ -760,7 +780,7 @@ def indice_conv(features: torch.Tensor,
continue continue
inp_indices = pair_in[i].slice_first_axis(0, nhot) inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
b = filters_tv[i] b = filters_tv.select(kv_dim, i)
# inp @ filter.T, NC @ KC # inp @ filter.T, NC @ KC
beta = 1.0 if inited else 0.0 beta = 1.0 if inited else 0.0
algo_desp = GEMM.run_with_tuned_result( algo_desp = GEMM.run_with_tuned_result(
...@@ -769,7 +789,7 @@ def indice_conv(features: torch.Tensor, ...@@ -769,7 +789,7 @@ def indice_conv(features: torch.Tensor,
b, b,
c, c,
False, False,
False if FILTER_HWIO else True, is_KC_not_CK,
False, False,
arch=arch, arch=arch,
stream=stream, stream=stream,
...@@ -807,11 +827,27 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -807,11 +827,27 @@ def indice_conv_backward(features: torch.Tensor,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min()) # print(out_bp.mean(), out_bp.max(), out_bp.min())
num_activate_out = out_bp.shape[0]
out_channel = out_bp.shape[-1]
filters_shape = filters.shape filters_shape = filters.shape
filters = filters.reshape(-1, *filters.shape[-2:]) if not ALL_WEIGHT_IS_KRSC:
kv = filters.shape[0] kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
if FILTER_HWIO:
out_channel = filters.shape[-1]
filter_shape_per_kv = [filters.shape[-2], out_channel]
else:
out_channel = filters.shape[-2]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
filters = filters.reshape(-1, *filters.shape[-2:])
kv = filters.shape[0]
else:
kv_dim = 1
out_channel = filters.shape[0]
filters = filters.reshape(out_channel, -1, filters.shape[-1])
is_KC_not_CK = True
kv = filters.shape[1]
filter_shape_per_kv = [out_channel, filters.shape[-1]]
kv_center = kv // 2 kv_center = kv // 2
# TODO handle this in nn.Module to make sure features in backward is contiguous # TODO handle this in nn.Module to make sure features in backward is contiguous
if not features.is_contiguous(): if not features.is_contiguous():
...@@ -824,20 +860,24 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -824,20 +860,24 @@ def indice_conv_backward(features: torch.Tensor,
if subm: if subm:
dfilters = torch.zeros_like(filters) dfilters = torch.zeros_like(filters)
if FILTER_HWIO: if not ALL_WEIGHT_IS_KRSC:
torch.mm(features.T, out_bp, out=dfilters[kv_center]) if not is_KC_not_CK:
# TODO can we use torch mm for f16 backward weight? torch.mm(features.T, out_bp, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center].T) din = torch.mm(out_bp, filters[kv_center].T)
else:
torch.mm(out_bp.T, features, out=dfilters[kv_center])
din = torch.mm(out_bp, filters[kv_center])
else: else:
torch.mm(out_bp.T, features, out=dfilters[kv_center]) # KN @ NC
# TODO can we use torch mm for f16 backward weight? torch.mm(out_bp.T, features, out=dfilters[:, kv_center])
din = torch.mm(out_bp, filters[kv_center]) # NK @ KC
din = torch.mm(out_bp, filters[:, kv_center])
else: else:
dfilters = torch.zeros_like(filters) dfilters = torch.zeros_like(filters)
din = torch.zeros_like(features) din = torch.zeros_like(features)
if kv == 1 and subm: if kv == 1 and subm:
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
inited: bool = subm inited: bool = subm
indice_pairs_tv = torch_tensor_to_tv(indice_pairs) indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
# torch slice (a_th[x]) is very slow, so we need to use tv.Tensor earlier. # torch slice (a_th[x]) is very slow, so we need to use tv.Tensor earlier.
...@@ -881,12 +921,18 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -881,12 +921,18 @@ def indice_conv_backward(features: torch.Tensor,
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
SpconvOps.gather_cpu(inp_buffer_tv, features_tv, inp_indices) SpconvOps.gather_cpu(inp_buffer_tv, features_tv, inp_indices)
SpconvOps.gather_cpu(out_buffer_tv, out_bp_tv, out_indices) SpconvOps.gather_cpu(out_buffer_tv, out_bp_tv, out_indices)
filters_T_cur = filters[i].T if FILTER_HWIO else filters[i] filters_i = filters.select(kv_dim, i)
dfilters_cur = dfilters[i] if FILTER_HWIO else dfilters[i].T dfilters_i = dfilters.select(kv_dim, i)
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_cur)
torch.mm(out_buffer[:nhot], filters_T_cur, out=inp_buffer[:nhot])
filters_KC = filters_i if is_KC_not_CK else filters_i.T
if is_KC_not_CK:
# KN @ NC
torch.mm(out_buffer[:nhot].T, inp_buffer[:nhot], out=dfilters_i)
else:
# CN @ NK
torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_i)
# NK @ KC
torch.mm(out_buffer[:nhot], filters_KC, out=inp_buffer[:nhot])
SpconvOps.scatter_add_cpu(din_tv, inp_buffer_tv, inp_indices) SpconvOps.scatter_add_cpu(din_tv, inp_buffer_tv, inp_indices)
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
...@@ -910,10 +956,10 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -910,10 +956,10 @@ def indice_conv_backward(features: torch.Tensor,
filters_tv.dtype, filters_tv.dtype,
din_tv.dtype, din_tv.dtype,
out_bp_tv.shape, out_bp_tv.shape,
filters.shape[-2:], filter_shape_per_kv,
din_tv.shape, din_tv.shape,
False, False,
True if FILTER_HWIO else False, not is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -923,13 +969,13 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -923,13 +969,13 @@ def indice_conv_backward(features: torch.Tensor,
if tuned_res_dgrad is None: if tuned_res_dgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile) inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile) out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
filter_tv = filters_tv[profile_idx] filter_tv = filters_tv.select(kv_dim, profile_idx)
tuned_res_dgrad, min_time = GEMM.tune_and_cache( tuned_res_dgrad, min_time = GEMM.tune_and_cache(
out_bp_tv, out_bp_tv,
filter_tv, filter_tv,
din_tv, din_tv,
False, False,
True if FILTER_HWIO else False, not is_KC_not_CK,
False, False,
arch=arch, arch=arch,
shuffle_type=ShuffleStrideType.ShuffleAC, shuffle_type=ShuffleStrideType.ShuffleAC,
...@@ -939,7 +985,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -939,7 +985,7 @@ def indice_conv_backward(features: torch.Tensor,
beta=0.0, beta=0.0,
hint=AlgoHint.BackwardInput.value, hint=AlgoHint.BackwardInput.value,
stream=stream) stream=stream)
if not FILTER_HWIO: if is_KC_not_CK:
a_wgrad = out_bp_tv a_wgrad = out_bp_tv
b_wgrad = features_tv b_wgrad = features_tv
else: else:
...@@ -951,7 +997,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -951,7 +997,7 @@ def indice_conv_backward(features: torch.Tensor,
filters_tv.dtype, filters_tv.dtype,
a_wgrad.shape, a_wgrad.shape,
b_wgrad.shape, b_wgrad.shape,
filters.shape[-2:], filter_shape_per_kv,
True, True,
False, False,
False, False,
...@@ -964,8 +1010,8 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -964,8 +1010,8 @@ def indice_conv_backward(features: torch.Tensor,
if tuned_res_wgrad is None: if tuned_res_wgrad is None:
inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile) inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile)
out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile) out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile)
dfilter_tv = dfilters_tv[profile_idx] dfilter_tv = dfilters_tv.select(kv_dim, profile_idx)
if not FILTER_HWIO: if is_KC_not_CK:
a_inds_wgrad = out_indices a_inds_wgrad = out_indices
b_inds_wgrad = inp_indices b_inds_wgrad = inp_indices
else: else:
...@@ -988,7 +1034,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -988,7 +1034,7 @@ def indice_conv_backward(features: torch.Tensor,
stream=stream) stream=stream)
# print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time) # print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time)
# get workspace size for wgrad # get workspace size for wgrad
if not FILTER_HWIO: if is_KC_not_CK:
a_shape = [maxnhot, out_bp_tv.dim(1)] a_shape = [maxnhot, out_bp_tv.dim(1)]
b_shape = [maxnhot, features_tv.dim(1)] b_shape = [maxnhot, features_tv.dim(1)]
else: else:
...@@ -1030,13 +1076,13 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1030,13 +1076,13 @@ def indice_conv_backward(features: torch.Tensor,
inp_indices = pair_in[i].slice_first_axis(0, nhot) inp_indices = pair_in[i].slice_first_axis(0, nhot)
out_indices = pair_out[i].slice_first_axis(0, nhot) out_indices = pair_out[i].slice_first_axis(0, nhot)
# out.T @ inp, NK @ NC # out.T @ inp, NK @ NC
# print(features_tv.shape, out_bp_tv.shape) filter_i_tv = filters_tv.select(kv_dim, i)
GEMM.run_with_tuned_result(tuned_res_dgrad, GEMM.run_with_tuned_result(tuned_res_dgrad,
out_bp_tv, out_bp_tv,
filters_tv[i], filter_i_tv,
din_tv, din_tv,
False, False,
True if FILTER_HWIO else False, not is_KC_not_CK,
False, False,
arch=arch, arch=arch,
stream=stream, stream=stream,
...@@ -1047,7 +1093,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1047,7 +1093,7 @@ def indice_conv_backward(features: torch.Tensor,
alpha=1.0, alpha=1.0,
beta=beta) beta=beta)
if not FILTER_HWIO: if is_KC_not_CK:
a = out_bp_tv a = out_bp_tv
b = features_tv b = features_tv
a_inds = out_indices a_inds = out_indices
...@@ -1060,7 +1106,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1060,7 +1106,7 @@ def indice_conv_backward(features: torch.Tensor,
GEMM.run_with_tuned_result(tuned_res_wgrad, GEMM.run_with_tuned_result(tuned_res_wgrad,
a, a,
b, b,
dfilters_tv[i], dfilters_tv.select(kv_dim, i),
True, True,
False, False,
False, False,
...@@ -1365,6 +1411,9 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1365,6 +1411,9 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_width=-1, mask_width=-1,
beta=beta, beta=beta,
stream=stream) stream=stream)
# for backward weight, beta = 0 because each split
# handle different kernel locations.
# TODO remove D iterator in backward weight kernel
CONV.run_with_tuned_result( CONV.run_with_tuned_result(
wgrad_tune_res, wgrad_tune_res,
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
...@@ -1378,7 +1427,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1378,7 +1427,7 @@ def implicit_gemm_backward(features: torch.Tensor,
reverse_mask=False, reverse_mask=False,
mask_filter=masks[j].item(), mask_filter=masks[j].item(),
mask_width=mask_width, mask_width=mask_width,
beta=beta, beta=0,
workspace=workspace_tv, workspace=workspace_tv,
stream=stream) stream=stream)
......
...@@ -24,7 +24,7 @@ from spconv.core import ConvAlgo ...@@ -24,7 +24,7 @@ from spconv.core import ConvAlgo
import spconv.pytorch as spconv import spconv.pytorch as spconv
from spconv.utils import Point2VoxelCPU3d from spconv.utils import Point2VoxelCPU3d
# torch.backends.cudnn.enabled = False
def waymo_data(batch_size=1): def waymo_data(batch_size=1):
gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3, gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3,
150000, 1) 150000, 1)
...@@ -168,8 +168,8 @@ class Net(nn.Module): ...@@ -168,8 +168,8 @@ class Net(nn.Module):
# nn.ReLU(), # nn.ReLU(),
# spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo), # spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo),
# # nn.BatchNorm1d(128), # # # nn.BatchNorm1d(128),
# # nn.ReLU(), # # # nn.ReLU(),
# spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo), # spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo),
) )
...@@ -312,7 +312,8 @@ def main(): ...@@ -312,7 +312,8 @@ def main():
# MaskImpGemm: 51.0ms # MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms # MaskSplitImpGemm: 41.1ms
# algo = None # algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype).train() net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train()
# net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net) spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape) print(coors_th.shape)
out = net(voxels_th, coors_th, 1) out = net(voxels_th, coors_th, 1)
...@@ -329,12 +330,12 @@ def main(): ...@@ -329,12 +330,12 @@ def main():
print("------------") print("------------")
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.time() t = time.time()
out_nograd = net(voxels_th, coors_th, 1, True) out_nograd = net(voxels_th, coors_th, 1, False)
timer = out_nograd._timer timer = out_nograd._timer
res = timer.collect_by_name("forward", timer.get_all_pair_time()) # res = timer.collect_by_name("forward", timer.get_all_pair_time())
res2 = timer.collect_by_name("forward0", timer.get_all_pair_time()) # res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
print(sum(res.values()) + sum(res2.values())) # print(sum(res.values()) + sum(res2.values()))
# print(timer.get_all_pair_time()) # print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values())) # print(sum(timer.get_all_pair_time().values()))
...@@ -342,7 +343,7 @@ def main(): ...@@ -342,7 +343,7 @@ def main():
# sort_bench() # sort_bench()
times.append(time.time() - t) times.append(time.time() - t)
print("spconv time", np.mean(times[10:])) print("spconv time", np.mean(times[10:]))
# times = [] times = []
# for i in range(10): # for i in range(10):
# out = net(voxels_th, coors_th, 1) # out = net(voxels_th, coors_th, 1)
......
This diff is collapsed.
This diff is collapsed.
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from spconv.core_cc.csrc.sparse.all import SpconvOps
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare results between different algos:
CPU: simple gather-mm-scatter
Native: Fused gather-mm-scatter
ImplicitGemm: implicit gemm
"""
import time
from pathlib import Path
import numpy as np
import torch
from torch import nn
from cumm import tensorview as tv
from spconv.core import ConvAlgo
import spconv.pytorch as spconv
import pickle
from spconv.test_utils import generate_sparse_data, params_grid
class Net(nn.Module):
def __init__(self, shape, algo):
super().__init__()
pool_algo = algo
# pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 32, 3, bias=False, indice_key="c0",
algo=algo),
spconv.SubMConv3d(32,
32,
3,
bias=False,
indice_key="c0",
algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
algo=algo),
spconv.SubMConv3d(64,
64,
3,
bias=False,
indice_key="c0",
algo=algo),
# nn.BatchNorm1d(32),
# # nn.ReLU(),
spconv.SparseConv3d(64, 64, 3, 2, 1, bias=False, indice_key="m0", algo=algo),
# # spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64,
96,
3,
bias=False,
indice_key="c1",
algo=algo),
spconv.SubMConv3d(96,
96,
3,
bias=False,
indice_key="c1",
algo=algo),
# nn.BatchNorm1d(64),
# nn.ReLU(),
spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1", algo=algo),
# spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(96,
128,
3,
bias=False,
indice_key="c2",
algo=algo),
spconv.SubMConv3d(128,
128,
3,
bias=False,
indice_key="c2",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
# spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(128,
160,
3,
bias=False,
indice_key="c3",
algo=algo),
spconv.SubMConv3d(160,
160,
3,
bias=False,
indice_key="c3",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
# spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo, indice_key="m3"),
spconv.SubMConv3d(160,
192,
3,
bias=False,
indice_key="c4",
algo=algo),
spconv.SubMConv3d(192,
192,
3,
bias=False,
indice_key="c4",
algo=algo),
# nn.BatchNorm1d(128),
# nn.ReLU(),
spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo),
# spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"),
spconv.SubMConv3d(192,
224,
3,
bias=False,
indice_key="c5",
algo=algo),
spconv.SubMConv3d(224,
224,
3,
bias=False,
indice_key="c5",
algo=algo),
# nn.BatchNorm1d(256),
# nn.ReLU(),
spconv.SparseInverseConv3d(224, 128, 2, indice_key="m4", bias=False, algo=algo),
# # nn.BatchNorm1d(128),
# nn.ReLU(),
spconv.SparseInverseConv3d(128, 64, 2, indice_key="m3", bias=False, algo=algo),
)
max_batch_size = 1
# grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
# self.grid = None
self.shape = shape
def forward(self, features, coors, batch_size):
x = spconv.SparseConvTensor(features,
coors,
self.shape,
batch_size)
return self.net(x)
class NetLight(nn.Module):
def __init__(self, shape, algo):
super().__init__()
pool_algo = algo
# pool_algo = ConvAlgo.Native
self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 32, 3, bias=False, indice_key="c0",
algo=algo),
spconv.SubMConv3d(32,
32,
3,
bias=False,
indice_key="c0",
algo=algo),
# # nn.BatchNorm1d(32),
# # nn.ReLU(),
spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
algo=algo),
spconv.SubMConv3d(64,
64,
3,
bias=False,
indice_key="c0",
algo=algo),
# nn.BatchNorm1d(32),
# # nn.ReLU(),
spconv.SparseConv3d(64, 64, 3, 2, 1, bias=False, indice_key="m0", algo=algo),
# # spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64,
96,
3,
bias=False,
indice_key="c1",
algo=algo),
spconv.SubMConv3d(96,
96,
3,
bias=False,
indice_key="c1",
algo=algo),
# nn.BatchNorm1d(64),
# nn.ReLU(),
spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1", algo=algo),
# spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SparseInverseConv3d(96, 64, 2, indice_key="m1", bias=False, algo=algo),
# # nn.BatchNorm1d(128),
# nn.ReLU(),
spconv.SparseInverseConv3d(64, 32, 3, indice_key="m0", bias=False, algo=algo),
)
max_batch_size = 1
# grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
# self.grid = None
self.shape = shape
def forward(self, features, coors, batch_size):
x = spconv.SparseConvTensor(features,
coors,
self.shape,
batch_size)
return self.net(x)
def _test_multi_impl(dtype: torch.dtype):
# TODO remove or release this when tf32 op is ready
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
np.random.seed(50051)
if dtype != torch.float16:
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f)
else:
# CPU fp16 is very slow, so we use a small data here.
spatial_shape = [19, 18, 17]
sparse_dict = generate_sparse_data(spatial_shape, [1500] * 1, 3)
voxels = np.ascontiguousarray(sparse_dict["features"]).astype(
np.float32)
coors = np.ascontiguousarray(
sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
device = torch.device("cuda:0")
device_cpu = torch.device("cpu:0")
voxels_th = torch.from_numpy(voxels).to(device_cpu).to(dtype)
coors_th = torch.from_numpy(coors).to(device_cpu).int()
voxels_th_cuda = torch.from_numpy(voxels).to(device).to(dtype)
coors_th_cuda = torch.from_numpy(coors).to(device).int()
net_cls = Net
if dtype == torch.float16:
# CPU fp16 is very slow, so we use a small network here.
net_cls = NetLight
# cpu
torch.manual_seed(50051)
net_native_cpu = net_cls(spatial_shape, ConvAlgo.Native).to(device_cpu).to(dtype)
# gpu_native
torch.manual_seed(50051)
net_native_gpu = net_cls(spatial_shape, ConvAlgo.Native).to(device).to(dtype)
torch.manual_seed(50051)
net_imp_gpu = net_cls(spatial_shape, ConvAlgo.MaskImplicitGemm).to(device).to(dtype)
torch.manual_seed(50051)
net_simp_gpu = net_cls(spatial_shape, ConvAlgo.MaskSplitImplicitGemm).to(device).to(dtype)
spconv.assign_name_for_sparse_modules(net_native_cpu)
spconv.assign_name_for_sparse_modules(net_native_gpu)
spconv.assign_name_for_sparse_modules(net_imp_gpu)
spconv.assign_name_for_sparse_modules(net_simp_gpu)
with torch.no_grad():
out: torch.Tensor = net_native_cpu(voxels_th, coors_th, 1).dense()
dout = np.random.uniform(-0.2, 0.2, out.shape).astype(np.float32)
dout_t = torch.from_numpy(dout).to(device_cpu).to(dtype)
dout_t_cu = torch.from_numpy(dout).to(device).to(dtype)
out_cpu = net_native_cpu(voxels_th, coors_th, 1).dense()
out_cpu.backward(dout_t)
out = net_native_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
out.backward(dout_t_cu)
out_imp = net_imp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
out_imp.backward(dout_t_cu)
out_simp = net_simp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
out_simp.backward(dout_t_cu)
with torch.no_grad():
dense_cpu = out_cpu.cuda()
dense_native = out
dense_imp = out_imp
dense_simp = out_simp
error_native = torch.linalg.norm(dense_cpu - dense_native).cpu().item()
error_imp = torch.linalg.norm(dense_cpu - dense_imp).cpu().item()
error_simp = torch.linalg.norm(dense_cpu - dense_simp).cpu().item()
print("error_native", error_native)
print("error_imp", error_imp)
print("error_simp", error_simp)
if dtype == torch.float32:
assert error_native < 0.01
assert error_imp < 0.01
assert error_simp < 0.01
else:
assert error_native < 10
assert error_imp < 10
assert error_simp < 10
cpu_params = dict(net_native_cpu.named_parameters())
native_params = dict(net_native_gpu.named_parameters())
imp_params = dict(net_imp_gpu.named_parameters())
simp_params = dict(net_simp_gpu.named_parameters())
for k, cpu_w in cpu_params.items():
native_w = native_params[k]
imp_w = imp_params[k]
simp_w = simp_params[k]
cpu_w_grad = cpu_w.grad.detach().cuda()
native_w_grad = native_w.grad.detach()
imp_w_grad = imp_w.grad.detach()
simp_w_grad = simp_w.grad.detach()
error_native = torch.linalg.norm(native_w_grad - cpu_w_grad).cpu().item()
error_imp = torch.linalg.norm(native_w_grad - imp_w_grad).cpu().item()
error_simp = torch.linalg.norm(native_w_grad - simp_w_grad).cpu().item()
print(k, error_native, error_imp, error_simp)
assert error_imp < 1
assert error_simp < 1
def test_multi_impl():
_test_multi_impl(torch.float32)
_test_multi_impl(torch.float16)
if __name__ == "__main__":
test_multi_impl()
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# developers must run this file before push or pull request.
# this script contains three parts:
# 1. unit tests for all gemm/conv kernels
# 2. comparison test: compare network fwd/bwd results between CPU, Native, ImplicitGemm
# 3. f32/f16 train/eval test based on mnist and some small datasets
echo "-------------UNIT TEST START--------------"
pytest ./test
echo "-------------UNIT TEST END--------------"
python ./example/mnist_sparse.py --fp16
\ No newline at end of file
...@@ -28,12 +28,12 @@ if (($CUDA_VERSION_FULL -eq "10.2") -or ($CUDA_VERSION_FULL -eq "11.0") -or ($CU ...@@ -28,12 +28,12 @@ if (($CUDA_VERSION_FULL -eq "10.2") -or ($CUDA_VERSION_FULL -eq "11.0") -or ($CU
) )
} elseif ($CUDA_VERSION_FULL -eq "11.3"){ } elseif ($CUDA_VERSION_FULL -eq "11.3"){
$CUDA_PACKAGES_IN = @( $CUDA_PACKAGES_IN = @(
"cuda_nvcc"; "nvcc";
"visual_studio_integration"; "visual_studio_integration";
"cuda_nvrtc"; "nvrtc_dev";
"cuda_cudart"; "cudart";
"cuda_thrust"; "thrust";
"libcurand"; "curand_dev";
) )
} else { } else {
# after cuda 11.4 # after cuda 11.4
......
2.1.21 2.2.0
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment