Commit 0c07559f authored by yan.yan
Browse files

working on performance problem

parent 21bb00ae
...@@ -752,7 +752,7 @@ class SimpleConv: ...@@ -752,7 +752,7 @@ class SimpleConv:
use_f32_as_accum = weight.dim(0) * kv > 128 * 27 use_f32_as_accum = weight.dim(0) * kv > 128 * 27
else: else:
use_f32_as_accum = fp32_accum use_f32_as_accum = fp32_accum
use_f32_as_accum = False # use_f32_as_accum = False
for algo in avail_algos: for algo in avail_algos:
static_key = (layout_i.layout_type.value, static_key = (layout_i.layout_type.value,
layout_w.layout_type.value, layout_w.layout_type.value,
......
...@@ -99,9 +99,9 @@ class AllocKeys: ...@@ -99,9 +99,9 @@ class AllocKeys:
SPCONV_DEBUG_WEIGHT = False SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = True SPCONV_CPP_INDICE_PAIRS = False
SPCONV_CPP_INDICE_PAIRS_IGEMM = True SPCONV_CPP_INDICE_PAIRS_IGEMM = False
SPCONV_CPP_GEMM = True SPCONV_CPP_GEMM = False
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1" SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
\ No newline at end of file
...@@ -16,10 +16,9 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo ...@@ -16,10 +16,9 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo
from cumm.gemm import kernel from cumm.gemm import kernel
from typing import List from typing import List
from cumm.gemm.algospec.core import TensorOp from cumm.gemm.algospec.core import TensorOp
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvFwd, ConvIterAlgo, GemmAlgo from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout, from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType) ConvLayoutType, ConvMode, ConvOpType)
from spconv.algocore import get_gemm_algo_desp_from_param
from spconv.constants import NDIM_DONT_CARE from spconv.constants import NDIM_DONT_CARE
...@@ -41,17 +40,17 @@ class AlgoHint(Enum): ...@@ -41,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True), 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True), 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], *gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"],
"", 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True), "", 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True), kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True), 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], *gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# *gen_shuffle_params( # *gen_shuffle_params(
...@@ -84,9 +83,6 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ ...@@ -84,9 +83,6 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f32,f32,f32,f32,f32"], *gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((16, 32, 8), (16, 16, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# fall back kernels if mat is misaligned for half # fall back kernels if mat is misaligned for half
# TODO use access-per-vector kernel instead of simt kernel for fallback # TODO use access-per-vector kernel instead of simt kernel for fallback
*gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f32,f32"], *gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f32,f32"],
...@@ -169,11 +165,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -169,11 +165,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16)), is_nvrtc=True), TensorOp((8, 8, 16))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, # (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
...@@ -181,15 +177,15 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -181,15 +177,15 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16)), is_nvrtc=True), TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16)), is_nvrtc=True), TensorOp((8, 8, 16))),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
] ]
# SHUFFLE_TURING_PARAMS = [] # SHUFFLE_TURING_PARAMS = []
...@@ -403,8 +399,6 @@ IMPLGEMM_SIMT_PARAMS = [ ...@@ -403,8 +399,6 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
] ]
IMPLGEMM_VOLTA_PARAMS = [ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
...@@ -668,181 +662,6 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -668,181 +662,6 @@ IMPLGEMM_TURING_PARAMS = [
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), # NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, ) # gen_conv_params(ConvFwdAndBwdInput, )
# all int8 kernels use nvrtc.
*gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 64, 64), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 64, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
# *gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
] ]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS
......
...@@ -11,7 +11,7 @@ class ConvGemmOps: ...@@ -11,7 +11,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None: def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
""" """
1. this function need to take a out features 1. this function need to take a out features
that from subm first mm. that from subm first mm.
...@@ -26,6 +26,7 @@ class ConvGemmOps: ...@@ -26,6 +26,7 @@ class ConvGemmOps:
filters: filters:
indice_pairs: indice_pairs:
indice_pair_num: indice_pair_num:
arch:
num_activate_out: num_activate_out:
inverse: inverse:
subm: subm:
...@@ -34,7 +35,7 @@ class ConvGemmOps: ...@@ -34,7 +35,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None: def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
""" """
Args: Args:
allocator: allocator:
...@@ -47,6 +48,7 @@ class ConvGemmOps: ...@@ -47,6 +48,7 @@ class ConvGemmOps:
out_bp: out_bp:
indice_pairs: indice_pairs:
indice_pair_num: indice_pair_num:
arch:
inverse: inverse:
subm: subm:
algo: algo:
...@@ -54,7 +56,7 @@ class ConvGemmOps: ...@@ -54,7 +56,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int: def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
""" """
Args: Args:
allocator: allocator:
...@@ -66,6 +68,7 @@ class ConvGemmOps: ...@@ -66,6 +68,7 @@ class ConvGemmOps:
mask_argsort_fwd_splits: mask_argsort_fwd_splits:
num_activate_out: num_activate_out:
masks: masks:
arch:
is_train: is_train:
is_subm: is_subm:
stream_int: stream_int:
...@@ -75,7 +78,7 @@ class ConvGemmOps: ...@@ -75,7 +78,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None: def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
""" """
Args: Args:
allocator: allocator:
...@@ -91,6 +94,7 @@ class ConvGemmOps: ...@@ -91,6 +94,7 @@ class ConvGemmOps:
mask_argsort_bwd_splits: mask_argsort_bwd_splits:
mask_output_fwd: mask_output_fwd:
masks: masks:
arch:
mask_width: mask_width:
is_subm: is_subm:
stream_int: stream_int:
......
...@@ -1377,6 +1377,8 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1377,6 +1377,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("features, filters, indice_pairs", "tv::Tensor") code.arg("features, filters, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor") code.arg("indice_pair_num", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("num_activate_out", "int") code.arg("num_activate_out", "int")
code.arg("inverse", "bool", "false") code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false") code.arg("subm", "bool", "false")
...@@ -1489,7 +1491,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1489,7 +1491,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}} }}
}} }}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen"); TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability(); // auto arch = get_compute_capability();
auto a_shape = a.shape(); auto a_shape = a.shape();
auto c_shape = c.shape(); auto c_shape = c.shape();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC); int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
...@@ -1584,6 +1586,8 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1584,6 +1586,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("all_w_is_krsc, filter_hwio", "bool") code.arg("all_w_is_krsc, filter_hwio", "bool")
code.arg("features, filters, out_bp, indice_pairs", "tv::Tensor") code.arg("features, filters, out_bp, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor") code.arg("indice_pair_num", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("inverse", "bool", "false") code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false") code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}") code.arg("algo", "int", f"{ConvAlgo.Native.value}")
...@@ -1594,6 +1598,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1594,6 +1598,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
std::vector<int64_t> filter_shape_per_kv; std::vector<int64_t> filter_shape_per_kv;
auto prev_filter_shape_vec = filters.shape_vector(); auto prev_filter_shape_vec = filters.shape_vector();
bool is_KC_not_CK; bool is_KC_not_CK;
if (!all_w_is_krsc){{ if (!all_w_is_krsc){{
kv_dim = 0; kv_dim = 0;
is_KC_not_CK = !filter_hwio; is_KC_not_CK = !filter_hwio;
...@@ -1700,7 +1705,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1700,7 +1705,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}} }}
}} }}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen"); TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability(); // auto arch = get_compute_capability();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC); int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB); int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB);
...@@ -1899,6 +1904,8 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1899,6 +1904,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>") "std::vector<tv::Tensor>")
code.arg("num_activate_out", "int") code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("is_train, is_subm", "bool", "false") code.arg("is_train, is_subm", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)", code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
...@@ -1926,7 +1933,11 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1926,7 +1933,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int); {{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
}} }}
auto arch = get_compute_capability(); // auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
// auto arch = get_compute_capability();
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward); constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast); constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto tuned_res_exist = conv_tuner.get_tuned_algo( auto tuned_res_exist = conv_tuner.get_tuned_algo(
...@@ -1959,6 +1970,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1959,6 +1970,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
fp32_accum); fp32_accum);
tune_res = std::get<0>(tune_res_time); tune_res = std::get<0>(tune_res_time);
}} }}
int mask_width = tune_res.algo_desp.tile_shape[0]; int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd; tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits; std::vector<tv::Tensor> mask_output_fwd_splits;
...@@ -1974,6 +1986,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1974,6 +1986,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_output_fwd_splits.push_back(tv::Tensor()); mask_output_fwd_splits.push_back(tv::Tensor());
}} }}
}} }}
for (int j = 0; j < num_split; ++j){{ for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1; float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result( conv_tuner.run_with_tuned_result(
...@@ -1995,6 +2008,11 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1995,6 +2008,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // verbose false, // verbose
timer); timer);
}} }}
// auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int);
// tv::ssprint(tune_res.algo_desp.__repr__(), "WTF", exists,
// features.shape(), filters.shape(), out_features.shape(), tv::CUDAEvent::sync_and_duration(start_ev, end_ev));
return mask_width; return mask_width;
""") """)
return code.ret("int") return code.ret("int")
...@@ -2013,6 +2031,8 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2013,6 +2031,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("mask_output_fwd", "tv::Tensor") code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor") code.arg("masks", "tv::Tensor")
code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int") code.arg("mask_width", "int")
code.arg("is_subm", "bool") code.arg("is_subm", "bool")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
...@@ -2056,7 +2076,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2056,7 +2076,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast); constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto arch = get_compute_capability(); // auto arch = get_compute_capability();
auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo( auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardInputInt, kBackwardInputInt,
......
...@@ -419,10 +419,16 @@ def get_indice_pairs_implicit_gemm( ...@@ -419,10 +419,16 @@ def get_indice_pairs_implicit_gemm(
is_mask_split = algo == ConvAlgo.MaskSplitImplicitGemm is_mask_split = algo == ConvAlgo.MaskSplitImplicitGemm
mask_split_count = 2 if is_mask_split else 1 mask_split_count = 2 if is_mask_split else 1
if subm: if subm:
pair = torch.full((2, kv, indices.shape[0]), if is_train:
-1, pair = torch.full((2, kv, indices.shape[0]),
dtype=indices.dtype, -1,
device=indices.device) dtype=indices.dtype,
device=indices.device)
else:
pair = torch.full((1, kv, indices.shape[0]),
-1,
dtype=indices.dtype,
device=indices.device)
else: else:
# for regular conv, pair-in not equal to pair-out # for regular conv, pair-in not equal to pair-out
pair = torch.full((kv, indices.shape[0]), pair = torch.full((kv, indices.shape[0]),
...@@ -476,6 +482,7 @@ def get_indice_pairs_implicit_gemm( ...@@ -476,6 +482,7 @@ def get_indice_pairs_implicit_gemm(
ksize=ksize, ksize=ksize,
dilation=dilation, dilation=dilation,
indice_pair_mask=pair_mask_tv, indice_pair_mask=pair_mask_tv,
backward=is_train,
stream_int=stream) stream_int=stream)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print("SUBM0", time.time() - t) # print("SUBM0", time.time() - t)
...@@ -505,10 +512,12 @@ def get_indice_pairs_implicit_gemm( ...@@ -505,10 +512,12 @@ def get_indice_pairs_implicit_gemm(
CONV.stream_synchronize(stream) CONV.stream_synchronize(stream)
print("SUBM", time.time() - t) print("SUBM", time.time() - t)
if is_train:
return (out_inds, indice_num_per_loc, pair[0], pair[1], return (out_inds, indice_num_per_loc, pair[0], pair[1],
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks) pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else:
return (out_inds, indice_num_per_loc, pair[0], torch.Tensor(),
pair_mask_in_splits, [], mask_argsort_in_splits, [], masks)
else: else:
if DEBUG: if DEBUG:
...@@ -753,11 +762,15 @@ def indice_conv(features: torch.Tensor, ...@@ -753,11 +762,15 @@ def indice_conv(features: torch.Tensor,
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num) indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
stream = 0 stream = 0
arch = (0, 0)
if features.is_cuda: if features.is_cuda:
# plain get_arch by cuda api is VERY SLOW.
arch = get_arch()
stream = get_current_stream() stream = get_current_stream()
ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC, ConvGemmOps.indice_conv(alloc, ext_mm, GEMM_CPP, ALL_WEIGHT_IS_KRSC,
FILTER_HWIO, features_tv, filters_tv, FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv, indice_pairs_tv, indice_pair_num_tv,
arch,
num_activate_out, inverse, subm, algo.value, num_activate_out, inverse, subm, algo.value,
stream) stream)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
...@@ -996,12 +1009,16 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -996,12 +1009,16 @@ def indice_conv_backward(features: torch.Tensor,
indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num) indice_pair_num_tv = torch_tensor_to_tv(indice_pair_num)
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
stream = 0 stream = 0
arch = (0, 0)
if features.is_cuda: if features.is_cuda:
stream = get_current_stream() stream = get_current_stream()
arch = get_arch()
ConvGemmOps.indice_conv_backward(alloc, ext_mm, GEMM_CPP, ConvGemmOps.indice_conv_backward(alloc, ext_mm, GEMM_CPP,
ALL_WEIGHT_IS_KRSC, FILTER_HWIO, ALL_WEIGHT_IS_KRSC, FILTER_HWIO,
features_tv, filters_tv, out_bp_tv, features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv, indice_pairs_tv, indice_pair_num_tv,
arch,
inverse, subm, algo.value, stream) inverse, subm, algo.value, stream)
din = alloc.allocated[AllocKeys.DIn] din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters] df = alloc.allocated[AllocKeys.DFilters]
...@@ -1347,10 +1364,12 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1347,10 +1364,12 @@ def implicit_gemm(features: torch.Tensor,
auto_fp32_accum = fp32_accum is None auto_fp32_accum = fp32_accum is None
if fp32_accum is None: if fp32_accum is None:
fp32_accum = False fp32_accum = False
arch = get_arch()
mask_width = ConvGemmOps.implicit_gemm( mask_width = ConvGemmOps.implicit_gemm(
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, is_train, is_subm, stream, timer_cpp, num_activate_out, mask_tv, arch, is_train, is_subm, stream, timer_cpp,
auto_fp32_accum, fp32_accum) auto_fp32_accum, fp32_accum)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None) mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
...@@ -1441,7 +1460,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1441,7 +1460,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
# print(tune_res.algo_desp) print(tune_res.algo_desp, "REF", features_tv.shape, filters.shape)
# with tv.measure_and_print("f16 time"): # with tv.measure_and_print("f16 time"):
with timer.record("implicit_gemm", stream): with timer.record("implicit_gemm", stream):
for j in range(num_split): for j in range(num_split):
...@@ -1613,11 +1632,13 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1613,11 +1632,13 @@ def implicit_gemm_backward(features: torch.Tensor,
auto_fp32_accum = fp32_accum is None auto_fp32_accum = fp32_accum is None
if fp32_accum is None: if fp32_accum is None:
fp32_accum = False fp32_accum = False
arch = get_arch()
ConvGemmOps.implicit_gemm_backward( ConvGemmOps.implicit_gemm_backward(
alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, out_bp_tv, pair_fwd_tv,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv, pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv, mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, mask_width, is_subm, stream, mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum) timer_cpp, auto_fp32_accum, fp32_accum)
din = alloc.allocated[AllocKeys.DIn] din = alloc.allocated[AllocKeys.DIn]
dfilters = alloc.allocated[AllocKeys.DFilters] dfilters = alloc.allocated[AllocKeys.DFilters]
......
...@@ -113,7 +113,7 @@ class Net(nn.Module): ...@@ -113,7 +113,7 @@ class Net(nn.Module):
# nn.BatchNorm1d(32), # nn.BatchNorm1d(32),
# nn.ReLU(), # nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"), # spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo, record_voxel_count=True), spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64, spconv.SubMConv3d(64,
96, 96,
3, 3,
...@@ -312,7 +312,7 @@ def sort_bench(): ...@@ -312,7 +312,7 @@ def sort_bench():
for i in range(10): for i in range(10):
a_tv_1 = a_tv.clone() a_tv_1 = a_tv.clone()
SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0]) SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0])
import json
def main(): def main():
import pickle import pickle
...@@ -332,7 +332,8 @@ def main(): ...@@ -332,7 +332,8 @@ def main():
voxels_th = torch.from_numpy(voxels).to(device).to(dtype) voxels_th = torch.from_numpy(voxels).to(device).to(dtype)
coors_th = torch.from_numpy(coors).to(device).int() coors_th = torch.from_numpy(coors).to(device).int()
voxels_th.requires_grad = True voxels_th.requires_grad = True
algo = spconv.ConvAlgo.Native algo = spconv.ConvAlgo.MaskImplicitGemm
print("ALGO")
# 3080 Laptop # 3080 Laptop
# MaskImpGemm: 11.2ms # MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms # MaskSplitImpGemm: 12.2ms
...@@ -355,7 +356,7 @@ def main(): ...@@ -355,7 +356,7 @@ def main():
# MaskImpGemm: 51.0ms # MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms # MaskSplitImpGemm: 41.1ms
# algo = None # algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype).train() net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train()
# net.load_state_dict(net.state_dict()) # net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net) spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape) print(coors_th.shape)
...@@ -368,13 +369,13 @@ def main(): ...@@ -368,13 +369,13 @@ def main():
print(out.spatial_shape, out.features.mean(), out.features.max(), print(out.spatial_shape, out.features.mean(), out.features.max(),
out.features.min()) out.features.min())
times = [] times = []
show_metrics = False
with torch.no_grad(): with torch.no_grad():
for i in range(20): for i in range(20):
print("------------") print("------------")
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.time() t = time.time()
out_nograd = net(voxels_th, coors_th, 1, False) out_nograd = net(voxels_th, coors_th, 1, show_metrics)
timer = out_nograd._timer
# res = timer.collect_by_name("forward", timer.get_all_pair_time()) # res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time()) # res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
...@@ -385,6 +386,12 @@ def main(): ...@@ -385,6 +386,12 @@ def main():
torch.cuda.synchronize() torch.cuda.synchronize()
# sort_bench() # sort_bench()
times.append(time.time() - t) times.append(time.time() - t)
if show_metrics:
timer = out_nograd._timer
items = list(timer.get_all_pair_time().items())
items.sort(key=lambda x: x[0])
print(json.dumps(dict(items), indent=2))
# state = net.state_dict() # state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training") # state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state) # net.load_state_dict(state)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment