Commit 73a5ce7d authored by yan.yan's avatar yan.yan
Browse files

add direct table

parent 0c07559f
...@@ -95,13 +95,19 @@ class AllocKeys: ...@@ -95,13 +95,19 @@ class AllocKeys:
HashV = "HashV" HashV = "HashV"
ThrustTemp = "ThrustTemp" ThrustTemp = "ThrustTemp"
TightUniqueCount = "TightUniqueCount"
SPCONV_DEBUG_WEIGHT = False SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = False SPCONV_CPP_INDICE_PAIRS = False
SPCONV_CPP_INDICE_PAIRS_IGEMM = False
SPCONV_CPP_GEMM = False # currently use cpp pair gen is slightly slower than python, I don't know why.
SPCONV_CPP_INDICE_PAIRS_IGEMM = os.getenv("SPCONV_CPP_INDICE_PAIRS_IGEMM", "0") == "1"
SPCONV_CPP_GEMM = True
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1" SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
\ No newline at end of file
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ThrustCustomAllocatorV2: class ThrustCustomAllocatorV2:
alloc_func: Callable[int, int] alloc_func: Callable[int, int]
class SpconvOps: class SpconvOps:
...@@ -92,6 +93,55 @@ class SpconvOps: ...@@ -92,6 +93,55 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_mask_stage1_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_num_per_loc: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> None:
"""
Args:
indices:
hashdata_k:
hashdata_v:
indice_pairs_bwd:
indice_pairs_uniq:
indice_num_per_loc:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
transposed:
stream_int:
"""
...
@staticmethod
def unique_hash(hashdata_k: Tensor, hashdata_v: Tensor, uniq_cnt: Tensor, out_indices_offset: Tensor, num_out_bound: int, stream_int: int = 0) -> int:
"""
Args:
hashdata_k:
hashdata_v:
uniq_cnt:
out_indices_offset:
num_out_bound:
stream_int:
"""
...
@staticmethod
def assign_output_direct_hash(out_indices_offset: Tensor, out_indices: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], stream_int: int = 0) -> None:
"""
Args:
out_indices_offset:
out_indices:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
stream_int:
"""
...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int: def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
...@@ -118,6 +168,32 @@ class SpconvOps: ...@@ -118,6 +168,32 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
hashdata_k:
hashdata_v:
indice_pairs_fwd:
indice_pairs_bwd:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
mask_fwd:
mask_bwd:
num_out_act:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
transposed:
stream_int:
"""
...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int: def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
...@@ -427,30 +503,45 @@ class SpconvOps: ...@@ -427,30 +503,45 @@ class SpconvOps:
@staticmethod @staticmethod
def get_int32_max() -> int: ... def get_int32_max() -> int: ...
@staticmethod @staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> int: def get_handcrafted_max_act_out(num_act_in: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int]) -> int:
"""
Args:
num_act_in:
ksize:
stride:
padding:
dilation:
"""
...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> int:
""" """
Args: Args:
kv: kv:
num_act_in: num_act_in:
num_act_out_bound: num_act_out_bound:
max_act_out_in_theory:
subm: subm:
use_int64_hash_k: use_int64_hash_k:
direct_table:
""" """
... ...
@staticmethod @staticmethod
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> Dict[str, Tensor]: def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> Dict[str, Tensor]:
""" """
Args: Args:
workspace: workspace:
kv: kv:
num_act_in: num_act_in:
num_act_out_bound: num_act_out_bound:
max_act_out_in_theory:
subm: subm:
use_int64_hash_k: use_int64_hash_k:
direct_table:
""" """
... ...
@staticmethod @staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]: def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
""" """
Args: Args:
allocator: allocator:
...@@ -468,6 +559,9 @@ class SpconvOps: ...@@ -468,6 +559,9 @@ class SpconvOps:
is_train: is_train:
stream_int: stream_int:
num_out_act_bound: num_out_act_bound:
timer:
direct_table:
preallocated:
""" """
... ...
@staticmethod @staticmethod
......
This diff is collapsed.
This diff is collapsed.
...@@ -33,13 +33,21 @@ _TORCH_DTYPE_TO_TV = { ...@@ -33,13 +33,21 @@ _TORCH_DTYPE_TO_TV = {
torch.int16: tv.int16, torch.int16: tv.int16,
torch.uint8: tv.uint8, torch.uint8: tv.uint8,
} }
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TORCH_UINT_WORKAROUNDS = { _TORCH_UINT_WORKAROUNDS = {
tv.uint32: tv.int32, tv.uint32: tv.int32,
tv.uint16: tv.int16, tv.uint16: tv.int16,
tv.uint64: tv.int64 tv.uint64: tv.int64
} }
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TV_DTYPE_TO_TORCH.update({
tv.uint32: torch.int32,
tv.uint16: torch.int16,
tv.uint64: torch.int64
})
_ALL_INTS = { _ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32, tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16 tv.uint16
...@@ -106,91 +114,66 @@ class TorchAllocator(ExternalAllocator): ...@@ -106,91 +114,66 @@ class TorchAllocator(ExternalAllocator):
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
# TODO free memory by name if its already free by pointer. # TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit. # provide a name if you want to access it after c++ function exit.
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
ten = torch.zeros(shape, dtype=th_dtype, device=dev) ten = torch.zeros(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten.data_ptr()] = ten self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def empty(self, name: str, shape: List[int], dtype: int, def empty(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
# assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
ten = torch.empty(shape, dtype=th_dtype, device=dev) ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten.data_ptr()] = ten self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def full_int(self, name: str, shape: List[int], value: int, dtype: int, def full_int(self, name: str, shape: List[int], value: int, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev) ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten.data_ptr()] = ten self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def full_float(self, name: str, shape: List[int], value: float, dtype: int, def full_float(self, name: str, shape: List[int], value: float, dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
torch_uint_workaround = dtype in _TORCH_UINT_WORKAROUNDS
dtype_bkp = dtype dtype_bkp = dtype
if dtype in _TORCH_UINT_WORKAROUNDS:
assert name == "", "must be temp memory for uint dtypes"
dtype = _TORCH_UINT_WORKAROUNDS[dtype]
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = _TV_DTYPE_TO_TORCH[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
ten = torch.full(shape, value, dtype=th_dtype, device=dev) ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten.data_ptr()] = ten self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
if torch_uint_workaround:
return ten_tv.type_view(dtype_bkp)
return ten_tv return ten_tv
def get_tensor_by_name(self, name: str): def get_tensor_by_name(self, name: str):
......
...@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator ...@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
import spconv.core_cc as _ext import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.utils import nullcontext from spconv.utils import nullcontext
...@@ -46,7 +46,7 @@ from cumm.gemm import codeops ...@@ -46,7 +46,7 @@ from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
DEBUG = False DEBUG = False
DEBUG_INT64_HASH_K = True DEBUG_INT64_HASH_K = False
INT32_MAX = SpconvOps.get_int32_max() INT32_MAX = SpconvOps.get_int32_max()
...@@ -77,12 +77,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation, ...@@ -77,12 +77,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
class _HashData: class _HashData:
def __init__(self, num: int, use_i64: bool, device: torch.device) -> None:
def __init__(self,
num: int,
use_i64: bool,
device: torch.device,
rate: float = 2.0) -> None:
if use_i64: if use_i64:
self.hashdata_k = torch.empty((num * 2, ), self.hashdata_k = torch.empty((int(num * rate), ),
dtype=torch.int64, dtype=torch.int64,
device=device) device=device)
self.hashdata_v = torch.empty((num * 2, ), self.hashdata_v = torch.empty((int(num * rate), ),
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
self.hashdata_k_tv = torch_tensor_to_tv(self.hashdata_k) self.hashdata_k_tv = torch_tensor_to_tv(self.hashdata_k)
...@@ -91,7 +96,7 @@ class _HashData: ...@@ -91,7 +96,7 @@ class _HashData:
else: else:
self.hashdata = torch.empty(( self.hashdata = torch.empty((
2, 2,
num * 2, int(num * rate),
), ),
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
...@@ -309,7 +314,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -309,7 +314,8 @@ def get_indice_pairs_implicit_gemm(
is_train: bool = True, is_train: bool = True,
alloc: Optional[ThrustSortAllocator] = None, alloc: Optional[ThrustSortAllocator] = None,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
num_out_act_bound: int = -1): num_out_act_bound: int = -1,
direct_table: bool = True):
""" """
Why return tuple? because pytorch seems don't support custom object in autograd. Why return tuple? because pytorch seems don't support custom object in autograd.
return: ( return: (
...@@ -323,14 +329,33 @@ def get_indice_pairs_implicit_gemm( ...@@ -323,14 +329,33 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_bwd_splits, # torch.Tensor() if subm or inference mode mask_argsort_bwd_splits, # torch.Tensor() if subm or inference mode
masks, masks,
) )
direct_table: a hash-based regular conv pair gen algo to avoid unique operation.
runs faster than pytorch unique with num_voxel < 1000k.
""" """
stream = get_current_stream() stream = get_current_stream()
if SPCONV_CPP_INDICE_PAIRS_IGEMM: if SPCONV_CPP_INDICE_PAIRS_IGEMM:
thalloc = TorchAllocator(indices.device) thalloc = TorchAllocator(indices.device)
timer_cpp = tv.CUDAKernelTimer(False)
if timer._timer is not None:
timer_cpp = timer._timer
mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm( mask_tensor, num_act_out = SpconvOps.get_indice_pairs_implicit_gemm(
thalloc, torch_tensor_to_tv(indices), batch_size, spatial_shape, thalloc,
algo.value, ksize, stride, padding, dilation, out_padding, subm, torch_tensor_to_tv(indices),
transpose, is_train, stream, num_out_act_bound) batch_size,
spatial_shape,
algo.value,
ksize,
stride,
padding,
dilation,
out_padding,
subm,
transpose,
is_train,
stream,
num_out_act_bound,
timer=timer_cpp,
direct_table=direct_table)
mask_split_count = mask_tensor.dim(0) mask_split_count = mask_tensor.dim(0)
masks = [mask_tensor[i:i + 1].numpy() for i in range(mask_split_count)] masks = [mask_tensor[i:i + 1].numpy() for i in range(mask_split_count)]
if subm: if subm:
...@@ -342,7 +367,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -342,7 +367,6 @@ def get_indice_pairs_implicit_gemm(
# for subm, if training, pair shape is [2, kv, ...] # for subm, if training, pair shape is [2, kv, ...]
# if not training, pair is [1, kv, ...] # if not training, pair is [1, kv, ...]
pair = thalloc.allocated[AllocKeys.PairFwd] pair = thalloc.allocated[AllocKeys.PairFwd]
pair_mask = thalloc.allocated[AllocKeys.PairMask] pair_mask = thalloc.allocated[AllocKeys.PairMask]
mask_argsort = thalloc.allocated[AllocKeys.MaskArgSort] mask_argsort = thalloc.allocated[AllocKeys.MaskArgSort]
pair_mask_in_splits = [ pair_mask_in_splits = [
...@@ -367,7 +391,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -367,7 +391,6 @@ def get_indice_pairs_implicit_gemm(
if is_train: if is_train:
pair_mask_bwd = thalloc.allocated[AllocKeys.PairMaskBwd] pair_mask_bwd = thalloc.allocated[AllocKeys.PairMaskBwd]
mask_argsort_bwd = thalloc.allocated[AllocKeys.MaskArgSortBwd] mask_argsort_bwd = thalloc.allocated[AllocKeys.MaskArgSortBwd]
mask_argsort_fwd = thalloc.allocated[AllocKeys.MaskArgSort] mask_argsort_fwd = thalloc.allocated[AllocKeys.MaskArgSort]
if not is_train: if not is_train:
pair_mask_bwd_splits: List[torch.Tensor] = [] pair_mask_bwd_splits: List[torch.Tensor] = []
...@@ -388,11 +411,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -388,11 +411,6 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd, return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks) mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks)
t = 0
if DEBUG:
CONV.stream_synchronize(stream)
t = time.time()
assert indices.is_cuda, "implicit gemm only support cuda" assert indices.is_cuda, "implicit gemm only support cuda"
ndim = indices.shape[1] - 1 ndim = indices.shape[1] - 1
kv: int = functools.reduce(lambda x, y: x * y, ksize, 1) kv: int = functools.reduce(lambda x, y: x * y, ksize, 1)
...@@ -452,8 +470,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -452,8 +470,6 @@ def get_indice_pairs_implicit_gemm(
masks = [first.astype(np.uint32), second.astype(np.uint32)] masks = [first.astype(np.uint32), second.astype(np.uint32)]
else: else:
masks = [np.array([0xffffffff], dtype=np.uint32)] masks = [np.array([0xffffffff], dtype=np.uint32)]
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
if subm: if subm:
out_inds = indices out_inds = indices
...@@ -508,10 +524,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -508,10 +524,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_in_splits = [ mask_argsort_in_splits = [
mask_argsort[i] for i in range(mask_split_count) mask_argsort[i] for i in range(mask_split_count)
] ]
if DEBUG:
CONV.stream_synchronize(stream)
print("SUBM", time.time() - t)
if is_train: if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1], return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
...@@ -519,11 +531,10 @@ def get_indice_pairs_implicit_gemm( ...@@ -519,11 +531,10 @@ def get_indice_pairs_implicit_gemm(
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(), return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else: else:
if DEBUG: max_num_act = SpconvOps.get_handcrafted_max_act_out(
indices.shape[0], ksize, stride, padding, dilation)
CONV.stream_synchronize(stream) if transpose:
print("REGU_PREPARE", time.time() - t) max_num_act = kv * indices.shape[0]
t = time.time()
pair_bwd = pair pair_bwd = pair
pair_bwd_tv = pair_tv pair_bwd_tv = pair_tv
...@@ -531,8 +542,38 @@ def get_indice_pairs_implicit_gemm( ...@@ -531,8 +542,38 @@ def get_indice_pairs_implicit_gemm(
dtype=indice_dtype, dtype=indice_dtype,
device=indices.device) device=indices.device)
indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq) indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq)
hashdata = _HashData(0, use_int64_hash_k, indices.device)
indice_pairs_uniq_bkp_tv = tv.Tensor()
if direct_table:
# print("HASH SIZE", max_num_act * 2)
hashdata = _HashData(max_num_act, use_int64_hash_k, indices.device,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE)
indice_pairs_uniq_bkp = torch.empty((pair.numel() + 1, ),
dtype=indice_dtype,
device=indices.device)
indice_pairs_uniq_bkp_tv = torch_tensor_to_tv(
indice_pairs_uniq_bkp)
with timer.record("gen_conv_inds_stage1", stream): with timer.record("gen_conv_inds_stage1", stream):
SpconvOps.generate_conv_inds_mask_stage1(inds_tv, SpconvOps.generate_conv_inds_mask_stage1_direct_table(
inds_tv,
hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
pair_bwd_tv,
indice_pairs_uniq_bkp_tv,
indice_num_per_loc_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
transposed=transpose,
stream_int=stream)
else:
with timer.record("gen_conv_inds_stage1", stream):
SpconvOps.generate_conv_inds_mask_stage1(
inds_tv,
pair_bwd_tv, pair_bwd_tv,
indice_pairs_uniq_tv, indice_pairs_uniq_tv,
indice_num_per_loc_tv, indice_num_per_loc_tv,
...@@ -545,23 +586,31 @@ def get_indice_pairs_implicit_gemm( ...@@ -545,23 +586,31 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation, dilation=dilation,
transposed=transpose, transposed=transpose,
stream_int=stream) stream_int=stream)
if DEBUG: uniq_out_indices_offset_tv = tv.Tensor()
with timer.record(f"unique_{indice_pairs_uniq.shape[0]}", stream):
CONV.stream_synchronize(stream)
print("REGU_S1", time.time() - t)
t = time.time()
if direct_table:
uniq_cnt = torch.zeros([1],
dtype=torch.int32,
device=indices.device)
uniq_cnt_tv = torch_tensor_to_tv(uniq_cnt)
num_act_out = SpconvOps.unique_hash(hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv,
uniq_cnt_tv,
indice_pairs_uniq_tv,
num_out_act_bound, stream)
uniq_out_indices_offset_tv = indice_pairs_uniq_tv
raw_out_indices_offset_tv = indice_pairs_uniq_bkp_tv
else:
uniq_res = indice_pairs_uniq.unique() uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1 num_act_out = uniq_res.shape[0] - 1
uniq_out_indices_offset_tv = torch_tensor_to_tv(uniq_res)
raw_out_indices_offset_tv = indice_pairs_uniq_tv
if num_out_act_bound > 0 and num_act_out > num_out_act_bound: if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound num_act_out = num_out_act_bound
if DEBUG: with timer.record(f"alloc_stage2", stream):
CONV.stream_synchronize(stream)
print("REGU_UNIQ", time.time() - t)
t = time.time()
uniq_res_tv = torch_tensor_to_tv(uniq_res)
out_inds = torch.empty((num_act_out, indices.shape[1]), out_inds = torch.empty((num_act_out, indices.shape[1]),
dtype=indices.dtype, dtype=indices.dtype,
device=indices.device) device=indices.device)
...@@ -574,15 +623,18 @@ def get_indice_pairs_implicit_gemm( ...@@ -574,15 +623,18 @@ def get_indice_pairs_implicit_gemm(
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_tv = torch_tensor_to_tv(pair_mask_fwd, dtype=tv.uint32) pair_mask_fwd_tv = torch_tensor_to_tv(pair_mask_fwd,
dtype=tv.uint32)
pair_mask_bwd = torch.Tensor() pair_mask_bwd = torch.Tensor()
pair_mask_bwd_tv = tv.Tensor() pair_mask_bwd_tv = tv.Tensor()
if is_train: if is_train:
pair_mask_bwd = torch.zeros((mask_split_count, indices.shape[0]), pair_mask_bwd = torch.zeros(
(mask_split_count, indices.shape[0]),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd, pair_mask_bwd_tv = torch_tensor_to_tv(pair_mask_bwd,
dtype=tv.uint32) dtype=tv.uint32)
if not direct_table:
hashdata = _HashData(out_inds.shape[0], use_int64_hash_k, hashdata = _HashData(out_inds.shape[0], use_int64_hash_k,
indices.device) indices.device)
...@@ -591,19 +643,28 @@ def get_indice_pairs_implicit_gemm( ...@@ -591,19 +643,28 @@ def get_indice_pairs_implicit_gemm(
# device=indices.device) # device=indices.device)
out_inds_tv = torch_tensor_to_tv(out_inds) out_inds_tv = torch_tensor_to_tv(out_inds)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64) # hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
if DEBUG: with timer.record(f"gen_conv_inds_stage2_{num_act_out}", stream):
stage2_fn = SpconvOps.generate_conv_inds_mask_stage2
if direct_table:
SpconvOps.assign_output_direct_hash(indice_pairs_uniq_tv,
out_inds_tv,
batch_size=batch_size,
output_dims=out_shape,
input_dims=spatial_shape,
ksize=ksize,
stride=stride,
padding=padding,
dilation=dilation,
stream_int=stream)
stage2_fn = SpconvOps.generate_conv_inds_stage2_mask_direct_table
CONV.stream_synchronize(stream) stage2_fn(inds_tv,
print("REGU_S2_PREPARE", time.time() - t)
t = time.time()
with timer.record("gen_conv_inds_stage2", stream):
SpconvOps.generate_conv_inds_mask_stage2(inds_tv,
hashdata.hashdata_k_tv, hashdata.hashdata_k_tv,
hashdata.hashdata_v_tv, hashdata.hashdata_v_tv,
pair_fwd_tv, pair_fwd_tv,
pair_bwd_tv, pair_bwd_tv,
uniq_res_tv, uniq_out_indices_offset_tv,
indice_pairs_uniq_tv, raw_out_indices_offset_tv,
out_inds_tv, out_inds_tv,
pair_mask_fwd_tv, pair_mask_fwd_tv,
pair_mask_bwd_tv, pair_mask_bwd_tv,
...@@ -617,12 +678,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -617,12 +678,6 @@ def get_indice_pairs_implicit_gemm(
dilation=dilation, dilation=dilation,
transposed=transpose, transposed=transpose,
stream_int=stream) stream_int=stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2", time.time() - t)
t = time.time()
mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]), mask_argsort_fwd = torch.empty((mask_split_count, out_inds.shape[0]),
dtype=torch.int32, dtype=torch.int32,
device=indices.device) device=indices.device)
...@@ -693,10 +748,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -693,10 +748,6 @@ def get_indice_pairs_implicit_gemm(
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream) mask_argsort_bwd_tv[0], stream)
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU_S2_FINISH", time.time() - t)
t = time.time()
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
if not is_train: if not is_train:
...@@ -716,9 +767,6 @@ def get_indice_pairs_implicit_gemm( ...@@ -716,9 +767,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_fwd_splits = [ mask_argsort_fwd_splits = [
mask_argsort_fwd[i] for i in range(mask_split_count) mask_argsort_fwd[i] for i in range(mask_split_count)
] ]
if DEBUG:
CONV.stream_synchronize(stream)
print("REGU", time.time() - t)
return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd, return (out_inds, indice_num_per_loc, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
...@@ -769,8 +817,7 @@ def indice_conv(features: torch.Tensor, ...@@ -769,8 +817,7 @@ def indice_conv(features: torch.Tensor,
stream = get_current_stream() stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC,
FILTER_HWIO, features_tv, filters_tv, FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv, indice_pairs_tv, indice_pair_num_tv, arch,
arch,
num_activate_out, inverse, subm, algo.value, num_activate_out, inverse, subm, algo.value,
stream) stream)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
...@@ -1018,8 +1065,8 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1018,8 +1065,8 @@ def indice_conv_backward(features: torch.Tensor,
ALL_WEIGHT_IS_KRSC, FILTER_HWIO, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv, features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv, indice_pairs_tv, indice_pair_num_tv,
arch, arch, inverse, subm, algo.value,
inverse, subm, algo.value, stream) stream)
din = alloc.allocated[AllocKeys.DIn] din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters] df = alloc.allocated[AllocKeys.DFilters]
return din, df return din, df
...@@ -1369,8 +1416,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1369,8 +1416,8 @@ def implicit_gemm(features: torch.Tensor,
mask_width = ConvGemmOps.implicit_gemm( mask_width = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, timer_cpp, num_activate_out, mask_tv, arch, is_train, is_subm, stream,
auto_fp32_accum, fp32_accum) timer_cpp, auto_fp32_accum, fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None) mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train: if is_train:
...@@ -1460,7 +1507,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1460,7 +1507,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape) # print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"): # with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream): with timer.record("implicit_gemm", stream):
for j in range(num_split): for j in range(num_split):
...@@ -1921,8 +1968,10 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp, ...@@ -1921,8 +1968,10 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
indice_pairs_tv, stream) indice_pairs_tv, stream)
return din return din
def indice_avgpool_implicit_gemm(features: torch.Tensor, def indice_avgpool_implicit_gemm(features: torch.Tensor,
indice_pairs: torch.Tensor, num_activate_out, calc_count: bool): indice_pairs: torch.Tensor, num_activate_out,
calc_count: bool):
# torch.cuda.synchronize() # torch.cuda.synchronize()
# t = time.time() # t = time.time()
stream = get_current_stream() stream = get_current_stream()
...@@ -1943,12 +1992,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor, ...@@ -1943,12 +1992,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
count_out = torch.Tensor() count_out = torch.Tensor()
count_out_tv = tv.Tensor() count_out_tv = tv.Tensor()
if calc_count: if calc_count:
count_out = torch.zeros((num_activate_out,), count_out = torch.zeros((num_activate_out, ),
dtype=torch.int32, dtype=torch.int32,
device=features.device) device=features.device)
count_out_tv = torch_tensor_to_tv(count_out) count_out_tv = torch_tensor_to_tv(count_out)
SpconvOps.avgpool_implicit_gemm_forward(out_features_tv, features_tv, SpconvOps.avgpool_implicit_gemm_forward(out_features_tv, features_tv,
indice_pairs_tv, count_out_tv, stream) indice_pairs_tv, count_out_tv,
stream)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# print("M", time.time() - t) # print("M", time.time() - t)
...@@ -1956,12 +2006,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor, ...@@ -1956,12 +2006,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
return out_features, count_out return out_features, count_out
def indice_avgpool_implicit_gemm_backward(out_bp, def indice_avgpool_implicit_gemm_backward(out_bp, indice_pairs, count_out):
indice_pairs, count_out):
# torch.cuda.synchronize() # torch.cuda.synchronize()
# t = time.time() # t = time.time()
out_channel = out_bp.shape[-1] out_channel = out_bp.shape[-1]
din = torch.zeros((indice_pairs.shape[1], out_bp.shape[1]), dtype=out_bp.dtype, device=out_bp.device) din = torch.zeros((indice_pairs.shape[1], out_bp.shape[1]),
dtype=out_bp.dtype,
device=out_bp.device)
assert out_bp.is_cuda assert out_bp.is_cuda
if not out_bp.is_contiguous(): if not out_bp.is_contiguous():
out_bp = out_bp.contiguous() out_bp = out_bp.contiguous()
...@@ -1972,7 +2023,8 @@ def indice_avgpool_implicit_gemm_backward(out_bp, ...@@ -1972,7 +2023,8 @@ def indice_avgpool_implicit_gemm_backward(out_bp,
din_tv = torch_tensor_to_tv(din) din_tv = torch_tensor_to_tv(din)
indice_pairs_tv = torch_tensor_to_tv(indice_pairs) indice_pairs_tv = torch_tensor_to_tv(indice_pairs)
SpconvOps.avgpool_implicit_gemm_backward(out_bp_tv, din_tv, SpconvOps.avgpool_implicit_gemm_backward(out_bp_tv, din_tv,
indice_pairs_tv, count_out_tv, stream) indice_pairs_tv, count_out_tv,
stream)
return din return din
......
...@@ -323,6 +323,8 @@ def main(): ...@@ -323,6 +323,8 @@ def main():
# pickle.dump((voxels, coors, spatial_shape), f) # pickle.dump((voxels, coors, spatial_shape), f)
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f: with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f) (voxels, coors, spatial_shape) = pickle.load(f)
# voxels, coors, spatial_shape = waymo_data_large()
print(spatial_shape) print(spatial_shape)
print(voxels.shape) print(voxels.shape)
# voxels = voxels[:100] # voxels = voxels[:100]
...@@ -366,15 +368,14 @@ def main(): ...@@ -366,15 +368,14 @@ def main():
dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32) dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32)
dout_t = torch.from_numpy(dout).to(device).to(dtype) dout_t = torch.from_numpy(dout).to(device).to(dtype)
print(out.spatial_shape, out.features.mean(), out.features.max(), print(out.spatial_shape, out.features.sum(1).mean(), out.features.max(),
out.features.min()) out.features.min())
times = [] times = []
show_metrics = False show_metrics = False
with torch.no_grad(): with torch.no_grad():
for i in range(20): for i in range(100):
print("------------") # print("------------")
torch.cuda.synchronize() with tv.measure_duration() as measure:
t = time.time()
out_nograd = net(voxels_th, coors_th, 1, show_metrics) out_nograd = net(voxels_th, coors_th, 1, show_metrics)
# res = timer.collect_by_name("forward", timer.get_all_pair_time()) # res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time()) # res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
...@@ -383,14 +384,19 @@ def main(): ...@@ -383,14 +384,19 @@ def main():
# print(timer.get_all_pair_time()) # print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values())) # print(sum(timer.get_all_pair_time().values()))
torch.cuda.synchronize()
# sort_bench() # sort_bench()
times.append(time.time() - t) times.append(measure.duration)
if show_metrics: if show_metrics:
timer = out_nograd._timer timer = out_nograd._timer
items = list(timer.get_all_pair_time().items()) items = list(timer.get_all_pair_time().items())
items.sort(key=lambda x: x[0]) items.sort(key=lambda x: x[0])
print("SUM TIME:", sum([x[1] for x in items]))
print(json.dumps(dict(items), indent=2)) print(json.dumps(dict(items), indent=2))
inds_sum = 0
for k, v in items:
if "gen_pairs" in k:
inds_sum += v
print("SUM GEN INDS:", inds_sum)
# state = net.state_dict() # state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training") # state.pop("net.2.max_num_voxels_during_training")
......
...@@ -231,8 +231,8 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -231,8 +231,8 @@ def _test_impgemm_conv_cuda(subm: bool):
# out_channels = [32, 48, 64] # out_channels = [32, 48, 64]
in_channels = [32, 47] in_channels = [32, 47]
out_channels = [32, 48, 62] out_channels = [32, 48, 62]
in_channels = [32] # in_channels = [32]
out_channels = [32] # out_channels = [32]
multiple_base = 16 multiple_base = 16
if subm: if subm:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment