Commit 0c07559f authored by yan.yan's avatar yan.yan
Browse files

working on performance problem

parent 21bb00ae
......@@ -752,7 +752,7 @@ class SimpleConv:
use_f32_as_accum = weight.dim(0) * kv > 128 * 27
else:
use_f32_as_accum = fp32_accum
use_f32_as_accum = False
# use_f32_as_accum = False
for algo in avail_algos:
static_key = (layout_i.layout_type.value,
layout_w.layout_type.value,
......
......@@ -99,9 +99,9 @@ class AllocKeys:
SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = True
SPCONV_CPP_INDICE_PAIRS_IGEMM = True
SPCONV_CPP_INDICE_PAIRS = False
SPCONV_CPP_INDICE_PAIRS_IGEMM = False
SPCONV_CPP_GEMM = True
SPCONV_CPP_GEMM = False
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
\ No newline at end of file
......@@ -16,10 +16,9 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo
from cumm.gemm import kernel
from typing import List
from cumm.gemm.algospec.core import TensorOp
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvFwd, ConvIterAlgo, GemmAlgo
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType)
from spconv.algocore import get_gemm_algo_desp_from_param
from spconv.constants import NDIM_DONT_CARE
......@@ -41,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"],
"", 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
"", 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params(
(128, 128, 32),
(64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# *gen_shuffle_params(
......@@ -84,9 +83,6 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((16, 32, 8), (16, 16, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# fall back kernels if mat is misaligned for half
# TODO use access-per-vector kernel instead of simt kernel for fallback
*gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f32,f32"],
......@@ -169,11 +165,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params(
(128, 128, 32),
(32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16)), is_nvrtc=True),
TensorOp((8, 8, 16))),
# *gen_shuffle_params(
# (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
......@@ -181,15 +177,15 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params(
(128, 256, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16)), is_nvrtc=True),
TensorOp((8, 8, 16))),
*gen_shuffle_params(
(256, 128, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16)), is_nvrtc=True),
TensorOp((8, 8, 16))),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
]
# SHUFFLE_TURING_PARAMS = []
......@@ -403,8 +399,6 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
......@@ -668,181 +662,6 @@ IMPLGEMM_TURING_PARAMS = [
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, )
# all int8 kernels use nvrtc.
*gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 64, 64), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 64, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
# *gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS
......
......@@ -11,7 +11,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
1. this function need to take a out features
that from subm first mm.
......@@ -26,6 +26,7 @@ class ConvGemmOps:
filters:
indice_pairs:
indice_pair_num:
arch:
num_activate_out:
inverse:
subm:
......@@ -34,7 +35,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
Args:
allocator:
......@@ -47,6 +48,7 @@ class ConvGemmOps:
out_bp:
indice_pairs:
indice_pair_num:
arch:
inverse:
subm:
algo:
......@@ -54,7 +56,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
"""
Args:
allocator:
......@@ -66,6 +68,7 @@ class ConvGemmOps:
mask_argsort_fwd_splits:
num_activate_out:
masks:
arch:
is_train:
is_subm:
stream_int:
......@@ -75,7 +78,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
"""
Args:
allocator:
......@@ -91,6 +94,7 @@ class ConvGemmOps:
mask_argsort_bwd_splits:
mask_output_fwd:
masks:
arch:
mask_width:
is_subm:
stream_int:
......
......@@ -1377,6 +1377,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("features, filters, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("num_activate_out", "int")
code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false")
......@@ -1489,7 +1491,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
// auto arch = get_compute_capability();
auto a_shape = a.shape();
auto c_shape = c.shape();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
......@@ -1584,6 +1586,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("all_w_is_krsc, filter_hwio", "bool")
code.arg("features, filters, out_bp, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}")
......@@ -1594,6 +1598,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
std::vector<int64_t> filter_shape_per_kv;
auto prev_filter_shape_vec = filters.shape_vector();
bool is_KC_not_CK;
if (!all_w_is_krsc){{
kv_dim = 0;
is_KC_not_CK = !filter_hwio;
......@@ -1700,7 +1705,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
// auto arch = get_compute_capability();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB);
......@@ -1899,6 +1904,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>")
code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("is_train, is_subm", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
......@@ -1926,7 +1933,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
}}
auto arch = get_compute_capability();
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
// auto arch = get_compute_capability();
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto tuned_res_exist = conv_tuner.get_tuned_algo(
......@@ -1959,6 +1970,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
fp32_accum);
tune_res = std::get<0>(tune_res_time);
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
......@@ -1974,6 +1986,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_output_fwd_splits.push_back(tv::Tensor());
}}
}}
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result(
......@@ -1995,6 +2008,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // verbose
timer);
}}
// auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int);
// tv::ssprint(tune_res.algo_desp.__repr__(), "WTF", exists,
// features.shape(), filters.shape(), out_features.shape(), tv::CUDAEvent::sync_and_duration(start_ev, end_ev));
return mask_width;
""")
return code.ret("int")
......@@ -2013,6 +2031,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int")
code.arg("is_subm", "bool")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
......@@ -2056,7 +2076,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto arch = get_compute_capability();
// auto arch = get_compute_capability();
auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardInputInt,
......
......@@ -419,10 +419,16 @@ def get_indice_pairs_implicit_gemm(
is_mask_split = algo == ConvAlgo.MaskSplitImplicitGemm
mask_split_count = 2 if is_mask_split else 1
if subm:
pair = torch.full((2, kv, indices.shape[0]),
-1,
dtype=indices.dtype,
device=indices.device)
if is_train:
pair = torch.full((2, kv, indices.shape[0]),
-1,
dtype=indices.dtype,
device=indices.device)
else:
pair = torch.full((1, kv, indices.shape[0]),
-1,
dtype=indices.dtype,
device=indices.device)
else:
# for regular conv, pair-in not equal to pair-out
pair = torch.full((kv, indices.shape[0]),
......@@ -476,6 +482,7 @@ def get_indice_pairs_implicit_gemm(
ksize=ksize,
dilation=dilation,
indice_pair_mask=pair_mask_tv,
backward=is_train,
stream_int=stream)
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
......@@ -505,10 +512,12 @@ def get_indice_pairs_implicit_gemm(
CONV.stream_synchronize(stream)
print("SUBM", time.time() - t)
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
if DEBUG:
......@@ -753,11 +762,15 @@ def indice_conv(features: torch.Tensor,
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters)
stream = 0
arch = (0, 0)
if features.is_cuda:
# plain get_arch by cuda api is VERY SLOW.
arch = get_arch()
stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC,
FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv,
arch,
num_activate_out, inverse, subm, algo.value,
stream)
out_features = alloc.allocated[AllocKeys.OutFeatures]
......@@ -996,12 +1009,16 @@ def indice_conv_backward(features: torch.Tensor,
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters)
stream = 0
arch = (0, 0)
if features.is_cuda:
stream = get_current_stream()
arch = get_arch()
ConvGemmOps.indice_conv_backward(alloc, ext_mm, GEMM_CPP,
ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv,
arch,
inverse, subm, algo.value, stream)
din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters]
......@@ -1347,10 +1364,12 @@ def implicit_gemm(features: torch.Tensor,
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
arch = get_arch()
mask_width = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, is_train, is_subm, stream, timer_cpp,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, timer_cpp,
auto_fp32_accum, fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
......@@ -1441,7 +1460,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# t = time.time()
# print(tune_res.algo_desp)
print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream):
for j in range(num_split):
......@@ -1613,11 +1632,13 @@ def implicit_gemm_backward(features: torch.Tensor,
auto_fp32_accum = fp32_accum is None
if fp32_accum is None:
fp32_accum = False
arch = get_arch()
ConvGemmOps.implicit_gemm_backward(
alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, mask_width, is_subm, stream,
mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum)
din = alloc.allocated[AllocKeys.DIn]
dfilters = alloc.allocated[AllocKeys.DFilters]
......
......@@ -113,7 +113,7 @@ class Net(nn.Module):
# nn.BatchNorm1d(32),
# nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo, record_voxel_count=True),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64,
96,
3,
......@@ -312,7 +312,7 @@ def sort_bench():
for i in range(10):
a_tv_1 = a_tv.clone()
SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0])
import json
def main():
import pickle
......@@ -332,7 +332,8 @@ def main():
voxels_th = torch.from_numpy(voxels).to(device).to(dtype)
coors_th = torch.from_numpy(coors).to(device).int()
voxels_th.requires_grad = True
algo = spconv.ConvAlgo.Native
algo = spconv.ConvAlgo.MaskImplicitGemm
print("ALGO")
# 3080 Laptop
# MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms
......@@ -355,7 +356,7 @@ def main():
# MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms
# algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype).train()
net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train()
# net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape)
......@@ -368,13 +369,13 @@ def main():
print(out.spatial_shape, out.features.mean(), out.features.max(),
out.features.min())
times = []
show_metrics = False
with torch.no_grad():
for i in range(20):
print("------------")
torch.cuda.synchronize()
t = time.time()
out_nograd = net(voxels_th, coors_th, 1, False)
timer = out_nograd._timer
out_nograd = net(voxels_th, coors_th, 1, show_metrics)
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
......@@ -385,6 +386,12 @@ def main():
torch.cuda.synchronize()
# sort_bench()
times.append(time.time() - t)
if show_metrics:
timer = out_nograd._timer
items = list(timer.get_all_pair_time().items())
items.sort(key=lambda x: x[0])
print(json.dumps(dict(items), indent=2))
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment