Commit d61ab8e1 authored by EvernightAurora's avatar EvernightAurora
Browse files

divide ampere, support multistage for i8 and f16

parent e8bc31ec
...@@ -151,8 +151,8 @@ if disable_jit is not None and disable_jit == "1": ...@@ -151,8 +151,8 @@ if disable_jit is not None and disable_jit == "1":
'build_ext': PCCMBuild, 'build_ext': PCCMBuild,
} }
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
from cumm.constants import CUMM_CPU_ONLY_BUILD from cumm.constants import CUMM_CPU_ONLY_BUILD
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
...@@ -163,8 +163,8 @@ if disable_jit is not None and disable_jit == "1": ...@@ -163,8 +163,8 @@ if disable_jit is not None and disable_jit == "1":
from spconv.csrc.sparse.convops import GemmTunerSimple, ExternalSpconvMatmul from spconv.csrc.sparse.convops import GemmTunerSimple, ExternalSpconvMatmul
from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS) cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS)
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS) convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
......
...@@ -22,8 +22,8 @@ from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT ...@@ -22,8 +22,8 @@ from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT
if project_is_installed(PACKAGE_NAME) and project_is_editable( if project_is_installed(PACKAGE_NAME) and project_is_editable(
PACKAGE_NAME) and not DISABLE_JIT: PACKAGE_NAME) and not DISABLE_JIT:
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
...@@ -37,12 +37,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -37,12 +37,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps
from spconv.csrc.sparse.convops import SimpleExternalSpconvMatmul from spconv.csrc.sparse.convops import SimpleExternalSpconvMatmul
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS) IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp) convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
......
...@@ -168,7 +168,7 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -168,7 +168,7 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (32, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), TensorOp((8, 8, 16))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
...@@ -176,11 +176,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -176,11 +176,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
# kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), # kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), TensorOp((8, 8, 16))),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
...@@ -188,6 +188,17 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -188,6 +188,17 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
] ]
SHUFFLE_AMPERE_PARAMS = [
*gen_shuffle_params(
(128, 128, 64),
(64, 64, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))),
*gen_shuffle_params(
(128, 64, 64),
(64, 32, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))),
]
# SHUFFLE_TURING_PARAMS = [] # SHUFFLE_TURING_PARAMS = []
# here we must use f32 for simt f16 accumulators because # here we must use f32 for simt f16 accumulators because
# half intristics is VERY SLOW in GTX 1000 series. # half intristics is VERY SLOW in GTX 1000 series.
...@@ -486,8 +497,415 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -486,8 +497,415 @@ IMPLGEMM_VOLTA_PARAMS = [
access_per_vector=1), access_per_vector=1),
] ]
IMPLGEMM_AMPERE_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0),
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"f32,f32,f32,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4, 5],
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_TURING_PARAMS = [ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16), *gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
...@@ -743,6 +1161,6 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -743,6 +1161,6 @@ IMPLGEMM_TURING_PARAMS = [
# gen_conv_params(ConvFwdAndBwdInput, ) # gen_conv_params(ConvFwdAndBwdInput, )
] ]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS
...@@ -5,9 +5,9 @@ from cumm.common import CompileInfo ...@@ -5,9 +5,9 @@ from cumm.common import CompileInfo
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from pccm.builder.pybind import gen_cmake from pccm.builder.pybind import gen_cmake
from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS, from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS,
IMPLGEMM_VOLTA_PARAMS, SHUFFLE_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, SHUFFLE_SIMT_PARAMS,
SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS) SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_AMPERE_PARAMS)
from spconv.csrc.hash.core import HashTable from spconv.csrc.hash.core import HashTable
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator, StaticAllocator from spconv.csrc.sparse.alloc import ExternalAllocator, StaticAllocator
...@@ -24,7 +24,7 @@ def main(include: str, ...@@ -24,7 +24,7 @@ def main(include: str,
libname: str = "spconv", libname: str = "spconv",
prefix: str = "spconvlib", prefix: str = "spconvlib",
inference_only: bool = False): inference_only: bool = False):
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
if inference_only: if inference_only:
all_shuffle = list(filter(lambda x: x.shuffle_stride != ShuffleStrideType.ShuffleAB, all_shuffle)) all_shuffle = list(filter(lambda x: x.shuffle_stride != ShuffleStrideType.ShuffleAB, all_shuffle))
...@@ -32,7 +32,7 @@ def main(include: str, ...@@ -32,7 +32,7 @@ def main(include: str,
cu = GemmMainUnitTest(all_shuffle) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS) IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
# all_imp = IMPLGEMM_SIMT_PARAMS # all_imp = IMPLGEMM_SIMT_PARAMS
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
if inference_only: if inference_only:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment