Commit 73a5ce7d authored by yan.yan's avatar yan.yan
Browse files

add direct table

parent 0c07559f
......@@ -95,13 +95,19 @@ class AllocKeys:
HashV = "HashV"
ThrustTemp = "ThrustTemp"
TightUniqueCount = "TightUniqueCount"
SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = False
SPCONV_CPP_INDICE_PAIRS_IGEMM = False
SPCONV_CPP_GEMM = False
# currently use cpp pair gen is slightly slower than python, I don't know why.
SPCONV_CPP_INDICE_PAIRS_IGEMM = os.getenv("SPCONV_CPP_INDICE_PAIRS_IGEMM", "0") == "1"
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
\ No newline at end of file
SPCONV_CPP_GEMM = True
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
\ No newline at end of file
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ThrustCustomAllocatorV2:
alloc_func: Callable[int, int]
class SpconvOps:
......@@ -92,6 +93,55 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_mask_stage1_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_num_per_loc: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> None:
"""
Args:
indices:
hashdata_k:
hashdata_v:
indice_pairs_bwd:
indice_pairs_uniq:
indice_num_per_loc:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
transposed:
stream_int:
"""
...
@staticmethod
def unique_hash(hashdata_k: Tensor, hashdata_v: Tensor, uniq_cnt: Tensor, out_indices_offset: Tensor, num_out_bound: int, stream_int: int = 0) -> int:
"""
Args:
hashdata_k:
hashdata_v:
uniq_cnt:
out_indices_offset:
num_out_bound:
stream_int:
"""
...
@staticmethod
def assign_output_direct_hash(out_indices_offset: Tensor, out_indices: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], stream_int: int = 0) -> None:
"""
Args:
out_indices_offset:
out_indices:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
stream_int:
"""
...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
......@@ -118,6 +168,32 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
hashdata_k:
hashdata_v:
indice_pairs_fwd:
indice_pairs_bwd:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
mask_fwd:
mask_bwd:
num_out_act:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
transposed:
stream_int:
"""
...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
Args:
......@@ -427,30 +503,45 @@ class SpconvOps:
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> int:
def get_handcrafted_max_act_out(num_act_in: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int]) -> int:
"""
Args:
num_act_in:
ksize:
stride:
padding:
dilation:
"""
...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> int:
"""
Args:
kv:
num_act_in:
num_act_out_bound:
max_act_out_in_theory:
subm:
use_int64_hash_k:
direct_table:
"""
...
@staticmethod
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> Dict[str, Tensor]:
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> Dict[str, Tensor]:
"""
Args:
workspace:
kv:
num_act_in:
num_act_out_bound:
max_act_out_in_theory:
subm:
use_int64_hash_k:
direct_table:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
"""
Args:
allocator:
......@@ -468,6 +559,9 @@ class SpconvOps:
is_train:
stream_int:
num_out_act_bound:
timer:
direct_table:
preallocated:
"""
...
@staticmethod
......
This diff is collapsed.
This diff is collapsed.
......@@ -33,13 +33,21 @@ _TORCH_DTYPE_TO_TV = {
torch.int16: tv.int16,
torch.uint8: tv.uint8,
}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TORCH_UINT_WORKAROUNDS = {
tv.uint32: tv.int32,
tv.uint16: tv.int16,
tv.uint64: tv.int64
}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TV_DTYPE_TO_TORCH.update({
tv.uint32: torch.int32,
tv.uint16: torch.int16,
tv.uint64: torch.int64
})
_ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16
......@@ -106,91 +114,66 @@ class TorchAllocator(ExternalAllocator):
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit.
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.zeros(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def empty(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_int(self, name: str, shape: List[int], value: int, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def full_float(self, name: str, shape: List[int], value: float, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten)
self.allocated[ten.data_ptr()] = ten
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv
def get_tensor_by_name(self, name: str):
......
This diff is collapsed.
......@@ -323,6 +323,8 @@ def main():
# pickle.dump((voxels, coors, spatial_shape), f)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f)
# voxels, coors, spatial_shape = waymo_data_large()
print(spatial_shape)
print(voxels.shape)
# voxels = voxels[:100]
......@@ -366,16 +368,15 @@ def main():
dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32)
dout_t = torch.from_numpy(dout).to(device).to(dtype)
print(out.spatial_shape, out.features.mean(), out.features.max(),
print(out.spatial_shape, out.features.sum(1).mean(), out.features.max(),
out.features.min())
times = []
show_metrics = False
with torch.no_grad():
for i in range(20):
print("------------")
torch.cuda.synchronize()
t = time.time()
out_nograd = net(voxels_th, coors_th, 1, show_metrics)
for i in range(100):
# print("------------")
with tv.measure_duration() as measure:
out_nograd = net(voxels_th, coors_th, 1, show_metrics)
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
......@@ -383,14 +384,19 @@ def main():
# print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values()))
torch.cuda.synchronize()
# sort_bench()
times.append(time.time() - t)
times.append(measure.duration)
if show_metrics:
timer = out_nograd._timer
items = list(timer.get_all_pair_time().items())
items.sort(key=lambda x: x[0])
print("SUM TIME:", sum([x[1] for x in items]))
print(json.dumps(dict(items), indent=2))
inds_sum = 0
for k, v in items:
if "gen_pairs" in k:
inds_sum += v
print("SUM GEN INDS:", inds_sum)
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
......
......@@ -231,8 +231,8 @@ def _test_impgemm_conv_cuda(subm: bool):
# out_channels = [32, 48, 64]
in_channels = [32, 47]
out_channels = [32, 48, 62]
in_channels = [32]
out_channels = [32]
# in_channels = [32]
# out_channels = [32]
multiple_base = 16
if subm:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment