Commit 238d6a83 authored by yan.yan's avatar yan.yan
Browse files

add a simple example for c++ inference

parent ce8a91e4
...@@ -2,7 +2,7 @@ name: 'Close stale issues and PRs' ...@@ -2,7 +2,7 @@ name: 'Close stale issues and PRs'
on: on:
schedule: schedule:
- cron: '30 1 * * *' - cron: '30 1 1 * *'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
logLevel: logLevel:
......
...@@ -7,7 +7,7 @@ set(CUMM_DISABLE_CMAKE_INSTALL ON CACHE BOOL "enable X functionality" FORCE) ...@@ -7,7 +7,7 @@ set(CUMM_DISABLE_CMAKE_INSTALL ON CACHE BOOL "enable X functionality" FORCE)
add_subdirectory(cumm) add_subdirectory(cumm)
add_subdirectory(spconv) add_subdirectory(spconv)
add_executable(main main.cc) add_executable(main main.cu)
# SPCONV2_INCLUDE_PATH come from spconv/CMakeLists.txt # SPCONV2_INCLUDE_PATH come from spconv/CMakeLists.txt
target_include_directories(main PRIVATE ${SPCONV2_INCLUDE_PATH}) target_include_directories(main PRIVATE ${SPCONV2_INCLUDE_PATH})
target_link_libraries(main spconv cumm::cumm) target_link_libraries(main spconv cumm::cumm)
#include <spconvlib/cumm/gemm/main/GemmMainUnitTest.h>
#include <spconvlib/spconv/csrc/sparse/all/SpconvOps.h>
#include <spconvlib/spconv/csrc/sparse/alloc/StaticAllocator.h>
#include <spconvlib/spconv/csrc/sparse/convops/spops/ConvGemmOps.h>
#include <spconvlib/spconv/csrc/sparse/inference/InferenceOps.h>
#include <spconvlib/spconv/csrc/sparse/convops/SimpleExternalSpconvMatmul.h>
#include <spconvlib/spconv/csrc/sparse/convops/gemmops/GemmTunerSimple.h>
#include <spconvlib/spconv/csrc/sparse/convops/spops/ConvGemmOps.h>
using StaticAllocator = spconvlib::spconv::csrc::sparse::alloc::StaticAllocator;
using SpconvOps = spconvlib::spconv::csrc::sparse::all::SpconvOps;
using ConvMain = spconvlib::cumm::conv::main::ConvMainUnitTest;
using ConvTunerSimple =
spconvlib::spconv::csrc::sparse::convops::spops::ConvTuner;
using ConvGemmOps =
spconvlib::spconv::csrc::sparse::convops::spops::ConvGemmOps;
using SimpleExternalSpconvMatmul =
spconvlib::spconv::csrc::sparse::convops::SimpleExternalSpconvMatmul;
using InferenceOps =
spconvlib::spconv::csrc::sparse::inference::InferenceOps;
int main(){
tv::ssprint("Hello libspconv!!!");
return 0;
}
\ No newline at end of file
#include <spconvlib/cumm/gemm/main/GemmMainUnitTest.h>
#include <spconvlib/spconv/csrc/sparse/all/SpconvOps.h>
#include <spconvlib/spconv/csrc/sparse/alloc/StaticAllocator.h>
#include <spconvlib/spconv/csrc/sparse/convops/spops/ConvGemmOps.h>
#include <spconvlib/spconv/csrc/sparse/inference/InferenceOps.h>
#include <spconvlib/spconv/csrc/sparse/all/ops3d/Point2Voxel.h>
#include <spconvlib/spconv/csrc/sparse/convops/SimpleExternalSpconvMatmul.h>
#include <spconvlib/spconv/csrc/sparse/convops/gemmops/GemmTunerSimple.h>
#include <spconvlib/spconv/csrc/sparse/convops/spops/ConvGemmOps.h>
#include <tensorview/io/jsonarray.h>
#include <tensorview/parallel/map.h>
using StaticAllocator = spconvlib::spconv::csrc::sparse::alloc::StaticAllocator;
using SpconvOps = spconvlib::spconv::csrc::sparse::all::SpconvOps;
using ConvMain = spconvlib::cumm::conv::main::ConvMainUnitTest;
using ConvTunerSimple =
spconvlib::spconv::csrc::sparse::convops::spops::ConvTuner;
using ConvGemmOps =
spconvlib::spconv::csrc::sparse::convops::spops::ConvGemmOps;
using SimpleExternalSpconvMatmul =
spconvlib::spconv::csrc::sparse::convops::SimpleExternalSpconvMatmul;
using InferenceOps =
spconvlib::spconv::csrc::sparse::inference::InferenceOps;
using Point2VoxelGPU3D =
spconvlib::spconv::csrc::sparse::all::ops3d::Point2Voxel;
int main(int argc, char** argv){
tv::ssprint("Hello libspconv!!!");
TV_ASSERT_RT_ERR(argc == 2, "usage: main /path/to/benchmark-pc.jarr, you can find it in example/libspconv.")
std::string path = argv[1];
Point2VoxelGPU3D p2v{{0.1, 0.1, 0.1}, {-80, -80, -2, 80, 80, 6}, 3, 200000, 1};
auto pc_jarr = tv::io::load_from_file(path);
auto pc = pc_jarr.tensors.at(0).cuda();
// you should use point_to_voxel_hash_static in tensorrt and manage hash data in tensorrt workspace.
auto p2v_res = p2v.point_to_voxel_hash(pc);
tv::Tensor voxels = std::get<0>(p2v_res).cuda().view(-1, 3);
auto indices_without_bs = std::get<1>(p2v_res);
auto indices = tv::zeros({indices_without_bs.dim(0), 4}, tv::int32, 0);
indices.slice(1, 1, 4, 1, false, false).copy_2d_pitched_(indices_without_bs);
auto indices_cpu = indices.cpu();
auto indices_cpu_data_ptr = indices_cpu.data_ptr<int32_t>();
for (int i = 0; i < 5; ++i){
auto cur_indices_cpu_data_ptr = indices_cpu_data_ptr + i * 4;
tv::ssprint(cur_indices_cpu_data_ptr[0], cur_indices_cpu_data_ptr[1], cur_indices_cpu_data_ptr[2], cur_indices_cpu_data_ptr[3]);
}
auto num_per_voxel = std::get<2>(p2v_res);
tv::ssprint("num voxels", voxels.shape());
auto voxels_f16 = tv::zeros(voxels.shape(), tv::float16, 0);
auto voxels_f16_ptr = voxels_f16.data_ptr<__half>();
auto voxels_ptr = voxels.data_ptr<float>();
tv::kernel_1d_map(0, voxels_f16.size(), [=]TV_GPU_LAMBDA(size_t i)mutable{
voxels_f16_ptr[i] = __half(voxels_ptr[i]);
});
// out channels, ksize, in channels
tv::Tensor weights = tv::zeros({64, 3, 3, 3, 3}, tv::float16, 0);
tv::Tensor bias = tv::zeros({64}, tv::float16, 0);
int KV = 27;
int out_inds_num_limit = 100000; // upper bound of number of output indices.
std::vector<int32_t> ksize{3, 3, 3};
std::vector<int32_t> padding{1, 1, 1};
std::vector<int32_t> dilation{1, 1, 1};
std::vector<int32_t> stride{1, 1, 1};
int ndim = 3;
auto p2v_grid_size = p2v.get_grid_size();
std::vector<int32_t> input_dims(p2v_grid_size.begin(),
p2v_grid_size.end());
auto out_dims = SpconvOps::get_conv_output_size(input_dims, ksize, stride, padding, dilation);
tv::ssprint(ksize, input_dims, out_dims);
std::vector<int64_t> output_dims_i64(out_dims.begin(),
out_dims.end());
int64_t out_spatial_volume =
std::accumulate(output_dims_i64.begin(), output_dims_i64.end(),
int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k =
out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
int num_act_in = voxels.dim(0);
bool is_subm = true;
bool direct_table = true;
int batch_size = 1;
int transpose = false;
bool use_direct_table = direct_table && !is_subm;
auto conv_algo = tv::gemm::SparseConvAlgo::kMaskImplicitGemm;
auto max_act_out_theory = SpconvOps::get_handcrafted_max_act_out(num_act_in,
ksize, stride, padding, dilation);
int workspace_size = SpconvOps::get_indice_gen_workspace_size(
KV, num_act_in, out_inds_num_limit, max_act_out_theory, is_subm,
use_int64_hash_k, use_direct_table);
// you should return workspace size in tensorrt plugin method.
tv::Tensor workspace = tv::empty({workspace_size}, tv::uint8, 0);
// get tensor map required by pair gen from workspace
auto ws_tensors = SpconvOps::get_indice_gen_tensors_from_workspace(
workspace.raw_data(), KV, num_act_in, is_subm ? num_act_in : out_inds_num_limit,
max_act_out_theory, is_subm, use_int64_hash_k, use_direct_table);
// create output tensors and insert them to static allocator
int pair_size = is_subm ? num_act_in : out_inds_num_limit;
tv::Tensor pair_fwd = tv::empty({KV, pair_size}, tv::int32, 0);
bool is_split_mask =
conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm;
int mask_count = is_split_mask ? 2 : 1;
tv::Tensor pair_mask_fwd = tv::empty({mask_count, pair_size}, tv::int32, 0);
tv::Tensor mask_argsort_fwd = tv::empty({mask_count, pair_size}, tv::int32, 0);
tv::Tensor out_inds = tv::empty({out_inds_num_limit, ndim + 1}, tv::int32, 0);
tv::Tensor indices_kernel_num = tv::zeros({KV}, tv::int32, 0);
cudaStream_t stream = 0;
ws_tensors.insert({SPCONV_ALLOC_PAIR_FWD, pair_fwd});
ws_tensors.insert({SPCONV_ALLOC_PAIR_MASK, pair_mask_fwd});
ws_tensors.insert({SPCONV_ALLOC_MASK_ARG_SORT, mask_argsort_fwd});
ws_tensors.insert({SPCONV_ALLOC_OUT_INDICES, out_inds});
ws_tensors.insert({SPCONV_ALLOC_INDICE_NUM_PER_LOC, indices_kernel_num});
StaticAllocator alloc(ws_tensors);
auto pair_res = SpconvOps::get_indice_pairs_implicit_gemm(
alloc, indices, batch_size, input_dims, static_cast<int>(conv_algo),
ksize, stride, padding, dilation, {0, 0, 0},
is_subm, transpose, false,
reinterpret_cast<std::uintptr_t>(stream), out_inds_num_limit,
tv::CUDAKernelTimer(false), use_direct_table);
int num_act_out = std::get<1>(pair_res);
tv::Tensor out_features = tv::empty({num_act_out, 64}, tv::float16, 0);
// this function is very slow, don't forget to cache result.
auto arch = ConvGemmOps::get_compute_capability();
int kv = pair_fwd.dim(0);
bool is_mask_split = pair_mask_fwd.dim(0) > 1;
int mask_split_cnt = pair_mask_fwd.dim(0);
tv::Tensor mask_tensor =
tv::zeros({pair_mask_fwd.dim(0)}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
if (is_mask_split) {
auto kv_div_2 = kv / 2;
auto remain = kv - kv_div_2;
uint64_t mask_np_1 = 1;
uint64_t first = ((mask_np_1 << remain) - 1);
uint64_t second = ((mask_np_1 << kv_div_2) - 1) << remain;
mask_tensor_ptr[0] = uint32_t(first);
mask_tensor_ptr[1] = uint32_t(second);
} else {
mask_tensor_ptr[0] = 0xffffffff;
}
std::vector<tv::Tensor> pair_mask_splits;
std::vector<tv::Tensor> mask_argsort_splits;
for (int i = 0; i < mask_split_cnt; ++i) {
pair_mask_splits.push_back(
pair_mask_fwd[i]);
mask_argsort_splits.push_back(
mask_argsort_fwd[i]);
}
std::unordered_map<std::string, tv::Tensor> tensor_dict{
{SPCONV_ALLOC_FEATURES, voxels_f16},
{SPCONV_ALLOC_FILTERS, weights},
{SPCONV_ALLOC_OUT_FEATURES, out_features}};
StaticAllocator alloc2(tensor_dict);
ConvTunerSimple tuner(ConvMain::get_all_conv_algo_desp());
auto conv_res = ConvGemmOps::implicit_gemm(
alloc2, tuner, voxels_f16, weights, pair_fwd,
pair_mask_splits, mask_argsort_splits, num_act_out,
mask_tensor, arch, false, is_subm,
reinterpret_cast<std::uintptr_t>(stream), tv::CUDAKernelTimer(false),
false, false, bias, 1.0,
0.0, tv::gemm::Activation::kReLU);
// p2v.point_to_voxel_hash()
return 0;
}
\ No newline at end of file
...@@ -16,4 +16,4 @@ python -m spconv.gencode --include=$SCRIPT_DIR/spconv/include --src=$SCRIPT_DIR/ ...@@ -16,4 +16,4 @@ python -m spconv.gencode --include=$SCRIPT_DIR/spconv/include --src=$SCRIPT_DIR/
mkdir -p $SCRIPT_DIR/build mkdir -p $SCRIPT_DIR/build
cd $SCRIPT_DIR/build cd $SCRIPT_DIR/build
cmake .. cmake ..
cmake --build $SCRIPT_DIR/build --config Release -j 8 cmake --build $SCRIPT_DIR/build --config Release -j 8 --verbose
...@@ -48,11 +48,11 @@ DESCRIPTION = 'spatial sparse convolution' ...@@ -48,11 +48,11 @@ DESCRIPTION = 'spatial sparse convolution'
URL = 'https://github.com/traveller59/spconv' URL = 'https://github.com/traveller59/spconv'
EMAIL = 'yanyan.sub@outlook.com' EMAIL = 'yanyan.sub@outlook.com'
AUTHOR = 'Yan Yan' AUTHOR = 'Yan Yan'
REQUIRES_PYTHON = '>=3.6' REQUIRES_PYTHON = '>=3.7'
VERSION = None VERSION = None
# What packages are required for this module to be executed? # What packages are required for this module to be executed?
REQUIRED = ["pccm>=0.3.5", "pybind11>=2.6.0", "fire", "numpy", *deps] REQUIRED = ["pccm>=0.4.0", "ccimport>=0.4.0", "pybind11>=2.6.0", "fire", "numpy", *deps]
# What packages are optional? # What packages are optional?
EXTRAS = { EXTRAS = {
......
...@@ -55,7 +55,6 @@ from spconv.core_cc.csrc.sparse.convops.convops import ConvTunerSimple as ConvTu ...@@ -55,7 +55,6 @@ from spconv.core_cc.csrc.sparse.convops.convops import ConvTunerSimple as ConvTu
ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp() ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp()
ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp() ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp()
_GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, int, str]
class SimpleGemmAlgoMeta: class SimpleGemmAlgoMeta:
...@@ -205,6 +204,8 @@ class ConvTunerSimple(ConvTunerSimpleBase): ...@@ -205,6 +204,8 @@ class ConvTunerSimple(ConvTunerSimpleBase):
self._nvrtc_caches[key] = nvrtc_params self._nvrtc_caches[key] = nvrtc_params
return nvrtc_params return nvrtc_params
_GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, int]
class SimpleGemm: class SimpleGemm:
def __init__(self, prebuilt_desps: List[GemmAlgoDesp]) -> None: def __init__(self, prebuilt_desps: List[GemmAlgoDesp]) -> None:
...@@ -256,7 +257,7 @@ class SimpleGemm: ...@@ -256,7 +257,7 @@ class SimpleGemm:
@staticmethod @staticmethod
def get_static_key(d: GemmAlgoDesp) -> _GEMM_STATIC_KEY: def get_static_key(d: GemmAlgoDesp) -> _GEMM_STATIC_KEY:
return (d.trans_a, d.trans_b, d.trans_c, d.dtype_a, d.dtype_b, return (d.trans_a, d.trans_b, d.trans_c, d.dtype_a, d.dtype_b,
d.dtype_c, d.shuffle_type.value, d.algo) d.dtype_c, d.shuffle_type.value)
def device_synchronize(self): def device_synchronize(self):
return GemmMainUnitTest.device_synchronize() return GemmMainUnitTest.device_synchronize()
...@@ -310,119 +311,34 @@ class SimpleGemm: ...@@ -310,119 +311,34 @@ class SimpleGemm:
avail_algos = get_available_algo_str_from_arch(arch) avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[GemmAlgoDesp] = [] finally_algos: List[GemmAlgoDesp] = []
# print(self.static_key_to_desps) # print(self.static_key_to_desps)
for algo in avail_algos: static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype, shuffle_type.value)
shuffle_type.value, algo) # for algo in avail_algos:
# static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
# shuffle_type.value)
# print(static_key) # print(static_key)
desps = self.static_key_to_desps.get(static_key, None) desps = self.static_key_to_desps.get(static_key, None)
if desps is None or len(desps) == 0: if desps is None or len(desps) == 0:
return finally_algos
# print(desps)
for desp in desps:
if arch < desp.min_arch:
continue continue
# print(desps) # skip volta tensor op since it is very slow in architectures except volta.
for desp in desps: if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
# skip volta tensor op since it is very slow in architectures except volta. continue
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value: lda = a.stride[0]
continue ldb = b.stride[0]
lda = a.stride[0] ldc = c.stride[0]
ldb = b.stride[0] if desp.supported_ldx(lda, ldb, ldc):
ldc = c.stride[0] if arch not in COMPILED_CUDA_ARCHS:
if desp.supported_ldx(lda, ldb, ldc): desp = desp.copy()
if arch not in COMPILED_CUDA_ARCHS: desp.is_nvrtc = True
desp = desp.copy() if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True desp.is_nvrtc = True
if SPCONV_DEBUG_NVRTC_KERNELS: finally_algos.append(desp)
desp.is_nvrtc = True
finally_algos.append(desp)
return finally_algos return finally_algos
def select(self,
a: tv.Tensor,
b: tv.Tensor,
c: tv.Tensor,
trans_a: bool,
trans_b: bool,
trans_c: bool,
arch: Tuple[int, int],
shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
a_inds: tv.Tensor = tv.Tensor(),
b_inds: tv.Tensor = tv.Tensor(),
c_inds: tv.Tensor = tv.Tensor(),
hint: int = AlgoHint.NoHint.value):
m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
trans_b, trans_c,
shuffle_type.value,
a_inds.shape, b_inds.shape,
c_inds.shape)
if trans_c:
trans_a = not trans_a
trans_b = not trans_b
trans_a, trans_b = trans_b, trans_a
a, b = b, a
trans_c = False
avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[GemmAlgoDesp] = []
for algo in avail_algos:
static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
shuffle_type.value, algo)
desps = self.static_key_to_desps.get(static_key, None)
if desps is None or len(desps) == 0:
continue
meta = self.static_key_to_meta[static_key]
# for shuffle stride algos, we need to make channel tile size as large as possible.
# so if ShuffleAC, we need to make k largest.
selected_algo_desps = GemmMainUnitTest.simple_select_tile_shape(
m,
n,
k,
meta.tile_ms,
meta.tile_ns,
meta.tile_ks,
meta.tile_shape_to_algos,
large_k_first=shuffle_type == shuffle_type.ShuffleAC)
if not selected_algo_desps:
candidate = desps
else:
candidate = [desps[i] for i in selected_algo_desps]
# select by hint
if hint == 0:
return candidate[0]
if hint & (AlgoHint.Fowrard.value | AlgoHint.BackwardInput.value):
# m may be huge, n and k are small
# don't need mixed precision
# don't need splitk
finally_algos = []
if a.dtype == tv.float16:
dacc = tv.float16
dcomp = tv.float16
for can in candidate:
if can.dacc == dacc and can.dcomp == dcomp:
finally_algos.append(can)
else:
finally_algos = candidate
elif hint & AlgoHint.BackwardWeight.value:
# k is huge
# don't support i8
# if f16, acc and comp must be f32
finally_algos = []
candidate_filtered: List[GemmAlgoDesp] = list(
filter(lambda x: x.split_k_serial, candidate))
if not candidate_filtered:
candidate_filtered = candidate
if a.dtype == tv.int8:
continue
elif a.dtype == tv.float16:
dacc = tv.float32
dcomp = tv.float32
for can in candidate_filtered:
if can.dacc == dacc and can.dcomp == dcomp:
finally_algos.append(can)
else:
finally_algos = candidate_filtered
else:
return candidate[0]
# print(finally_algos)
if finally_algos:
return finally_algos[0]
return None
def get_tuned_algo( def get_tuned_algo(
self, self,
...@@ -672,7 +588,7 @@ class SimpleGemm: ...@@ -672,7 +588,7 @@ class SimpleGemm:
return algo_desp return algo_desp
_CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int] _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, int]
class SimpleConv: class SimpleConv:
...@@ -729,7 +645,7 @@ class SimpleConv: ...@@ -729,7 +645,7 @@ class SimpleConv:
def get_static_key(d: ConvAlgoDesp) -> _CONV_STATIC_KEY: def get_static_key(d: ConvAlgoDesp) -> _CONV_STATIC_KEY:
return (d.layout_i.value, d.layout_w.value, d.layout_o.value, return (d.layout_i.value, d.layout_w.value, d.layout_o.value,
d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input, d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input,
d.dtype_weight, d.dtype_output, d.algo, d.op_type.value) d.dtype_weight, d.dtype_output, d.op_type.value)
def device_synchronize(self): def device_synchronize(self):
return GemmMainUnitTest.device_synchronize() return GemmMainUnitTest.device_synchronize()
...@@ -762,41 +678,42 @@ class SimpleConv: ...@@ -762,41 +678,42 @@ class SimpleConv:
else: else:
use_f32_as_accum = fp32_accum use_f32_as_accum = fp32_accum
# use_f32_as_accum = False # use_f32_as_accum = False
for algo in avail_algos: static_key = (layout_i.layout_type.value,
static_key = (layout_i.layout_type.value, layout_w.layout_type.value,
layout_w.layout_type.value, layout_o.layout_type.value, layout_i.interleave,
layout_o.layout_type.value, layout_i.interleave, layout_w.interleave, layout_o.interleave, inp.dtype,
layout_w.interleave, layout_o.interleave, inp.dtype, weight.dtype, out.dtype, op_type.value)
weight.dtype, out.dtype, algo, op_type.value) desps = self.static_key_to_desps.get(static_key, None)
desps = self.static_key_to_desps.get(static_key, None) if desps is None or len(desps) == 0:
if desps is None or len(desps) == 0: return finally_algos
for desp in desps:
if arch < desp.min_arch:
continue continue
for desp in desps: # skip volta tensor op since it is very slow in architectures except volta.
# skip volta tensor op since it is very slow in architectures except volta. if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value: continue
if arch >= (7, 0) and is_fp16:
if desp.algo == GemmAlgo.Simt:
continue continue
if arch >= (7, 0) and is_fp16: if use_f32_as_accum:
if desp.algo == GemmAlgo.Simt: if desp.dacc == tv.float16:
continue continue
if use_f32_as_accum:
if desp.dacc == tv.float16: ldi = inp.dim(-1)
continue ldw = weight.dim(-1)
ldo = out.dim(-1)
ldi = inp.dim(-1) mask_width_valid = True
ldw = weight.dim(-1)
ldo = out.dim(-1) if desp.op_type == ConvOpType.kBackwardWeight.value:
mask_width_valid = True assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0
if desp.op_type == ConvOpType.kBackwardWeight.value: if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
assert mask_width > 0 if arch not in COMPILED_CUDA_ARCHS:
mask_width_valid = mask_width % desp.tile_shape[2] == 0 desp = desp.copy()
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: desp.is_nvrtc = True
if arch not in COMPILED_CUDA_ARCHS: if SPCONV_DEBUG_NVRTC_KERNELS:
desp = desp.copy() desp.is_nvrtc = True
desp.is_nvrtc = True finally_algos.append(desp)
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
finally_algos.append(desp)
return finally_algos return finally_algos
def get_tuned_algo(self, def get_tuned_algo(self,
...@@ -1058,5 +975,5 @@ CONV_CPP = ConvTunerSimple([ ...@@ -1058,5 +975,5 @@ CONV_CPP = ConvTunerSimple([
for p in ALL_IMPGEMM_PARAMS]) for p in ALL_IMPGEMM_PARAMS])
if __name__ == "__main__": if __name__ == "__main__":
print(len(ALL_CONV_ALGO_DESPS)) for desp in ALL_CONV_ALGO_DESPS:
print(ALL_CONV_ALGO_DESPS[0]) print(desp, desp.min_arch)
...@@ -4,9 +4,7 @@ from cumm.common import GemmBasicHost, NlohmannJson, TensorView ...@@ -4,9 +4,7 @@ from cumm.common import GemmBasicHost, NlohmannJson, TensorView
from cumm.constants import CUMM_CPU_ONLY_BUILD from cumm.constants import CUMM_CPU_ONLY_BUILD
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
from cumm.gemm.algospec.core import (_GEMM_MIN_ARCH_TO_ALGO, GemmAlgo, from cumm.gemm.algospec.core import (_GEMM_MIN_ARCH_TO_ALGO, GemmAlgo,
ShuffleStrideType, ShuffleStrideType)
get_available_algo_str_from_arch,
get_min_arch_of_algo_str)
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from spconv.constants import NDIM_DONT_CARE, SPCONV_BWD_SPLITK, AllocKeys from spconv.constants import NDIM_DONT_CARE, SPCONV_BWD_SPLITK, AllocKeys
from spconv.core import AlgoHint, ConvAlgo from spconv.core import AlgoHint, ConvAlgo
...@@ -472,7 +470,7 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -472,7 +470,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
self.add_typedef( self.add_typedef(
"static_key_t", "std::tuple<bool, bool, bool, int, " "static_key_t", "std::tuple<bool, bool, bool, int, "
"int, int, int, std::string>") "int, int, int>")
self.add_typedef("algo_cache_key_t", "std::tuple<int, " self.add_typedef("algo_cache_key_t", "std::tuple<int, "
"int, int, int, int>") "int, int, int, int>")
...@@ -501,7 +499,7 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -501,7 +499,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
for (auto& d : desps){{ for (auto& d : desps){{
static_key_t static_key = std::make_tuple(d.trans_a(), d.trans_b(), d.trans_c(), d.dtype_a, d.dtype_b, static_key_t static_key = std::make_tuple(d.trans_a(), d.trans_b(), d.trans_c(), d.dtype_a, d.dtype_b,
d.dtype_c, int(d.shuffle_type), d.algo); d.dtype_c, int(d.shuffle_type));
auto& vec = static_key_to_desps_[static_key]; auto& vec = static_key_to_desps_[static_key];
vec.push_back(d); vec.push_back(d);
}} }}
...@@ -548,31 +546,32 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -548,31 +546,32 @@ class GemmTunerSimple(pccm.ParameterizedClass):
std::swap(a, b); std::swap(a, b);
trans_c = false; trans_c = false;
}} }}
auto avail_algos = get_available_algo_str_from_arch(arch); // auto avail_algos = get_available_algo_str_from_arch(arch);
std::vector<tv::gemm::GemmAlgoDesp> finally_algos; std::vector<tv::gemm::GemmAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch); auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
for (auto algo : avail_algos){{ static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()),
static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()), int(b.dtype()), int(c.dtype()), shuffle_type);
int(b.dtype()), int(c.dtype()), shuffle_type, algo); if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{ return finally_algos;
}}
auto& desps = static_key_to_desps_.at(static_key);
for (auto& desp : desps){{
if (arch < desp.min_arch){{
continue; continue;
}} }}
auto& desps = static_key_to_desps_.at(static_key); if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
for (auto& desp : desps){{ continue;
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{ }}
continue; auto lda = a.stride(0);
}} auto ldb = b.stride(0);
auto lda = a.stride(0); auto ldc = c.stride(0);
auto ldb = b.stride(0); if (desp.supported_ldx(lda, ldb, ldc)){{
auto ldc = c.stride(0); if (!is_arch_compiled){{
if (desp.supported_ldx(lda, ldb, ldc)){{ auto desp2 = desp;
if (!is_arch_compiled){{ desp2.is_nvrtc = true;
auto desp2 = desp; finally_algos.push_back(desp2);
desp2.is_nvrtc = true; }}else{{
finally_algos.push_back(desp2); finally_algos.push_back(desp);
}}else{{
finally_algos.push_back(desp);
}}
}} }}
}} }}
}} }}
...@@ -895,7 +894,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -895,7 +894,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
self.add_typedef("static_key_t", self.add_typedef("static_key_t",
("std::tuple<int, int, int, int, int, " ("std::tuple<int, int, int, int, int, "
"int, int, int, int, std::string, int>")) "int, int, int, int, int>"))
self.add_typedef( self.add_typedef(
"algo_cache_key_t", "std::tuple<int, int, int, int, " "algo_cache_key_t", "std::tuple<int, int, int, int, "
"int, int, int, int>") "int, int, int, int>")
...@@ -927,7 +926,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -927,7 +926,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
static_key_t static_key = std::make_tuple( static_key_t static_key = std::make_tuple(
int(d.layout_i), int(d.layout_w), int(d.layout_o), int(d.layout_i), int(d.layout_w), int(d.layout_o),
d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input(), d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input(),
d.dtype_weight(), d.dtype_output(), d.algo, int(d.op_type)); d.dtype_weight(), d.dtype_output(), int(d.op_type));
auto& vec = static_key_to_desps_[static_key]; auto& vec = static_key_to_desps_[static_key];
vec.push_back(d); vec.push_back(d);
}} }}
...@@ -974,7 +973,6 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -974,7 +973,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type); tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
auto avail_algos = get_available_algo_str_from_arch(arch);
bool is_fp16 = (inp.dtype() == tv::float16 && bool is_fp16 = (inp.dtype() == tv::float16 &&
weight.dtype() == tv::float16 && out.dtype() == tv::float16); weight.dtype() == tv::float16 && out.dtype() == tv::float16);
bool use_f32_as_accum = false; bool use_f32_as_accum = false;
...@@ -997,49 +995,50 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -997,49 +995,50 @@ class ConvTunerSimple(pccm.ParameterizedClass):
std::vector<tv::gemm::ConvAlgoDesp> finally_algos; std::vector<tv::gemm::ConvAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch); auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
for (auto algo : avail_algos){{ static_key_t static_key = std::make_tuple(
static_key_t static_key = std::make_tuple( layout_i, layout_w, layout_o,
layout_i, layout_w, layout_o, interleave_i, interleave_w, interleave_o, inp.dtype(),
interleave_i, interleave_w, interleave_o, inp.dtype(), weight.dtype(), out.dtype(), op_type);
weight.dtype(), out.dtype(), algo, op_type); if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{ return finally_algos;
}}
auto& desps = static_key_to_desps_.at(static_key);
for (auto& desp : desps){{
if (arch < desp.min_arch){{
continue; continue;
}} }}
auto& desps = static_key_to_desps_.at(static_key); if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
for (auto& desp : desps){{ continue;
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{ }}
if (arch >= std::make_tuple(7, 0) && is_fp16){{
// skip simt fp16 kernels if we have tensor core
if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{
continue; continue;
}} }}
if (arch >= std::make_tuple(7, 0) && is_fp16){{ if (use_f32_as_accum){{
// skip simt fp16 kernels if we have tensor core if (desp.dacc == tv::float16){{
if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{
continue; continue;
}} }}
if (use_f32_as_accum){{
if (desp.dacc == tv::float16){{
continue;
}}
}}
}} }}
}}
int ldi = inp.dim(-1); int ldi = inp.dim(-1);
int ldw = weight.dim(-1); int ldw = weight.dim(-1);
int ldo = out.dim(-1); int ldo = out.dim(-1);
bool mask_width_valid = true; bool mask_width_valid = true;
if (desp.op_type == tv::gemm::ConvOpType::kBackwardWeight){{ if (desp.op_type == tv::gemm::ConvOpType::kBackwardWeight){{
TV_ASSERT_RT_ERR(mask_width > 0, "eroro"); TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0; mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}} }}
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{ if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!is_arch_compiled){{ if (!is_arch_compiled){{
auto desp2 = desp; auto desp2 = desp;
desp2.is_nvrtc = true; desp2.is_nvrtc = true;
finally_algos.push_back(desp2); finally_algos.push_back(desp2);
}}else{{ }}else{{
finally_algos.push_back(desp); finally_algos.push_back(desp);
}}
}} }}
}} }}
}} }}
......
from spconv.pytorch.cppcore import TorchAllocator from cumm import tensorview as tv
print(1)
from cumm.tensorview import tvio
import numpy as np
from pathlib import Path
from spconv.core_cc.csrc.sparse.all import SpconvOps
import torch
print(2)
if __name__ == "__main__":
alloc = TorchAllocator(torch.device("cuda:0"))
SpconvOps.test_allocator(alloc) def main():
data = np.load(Path(__file__).parent / "data" / "benchmark-pc.npz")
with open(Path(__file__).parent / "data" / "benchmark-pc.jarr", "wb") as f:
f.write(tvio.dumps_jsonarray({
"pc": data
}).tobytes())
if __name__ == "__main__":
main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment