Commit 73a5ce7d authored by yan.yan
Browse files

add direct table

parent 0c07559f
......@@ -95,13 +95,19 @@ class AllocKeys:
HashV = "HashV"
ThrustTemp = "ThrustTemp"
TightUniqueCount = "TightUniqueCount"
SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = False
SPCONV_CPP_INDICE_PAIRS_IGEMM = False
SPCONV_CPP_GEMM = False
# Currently the C++ pair generation is slightly slower than the Python implementation; the cause is unknown.
SPCONV_CPP_INDICE_PAIRS_IGEMM = os.getenv("SPCONV_CPP_INDICE_PAIRS_IGEMM", "0") == "1"
SPCONV_CPP_GEMM = True
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
\ No newline at end of file
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ThrustCustomAllocatorV2:
    """Type stub for the allocator object that carries an allocation callback.

    Only the callback attribute is declared here.
    """
    # Callback taking a single int and returning an int.
    # NOTE(review): presumably "requested size -> pointer"; confirm against
    # the native binding. The parameter list must be written as a list
    # (``Callable[[int], int]``); ``Callable[int, int]`` is rejected by type
    # checkers.
    alloc_func: Callable[[int], int]
class SpconvOps:
......@@ -92,6 +93,55 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_mask_stage1_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_num_per_loc: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> None:
    """Type stub for the direct-table variant of stage 1 of conv indice generation.

    Declaration only (the body is ``...``). The Python caller in this commit
    invokes it under the "gen_conv_inds_stage1" timer when ``direct_table``
    is enabled, in place of ``generate_conv_inds_mask_stage1``. Unlike that
    function, it additionally receives the hash-table key/value buffers.

    Args:
        indices: indices tensor of the input (the caller passes ``inds_tv``).
        hashdata_k: hash-table key buffer. The caller allocates it from the
            handcrafted maximum output count scaled by
            ``SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE``.
        hashdata_v: hash-table value buffer, allocated alongside ``hashdata_k``.
        indice_pairs_bwd: backward pair buffer (the caller passes ``pair_bwd_tv``).
        indice_pairs_uniq: buffer of ``pair.numel() + 1`` elements in the
            caller. NOTE(review): presumably receives the raw, not yet
            deduplicated output offsets; confirm against the native code.
        indice_num_per_loc: the caller passes ``indice_num_per_loc_tv``.
        batch_size: batch size forwarded from the caller.
        output_dims: output spatial shape (the caller passes ``out_shape``).
        input_dims: input spatial shape (the caller passes ``spatial_shape``).
        ksize: kernel size per dimension.
        stride: stride per dimension.
        padding: padding per dimension.
        dilation: dilation per dimension.
        transposed: the caller forwards its ``transpose`` flag.
        stream_int: stream handle as an int (the caller passes the value of
            ``get_current_stream()``).
    """
    ...
@staticmethod
def unique_hash(hashdata_k: Tensor, hashdata_v: Tensor, uniq_cnt: Tensor, out_indices_offset: Tensor, num_out_bound: int, stream_int: int = 0) -> int:
    """Type stub for the hash-table based replacement of the ``unique`` step.

    Declaration only (the body is ``...``). In the Python caller visible in
    this commit it is used on the ``direct_table`` path instead of
    ``Tensor.unique()``, and its return value is taken as the number of
    active outputs (``num_act_out``).

    Args:
        hashdata_k: hash-table key buffer filled by the direct-table stage 1.
        hashdata_v: hash-table value buffer filled by the direct-table stage 1.
        uniq_cnt: counter tensor. The caller passes a zero-initialised int32
            tensor of shape ``[1]``.
        out_indices_offset: output buffer. The caller passes
            ``indice_pairs_uniq_tv`` and afterwards uses it as the unique
            output-offset tensor.
        num_out_bound: the caller passes ``num_out_act_bound``.
            NOTE(review): presumably an upper bound on the returned count,
            with non-positive meaning "no bound"; confirm against the native
            code.
        stream_int: stream handle as an int.

    Returns:
        An int that the caller treats as the number of active outputs.
    """
    ...
@staticmethod
def assign_output_direct_hash(out_indices_offset: Tensor, out_indices: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], stream_int: int = 0) -> None:
    """Type stub for the step that fills output indices on the direct-table path.

    Declaration only (the body is ``...``). The Python caller in this commit
    runs it right before the direct-table stage 2, and only when
    ``direct_table`` is enabled.

    Args:
        out_indices_offset: the caller passes ``indice_pairs_uniq_tv``, i.e.
            the buffer previously handed to ``unique_hash``.
        out_indices: the caller passes the freshly allocated ``out_inds``
            tensor of shape ``(num_act_out, indices.shape[1])``.
            NOTE(review): presumably written by this call; confirm against
            the native code.
        batch_size: batch size forwarded from the caller.
        output_dims: output spatial shape (the caller passes ``out_shape``).
        input_dims: input spatial shape (the caller passes ``spatial_shape``).
        ksize: kernel size per dimension.
        stride: stride per dimension.
        padding: padding per dimension.
        dilation: dilation per dimension.
        stream_int: stream handle as an int.
    """
    ...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
......@@ -118,6 +168,32 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
    """Type stub for the direct-table variant of stage 2 of conv indice generation.

    Declaration only (the body is ``...``). Its parameter list is identical
    to ``generate_conv_inds_mask_stage2``; the Python caller in this commit
    selects one of the two as ``stage2_fn`` depending on ``direct_table``.

    Args:
        indices: indices tensor of the input (the caller passes ``inds_tv``).
        hashdata_k: hash-table key buffer. On the direct-table path the
            caller reuses the table built for stage 1 instead of allocating
            a new one.
        hashdata_v: hash-table value buffer, see ``hashdata_k``.
        indice_pairs_fwd: forward pair buffer (``pair_fwd_tv`` in the caller).
        indice_pairs_bwd: backward pair buffer (``pair_bwd_tv`` in the caller).
        indice_pairs_uniq: NOTE(review): the caller appears to pass the
            unique output-offset tensor here; the diff interleaves old and
            new argument lines, so confirm against the repository.
        indice_pairs_uniq_before_sort: NOTE(review): the caller appears to
            pass the raw stage-1 offset buffer here; confirm as above.
        out_inds: output indices tensor (``out_inds_tv`` in the caller).
        mask_fwd: forward pair mask (``pair_mask_fwd_tv`` in the caller,
            viewed as uint32).
        mask_bwd: backward pair mask. The caller passes an empty tensor
            unless ``is_train`` is set.
        num_out_act: number of active outputs.
        batch_size: batch size forwarded from the caller.
        output_dims: output spatial shape.
        input_dims: input spatial shape.
        ksize: kernel size per dimension.
        stride: stride per dimension.
        padding: padding per dimension.
        dilation: dilation per dimension.
        transposed: the caller forwards its ``transpose`` flag.
        stream_int: stream handle as an int.

    Returns:
        An int. Its meaning is not visible in this diff.
    """
    ...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
Args:
......@@ -427,30 +503,45 @@ class SpconvOps:
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> int:
def get_handcrafted_max_act_out(num_act_in: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int]) -> int:
    """Type stub returning a maximum count of active outputs for a regular conv.

    Declaration only (the body is ``...``). The Python caller in this commit
    passes ``indices.shape[0]`` as ``num_act_in`` and uses the result to size
    the direct-table hash buffers. For transposed convolution the caller
    overrides the result with ``kv * indices.shape[0]``.

    Args:
        num_act_in: number of active inputs.
        ksize: kernel size per dimension.
        stride: stride per dimension.
        padding: padding per dimension.
        dilation: dilation per dimension.

    Returns:
        An int that the caller treats as the maximum number of active
        outputs. NOTE(review): how the bound is derived is not visible here.
    """
    ...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> int:
    """Type stub for the indice-generation workspace size query.

    Declaration only (the body is ``...``). No caller is visible in this
    diff, so the notes below are assumptions to be confirmed.

    Args:
        kv: NOTE(review): presumably the kernel volume (product of ``ksize``),
            as ``kv`` is computed in the Python indice-pair code; confirm.
        num_act_in: number of active inputs.
        num_act_out_bound: bound on the number of active outputs.
        max_act_out_in_theory: NOTE(review): presumably the value returned by
            ``get_handcrafted_max_act_out``; confirm.
        subm: whether the convolution is submanifold.
        use_int64_hash_k: whether hash keys are int64.
        direct_table: whether the direct-table algorithm is used.

    Returns:
        An int. NOTE(review): presumably the required workspace size; the
        unit is not visible here.
    """
    ...
@staticmethod
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> Dict[str, Tensor]:
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> Dict[str, Tensor]:
    """Type stub that maps a workspace to named indice-generation tensors.

    Declaration only (the body is ``...``). No caller is visible in this
    diff, so the notes below are assumptions to be confirmed. Apart from
    ``workspace``, the parameters mirror ``get_indice_gen_workspace_size``.

    Args:
        workspace: untyped in this stub. NOTE(review): presumably the buffer
            whose size was obtained from ``get_indice_gen_workspace_size``;
            confirm.
        kv: NOTE(review): presumably the kernel volume; confirm.
        num_act_in: number of active inputs.
        num_act_out_bound: bound on the number of active outputs.
        max_act_out_in_theory: NOTE(review): presumably the value returned by
            ``get_handcrafted_max_act_out``; confirm.
        subm: whether the convolution is submanifold.
        use_int64_hash_k: whether hash keys are int64.
        direct_table: whether the direct-table algorithm is used.

    Returns:
        A dict from string names to tensors. The key set is not visible here.
    """
    ...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
"""
Args:
allocator:
......@@ -468,6 +559,9 @@ class SpconvOps:
is_train:
stream_int:
num_out_act_bound:
timer:
direct_table:
preallocated:
"""
...
@staticmethod
......
This diff is collapsed.
This diff is collapsed.
......@@ -33,13 +33,21 @@ _TORCH_DTYPE_TO_TV = {
torch.int16: tv.int16,
torch.uint8: tv.uint8,
}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TORCH_UINT_WORKAROUNDS = {
tv.uint32: tv.int32,
tv.uint16: tv.int16,
tv.uint64: tv.int64
}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TV_DTYPE_TO_TORCH.update({
tv.uint32: torch.int32,
tv.uint16: torch.int16,
tv.uint64: torch.int64
})
_ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16
......@@ -106,91 +114,66 @@ class TorchAllocator(ExternalAllocator):
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
# TODO: free memory by name if it has already been freed by pointer.
# Provide a name if you want to access the tensor after the C++ function exits.
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.zeros(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def empty(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_int(self, name: str, shape: List[int], value: int, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_float(self, name: str, shape: List[int], value: float, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def get_tensor_by_name(self, name: str):
......
......@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.utils import nullcontext
......@@ -46,7 +46,7 @@ from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer
DEBUG = False
DEBUG_INT64_HASH_K = True
DEBUG_INT64_HASH_K = False
INT32_MAX = SpconvOps.get_int32_max()
......@@ -77,12 +77,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
class _HashData:
def __init__(self, num: int, use_i64: bool, device: torch.device) -> None:
def __init__(self,
num: int,
use_i64: bool,
device: torch.device,
rate: float = 2.0) -> None:
if use_i64:
self.hashdata_k = torch.empty((num * 2, ),
self.hashdata_k = torch.empty((int(num * rate), ),
dtype=torch.int64,
device=device)
self.hashdata_v = torch.empty((num * 2, ),
self.hashdata_v = torch.empty((int(num * rate), ),
dtype=torch.int32,
device=device)
self.hashdata_k_tv = torch_tensor_to_tv(self.hashdata_k)
......@@ -91,7 +96,7 @@ class _HashData:
else:
self.hashdata = torch.empty((
2,
num * 2,
int(num * rate),
),
dtype=torch.int32,
device=device)
......@@ -309,7 +314,8 @@ def get_indice_pairs_implicit_gemm(
is_train: bool = True,
alloc: Optional[ThrustSortAllocator] = None,
timer: CUDAKernelTimer = CUDAKernelTimer(False),
num_out_act_bound: int = -1):
num_out_act_bound: int = -1,
direct_table: bool = True):
"""
Why return a tuple? Because PyTorch does not seem to support custom objects in autograd.
return: (
......@@ -323,14 +329,33 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_bwd_splits, # torch.Tensor() if subm or inference mode
masks,
)
direct_table: use a hash-based pair-generation algorithm for regular conv that avoids the unique operation.
It runs faster than PyTorch's unique when num_voxel < 1000k.
"""
stream = get_current_stream()
if SPCONV_CPP_INDICE_PAIRS_IGEMM:
thalloc = TorchAllocator(indices.device)
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, torch_tensor_to_tv(indices), batch_size, spatial_shape,
algo.value, ksize, stride, padding, dilation, out_padding, subm,
transpose, is_train, stream, num_out_act_bound)
thalloc,
torch_tensor_to_tv(indices),
batch_size,
spatial_shape,
algo.value,
ksize,
stride,
padding,
dilation,
out_padding,
subm,
transpose,
is_train,
stream,
num_out_act_bound,
timer=timer_cpp,
direct_table=direct_table)
mask_split_count = mask_tensor.dim(0)
masks = [mask_tensor[i:i + 1].numpy() for i in range(mask_split_count)]
if subm:
......@@ -342,7 +367,6 @@ def get_indice_pairs_implicit_gemm(
# for subm, if training, pair shape is [2, kv, ...]
# if not training, pair is [1, kv, ...]
pair = thalloc.allocated[AllocKeys.PairFwd]
pair_mask = thalloc.allocated[AllocKeys.PairMask]
mask_argsort = thalloc.allocated[AllocKeys.MaskArgSort]
pair_mask_in_splits = [
......@@ -367,7 +391,6 @@ def get_indice_pairs_implicit_gemm(
if is_train:
pair_mask_bwd = thalloc.allocated[AllocKeys.PairMaskBwd]
mask_argsort_bwd = thalloc.allocated[AllocKeys.MaskArgSortBwd]
mask_argsort_fwd = thalloc.allocated[AllocKeys.MaskArgSort]
if not is_train:
pair_mask_bwd_splits: List[torch.Tensor] = []
......@@ -388,11 +411,6 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
t = 0
if DEBUG:
CONV.stream_synchronize(stream)
t = time.time()
assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
......@@ -452,8 +470,6 @@ def get_indice_pairs_implicit_gemm(
masks = [first.astype(np.uint32), second.astype(np.uint32)]
else:
masks = [np.array([0xffffffff], dtype=np.uint32)]
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
if subm:
out_inds = indices
......@@ -508,10 +524,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_in_splits = [
mask_argsort[i] for i in range(mask_split_count)
]
if DEBUG:
CONV.stream_synchronize(stream)
print("SUBM", time.time() - t)
if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
......@@ -519,11 +531,10 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_PREPARE", time.time() - t)
t = time.time()
max_num_act = SpconvOps.get_handcrafted_max_act_out(
indices.shape[0], ksize, stride, padding, dilation)
if transpose:
max_num_act = kv * indices.shape[0]
pair_bwd = pair
pair_bwd_tv = pair_tv
......@@ -531,8 +542,38 @@ def get_indice_pairs_implicit_gemm(
dtype=indice_dtype,
device=indices.device)
indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq)
hashdata = _HashData(0, use_int64_hash_k, indices.device)
indice_pairs_uniq_bkp_tv = tv.Tensor()
if direct_table:
# print("HASH SIZE", max_num_act * 2)
hashdata = _HashData(max_num_act, use_int64_hash_k, indices.device,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE)
indice_pairs_uniq_bkp = torch.empty((pair.numel() + 1, ),
dtype=indice_dtype,
device=indices.device)
indice_pairs_uniq_bkp_tv = torch_tensor_to_tv(
indice_pairs_uniq_bkp)
with timer.record("gen_conv_inds_stage1", stream):
SpconvOps.generate_conv_inds_mask_stage1(inds_tv,
SpconvOps.generate_conv_inds_mask_stage1_direct_table(
inds_tv,
hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
pair_bwd_tv,
indice_pairs_uniq_bkp_tv,
indice_num_per_loc_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
else:
with timer.record("gen_conv_inds_stage1", stream):
SpconvOps.generate_conv_inds_mask_stage1(
inds_tv,
pair_bwd_tv,
indice_pairs_uniq_tv,
indice_num_per_loc_tv,
......@@ -545,23 +586,31 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation,
transposed=transpose,
stream_int=stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S1", time.time() - t)
t = time.time()
uniq_out_indices_offset_tv = tv.Tensor()
with timer.record(f"unique_{indice_pairs_uniq.shape[0]}", stream):
if direct_table:
uniq_cnt = torch.zeros([1],
dtype=torch.int32,
device=indices.device)
uniq_cnt_tv = torch_tensor_to_tv(uniq_cnt)
num_act_out = SpconvOps.unique_hash(hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
uniq_cnt_tv,
indice_pairs_uniq_tv,
num_out_act_bound, stream)
uniq_out_indices_offset_tv = indice_pairs_uniq_tv
raw_out_indices_offset_tv = indice_pairs_uniq_bkp_tv
else:
uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1
uniq_out_indices_offset_tv = torch_tensor_to_tv(uniq_res)
raw_out_indices_offset_tv = indice_pairs_uniq_tv
if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_UNIQ", time.time() - t)
t = time.time()
with timer.record(f"alloc_stage2", stream):
uniq_res_tv = torch_tensor_to_tv(uniq_res)
out_inds = torch.empty((num_act_out, indices.shape[1]),
dtype=indices.dtype,
device=indices.device)
......@@ -574,15 +623,18 @@ def get_indice_pairs_implicit_gemm(
dtype=torch.int32,
device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_tv = torch_tensor_to_tv(pair_mask_fwd, dtype=tv.uint32)
pair_mask_fwd_tv = torch_tensor_to_tv(pair_mask_fwd,
dtype=tv.uint32)
pair_mask_bwd = torch.Tensor()
pair_mask_bwd_tv = tv.Tensor()
if is_train:
pair_mask_bwd = torch.zeros((mask_split_count, indices.shape[0]),
pair_mask_bwd = torch.zeros(
(mask_split_count, indices.shape[0]),
dtype=torch.int32,
device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
dtype=tv.uint32)
if not direct_table:
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device)
......@@ -591,19 +643,28 @@ def get_indice_pairs_implicit_gemm(
# device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
if DEBUG:
with timer.record(f"gen_conv_inds_stage2_{num_act_out}", stream):
stage2_fn = SpconvOps.generate_conv_inds_mask_stage2
if direct_table:
SpconvOps.assign_output_direct_hash(indice_pairs_uniq_tv,
out_inds_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
stream_int=stream)
stage2_fn = SpconvOps.generate_conv_inds_stage2_mask_direct_table
CONV.stream_synchronize(stream)
print("REGU_S2_PREPARE", time.time() - t)
t = time.time()
with timer.record("gen_conv_inds_stage2", stream):
SpconvOps.generate_conv_inds_mask_stage2(inds_tv,
stage2_fn(inds_tv,
hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
pair_fwd_tv,
pair_bwd_tv,
uniq_res_tv,
indice_pairs_uniq_tv,
uniq_out_indices_offset_tv,
raw_out_indices_offset_tv,
out_inds_tv,
pair_mask_fwd_tv,
pair_mask_bwd_tv,
......@@ -617,12 +678,6 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation,
transposed=transpose,
stream_int=stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2", time.time() - t)
t = time.time()
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32,
device=indices.device)
......@@ -693,10 +748,6 @@ def get_indice_pairs_implicit_gemm(
SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2_FINISH", time.time() - t)
t = time.time()
# CONV.stream_synchronize(stream)
if not is_train:
......@@ -716,9 +767,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_fwd_splits = [
mask_argsort_fwd[i] for i in range(mask_split_count)
]
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU", time.time() - t)
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
......@@ -769,8 +817,7 @@ def indice_conv(features: torch.Tensor,
stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC,
FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv,
arch,
indice_pairs_tv, indice_pair_num_tv, arch,
num_activate_out, inverse, subm, algo.value,
stream)
out_features = alloc.allocated[AllocKeys.OutFeatures]
......@@ -1018,8 +1065,8 @@ def indice_conv_backward(features: torch.Tensor,
ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv,
arch,
inverse, subm, algo.value, stream)
arch, inverse, subm, algo.value,
stream)
din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters]
return din, df
......@@ -1369,8 +1416,8 @@ def implicit_gemm(features: torch.Tensor,
mask_width = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, timer_cpp,
auto_fp32_accum, fp32_accum)
num_activate_out, mask_tv, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train:
......@@ -1460,7 +1507,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# t = time.time()
print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream):
for j in range(num_split):
......@@ -1921,8 +1968,10 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
indice_pairs_tv, stream)
return din
def indice_avgpool_implicit_gemm(features: torch.Tensor,
indice_pairs: torch.Tensor, num_activate_out, calc_count: bool):
indice_pairs: torch.Tensor, num_activate_out,
calc_count: bool):
# torch.cuda.synchronize()
# t = time.time()
stream = get_current_stream()
......@@ -1943,12 +1992,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
count_out = torch.Tensor()
count_out_tv = tv.Tensor()
if calc_count:
count_out = torch.zeros((num_activate_out,),
count_out = torch.zeros((num_activate_out, ),
dtype=torch.int32,
device=features.device)
count_out_tv = torch_tensor_to_tv(count_out)
SpconvOps.avgpool_implicit_gemm_forward(out_features_tv, features_tv,
indice_pairs_tv, count_out_tv, stream)
indice_pairs_tv, count_out_tv,
stream)
# CONV.stream_synchronize(stream)
# print("M", time.time() - t)
......@@ -1956,12 +2006,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
return out_features, count_out
def indice_avgpool_implicit_gemm_backward(out_bp,
indice_pairs, count_out):
def indice_avgpool_implicit_gemm_backward(out_bp, indice_pairs, count_out):
# torch.cuda.synchronize()
# t = time.time()
out_channel = out_bp.shape[-1]
din = torch.zeros((indice_pairs.shape[1], out_bp.shape[1]), dtype=out_bp.dtype, device=out_bp.device)
din = torch.zeros((indice_pairs.shape[1], out_bp.shape[1]),
dtype=out_bp.dtype,
device=out_bp.device)
assert out_bp.is_cuda
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
......@@ -1972,7 +2023,8 @@ def indice_avgpool_implicit_gemm_backward(out_bp,
din_tv = torch_tensor_to_tv(din)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
SpconvOps.avgpool_implicit_gemm_backward(out_bp_tv, din_tv,
indice_pairs_tv, count_out_tv, stream)
indice_pairs_tv, count_out_tv,
stream)
return din
......
......@@ -323,6 +323,8 @@ def main():
# pickle.dump((voxels, coors, spatial_shape), f)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f)
# voxels, coors, spatial_shape = waymo_data_large()
print(spatial_shape)
print(voxels.shape)
# voxels = voxels[:100]
......@@ -366,15 +368,14 @@ def main():
dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32)
dout_t = torch.from_numpy(dout).to(device).to(dtype)
print(out.spatial_shape, out.features.mean(), out.features.max(),
print(out.spatial_shape, out.features.sum(1).mean(), out.features.max(),
out.features.min())
times = []
show_metrics = False
with torch.no_grad():
for i in range(20):
print("------------")
torch.cuda.synchronize()
t = time.time()
for i in range(100):
# print("------------")
with tv.measure_duration() as measure:
out_nograd = net(voxels_th, coors_th, 1, show_metrics)
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
......@@ -383,14 +384,19 @@ def main():
# print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values()))
torch.cuda.synchronize()
# sort_bench()
times.append(time.time() - t)
times.append(measure.duration)
if show_metrics:
timer = out_nograd._timer
items = list(timer.get_all_pair_time().items())
items.sort(key=lambda x: x[0])
print("SUM TIME:", sum([x[1] for x in items]))
print(json.dumps(dict(items), indent=2))
inds_sum = 0
for k, v in items:
if "gen_pairs" in k:
inds_sum += v
print("SUM GEN INDS:", inds_sum)
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
......
......@@ -231,8 +231,8 @@ def _test_impgemm_conv_cuda(subm: bool):
# out_channels = [32, 48, 64]
in_channels = [32, 47]
out_channels = [32, 48, 62]
in_channels = [32]
out_channels = [32]
# in_channels = [32]
# out_channels = [32]
multiple_base = 16
if subm:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment