Unverified Commit 34e97911 authored by FindDefinition's avatar FindDefinition Committed by GitHub
Browse files

Merge pull request #515 from EvernightAurora/master

Feature/Ampere
parents f8c25027 d61ab8e1
...@@ -151,8 +151,8 @@ if disable_jit is not None and disable_jit == "1": ...@@ -151,8 +151,8 @@ if disable_jit is not None and disable_jit == "1":
'build_ext': PCCMBuild, 'build_ext': PCCMBuild,
} }
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
from cumm.constants import CUMM_CPU_ONLY_BUILD from cumm.constants import CUMM_CPU_ONLY_BUILD
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
...@@ -164,8 +164,8 @@ if disable_jit is not None and disable_jit == "1": ...@@ -164,8 +164,8 @@ if disable_jit is not None and disable_jit == "1":
from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps
from spconv.csrc.sparse.inference import InferenceOps from spconv.csrc.sparse.inference import InferenceOps
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS) cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS)
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS) convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
......
...@@ -22,8 +22,8 @@ from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT ...@@ -22,8 +22,8 @@ from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT
if project_is_installed(PACKAGE_NAME) and project_is_editable( if project_is_installed(PACKAGE_NAME) and project_is_editable(
PACKAGE_NAME) and not DISABLE_JIT: PACKAGE_NAME) and not DISABLE_JIT:
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
...@@ -38,12 +38,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -38,12 +38,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from spconv.csrc.sparse.convops import SimpleExternalSpconvMatmul from spconv.csrc.sparse.convops import SimpleExternalSpconvMatmul
from spconv.csrc.sparse.inference import InferenceOps from spconv.csrc.sparse.inference import InferenceOps
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS) IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp) convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
......
...@@ -168,7 +168,7 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -168,7 +168,7 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (32, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), TensorOp((8, 8, 16))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
...@@ -176,11 +176,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -176,11 +176,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
# kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), # kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), TensorOp((8, 8, 16))),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
...@@ -188,6 +188,17 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -188,6 +188,17 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
] ]
SHUFFLE_AMPERE_PARAMS = [
*gen_shuffle_params(
(128, 128, 64),
(64, 64, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))),
*gen_shuffle_params(
(128, 64, 64),
(64, 32, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))),
]
# SHUFFLE_TURING_PARAMS = [] # SHUFFLE_TURING_PARAMS = []
# here we must use f32 for simt f16 accumulators because # here we must use f32 for simt f16 accumulators because
# half intristics is VERY SLOW in GTX 1000 series. # half intristics is VERY SLOW in GTX 1000 series.
...@@ -486,7 +497,464 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -486,7 +497,464 @@ IMPLGEMM_VOLTA_PARAMS = [
access_per_vector=1), access_per_vector=1),
] ]
IMPLGEMM_AMPERE_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0),
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"f32,f32,f32,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4, 5],
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_TURING_PARAMS = [ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0),
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16), *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
...@@ -658,12 +1126,41 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -658,12 +1126,41 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 16, 32), (32, 16, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0),
*gen_conv_params(ConvBwdWeight, (64, 16, 32), (32, 16, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
# *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32", # *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32",
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), # NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, ) # gen_conv_params(ConvFwdAndBwdInput, )
] ]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS
...@@ -5,9 +5,9 @@ from cumm.common import CompileInfo ...@@ -5,9 +5,9 @@ from cumm.common import CompileInfo
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from pccm.builder.pybind import gen_cmake from pccm.builder.pybind import gen_cmake
from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS, from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS,
IMPLGEMM_VOLTA_PARAMS, SHUFFLE_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, SHUFFLE_SIMT_PARAMS,
SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS) SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_AMPERE_PARAMS)
from spconv.csrc.hash.core import HashTable from spconv.csrc.hash.core import HashTable
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator, StaticAllocator from spconv.csrc.sparse.alloc import ExternalAllocator, StaticAllocator
...@@ -25,7 +25,7 @@ def main(include: str, ...@@ -25,7 +25,7 @@ def main(include: str,
libname: str = "spconv", libname: str = "spconv",
prefix: str = "spconvlib", prefix: str = "spconvlib",
inference_only: bool = False): inference_only: bool = False):
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
if inference_only: if inference_only:
all_shuffle = list(filter(lambda x: x.shuffle_stride != ShuffleStrideType.ShuffleAB, all_shuffle)) all_shuffle = list(filter(lambda x: x.shuffle_stride != ShuffleStrideType.ShuffleAB, all_shuffle))
...@@ -33,7 +33,7 @@ def main(include: str, ...@@ -33,7 +33,7 @@ def main(include: str,
cu = GemmMainUnitTest(all_shuffle) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS) IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
# all_imp = IMPLGEMM_SIMT_PARAMS # all_imp = IMPLGEMM_SIMT_PARAMS
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
if inference_only: if inference_only:
......
...@@ -290,6 +290,49 @@ class Net2(nn.Module): ...@@ -290,6 +290,49 @@ class Net2(nn.Module):
return self.net(x) return self.net(x)
class NetSm(nn.Module):
def __init__(self, shape, algo):
super().__init__()
self.net = spconv.SparseSequential(
spconv.SubMConv3d(3,
8,
3,
bias=False,
indice_key="c0",
algo=algo),
spconv.SubMConv3d(8,
16,
3,
bias=False,
indice_key="c0",
algo=algo),
spconv.SubMConv3d(16,
32,
3,
bias=False,
indice_key="c0",
algo=algo),
spconv.SubMConv3d(32,
64,
3,
bias=False,
indice_key="c0",
algo=algo),
)
max_batch_size = 1
# grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
self.grid = torch.full([max_batch_size, *shape], -1,
dtype=torch.int32).cuda()
# self.grid = None
self.shape = shape
def forward(self, features, coors, batch_size, enable_timer: bool = False):
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size,
self.grid, enable_timer=enable_timer)
return self.net(x)
import numpy as np import numpy as np
from cumm import tensorview as tv from cumm import tensorview as tv
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
...@@ -359,7 +402,7 @@ def main(): ...@@ -359,7 +402,7 @@ def main():
# MaskImpGemm: 51.0ms # MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms # MaskSplitImpGemm: 41.1ms
# algo = None # algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype)# .train() net = NetSm(spatial_shape, algo).to(device).eval().to(dtype)# .train()
# net.load_state_dict(net.state_dict()) # net.load_state_dict(net.state_dict())
spconv.assign_name_for_sparse_modules(net) spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape) print(coors_th.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment