Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
73a5ce7d
"platforms/common/src/CommonCalcNonbondedForce.cpp" did not exist on "b49b82efb5a253a7c891ca084b3370e181de2ea3"
Commit
73a5ce7d
authored
Aug 25, 2022
by
yan.yan
Browse files
add direct table
parent
0c07559f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1113 additions
and
375 deletions
+1113
-375
spconv/constants.py
spconv/constants.py
+9
-3
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+97
-3
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+417
-155
spconv/csrc/sparse/indices.py
spconv/csrc/sparse/indices.py
+371
-36
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+17
-34
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+186
-134
test/benchmark.py
test/benchmark.py
+14
-8
test/test_all_algo.py
test/test_all_algo.py
+2
-2
No files found.
spconv/constants.py
View file @
73a5ce7d
...
@@ -95,13 +95,19 @@ class AllocKeys:
...
@@ -95,13 +95,19 @@ class AllocKeys:
HashV
=
"HashV"
HashV
=
"HashV"
ThrustTemp
=
"ThrustTemp"
ThrustTemp
=
"ThrustTemp"
TightUniqueCount
=
"TightUniqueCount"
SPCONV_DEBUG_WEIGHT
=
False
SPCONV_DEBUG_WEIGHT
=
False
SPCONV_CPP_INDICE_PAIRS
=
False
SPCONV_CPP_INDICE_PAIRS
=
False
SPCONV_CPP_INDICE_PAIRS_IGEMM
=
False
SPCONV_CPP_GEMM
=
False
# currently use cpp pair gen is slightly slower than python, I don't know why.
SPCONV_CPP_INDICE_PAIRS_IGEMM
=
os
.
getenv
(
"SPCONV_CPP_INDICE_PAIRS_IGEMM"
,
"0"
)
==
"1"
SPCONV_CPP_GEMM
=
True
SPCONV_FX_TRACE_MODE
=
os
.
getenv
(
"SPCONV_FX_TRACE_MODE"
,
"0"
)
==
"1"
SPCONV_FX_TRACE_MODE
=
os
.
getenv
(
"SPCONV_FX_TRACE_MODE"
,
"0"
)
==
"1"
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
=
1.1
\ No newline at end of file
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
73a5ce7d
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ThrustCustomAllocatorV2:
class ThrustCustomAllocatorV2:
alloc_func: Callable[int, int]
alloc_func: Callable[int, int]
class SpconvOps:
class SpconvOps:
...
@@ -92,6 +93,55 @@ class SpconvOps:
...
@@ -92,6 +93,55 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_mask_stage1_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_num_per_loc: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> None:
"""
Args:
indices:
hashdata_k:
hashdata_v:
indice_pairs_bwd:
indice_pairs_uniq:
indice_num_per_loc:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
transposed:
stream_int:
"""
...
@staticmethod
def unique_hash(hashdata_k: Tensor, hashdata_v: Tensor, uniq_cnt: Tensor, out_indices_offset: Tensor, num_out_bound: int, stream_int: int = 0) -> int:
"""
Args:
hashdata_k:
hashdata_v:
uniq_cnt:
out_indices_offset:
num_out_bound:
stream_int:
"""
...
@staticmethod
def assign_output_direct_hash(out_indices_offset: Tensor, out_indices: Tensor, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], stream_int: int = 0) -> None:
"""
Args:
out_indices_offset:
out_indices:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
stream_int:
"""
...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
...
@@ -118,6 +168,32 @@ class SpconvOps:
...
@@ -118,6 +168,32 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
hashdata_k:
hashdata_v:
indice_pairs_fwd:
indice_pairs_bwd:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
mask_fwd:
mask_bwd:
num_out_act:
batch_size:
output_dims:
input_dims:
ksize:
stride:
padding:
dilation:
transposed:
stream_int:
"""
...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
...
@@ -427,30 +503,45 @@ class SpconvOps:
...
@@ -427,30 +503,45 @@ class SpconvOps:
@staticmethod
@staticmethod
def get_int32_max() -> int: ...
def get_int32_max() -> int: ...
@staticmethod
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> int:
def get_handcrafted_max_act_out(num_act_in: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int]) -> int:
"""
Args:
num_act_in:
ksize:
stride:
padding:
dilation:
"""
...
@staticmethod
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, max_act_out_in_theory: int, subm: bool, use_int64_hash_k: bool, direct_table: bool) -> int:
"""
"""
Args:
Args:
kv:
kv:
num_act_in:
num_act_in:
num_act_out_bound:
num_act_out_bound:
max_act_out_in_theory:
subm:
subm:
use_int64_hash_k:
use_int64_hash_k:
direct_table:
"""
"""
...
...
@staticmethod
@staticmethod
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> Dict[str, Tensor]:
def get_indice_gen_tensors_from_workspace(workspace, kv: int, num_act_in: int, num_act_out_bound: int,
max_act_out_in_theory: int,
subm: bool, use_int64_hash_k:
bool, direct_table:
bool) -> Dict[str, Tensor]:
"""
"""
Args:
Args:
workspace:
workspace:
kv:
kv:
num_act_in:
num_act_in:
num_act_out_bound:
num_act_out_bound:
max_act_out_in_theory:
subm:
subm:
use_int64_hash_k:
use_int64_hash_k:
direct_table:
"""
"""
...
...
@staticmethod
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1
, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}
) -> Tuple[Tensor, int]:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -468,6 +559,9 @@ class SpconvOps:
...
@@ -468,6 +559,9 @@ class SpconvOps:
is_train:
is_train:
stream_int:
stream_int:
num_out_act_bound:
num_out_act_bound:
timer:
direct_table:
preallocated:
"""
"""
...
...
@staticmethod
@staticmethod
...
...
spconv/csrc/sparse/all.py
View file @
73a5ce7d
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
from
typing
import
List
from
typing
import
List
from
cumm.common
import
TensorView
,
TensorViewCPU
,
TensorViewKernel
,
ThrustLib
,
GemmBasicHost
from
cumm.common
import
TensorView
,
TensorViewCPU
,
TensorViewKernel
,
ThrustLib
,
GemmBasicHost
,
CppTimer
import
cumm
import
cumm
from
cumm.conv.bases
import
ConvOpType
,
NHWC
from
cumm.conv.bases
import
ConvOpType
,
NHWC
from
cumm.conv.params
import
ConvProblem
from
cumm.conv.params
import
ConvProblem
...
@@ -27,7 +27,7 @@ from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndice
...
@@ -27,7 +27,7 @@ from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndice
from
.maxpool
import
IndiceMaxPool
,
IndiceMaxPoolCPU
from
.maxpool
import
IndiceMaxPool
,
IndiceMaxPoolCPU
from
.gather
import
GatherCPU
from
.gather
import
GatherCPU
from
.alloc
import
ExternalAllocator
,
ThrustAllocator
from
.alloc
import
ExternalAllocator
,
ThrustAllocator
from
spconv.constants
import
AllocKeys
from
spconv.constants
import
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
,
AllocKeys
import
re
import
re
class
CustomThrustLib
(
pccm
.
Class
):
class
CustomThrustLib
(
pccm
.
Class
):
...
@@ -78,6 +78,11 @@ def to_snake_case(name):
...
@@ -78,6 +78,11 @@ def to_snake_case(name):
name
=
re
.
sub
(
'([a-z0-9])([A-Z])'
,
r
'\1_\2'
,
name
)
name
=
re
.
sub
(
'([a-z0-9])([A-Z])'
,
r
'\1_\2'
,
name
)
return
name
.
lower
()
return
name
.
lower
()
class
HashCoreHost
(
pccm
.
Class
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add_include
(
"tensorview/hash/hash_core.h"
)
class
SpconvOps
(
pccm
.
Class
):
class
SpconvOps
(
pccm
.
Class
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
@@ -104,7 +109,10 @@ class SpconvOps(pccm.Class):
...
@@ -104,7 +109,10 @@ class SpconvOps(pccm.Class):
self
.
generate_conv_inds_stage1_5
,
self
.
generate_conv_inds_stage1_5
,
self
.
generate_conv_inds_stage2
,
self
.
sort_1d_by_key
,
self
.
generate_conv_inds_stage2
,
self
.
sort_1d_by_key
,
self
.
generate_conv_inds_mask_stage1
,
self
.
generate_conv_inds_mask_stage1
,
self
.
generate_conv_inds_mask_stage2
self
.
generate_conv_inds_mask_stage2
,
self
.
unique_hash
,
self
.
assign_output_direct_hash
,
self
.
generate_conv_inds_mask_stage1_direct_table
,
self
.
generate_conv_inds_stage2_mask_direct_table
]
]
self
.
add_impl_only_param_class
(
cuda_funcs
,
f
"ops
{
ndim
}
d"
,
self
.
add_impl_only_param_class
(
cuda_funcs
,
f
"ops
{
ndim
}
d"
,
indices
,
indices
,
...
@@ -306,6 +314,110 @@ class SpconvOps(pccm.Class):
...
@@ -306,6 +314,110 @@ class SpconvOps(pccm.Class):
return
code
# .ret("int")
return
code
# .ret("int")
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_mask_stage1_direct_table
(
self
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
code
.
arg
(
"indices, hashdata_k, hashdata_v"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
"""
)
for
ndim
in
self
.
ndims
:
code
.
raw
(
f
"""
if (ndim ==
{
ndim
}
){{
tv::array<int,
{
ndim
}
> output_dims_, input_dims_;
tv::array<int,
{
ndim
}
> ksize_, stride_, padding_, dilation_;
for (int i = 0; i <
{
ndim
}
; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices
{
ndim
}
D::generate_conv_inds_mask_stage1_direct_table(indices,
hashdata_k, hashdata_v, indice_pairs_bwd, indice_pairs_uniq,
indice_num_per_loc, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
return
code
# .ret("int")
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
unique_hash
(
self
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
code
.
arg
(
"hashdata_k, hashdata_v, uniq_cnt, out_indices_offset"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_bound"
,
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
raw
(
f
"""
return SpconvIndices3D::unique_hash(hashdata_k, hashdata_v,
uniq_cnt, out_indices_offset, num_out_bound, stream_int);
"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
assign_output_direct_hash
(
self
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
code
.
arg
(
"out_indices_offset, out_indices"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
raw
(
f
"""
int ndim = out_indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
"""
)
for
ndim
in
self
.
ndims
:
code
.
raw
(
f
"""
if (ndim ==
{
ndim
}
){{
tv::array<int,
{
ndim
}
> output_dims_, input_dims_;
tv::array<int,
{
ndim
}
> ksize_, stride_, padding_, dilation_;
for (int i = 0; i <
{
ndim
}
; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices
{
ndim
}
D::assign_output_direct_hash(
out_indices_offset, out_indices, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, stream_int);
}}
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_mask_stage2
(
self
):
def
generate_conv_inds_mask_stage2
(
self
):
...
@@ -356,6 +468,55 @@ class SpconvOps(pccm.Class):
...
@@ -356,6 +468,55 @@ class SpconvOps(pccm.Class):
return
code
.
ret
(
"int"
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_stage2_mask_direct_table
(
self
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
code
.
arg
(
"indices, hashdata_k, hashdata_v"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"mask_fwd, mask_bwd"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
ksize.size() == ndim && stride.size() == ndim && dilation.size() == ndim &&
padding.size() == ndim, "your params size not equal to ndim", ndim);
"""
)
for
ndim
in
self
.
ndims
:
code
.
raw
(
f
"""
if (ndim ==
{
ndim
}
){{
tv::array<int,
{
ndim
}
> output_dims_, input_dims_;
tv::array<int,
{
ndim
}
> ksize_, stride_, padding_, dilation_;
for (int i = 0; i <
{
ndim
}
; ++i){{
output_dims_[i] = output_dims[i];
input_dims_[i] = input_dims[i];
ksize_[i] = ksize[i];
stride_[i] = stride[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
}}
return SpconvIndices
{
ndim
}
D::generate_conv_inds_stage2_mask_direct_table(
indices, hashdata_k, hashdata_v,
indice_pairs_fwd, indice_pairs_bwd,
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_subm_conv_inds
(
self
):
def
generate_subm_conv_inds
(
self
):
...
@@ -718,53 +879,6 @@ class SpconvOps(pccm.Class):
...
@@ -718,53 +879,6 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
return
code
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key
(
self
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
code
.
arg
(
"data"
,
"tv::Tensor"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
,
"tv::Tensor()"
,
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
code_after_include
=
f
"""
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
self
.
cuda_common_kernel
)
code
.
raw
(
f
"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::stable_sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k, SmallOrEqualTo<uint32_t>());
}});
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
"""
)
return
code
.
ret
(
"tv::Tensor"
)
def
sort_1d_by_key_allocator_template
(
self
,
use_allocator
:
bool
):
def
sort_1d_by_key_allocator_template
(
self
,
use_allocator
:
bool
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
...
@@ -1379,6 +1493,29 @@ class SpconvOps(pccm.Class):
...
@@ -1379,6 +1493,29 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
return
code
.
ret
(
"int"
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_handcrafted_max_act_out
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"num_act_in"
,
"size_t"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
"std::vector<int>"
)
code
.
raw
(
f
"""
int res = num_act_in;
for (int i = 0; i < ksize.size(); ++i){{
if (ksize[i] <= stride[i]){{
res *= 1;
}}
else if (ksize[i] > stride[i]){{
res *= tv::div_up(ksize[i], stride[i]);
}}
else{{
res *= ksize[i];
}}
}}
return res;
"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
@
pccm
.
static_function
def
get_indice_gen_workspace_size
(
self
):
def
get_indice_gen_workspace_size
(
self
):
...
@@ -1386,15 +1523,20 @@ class SpconvOps(pccm.Class):
...
@@ -1386,15 +1523,20 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"kv"
,
"size_t"
)
code
.
arg
(
"kv"
,
"size_t"
)
code
.
arg
(
"num_act_in"
,
"size_t"
)
code
.
arg
(
"num_act_in"
,
"size_t"
)
code
.
arg
(
"num_act_out_bound"
,
"size_t"
)
code
.
arg
(
"num_act_out_bound"
,
"size_t"
)
code
.
arg
(
"subm, use_int64_hash_k"
,
"bool"
)
code
.
arg
(
"max_act_out_in_theory"
,
"size_t"
)
code
.
arg
(
"subm, use_int64_hash_k, direct_table"
,
"bool"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory);
}}
if (subm){{
if (subm){{
return
2 * num_act_out_bound
* (use_int64_hash_k ? 3 : 2) * sizeof(int);
return
hash_size
* (use_int64_hash_k ? 3 : 2) *
sizeof(int) + 1 *
sizeof(int);
}}else{{
}}else{{
size_t pair_single_size = kv * num_act_in; // 40000
size_t pair_single_size = kv * num_act_in; // 40000
size_t ind_uniq_and_bkp_size = (pair_single_size + 1) * 2 * (use_int64_hash_k ? sizeof(int64_t) : sizeof(int32_t));
size_t ind_uniq_and_bkp_size = (pair_single_size + 1) * 2 * (use_int64_hash_k ? sizeof(int64_t) : sizeof(int32_t));
size_t hash_size =
2 * num_act_out_bound
* (use_int64_hash_k ? 3 : 2) * sizeof(int);
size_t hash_size =
hash_size
* (use_int64_hash_k ? 3 : 2) * sizeof(int);
return ind_uniq_and_bkp_size + hash_size;
return ind_uniq_and_bkp_size + hash_size
+ 1 * sizeof(int)
;
}}
}}
"""
)
"""
)
return
code
.
ret
(
"std::size_t"
)
return
code
.
ret
(
"std::size_t"
)
...
@@ -1407,20 +1549,26 @@ class SpconvOps(pccm.Class):
...
@@ -1407,20 +1549,26 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"kv"
,
"size_t"
)
code
.
arg
(
"kv"
,
"size_t"
)
code
.
arg
(
"num_act_in"
,
"size_t"
)
code
.
arg
(
"num_act_in"
,
"size_t"
)
code
.
arg
(
"num_act_out_bound"
,
"size_t"
)
code
.
arg
(
"num_act_out_bound"
,
"size_t"
)
code
.
arg
(
"subm, use_int64_hash_k"
,
"bool"
)
code
.
arg
(
"max_act_out_in_theory"
,
"size_t"
)
code
.
arg
(
"subm, use_int64_hash_k, direct_table"
,
"bool"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
std::unordered_map<std::string, tv::Tensor> res;
std::unordered_map<std::string, tv::Tensor> res;
auto ws_prev = workspace;
auto ws_prev = workspace;
auto expected_size = get_indice_gen_workspace_size(kv, num_act_in, num_act_out_bound, subm, use_int64_hash_k);
auto expected_size = get_indice_gen_workspace_size(kv, num_act_in, num_act_out_bound,
max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory);
}}
if (use_int64_hash_k){{
if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(
num_act_out_bound
) * 2}}, tv::int64, 0);
auto ten = tv::from_blob(workspace, {{int64_t(
hash_size
) * 2}}, tv::int64, 0);
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
, ten}});
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
, ten}});
workspace += ten.nbytes();
workspace += ten.nbytes();
auto ten2 = tv::from_blob(workspace, {{int64_t(
num_act_out_bound
) * 2}}, tv::int32, 0);
auto ten2 = tv::from_blob(workspace, {{int64_t(
hash_size
) * 2}}, tv::int32, 0);
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
, ten2}});
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
, ten2}});
workspace += ten2.nbytes();
workspace += ten2.nbytes();
}}else{{
}}else{{
auto ten = tv::from_blob(workspace, {{2, int64_t(
num_act_out_bound
) * 2}}, tv::int32, 0);
auto ten = tv::from_blob(workspace, {{2, int64_t(
hash_size
) * 2}}, tv::int32, 0);
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
, ten}});
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
, ten}});
workspace += ten.nbytes();
workspace += ten.nbytes();
}}
}}
...
@@ -1433,6 +1581,10 @@ class SpconvOps(pccm.Class):
...
@@ -1433,6 +1581,10 @@ class SpconvOps(pccm.Class):
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniqBackup
)
}
, ten2}});
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniqBackup
)
}
, ten2}});
workspace += ten2.nbytes();
workspace += ten2.nbytes();
}}
}}
auto uniq_cnt = tv::from_blob(workspace, {{1}}, tv::int32, 0);
res.insert({{
{
pccm
.
literal
(
AllocKeys
.
TightUniqueCount
)
}
, uniq_cnt}});
workspace += uniq_cnt.nbytes();
TV_ASSERT_RT_ERR(workspace - ws_prev == expected_size, "this shouldn't happen");
TV_ASSERT_RT_ERR(workspace - ws_prev == expected_size, "this shouldn't happen");
return res;
return res;
"""
)
"""
)
...
@@ -1442,6 +1594,7 @@ class SpconvOps(pccm.Class):
...
@@ -1442,6 +1594,7 @@ class SpconvOps(pccm.Class):
@
pccm
.
static_function
@
pccm
.
static_function
def
get_indice_pairs_implicit_gemm
(
self
):
def
get_indice_pairs_implicit_gemm
(
self
):
code
=
pccm
.
code
()
code
=
pccm
.
code
()
code
.
add_dependency
(
HashCoreHost
)
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
...
@@ -1452,12 +1605,18 @@ class SpconvOps(pccm.Class):
...
@@ -1452,12 +1605,18 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"num_out_act_bound"
,
f
"int"
,
"-1"
)
code
.
arg
(
"num_out_act_bound"
,
f
"int"
,
"-1"
)
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)"
)
code
.
arg
(
"direct_table"
,
f
"bool"
,
"false"
)
code
.
arg
(
"preallocated"
,
f
"std::unordered_map<std::string, tv::Tensor>"
,
"std::unordered_map<std::string, tv::Tensor>{}"
,
"Dict[str, cumm.tensorview.Tensor] = {}"
)
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
code
.
raw
(
f
"""
throw std::runtime_error("this function can only be used with CUDA.")
throw std::runtime_error("this function can only be used with CUDA.")
"""
)
"""
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"
std::tuple<
tv::Tensor
, int>
"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto tvctx = tv::Context();
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
...
@@ -1479,20 +1638,24 @@ class SpconvOps(pccm.Class):
...
@@ -1479,20 +1638,24 @@ class SpconvOps(pccm.Class):
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}}
}}
}}
}}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm, "only support implicit gemm");
conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm, "only support implicit gemm");
bool is_mask_split = conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm;
bool is_mask_split = conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm;
int mask_split_count = is_mask_split ? 2 : 1;
int mask_split_count = is_mask_split ? 2 : 1;
tv::Tensor pair;
int64_t num_act_in = indices.dim(0);
int64_t num_act_in = indices.dim(0);
"""
)
code
.
raw
(
f
"""
tv::Tensor pair;
if (subm){{
if (subm){{
if (preallocated.find(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
) != preallocated.end()){{
pair = preallocated.at(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
);
}}
else{{
if (is_train){{
if (is_train){{
// query pair for fwd and bwd
// query pair for fwd and bwd
pair = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
pair = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
...
@@ -1502,6 +1665,7 @@ class SpconvOps(pccm.Class):
...
@@ -1502,6 +1665,7 @@ class SpconvOps(pccm.Class):
pair = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
pair = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
{{1, kv, num_act_in}}, -1, indices.dtype(), indices.device(), stream_int);
{{1, kv, num_act_in}}, -1, indices.dtype(), indices.device(), stream_int);
}}
}}
}}
}}else{{
}}else{{
if (is_train){{
if (is_train){{
// query pair bwd
// query pair bwd
...
@@ -1512,9 +1676,17 @@ class SpconvOps(pccm.Class):
...
@@ -1512,9 +1676,17 @@ class SpconvOps(pccm.Class):
pair = tv::Tensor();
pair = tv::Tensor();
}}
}}
}}
}}
"""
)
auto indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
IndiceNumPerLoc
)
}
,
code
.
raw
(
f
"""
tv::Tensor indice_num_per_loc;
if (preallocated.find(
{
pccm
.
literal
(
AllocKeys
.
IndiceNumPerLoc
)
}
) != preallocated.end()){{
indice_num_per_loc = preallocated.at(
{
pccm
.
literal
(
AllocKeys
.
IndiceNumPerLoc
)
}
);
}}
else{{
indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
IndiceNumPerLoc
)
}
,
{{kv}}, indices.dtype(), indices.device(), stream_int);
{{kv}}, indices.dtype(), indices.device(), stream_int);
}}
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
...
@@ -1533,29 +1705,45 @@ class SpconvOps(pccm.Class):
...
@@ -1533,29 +1705,45 @@ class SpconvOps(pccm.Class):
tv::Tensor out_inds;
tv::Tensor out_inds;
ThrustAllocator thrustalloc(allocator);
ThrustAllocator thrustalloc(allocator);
int num_act_out = 0;
int num_act_out = 0;
if (subm){{
"""
)
with
code
.
if_
(
"subm"
):
code
.
raw
(
f
"""
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
out_inds = indices;
num_act_out = indices.dim(0);
num_act_out = indices.dim(0);
int num_points = out_inds.dim(0);
int hash_size = out_inds.dim(0) * 2;
"""
)
code
.
raw
(
f
"""
tv::Tensor hash_k, hash_v;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{
num_points * 2
}},
hash_k_guard = allocator.empty_guard({{
hash_size
}},
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_v_gurad = allocator.empty_guard({{
num_points * 2
}},
hash_v_gurad = allocator.empty_guard({{
hash_size
}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
hash_k = hash_k_guard->tensor;
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}},
if (preallocated.find(
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
) != preallocated.end()){{
auto hash_kv = preallocated.at(
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv[0];
hash_v = hash_kv[1];
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, hash_size}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv_gurad->tensor[0];
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
hash_v = hash_kv_gurad->tensor[1];
}}
}}
}}
auto pair_mask = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
"""
)
code
.
raw
(
f
"""
tv::Tensor pair_mask;
if (preallocated.find(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
) != preallocated.end()){{
pair_mask = preallocated.at(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
);
}}else{{
pair_mask = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_in}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_in}}, tv::uint32, 0, stream_int);
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
...
@@ -1563,64 +1751,135 @@ class SpconvOps(pccm.Class):
...
@@ -1563,64 +1751,135 @@ class SpconvOps(pccm.Class):
for (int j = 0; j < mask_split_count; ++j){{
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
}}
}}
"""
)
}}else{{
with
code
.
else_
():
code
.
raw
(
f
"""
// auto start = tv::CPUEvent().record(stream_int);
auto pair_bwd = pair;
auto pair_bwd = pair;
auto pair_size = kv * num_act_in;
auto pair_size = kv * num_act_in;
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
ExternalAllocator::guard_t indice_pairs_uniq_guard, indice_pairs_uniq_bkp_guard;
tv::Tensor hash_k, hash_v, indice_pairs_uniq;
int max_num_act = get_handcrafted_max_act_out(num_act_in, ksize, stride, padding, dilation);
if (transposed){{
max_num_act = pair_size;
}}
int hash_size = int(max_num_act *
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
);
if (direct_table){{
if (use_int64_hash_k){{
// temp memory don't need to be fixed, static alloc will check
// that tensor is large enough.
hash_k_guard = allocator.empty_guard({{hash_size}},
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_v_gurad = allocator.empty_guard({{hash_size}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, hash_size}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
}}
indice_pairs_uniq_guard = allocator.empty_guard({{2, int64_t(pair_size + 1)}},
indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniq
)
}
);
indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniq
)
}
);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniqBackup
)
}
);
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
indice_pairs_uniq = indice_pairs_uniq_guard->tensor[0];
auto indice_pairs_uniq_bkp = indice_pairs_uniq_guard->tensor[1];
// indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair_size + 1)}},
// indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniqBackup
)
}
);
{{
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_stage1",
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (direct_table){{
generate_conv_inds_mask_stage1_direct_table(indices,
hash_k, hash_v, pair_bwd, indice_pairs_uniq_bkp,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
}}else{{
generate_conv_inds_mask_stage1(indices, pair_bwd, indice_pairs_uniq,
generate_conv_inds_mask_stage1(indices, pair_bwd, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
indice_pairs_uniq_bkp.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
}}
}}
// TODO pytorch unique run faster.
{{
tv::CUDAKernelTimerGuard timer_guard(std::string("unique_") + std::to_string(indice_pairs_uniq.dim(0)),
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (direct_table){{
auto uniqcnt = allocator.zeros_guard({{1}}, tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
TightUniqueCount
)
}
, stream_int);
num_act_out = unique_hash(hash_k, hash_v, uniqcnt->tensor,
indice_pairs_uniq, num_out_act_bound, stream_int);
}}else{{
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
}}
}}
// tv::ssprint("HASH SIZE", hash_size, num_act_out);
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound;
num_act_out = num_out_act_bound;
}}
}}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
// for fixed size allocator, all memory alloc size must be fixed.
// for fixed size allocator, all memory alloc size must be fixed.
tv::Tensor pair_fwd, pair_mask_fwd, pair_mask_bwd;
{{
tv::CUDAKernelTimerGuard timer_guard("alloc_stage2",
timer, reinterpret_cast<cudaStream_t>(stream_int));
out_inds = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutIndices
)
}
,
out_inds = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutIndices
)
}
,
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0, stream_int);
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0, stream_int);
auto
pair_fwd = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
pair_fwd = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
auto
pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_out}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_out}}, tv::uint32, 0, stream_int);
auto
pair_mask_bwd = tv::Tensor();
pair_mask_bwd = tv::Tensor();
if (is_train){{
if (is_train){{
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMaskBwd
)
}
,
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMaskBwd
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0, stream_int);
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0, stream_int);
}}
}}
}}
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
if (!direct_table){{
tv::Tensor hash_k, hash_v
;
int hash_size = int(num_act_out * 2)
;
if (use_int64_hash_k){{
if (use_int64_hash_k){{
// temp memory don't need to be fixed, static alloc will check
// temp memory don't need to be fixed, static alloc will check
// that tensor is large enough.
// that tensor is large enough.
hash_k_guard = allocator.empty_guard({{
num_act_out * 2
}},
hash_k_guard = allocator.empty_guard({{
hash_size
}},
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_v_gurad = allocator.empty_guard({{
num_act_out * 2
}},
hash_v_gurad = allocator.empty_guard({{
hash_size
}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
hash_k = hash_k_guard->tensor;
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
}}else{{
hash_kv_gurad = allocator.empty_guard({{2,
num_act_out * 2
}},
hash_kv_gurad = allocator.empty_guard({{2,
hash_size
}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv_gurad->tensor[0];
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
hash_v = hash_kv_gurad->tensor[1];
}}
}}
}}
{{
tv::CUDAKernelTimerGuard timer_guard(std::string("gen_conv_inds_stage2_") + std::to_string(num_act_out),
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (direct_table){{
assign_output_direct_hash(indice_pairs_uniq, out_inds,
batch_size, out_shape,
input_dims, ksize, stride, padding, dilation, stream_int);
generate_conv_inds_stage2_mask_direct_table(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
}}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp
_guard->tensor
,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
transposed, stream_int);
}}
}}
"""
)
code
.
raw
(
f
"""
auto mask_argsort_fwd = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
auto mask_argsort_fwd = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
{{mask_split_count, num_act_out}}, tv::int32, 0, stream_int);
{{mask_split_count, num_act_out}}, tv::int32, 0, stream_int);
tv::Tensor mask_argsort_bwd = tv::Tensor();
tv::Tensor mask_argsort_bwd = tv::Tensor();
...
@@ -1628,7 +1887,9 @@ class SpconvOps(pccm.Class):
...
@@ -1628,7 +1887,9 @@ class SpconvOps(pccm.Class):
mask_argsort_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSortBwd
)
}
,
mask_argsort_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSortBwd
)
}
,
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
}}
}}
{{
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort",
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (is_mask_split){{
if (is_mask_split){{
for (int j = 0; j < mask_split_count; ++j){{
for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
...
@@ -1653,8 +1914,9 @@ class SpconvOps(pccm.Class):
...
@@ -1653,8 +1914,9 @@ class SpconvOps(pccm.Class):
mask_argsort_bwd[0], stream_int);
mask_argsort_bwd[0], stream_int);
}}
}}
}}
}}
}}
}}
"""
)
code
.
raw
(
f
"""
return std::make_tuple(mask_tensor, num_act_out);
return std::make_tuple(mask_tensor, num_act_out);
"""
)
"""
)
return
code
.
ret
(
"std::tuple<tv::Tensor, int>"
)
return
code
.
ret
(
"std::tuple<tv::Tensor, int>"
)
...
...
spconv/csrc/sparse/indices.py
View file @
73a5ce7d
...
@@ -73,7 +73,9 @@ class CudaCommonKernel(pccm.ParameterizedClass):
...
@@ -73,7 +73,9 @@ class CudaCommonKernel(pccm.ParameterizedClass):
"""
)
"""
)
return
code
return
code
class
ConvOutLocIter
(
pccm
.
ParameterizedClass
):
class
ConvOutLocIter
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
problem
:
ConvProblem
):
def
__init__
(
self
,
problem
:
ConvProblem
):
super
().
__init__
()
super
().
__init__
()
self
.
add_dependency
(
TensorView
)
self
.
add_dependency
(
TensorView
)
...
@@ -264,6 +266,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
...
@@ -264,6 +266,7 @@ class ConvOutLocIter(pccm.ParameterizedClass):
class
SparseConvIndicesKernel
(
pccm
.
ParameterizedClass
):
class
SparseConvIndicesKernel
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
problem
:
ConvProblem
,
dtype_indices
:
dtypes
.
DType
):
def
__init__
(
self
,
problem
:
ConvProblem
,
dtype_indices
:
dtypes
.
DType
):
super
().
__init__
()
super
().
__init__
()
self
.
add_dependency
(
TensorView
,
TensorViewKernel
,
TensorViewHashKernel
)
self
.
add_dependency
(
TensorView
,
TensorViewKernel
,
TensorViewHashKernel
)
...
@@ -278,7 +281,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -278,7 +281,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
assert
dtype_indices
==
dtypes
.
int32
or
dtype_indices
==
dtypes
.
int64
assert
dtype_indices
==
dtypes
.
int32
or
dtype_indices
==
dtypes
.
int64
@
pccm
.
cuda
.
cuda_global_function
@
pccm
.
cuda
.
cuda_global_function
def
calc_conv_indices_stage1
(
self
):
def
calc_conv_indices_stage1
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
...
@@ -331,7 +333,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -331,7 +333,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indices_out"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"indices_out"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_for_uniq"
,
code
.
arg
(
"indice_pairs_for_uniq"
,
f
"const typename TTable::key_type*"
)
# [2, kernelProd, MaxSize]
f
"const typename TTable::key_type*"
)
# [2, kernelProd, MaxSize]
code
.
arg
(
"layout_npq"
,
code
.
arg
(
"layout_npq"
,
...
@@ -349,12 +352,86 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -349,12 +352,86 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
return
code
return
code
@
pccm
.
cuda
.
cuda_global_function
@
pccm
.
cuda
.
cuda_global_function
def
calc_conv_indices_stage2
(
self
):
def
arange_hash_table_and_assign_out
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indices_out"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"count"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"limit"
,
f
"int"
)
# [N, ndim + 1]
code
.
arg
(
"layout_npq"
,
f
"spinds::LayoutNPQ"
)
# [2, kernelProd, MaxSize]
code
.
raw
(
f
"""
auto key_ptr = table.key_ptr();
auto value_ptr = table.value_ptr();
for (auto i : tv::KernelLoopX<int>(table.size())) {{
auto output_coord_offset = key_ptr[i];
if (output_coord_offset != TTable::empty_key) {{
auto output_index = tv::cuda::atomicAggInc(count);
if (output_index < limit){{
value_ptr[i] = output_index;
layout_npq.inverse(output_coord_offset, indices_out +
{
self
.
ndim
+
1
}
* output_index);
}}else{{
value_ptr[i] = -1;
}}
}}
}}
"""
)
return
code
@
pccm
.
cuda
.
cuda_global_function
def
arange_hash_table
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"out_indices_offset"
,
f
"typename TTable::key_type *"
)
# [N, ndim + 1]
code
.
arg
(
"count"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"limit"
,
f
"int"
)
# [N, ndim + 1]
code
.
raw
(
f
"""
auto key_ptr = table.key_ptr();
auto value_ptr = table.value_ptr();
for (auto i : tv::KernelLoopX<int>(table.size())) {{
auto output_coord_offset = key_ptr[i];
if (output_coord_offset != TTable::empty_key) {{
auto output_index = tv::cuda::atomicAggInc(count);
value_ptr[i] = output_index < limit ? output_index : -1;
out_indices_offset[output_index] = output_coord_offset;
}}
}}
"""
)
return
code
@
pccm
.
cuda
.
cuda_global_function
def
assign_out_indices
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"T"
)
code
.
arg
(
"indices_out"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"out_indices_offset"
,
f
"const T*"
)
# [N, ndim + 1]
code
.
arg
(
"layout_npq"
,
f
"spinds::LayoutNPQ"
)
# [2, kernelProd, MaxSize]
code
.
arg
(
"size"
,
f
"int"
)
# [N, ndim + 1]
code
.
raw
(
f
"""
for (auto i : tv::KernelLoopX<int>(size)) {{
layout_npq.inverse(out_indices_offset[i], indices_out +
{
self
.
ndim
+
1
}
* i);
}}
"""
)
return
code
@
pccm
.
cuda
.
cuda_global_function
def
calc_conv_indices_stage2
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_out_part"
,
f
"int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_out_part"
,
f
"int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"indices_pair_size"
,
"int"
)
code
.
arg
(
"indices_pair_size"
,
"int"
)
...
@@ -362,7 +439,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -362,7 +439,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
int filter_offset = blockIdx.y;
int filter_offset = blockIdx.y;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * indices_pair_size;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * indices_pair_size;
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
{
self
.
dtype_indices
}
output_coord_offset = indice_pairs_uniq_before_sort_filter[i];
{
self
.
dtype_indices
}
output_coord_offset = indice_pairs_uniq_before_sort_filter[i];
if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
...
@@ -386,8 +462,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -386,8 +462,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
targ
(
"TTable"
)
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
code
.
arg
(
"indice_pairs_in_part_temp"
,
f
"const int*"
)
# [kernelProd, MaxSize]
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_in_part_temp"
,
f
"const int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_in_part"
,
f
"int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_in_part"
,
f
"int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_out_part"
,
f
"int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_out_part"
,
f
"int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_num_per_loc"
,
f
"int*"
)
# [kernelProd]
code
.
arg
(
"indice_num_per_loc"
,
f
"int*"
)
# [kernelProd]
...
@@ -448,13 +526,63 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -448,13 +526,63 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
valid = loc_iter.query_npq(indices_in + input_index *
{
self
.
ndim
+
1
}
, npq_offset);
valid = loc_iter.query_npq(indices_in + input_index *
{
self
.
ndim
+
1
}
, npq_offset);
}}
}}
if (valid){{
if (valid){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
// int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
TIndiceUniq output_coord_offset = loc_iter.layout_npq(npq_offset);
// if (old_num < indices_pair_size){{
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// }}
}}
}}
"""
)
return
code
@
pccm
.
cuda
.
cuda_global_function
def
calc_conv_indices_stage1_mask_direct_table
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TIndiceUniq"
)
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"loc_iter"
,
f
"ConvLocIter"
)
# [N, ndim + 1]
code
.
arg
(
"indices_in"
,
f
"const int*"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_bwd"
,
f
"
{
self
.
dtype_indices
}
*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_for_uniq"
,
f
"TIndiceUniq*"
)
# [kernelProd * MaxSize + 1]
code
.
arg
(
"indice_num_per_loc"
,
f
"int*"
)
# [kernelProd]
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"RS"
,
"int"
)
code
.
arg
(
"transposed"
,
"bool"
)
code
.
raw
(
f
"""
int filter_offset = blockIdx.y;
loc_iter.set_filter_offset(filter_offset);
// int indices_pair_size_mul_RS = num_indices_in * RS;
int filter_offset_mul_indices_pair_size = filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
tv::array<int,
{
self
.
ndim
+
1
}
> npq_offset;
bool valid;
if (transposed){{
valid = loc_iter.query_nhw_out(indices_in + input_index *
{
self
.
ndim
+
1
}
, npq_offset);
}}else{{
valid = loc_iter.query_npq(indices_in + input_index *
{
self
.
ndim
+
1
}
, npq_offset);
}}
if (valid){{
// int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
TIndiceUniq output_coord_offset = loc_iter.layout_npq(npq_offset);
TIndiceUniq output_coord_offset = loc_iter.layout_npq(npq_offset);
// if (old_num < indices_pair_size){{
// if (old_num < indices_pair_size){{
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
// indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
table.insert_key_only(output_coord_offset);
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// }}
// }}
}}
}}
...
@@ -466,12 +594,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -466,12 +594,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def
calc_conv_indices_stage2_mask
(
self
):
def
calc_conv_indices_stage2_mask
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
targ
(
"TTable"
)
code
.
nontype_targ
(
"CheckValueValid"
,
"bool"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_fwd"
,
code
.
arg
(
"indice_pairs_fwd"
,
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
code
.
arg
(
"indice_pairs_bwd"
,
code
.
arg
(
"indice_pairs_bwd"
,
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_bwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_bwd"
,
f
"uint32_t*"
)
# [kernelProd]
...
@@ -495,6 +626,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -495,6 +626,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto table_offset = table.lookup_offset(output_coord_offset);
auto table_offset = table.lookup_offset(output_coord_offset);
if (table_offset != -1){{
if (table_offset != -1){{
auto output_index = table.value_ptr()[table_offset];
auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd);
atomicOr(mask_fwd + output_index, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index;
indice_pairs_fwd_filter[output_index] = input_index;
...
@@ -504,6 +637,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -504,6 +637,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}}
}}
}}
}}
}}
}}
}}
"""
)
"""
)
return
code
return
code
...
@@ -533,13 +667,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -533,13 +667,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def
calc_conv_indices_stage2_inference_mask
(
self
):
def
calc_conv_indices_stage2_inference_mask
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
targ
(
"TTable"
)
code
.
nontype_targ
(
"CheckValueValid"
,
"bool"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_fwd"
,
code
.
arg
(
"indice_pairs_fwd"
,
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
code
.
arg
(
"indice_pairs_bwd"
,
code
.
arg
(
"indice_pairs_bwd"
,
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"num_indices_in"
,
"int"
)
...
@@ -559,11 +695,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -559,11 +695,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto table_offset = table.lookup_offset(output_coord_offset);
auto table_offset = table.lookup_offset(output_coord_offset);
if (table_offset != -1){{
if (table_offset != -1){{
auto output_index = table.value_ptr()[table_offset];
auto output_index = table.value_ptr()[table_offset];
bool valid = CheckValueValid ? output_index >= 0 : true;
if (valid){{
atomicOr(mask_fwd + output_index, filter_mask_fwd);
atomicOr(mask_fwd + output_index, filter_mask_fwd);
indice_pairs_fwd_filter[output_index] = input_index;
indice_pairs_fwd_filter[output_index] = input_index;
}}
}}
}}
}}
}}
}}
}}
"""
)
"""
)
return
code
return
code
...
@@ -854,7 +993,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -854,7 +993,9 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def
generate_conv_inds_stage2
(
self
):
def
generate_conv_inds_stage2
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices, hashdata_k, hashdata_v"
,
"tv::Tensor"
)
code
.
arg
(
"indices, hashdata_k, hashdata_v"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"num_out_act"
,
"int"
)
...
@@ -938,8 +1079,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -938,8 +1079,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def
generate_conv_inds_mask_stage1
(
self
):
def
generate_conv_inds_mask_stage1
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices"
,
"tv::Tensor"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs_bwd, indice_pairs_uniq,
indice_num_per_loc"
,
code
.
arg
(
"indice_pairs_bwd, indice_pairs_uniq
"
,
"tv::Tensor"
)
"tv::Tensor"
)
code
.
arg
(
"indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
code
.
arg
(
"ksize, stride, padding, dilation"
,
...
@@ -982,8 +1123,67 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -982,8 +1123,67 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"""
)
"""
)
return
code
# .ret("int")
return
code
# .ret("int")
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_stage2_mask
(
self
):
def
generate_conv_inds_mask_stage1_direct_table
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices, hashdata_k, hashdata_v"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs_bwd, indice_pairs_uniq"
,
"tv::Tensor"
)
code
.
arg
(
"indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
// TODO stream
// TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>();
int num_act_in = indices.dim(0);
// indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_uniq: [kv * num_act_in + 1]
if (!indice_pairs_bwd.empty()){{
tv::check_shape(indice_pairs_bwd, {{kv, num_act_in}});
}}
tv::check_shape(indice_num_per_loc, {{kv}});
int64_t uniq_size = kv * num_act_in + 1;
TV_ASSERT_RT_ERR(indice_pairs_uniq.dim(0) == uniq_size, "error");
tv::cuda::Launch launcher_num_act_in(indices.dim(0), reinterpret_cast<cudaStream_t>(stream_int));
// tv::cuda::Launch launcher_num_act_in_2(indices.dim(0));
launcher_num_act_in.blocks.y = kv;
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
using V =
{
self
.
dtype_indices
}
;
using K = TV_DECLTYPE(I);
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::default_empty_key_v<K>, false>;
table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
tv::hash::clear_map_split(table, reinterpret_cast<cudaStream_t>(stream_int));
using T = TV_DECLTYPE(I);
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
"kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1_mask_direct_table<T, table_t>, table,
loc_iter, indices.data_ptr<const int>(),
indice_pairs_bwd.data_ptr<
{
self
.
dtype_indices
}
>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(),
indices.dim(0),
kv, transposed);
}});
"""
)
return
code
def
generate_conv_inds_stage2_mask_template
(
self
,
is_direct_table
:
bool
):
"""here indice_pairs_uniq may be bounded, some
"""here indice_pairs_uniq may be bounded, some
points may be dropped.
points may be dropped.
"""
"""
...
@@ -1013,8 +1213,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1013,8 +1213,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ctx.set_cuda_stream(custream);
ctx.set_cuda_stream(custream);
int num_act_in = indices.dim(0);
int num_act_in = indices.dim(0);
int num_act_out = num_out_act;
int num_act_out = num_out_act;
"""
)
if
not
is_direct_table
:
code
.
raw
(
f
"""
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
"""
)
code
.
raw
(
f
"""
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
// out_inds: [num_out_act,
{
self
.
ndim
+
1
}
]
// out_inds: [num_out_act,
{
self
.
ndim
+
1
}
]
// auto timer = tv::CudaContextTimer<>();
// auto timer = tv::CudaContextTimer<>();
...
@@ -1030,11 +1234,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1030,11 +1234,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
ConvLocIter loc_iter(problem);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
// TODO handle invalid num_out_act
// TODO handle invalid num_out_act
"""
)
if
not
is_direct_table
:
code
.
raw
(
f
"""
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
"""
)
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
with
code
.
block
(
""
,
start
=
"tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){"
,
end
=
"});"
):
code
.
raw
(
f
"""
using V =
{
self
.
dtype_indices
}
;
using V =
{
self
.
dtype_indices
}
;
using K = TV_DECLTYPE(I);
using K = TV_DECLTYPE(I);
using table_t =
using table_t =
...
@@ -1042,13 +1252,18 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1042,13 +1252,18 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::hash::default_empty_key_v<K>, false>;
tv::hash::default_empty_key_v<K>, false>;
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_out_act, "hash size not enough");
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_out_act, "hash size not enough");
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
"""
)
if
not
is_direct_table
:
# direct table built in stage 1.
code
.
raw
(
f
"""
tv::hash::clear_map_split(hash, custream);
tv::hash::clear_map_split(hash, custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
loc_iter.layout_npq, num_out_act);
"""
)
code
.
raw
(
f
"""
if (!mask_bwd.empty()){{
if (!mask_bwd.empty()){{
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash,
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t
,
{
pccm
.
literal
(
is_direct_table
)
}
>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
...
@@ -1064,7 +1279,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1064,7 +1279,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
mask_bwd[1].copy_(mask_bwd[0], ctx);
mask_bwd[1].copy_(mask_bwd[0], ctx);
}}
}}
}}else{{
}}else{{
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t>, hash,
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t
,
{
pccm
.
literal
(
is_direct_table
)
}
>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
mask_fwd.data_ptr<uint32_t>(),
mask_fwd.data_ptr<uint32_t>(),
...
@@ -1073,11 +1288,130 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1073,11 +1288,130 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
mask_fwd[1].copy_(mask_fwd[0], ctx);
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
}}
}}
}}
}});
"""
)
code
.
raw
(
f
"""
return num_out_act;
return num_out_act;
"""
)
"""
)
return
code
.
ret
(
"int"
)
return
code
.
ret
(
"int"
)
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_stage2_mask
(
self
):
"""here indice_pairs_uniq may be bounded, some
points may be dropped.
"""
return
self
.
generate_conv_inds_stage2_mask_template
(
False
)
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_stage2_mask_direct_table
(
self
):
"""here indice_pairs_uniq may be bounded, some
points may be dropped.
"""
return
self
.
generate_conv_inds_stage2_mask_template
(
True
)
@
pccm
.
cuda
.
static_function
def
unique_and_assign_output_direct_hash
(
self
):
"""unique by hash
"""
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"hashdata_k, hashdata_v, uniq_cnt"
,
"tv::Tensor"
)
code
.
arg
(
"out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_bound"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
tv::cuda::Launch lanucher_build_hash(hashdata_k.size(), custream);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
if (num_out_bound <= 0){{
num_out_bound = hashdata_k.size();
}}
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using V =
{
self
.
dtype_indices
}
;
using K = TV_DECLTYPE(I);
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::default_empty_key_v<K>, false>;
table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
lanucher_build_hash(arange_hash_table_and_assign_out<table_t>, table,
out_inds.data_ptr<int>(), uniq_cnt.data_ptr<int>(), num_out_bound,
loc_iter.layout_npq);
}});
auto uniq_cnt_cpu = uniq_cnt.cpu(tvctx);
return std::min(uniq_cnt_cpu.data_ptr<int>()[0], num_out_bound);
"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
cuda
.
static_function
def
unique_hash
(
self
):
"""unique by hash
"""
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"hashdata_k, hashdata_v, uniq_cnt, out_indices_offset"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_bound"
,
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
tv::cuda::Launch lanucher_build_hash(hashdata_k.size(), custream);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
if (num_out_bound <= 0){{
num_out_bound = out_indices_offset.dim(0);
}}
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using V =
{
self
.
dtype_indices
}
;
using K = TV_DECLTYPE(I);
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::default_empty_key_v<K>, false>;
table_t table = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
lanucher_build_hash(arange_hash_table<table_t>, table,
out_indices_offset.data_ptr<K>(),
uniq_cnt.data_ptr<int>(), num_out_bound);
}});
auto uniq_cnt_cpu = uniq_cnt.cpu(tvctx);
return std::min(uniq_cnt_cpu.data_ptr<int>()[0], num_out_bound);
"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
cuda
.
static_function
def
assign_output_direct_hash
(
self
):
"""unique by hash
"""
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"out_indices_offset"
,
"tv::Tensor"
)
code
.
arg
(
"out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
tv::cuda::Launch lanucher_build_hash(out_inds.dim(0), custream);
TV_ASSERT_RT_ERR(out_indices_offset.dim(0) >= out_inds.dim(0), "error");
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(out_indices_offset.dtype(), [&](auto I){{
using K = TV_DECLTYPE(I);
lanucher_build_hash(assign_out_indices<K>, out_inds.data_ptr<int>(),
out_indices_offset.data_ptr<const K>(),
loc_iter.layout_npq, out_inds.dim(0));
}});
"""
)
return
code
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_subm_conv_inds
(
self
):
def
generate_subm_conv_inds
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
...
@@ -1175,6 +1509,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1175,6 +1509,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
class
SparseConvIndicesCPU
(
pccm
.
ParameterizedClass
):
class
SparseConvIndicesCPU
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
problem
:
ConvProblem
,
dtype_indices
:
dtypes
.
DType
):
def
__init__
(
self
,
problem
:
ConvProblem
,
dtype_indices
:
dtypes
.
DType
):
super
().
__init__
()
super
().
__init__
()
self
.
add_dependency
(
TensorView
)
self
.
add_dependency
(
TensorView
)
...
...
spconv/pytorch/cppcore.py
View file @
73a5ce7d
...
@@ -33,13 +33,21 @@ _TORCH_DTYPE_TO_TV = {
...
@@ -33,13 +33,21 @@ _TORCH_DTYPE_TO_TV = {
torch
.
int16
:
tv
.
int16
,
torch
.
int16
:
tv
.
int16
,
torch
.
uint8
:
tv
.
uint8
,
torch
.
uint8
:
tv
.
uint8
,
}
}
_TV_DTYPE_TO_TORCH
=
{
v
:
k
for
k
,
v
in
_TORCH_DTYPE_TO_TV
.
items
()}
_TORCH_UINT_WORKAROUNDS
=
{
_TORCH_UINT_WORKAROUNDS
=
{
tv
.
uint32
:
tv
.
int32
,
tv
.
uint32
:
tv
.
int32
,
tv
.
uint16
:
tv
.
int16
,
tv
.
uint16
:
tv
.
int16
,
tv
.
uint64
:
tv
.
int64
tv
.
uint64
:
tv
.
int64
}
}
_TV_DTYPE_TO_TORCH
=
{
v
:
k
for
k
,
v
in
_TORCH_DTYPE_TO_TV
.
items
()}
_TV_DTYPE_TO_TORCH
.
update
({
tv
.
uint32
:
torch
.
int32
,
tv
.
uint16
:
torch
.
int16
,
tv
.
uint64
:
torch
.
int64
})
_ALL_INTS
=
{
_ALL_INTS
=
{
tv
.
int32
,
tv
.
int16
,
tv
.
int8
,
tv
.
int64
,
tv
.
uint64
,
tv
.
uint8
,
tv
.
uint32
,
tv
.
int32
,
tv
.
int16
,
tv
.
int8
,
tv
.
int64
,
tv
.
uint64
,
tv
.
uint8
,
tv
.
uint32
,
tv
.
uint16
tv
.
uint16
...
@@ -106,91 +114,66 @@ class TorchAllocator(ExternalAllocator):
...
@@ -106,91 +114,66 @@ class TorchAllocator(ExternalAllocator):
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
# TODO free memory by name if its already free by pointer.
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit.
# provide a name if you want to access it after c++ function exit.
torch_uint_workaround
=
dtype
in
_TORCH_UINT_WORKAROUNDS
dtype_bkp
=
dtype
dtype_bkp
=
dtype
if
dtype
in
_TORCH_UINT_WORKAROUNDS
:
# assert name == "", "must be temp memory for uint dtypes"
dtype
=
_TORCH_UINT_WORKAROUNDS
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
ten
=
torch
.
zeros
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
zeros
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten
.
data_pt
r
()]
=
ten
self
.
allocated
[
ten
_tv
.
byte_pointe
r
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
self
.
allocated
[
name
]
=
ten
self
.
allocated
[
name
]
=
ten
if
torch_uint_workaround
:
return
ten_tv
.
type_view
(
dtype_bkp
)
return
ten_tv
return
ten_tv
def
empty
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
def
empty
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
torch_uint_workaround
=
dtype
in
_TORCH_UINT_WORKAROUNDS
dtype_bkp
=
dtype
dtype_bkp
=
dtype
if
dtype
in
_TORCH_UINT_WORKAROUNDS
:
# assert name == "", "must be temp memory for uint dtypes"
dtype
=
_TORCH_UINT_WORKAROUNDS
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten
.
data_pt
r
()]
=
ten
self
.
allocated
[
ten
_tv
.
byte_pointe
r
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
self
.
allocated
[
name
]
=
ten
self
.
allocated
[
name
]
=
ten
if
torch_uint_workaround
:
return
ten_tv
.
type_view
(
dtype_bkp
)
return
ten_tv
return
ten_tv
def
full_int
(
self
,
name
:
str
,
shape
:
List
[
int
],
value
:
int
,
dtype
:
int
,
def
full_int
(
self
,
name
:
str
,
shape
:
List
[
int
],
value
:
int
,
dtype
:
int
,
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
torch_uint_workaround
=
dtype
in
_TORCH_UINT_WORKAROUNDS
dtype_bkp
=
dtype
dtype_bkp
=
dtype
if
dtype
in
_TORCH_UINT_WORKAROUNDS
:
assert
name
==
""
,
"must be temp memory for uint dtypes"
dtype
=
_TORCH_UINT_WORKAROUNDS
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten
.
data_pt
r
()]
=
ten
self
.
allocated
[
ten
_tv
.
byte_pointe
r
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
self
.
allocated
[
name
]
=
ten
self
.
allocated
[
name
]
=
ten
if
torch_uint_workaround
:
return
ten_tv
.
type_view
(
dtype_bkp
)
return
ten_tv
return
ten_tv
def
full_float
(
self
,
name
:
str
,
shape
:
List
[
int
],
value
:
float
,
dtype
:
int
,
def
full_float
(
self
,
name
:
str
,
shape
:
List
[
int
],
value
:
float
,
dtype
:
int
,
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
torch_uint_workaround
=
dtype
in
_TORCH_UINT_WORKAROUNDS
dtype_bkp
=
dtype
dtype_bkp
=
dtype
if
dtype
in
_TORCH_UINT_WORKAROUNDS
:
assert
name
==
""
,
"must be temp memory for uint dtypes"
dtype
=
_TORCH_UINT_WORKAROUNDS
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten
.
data_pt
r
()]
=
ten
self
.
allocated
[
ten
_tv
.
byte_pointe
r
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
self
.
allocated
[
name
]
=
ten
self
.
allocated
[
name
]
=
ten
if
torch_uint_workaround
:
return
ten_tv
.
type_view
(
dtype_bkp
)
return
ten_tv
return
ten_tv
def
get_tensor_by_name
(
self
,
name
:
str
):
def
get_tensor_by_name
(
self
,
name
:
str
):
...
...
spconv/pytorch/ops.py
View file @
73a5ce7d
...
@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
...
@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
from
spconv.pytorch.cppcore
import
TorchAllocator
,
torch_tensor_to_tv
,
get_current_stream
,
get_arch
,
TorchSpconvMatmul
from
spconv.pytorch.cppcore
import
TorchAllocator
,
torch_tensor_to_tv
,
get_current_stream
,
get_arch
,
TorchSpconvMatmul
from
spconv.core_cc.csrc.sparse.all
import
SpconvOps
from
spconv.core_cc.csrc.sparse.all
import
SpconvOps
from
spconv.core_cc.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.core_cc.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.constants
import
SPCONV_CPP_INDICE_PAIRS
,
SPCONV_CPP_INDICE_PAIRS_IGEMM
,
SPCONV_CPP_GEMM
from
spconv.constants
import
SPCONV_CPP_INDICE_PAIRS
,
SPCONV_CPP_INDICE_PAIRS_IGEMM
,
SPCONV_CPP_GEMM
,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
import
spconv.core_cc
as
_ext
import
spconv.core_cc
as
_ext
from
spconv.core_cc.csrc.sparse.convops.spops
import
ConvGemmOps
from
spconv.core_cc.csrc.sparse.convops.spops
import
ConvGemmOps
from
spconv.utils
import
nullcontext
from
spconv.utils
import
nullcontext
...
@@ -46,7 +46,7 @@ from cumm.gemm import codeops
...
@@ -46,7 +46,7 @@ from cumm.gemm import codeops
from
spconv.tools
import
CUDAKernelTimer
from
spconv.tools
import
CUDAKernelTimer
DEBUG
=
False
DEBUG
=
False
DEBUG_INT64_HASH_K
=
Tru
e
DEBUG_INT64_HASH_K
=
Fals
e
INT32_MAX
=
SpconvOps
.
get_int32_max
()
INT32_MAX
=
SpconvOps
.
get_int32_max
()
...
@@ -77,12 +77,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
...
@@ -77,12 +77,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
class
_HashData
:
class
_HashData
:
def
__init__
(
self
,
num
:
int
,
use_i64
:
bool
,
device
:
torch
.
device
)
->
None
:
def
__init__
(
self
,
num
:
int
,
use_i64
:
bool
,
device
:
torch
.
device
,
rate
:
float
=
2.0
)
->
None
:
if
use_i64
:
if
use_i64
:
self
.
hashdata_k
=
torch
.
empty
((
num
*
2
,
),
self
.
hashdata_k
=
torch
.
empty
((
int
(
num
*
rate
)
,
),
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
device
)
device
=
device
)
self
.
hashdata_v
=
torch
.
empty
((
num
*
2
,
),
self
.
hashdata_v
=
torch
.
empty
((
int
(
num
*
rate
)
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
self
.
hashdata_k_tv
=
torch_tensor_to_tv
(
self
.
hashdata_k
)
self
.
hashdata_k_tv
=
torch_tensor_to_tv
(
self
.
hashdata_k
)
...
@@ -91,7 +96,7 @@ class _HashData:
...
@@ -91,7 +96,7 @@ class _HashData:
else
:
else
:
self
.
hashdata
=
torch
.
empty
((
self
.
hashdata
=
torch
.
empty
((
2
,
2
,
num
*
2
,
int
(
num
*
rate
)
,
),
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -309,7 +314,8 @@ def get_indice_pairs_implicit_gemm(
...
@@ -309,7 +314,8 @@ def get_indice_pairs_implicit_gemm(
is_train
:
bool
=
True
,
is_train
:
bool
=
True
,
alloc
:
Optional
[
ThrustSortAllocator
]
=
None
,
alloc
:
Optional
[
ThrustSortAllocator
]
=
None
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
num_out_act_bound
:
int
=
-
1
):
num_out_act_bound
:
int
=
-
1
,
direct_table
:
bool
=
True
):
"""
"""
Why return tuple? because pytorch seems don't support custom object in autograd.
Why return tuple? because pytorch seems don't support custom object in autograd.
return: (
return: (
...
@@ -323,14 +329,33 @@ def get_indice_pairs_implicit_gemm(
...
@@ -323,14 +329,33 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_bwd_splits, # torch.Tensor() if subm or inference mode
mask_argsort_bwd_splits, # torch.Tensor() if subm or inference mode
masks,
masks,
)
)
direct_table: a hash-based regular conv pair gen algo to avoid unique operation.
runs faster than pytorch unique with num_voxel < 1000k.
"""
"""
stream
=
get_current_stream
()
stream
=
get_current_stream
()
if
SPCONV_CPP_INDICE_PAIRS_IGEMM
:
if
SPCONV_CPP_INDICE_PAIRS_IGEMM
:
thalloc
=
TorchAllocator
(
indices
.
device
)
thalloc
=
TorchAllocator
(
indices
.
device
)
timer_cpp
=
tv
.
CUDAKernelTimer
(
False
)
if
timer
.
_timer
is
not
None
:
timer_cpp
=
timer
.
_timer
mask_tensor
,
num_act_out
=
SpconvOps
.
get_indice_pairs_implicit_gemm
(
mask_tensor
,
num_act_out
=
SpconvOps
.
get_indice_pairs_implicit_gemm
(
thalloc
,
torch_tensor_to_tv
(
indices
),
batch_size
,
spatial_shape
,
thalloc
,
algo
.
value
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
subm
,
torch_tensor_to_tv
(
indices
),
transpose
,
is_train
,
stream
,
num_out_act_bound
)
batch_size
,
spatial_shape
,
algo
.
value
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
subm
,
transpose
,
is_train
,
stream
,
num_out_act_bound
,
timer
=
timer_cpp
,
direct_table
=
direct_table
)
mask_split_count
=
mask_tensor
.
dim
(
0
)
mask_split_count
=
mask_tensor
.
dim
(
0
)
masks
=
[
mask_tensor
[
i
:
i
+
1
].
numpy
()
for
i
in
range
(
mask_split_count
)]
masks
=
[
mask_tensor
[
i
:
i
+
1
].
numpy
()
for
i
in
range
(
mask_split_count
)]
if
subm
:
if
subm
:
...
@@ -342,7 +367,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -342,7 +367,6 @@ def get_indice_pairs_implicit_gemm(
# for subm, if training, pair shape is [2, kv, ...]
# for subm, if training, pair shape is [2, kv, ...]
# if not training, pair is [1, kv, ...]
# if not training, pair is [1, kv, ...]
pair
=
thalloc
.
allocated
[
AllocKeys
.
PairFwd
]
pair
=
thalloc
.
allocated
[
AllocKeys
.
PairFwd
]
pair_mask
=
thalloc
.
allocated
[
AllocKeys
.
PairMask
]
pair_mask
=
thalloc
.
allocated
[
AllocKeys
.
PairMask
]
mask_argsort
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSort
]
mask_argsort
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSort
]
pair_mask_in_splits
=
[
pair_mask_in_splits
=
[
...
@@ -367,7 +391,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -367,7 +391,6 @@ def get_indice_pairs_implicit_gemm(
if
is_train
:
if
is_train
:
pair_mask_bwd
=
thalloc
.
allocated
[
AllocKeys
.
PairMaskBwd
]
pair_mask_bwd
=
thalloc
.
allocated
[
AllocKeys
.
PairMaskBwd
]
mask_argsort_bwd
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSortBwd
]
mask_argsort_bwd
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSortBwd
]
mask_argsort_fwd
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSort
]
mask_argsort_fwd
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSort
]
if
not
is_train
:
if
not
is_train
:
pair_mask_bwd_splits
:
List
[
torch
.
Tensor
]
=
[]
pair_mask_bwd_splits
:
List
[
torch
.
Tensor
]
=
[]
...
@@ -388,11 +411,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -388,11 +411,6 @@ def get_indice_pairs_implicit_gemm(
return
(
out_inds
,
indice_num_per_loc
,
pair_fwd
,
pair_bwd
,
return
(
out_inds
,
indice_num_per_loc
,
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
masks
)
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
masks
)
t
=
0
if
DEBUG
:
CONV
.
stream_synchronize
(
stream
)
t
=
time
.
time
()
assert
indices
.
is_cuda
,
"implicit gemm only support cuda"
assert
indices
.
is_cuda
,
"implicit gemm only support cuda"
ndim
=
indices
.
shape
[
1
]
-
1
ndim
=
indices
.
shape
[
1
]
-
1
kv
:
int
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
ksize
,
1
)
kv
:
int
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
ksize
,
1
)
...
@@ -452,8 +470,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -452,8 +470,6 @@ def get_indice_pairs_implicit_gemm(
masks
=
[
first
.
astype
(
np
.
uint32
),
second
.
astype
(
np
.
uint32
)]
masks
=
[
first
.
astype
(
np
.
uint32
),
second
.
astype
(
np
.
uint32
)]
else
:
else
:
masks
=
[
np
.
array
([
0xffffffff
],
dtype
=
np
.
uint32
)]
masks
=
[
np
.
array
([
0xffffffff
],
dtype
=
np
.
uint32
)]
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
if
subm
:
if
subm
:
out_inds
=
indices
out_inds
=
indices
...
@@ -508,10 +524,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -508,10 +524,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_in_splits
=
[
mask_argsort_in_splits
=
[
mask_argsort
[
i
]
for
i
in
range
(
mask_split_count
)
mask_argsort
[
i
]
for
i
in
range
(
mask_split_count
)
]
]
if
DEBUG
:
CONV
.
stream_synchronize
(
stream
)
print
(
"SUBM"
,
time
.
time
()
-
t
)
if
is_train
:
if
is_train
:
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair
[
1
],
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair
[
1
],
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
...
@@ -519,11 +531,10 @@ def get_indice_pairs_implicit_gemm(
...
@@ -519,11 +531,10 @@ def get_indice_pairs_implicit_gemm(
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
torch
.
Tensor
(),
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
torch
.
Tensor
(),
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
else
:
if
DEBUG
:
max_num_act
=
SpconvOps
.
get_handcrafted_max_act_out
(
indices
.
shape
[
0
],
ksize
,
stride
,
padding
,
dilation
)
CONV
.
stream_synchronize
(
stream
)
if
transpose
:
print
(
"REGU_PREPARE"
,
time
.
time
()
-
t
)
max_num_act
=
kv
*
indices
.
shape
[
0
]
t
=
time
.
time
()
pair_bwd
=
pair
pair_bwd
=
pair
pair_bwd_tv
=
pair_tv
pair_bwd_tv
=
pair_tv
...
@@ -531,8 +542,38 @@ def get_indice_pairs_implicit_gemm(
...
@@ -531,8 +542,38 @@ def get_indice_pairs_implicit_gemm(
dtype
=
indice_dtype
,
dtype
=
indice_dtype
,
device
=
indices
.
device
)
device
=
indices
.
device
)
indice_pairs_uniq_tv
=
torch_tensor_to_tv
(
indice_pairs_uniq
)
indice_pairs_uniq_tv
=
torch_tensor_to_tv
(
indice_pairs_uniq
)
hashdata
=
_HashData
(
0
,
use_int64_hash_k
,
indices
.
device
)
indice_pairs_uniq_bkp_tv
=
tv
.
Tensor
()
if
direct_table
:
# print("HASH SIZE", max_num_act * 2)
hashdata
=
_HashData
(
max_num_act
,
use_int64_hash_k
,
indices
.
device
,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
)
indice_pairs_uniq_bkp
=
torch
.
empty
((
pair
.
numel
()
+
1
,
),
dtype
=
indice_dtype
,
device
=
indices
.
device
)
indice_pairs_uniq_bkp_tv
=
torch_tensor_to_tv
(
indice_pairs_uniq_bkp
)
with
timer
.
record
(
"gen_conv_inds_stage1"
,
stream
):
with
timer
.
record
(
"gen_conv_inds_stage1"
,
stream
):
SpconvOps
.
generate_conv_inds_mask_stage1
(
inds_tv
,
SpconvOps
.
generate_conv_inds_mask_stage1_direct_table
(
inds_tv
,
hashdata
.
hashdata_k_tv
,
hashdata
.
hashdata_v_tv
,
pair_bwd_tv
,
indice_pairs_uniq_bkp_tv
,
indice_num_per_loc_tv
,
batch_size
=
batch_size
,
output_dims
=
out_shape
,
input_dims
=
spatial_shape
,
ksize
=
ksize
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
transposed
=
transpose
,
stream_int
=
stream
)
else
:
with
timer
.
record
(
"gen_conv_inds_stage1"
,
stream
):
SpconvOps
.
generate_conv_inds_mask_stage1
(
inds_tv
,
pair_bwd_tv
,
pair_bwd_tv
,
indice_pairs_uniq_tv
,
indice_pairs_uniq_tv
,
indice_num_per_loc_tv
,
indice_num_per_loc_tv
,
...
@@ -545,23 +586,31 @@ def get_indice_pairs_implicit_gemm(
...
@@ -545,23 +586,31 @@ def get_indice_pairs_implicit_gemm(
dilation
=
dilation
,
dilation
=
dilation
,
transposed
=
transpose
,
transposed
=
transpose
,
stream_int
=
stream
)
stream_int
=
stream
)
if
DEBUG
:
uniq_out_indices_offset_tv
=
tv
.
Tensor
()
with
timer
.
record
(
f
"unique_
{
indice_pairs_uniq
.
shape
[
0
]
}
"
,
stream
):
CONV
.
stream_synchronize
(
stream
)
print
(
"REGU_S1"
,
time
.
time
()
-
t
)
t
=
time
.
time
()
if
direct_table
:
uniq_cnt
=
torch
.
zeros
([
1
],
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
uniq_cnt_tv
=
torch_tensor_to_tv
(
uniq_cnt
)
num_act_out
=
SpconvOps
.
unique_hash
(
hashdata
.
hashdata_k_tv
,
hashdata
.
hashdata_v_tv
,
uniq_cnt_tv
,
indice_pairs_uniq_tv
,
num_out_act_bound
,
stream
)
uniq_out_indices_offset_tv
=
indice_pairs_uniq_tv
raw_out_indices_offset_tv
=
indice_pairs_uniq_bkp_tv
else
:
uniq_res
=
indice_pairs_uniq
.
unique
()
uniq_res
=
indice_pairs_uniq
.
unique
()
num_act_out
=
uniq_res
.
shape
[
0
]
-
1
num_act_out
=
uniq_res
.
shape
[
0
]
-
1
uniq_out_indices_offset_tv
=
torch_tensor_to_tv
(
uniq_res
)
raw_out_indices_offset_tv
=
indice_pairs_uniq_tv
if
num_out_act_bound
>
0
and
num_act_out
>
num_out_act_bound
:
if
num_out_act_bound
>
0
and
num_act_out
>
num_out_act_bound
:
num_act_out
=
num_out_act_bound
num_act_out
=
num_out_act_bound
if
DEBUG
:
with
timer
.
record
(
f
"alloc_stage2"
,
stream
):
CONV
.
stream_synchronize
(
stream
)
print
(
"REGU_UNIQ"
,
time
.
time
()
-
t
)
t
=
time
.
time
()
uniq_res_tv
=
torch_tensor_to_tv
(
uniq_res
)
out_inds
=
torch
.
empty
((
num_act_out
,
indices
.
shape
[
1
]),
out_inds
=
torch
.
empty
((
num_act_out
,
indices
.
shape
[
1
]),
dtype
=
indices
.
dtype
,
dtype
=
indices
.
dtype
,
device
=
indices
.
device
)
device
=
indices
.
device
)
...
@@ -574,15 +623,18 @@ def get_indice_pairs_implicit_gemm(
...
@@ -574,15 +623,18 @@ def get_indice_pairs_implicit_gemm(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
device
=
indices
.
device
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
pair_mask_fwd_tv
=
torch_tensor_to_tv
(
pair_mask_fwd
,
dtype
=
tv
.
uint32
)
pair_mask_fwd_tv
=
torch_tensor_to_tv
(
pair_mask_fwd
,
dtype
=
tv
.
uint32
)
pair_mask_bwd
=
torch
.
Tensor
()
pair_mask_bwd
=
torch
.
Tensor
()
pair_mask_bwd_tv
=
tv
.
Tensor
()
pair_mask_bwd_tv
=
tv
.
Tensor
()
if
is_train
:
if
is_train
:
pair_mask_bwd
=
torch
.
zeros
((
mask_split_count
,
indices
.
shape
[
0
]),
pair_mask_bwd
=
torch
.
zeros
(
(
mask_split_count
,
indices
.
shape
[
0
]),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
device
=
indices
.
device
)
pair_mask_bwd_tv
=
torch_tensor_to_tv
(
pair_mask_bwd
,
pair_mask_bwd_tv
=
torch_tensor_to_tv
(
pair_mask_bwd
,
dtype
=
tv
.
uint32
)
dtype
=
tv
.
uint32
)
if
not
direct_table
:
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
indices
.
device
)
...
@@ -591,19 +643,28 @@ def get_indice_pairs_implicit_gemm(
...
@@ -591,19 +643,28 @@ def get_indice_pairs_implicit_gemm(
# device=indices.device)
# device=indices.device)
out_inds_tv
=
torch_tensor_to_tv
(
out_inds
)
out_inds_tv
=
torch_tensor_to_tv
(
out_inds
)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
if
DEBUG
:
with
timer
.
record
(
f
"gen_conv_inds_stage2_
{
num_act_out
}
"
,
stream
):
stage2_fn
=
SpconvOps
.
generate_conv_inds_mask_stage2
if
direct_table
:
SpconvOps
.
assign_output_direct_hash
(
indice_pairs_uniq_tv
,
out_inds_tv
,
batch_size
=
batch_size
,
output_dims
=
out_shape
,
input_dims
=
spatial_shape
,
ksize
=
ksize
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
stream_int
=
stream
)
stage2_fn
=
SpconvOps
.
generate_conv_inds_stage2_mask_direct_table
CONV
.
stream_synchronize
(
stream
)
stage2_fn
(
inds_tv
,
print
(
"REGU_S2_PREPARE"
,
time
.
time
()
-
t
)
t
=
time
.
time
()
with
timer
.
record
(
"gen_conv_inds_stage2"
,
stream
):
SpconvOps
.
generate_conv_inds_mask_stage2
(
inds_tv
,
hashdata
.
hashdata_k_tv
,
hashdata
.
hashdata_k_tv
,
hashdata
.
hashdata_v_tv
,
hashdata
.
hashdata_v_tv
,
pair_fwd_tv
,
pair_fwd_tv
,
pair_bwd_tv
,
pair_bwd_tv
,
uniq_
res
_tv
,
uniq_
out_indices_offset
_tv
,
indice_pairs_uniq
_tv
,
raw_out_indices_offset
_tv
,
out_inds_tv
,
out_inds_tv
,
pair_mask_fwd_tv
,
pair_mask_fwd_tv
,
pair_mask_bwd_tv
,
pair_mask_bwd_tv
,
...
@@ -617,12 +678,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -617,12 +678,6 @@ def get_indice_pairs_implicit_gemm(
dilation
=
dilation
,
dilation
=
dilation
,
transposed
=
transpose
,
transposed
=
transpose
,
stream_int
=
stream
)
stream_int
=
stream
)
if
DEBUG
:
CONV
.
stream_synchronize
(
stream
)
print
(
"REGU_S2"
,
time
.
time
()
-
t
)
t
=
time
.
time
()
mask_argsort_fwd
=
torch
.
empty
((
mask_split_count
,
out_inds
.
shape
[
0
]),
mask_argsort_fwd
=
torch
.
empty
((
mask_split_count
,
out_inds
.
shape
[
0
]),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
device
=
indices
.
device
)
...
@@ -693,10 +748,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -693,10 +748,6 @@ def get_indice_pairs_implicit_gemm(
SpconvOps
.
sort_1d_by_key_allocator
(
SpconvOps
.
sort_1d_by_key_allocator
(
pair_mask_bwd_tv
[
0
],
alloc
.
alloc
,
pair_mask_bwd_tv
[
0
],
alloc
.
alloc
,
mask_argsort_bwd_tv
[
0
],
stream
)
mask_argsort_bwd_tv
[
0
],
stream
)
if
DEBUG
:
CONV
.
stream_synchronize
(
stream
)
print
(
"REGU_S2_FINISH"
,
time
.
time
()
-
t
)
t
=
time
.
time
()
# CONV.stream_synchronize(stream)
# CONV.stream_synchronize(stream)
if
not
is_train
:
if
not
is_train
:
...
@@ -716,9 +767,6 @@ def get_indice_pairs_implicit_gemm(
...
@@ -716,9 +767,6 @@ def get_indice_pairs_implicit_gemm(
mask_argsort_fwd_splits
=
[
mask_argsort_fwd_splits
=
[
mask_argsort_fwd
[
i
]
for
i
in
range
(
mask_split_count
)
mask_argsort_fwd
[
i
]
for
i
in
range
(
mask_split_count
)
]
]
if
DEBUG
:
CONV
.
stream_synchronize
(
stream
)
print
(
"REGU"
,
time
.
time
()
-
t
)
return
(
out_inds
,
indice_num_per_loc
,
pair_fwd
,
pair_bwd
,
return
(
out_inds
,
indice_num_per_loc
,
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
...
@@ -769,8 +817,7 @@ def indice_conv(features: torch.Tensor,
...
@@ -769,8 +817,7 @@ def indice_conv(features: torch.Tensor,
stream
=
get_current_stream
()
stream
=
get_current_stream
()
ConvGemmOps
.
indice_conv
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
ConvGemmOps
.
indice_conv
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
arch
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
stream
)
stream
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
...
@@ -1018,8 +1065,8 @@ def indice_conv_backward(features: torch.Tensor,
...
@@ -1018,8 +1065,8 @@ def indice_conv_backward(features: torch.Tensor,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
out_bp_tv
,
features_tv
,
filters_tv
,
out_bp_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
arch
,
inverse
,
subm
,
algo
.
value
,
inverse
,
subm
,
algo
.
value
,
stream
)
stream
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
df
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
df
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
return
din
,
df
return
din
,
df
...
@@ -1369,8 +1416,8 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1369,8 +1416,8 @@ def implicit_gemm(features: torch.Tensor,
mask_width
=
ConvGemmOps
.
implicit_gemm
(
mask_width
=
ConvGemmOps
.
implicit_gemm
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
num_activate_out
,
mask_tv
,
arch
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
num_activate_out
,
mask_tv
,
arch
,
is_train
,
is_subm
,
stream
,
auto_fp32_accum
,
fp32_accum
)
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
if
is_train
:
if
is_train
:
...
@@ -1460,7 +1507,7 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1460,7 +1507,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# CONV.stream_synchronize(stream)
# t = time.time()
# t = time.time()
print
(
tune_res
.
algo_desp
,
"REF"
,
features_tv
.
shape
,
filters
.
shape
)
#
print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"):
# with tv.measure_and_print("f16 time"):
with
timer
.
record
(
"implicit_gemm"
,
stream
):
with
timer
.
record
(
"implicit_gemm"
,
stream
):
for
j
in
range
(
num_split
):
for
j
in
range
(
num_split
):
...
@@ -1921,8 +1968,10 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
...
@@ -1921,8 +1968,10 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
indice_pairs_tv
,
stream
)
indice_pairs_tv
,
stream
)
return
din
return
din
def
indice_avgpool_implicit_gemm
(
features
:
torch
.
Tensor
,
def
indice_avgpool_implicit_gemm
(
features
:
torch
.
Tensor
,
indice_pairs
:
torch
.
Tensor
,
num_activate_out
,
calc_count
:
bool
):
indice_pairs
:
torch
.
Tensor
,
num_activate_out
,
calc_count
:
bool
):
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# t = time.time()
# t = time.time()
stream
=
get_current_stream
()
stream
=
get_current_stream
()
...
@@ -1943,12 +1992,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
...
@@ -1943,12 +1992,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
count_out
=
torch
.
Tensor
()
count_out
=
torch
.
Tensor
()
count_out_tv
=
tv
.
Tensor
()
count_out_tv
=
tv
.
Tensor
()
if
calc_count
:
if
calc_count
:
count_out
=
torch
.
zeros
((
num_activate_out
,),
count_out
=
torch
.
zeros
((
num_activate_out
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
features
.
device
)
device
=
features
.
device
)
count_out_tv
=
torch_tensor_to_tv
(
count_out
)
count_out_tv
=
torch_tensor_to_tv
(
count_out
)
SpconvOps
.
avgpool_implicit_gemm_forward
(
out_features_tv
,
features_tv
,
SpconvOps
.
avgpool_implicit_gemm_forward
(
out_features_tv
,
features_tv
,
indice_pairs_tv
,
count_out_tv
,
stream
)
indice_pairs_tv
,
count_out_tv
,
stream
)
# CONV.stream_synchronize(stream)
# CONV.stream_synchronize(stream)
# print("M", time.time() - t)
# print("M", time.time() - t)
...
@@ -1956,12 +2006,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
...
@@ -1956,12 +2006,13 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
return
out_features
,
count_out
return
out_features
,
count_out
def
indice_avgpool_implicit_gemm_backward
(
out_bp
,
def
indice_avgpool_implicit_gemm_backward
(
out_bp
,
indice_pairs
,
count_out
):
indice_pairs
,
count_out
):
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# t = time.time()
# t = time.time()
out_channel
=
out_bp
.
shape
[
-
1
]
out_channel
=
out_bp
.
shape
[
-
1
]
din
=
torch
.
zeros
((
indice_pairs
.
shape
[
1
],
out_bp
.
shape
[
1
]),
dtype
=
out_bp
.
dtype
,
device
=
out_bp
.
device
)
din
=
torch
.
zeros
((
indice_pairs
.
shape
[
1
],
out_bp
.
shape
[
1
]),
dtype
=
out_bp
.
dtype
,
device
=
out_bp
.
device
)
assert
out_bp
.
is_cuda
assert
out_bp
.
is_cuda
if
not
out_bp
.
is_contiguous
():
if
not
out_bp
.
is_contiguous
():
out_bp
=
out_bp
.
contiguous
()
out_bp
=
out_bp
.
contiguous
()
...
@@ -1972,7 +2023,8 @@ def indice_avgpool_implicit_gemm_backward(out_bp,
...
@@ -1972,7 +2023,8 @@ def indice_avgpool_implicit_gemm_backward(out_bp,
din_tv
=
torch_tensor_to_tv
(
din
)
din_tv
=
torch_tensor_to_tv
(
din
)
indice_pairs_tv
=
torch_tensor_to_tv
(
indice_pairs
)
indice_pairs_tv
=
torch_tensor_to_tv
(
indice_pairs
)
SpconvOps
.
avgpool_implicit_gemm_backward
(
out_bp_tv
,
din_tv
,
SpconvOps
.
avgpool_implicit_gemm_backward
(
out_bp_tv
,
din_tv
,
indice_pairs_tv
,
count_out_tv
,
stream
)
indice_pairs_tv
,
count_out_tv
,
stream
)
return
din
return
din
...
...
test/benchmark.py
View file @
73a5ce7d
...
@@ -323,6 +323,8 @@ def main():
...
@@ -323,6 +323,8 @@ def main():
# pickle.dump((voxels, coors, spatial_shape), f)
# pickle.dump((voxels, coors, spatial_shape), f)
with
open
(
Path
(
__file__
).
parent
/
"data"
/
"test_spconv.pkl"
,
"rb"
)
as
f
:
with
open
(
Path
(
__file__
).
parent
/
"data"
/
"test_spconv.pkl"
,
"rb"
)
as
f
:
(
voxels
,
coors
,
spatial_shape
)
=
pickle
.
load
(
f
)
(
voxels
,
coors
,
spatial_shape
)
=
pickle
.
load
(
f
)
# voxels, coors, spatial_shape = waymo_data_large()
print
(
spatial_shape
)
print
(
spatial_shape
)
print
(
voxels
.
shape
)
print
(
voxels
.
shape
)
# voxels = voxels[:100]
# voxels = voxels[:100]
...
@@ -366,15 +368,14 @@ def main():
...
@@ -366,15 +368,14 @@ def main():
dout
=
np
.
random
.
uniform
(
-
0.2
,
0.2
,
out
.
features
.
shape
).
astype
(
np
.
float32
)
dout
=
np
.
random
.
uniform
(
-
0.2
,
0.2
,
out
.
features
.
shape
).
astype
(
np
.
float32
)
dout_t
=
torch
.
from_numpy
(
dout
).
to
(
device
).
to
(
dtype
)
dout_t
=
torch
.
from_numpy
(
dout
).
to
(
device
).
to
(
dtype
)
print
(
out
.
spatial_shape
,
out
.
features
.
mean
(),
out
.
features
.
max
(),
print
(
out
.
spatial_shape
,
out
.
features
.
sum
(
1
).
mean
(),
out
.
features
.
max
(),
out
.
features
.
min
())
out
.
features
.
min
())
times
=
[]
times
=
[]
show_metrics
=
False
show_metrics
=
False
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
i
in
range
(
20
):
for
i
in
range
(
100
):
print
(
"------------"
)
# print("------------")
torch
.
cuda
.
synchronize
()
with
tv
.
measure_duration
()
as
measure
:
t
=
time
.
time
()
out_nograd
=
net
(
voxels_th
,
coors_th
,
1
,
show_metrics
)
out_nograd
=
net
(
voxels_th
,
coors_th
,
1
,
show_metrics
)
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
...
@@ -383,14 +384,19 @@ def main():
...
@@ -383,14 +384,19 @@ def main():
# print(timer.get_all_pair_time())
# print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values()))
# print(sum(timer.get_all_pair_time().values()))
torch
.
cuda
.
synchronize
()
# sort_bench()
# sort_bench()
times
.
append
(
time
.
time
()
-
t
)
times
.
append
(
measure
.
duration
)
if
show_metrics
:
if
show_metrics
:
timer
=
out_nograd
.
_timer
timer
=
out_nograd
.
_timer
items
=
list
(
timer
.
get_all_pair_time
().
items
())
items
=
list
(
timer
.
get_all_pair_time
().
items
())
items
.
sort
(
key
=
lambda
x
:
x
[
0
])
items
.
sort
(
key
=
lambda
x
:
x
[
0
])
print
(
"SUM TIME:"
,
sum
([
x
[
1
]
for
x
in
items
]))
print
(
json
.
dumps
(
dict
(
items
),
indent
=
2
))
print
(
json
.
dumps
(
dict
(
items
),
indent
=
2
))
inds_sum
=
0
for
k
,
v
in
items
:
if
"gen_pairs"
in
k
:
inds_sum
+=
v
print
(
"SUM GEN INDS:"
,
inds_sum
)
# state = net.state_dict()
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
# state.pop("net.2.max_num_voxels_during_training")
...
...
test/test_all_algo.py
View file @
73a5ce7d
...
@@ -231,8 +231,8 @@ def _test_impgemm_conv_cuda(subm: bool):
...
@@ -231,8 +231,8 @@ def _test_impgemm_conv_cuda(subm: bool):
# out_channels = [32, 48, 64]
# out_channels = [32, 48, 64]
in_channels
=
[
32
,
47
]
in_channels
=
[
32
,
47
]
out_channels
=
[
32
,
48
,
62
]
out_channels
=
[
32
,
48
,
62
]
in_channels
=
[
32
]
#
in_channels = [32]
out_channels
=
[
32
]
#
out_channels = [32]
multiple_base
=
16
multiple_base
=
16
if
subm
:
if
subm
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment