Commit d61ab8e1 authored by EvernightAurora's avatar EvernightAurora
Browse files

divide ampere, support multistage for i8 and f16

parent e8bc31ec
......@@ -151,8 +151,8 @@ if disable_jit is not None and disable_jit == "1":
'build_ext': PCCMBuild,
}
from cumm.gemm.main import GemmMainUnitTest
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
from cumm.conv.main import ConvMainUnitTest
from cumm.constants import CUMM_CPU_ONLY_BUILD
from spconv.csrc.sparse.all import SpconvOps
......@@ -163,8 +163,8 @@ if disable_jit is not None and disable_jit == "1":
from spconv.csrc.sparse.convops import GemmTunerSimple, ExternalSpconvMatmul
from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS)
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS)
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
convcu.namespace = "cumm.conv.main"
cu.namespace = "cumm.gemm.main"
......
......@@ -22,8 +22,8 @@ from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT
if project_is_installed(PACKAGE_NAME) and project_is_editable(
PACKAGE_NAME) and not DISABLE_JIT:
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
from cumm.gemm.main import GemmMainUnitTest
from cumm.conv.main import ConvMainUnitTest
......@@ -37,12 +37,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps
from spconv.csrc.sparse.convops import SimpleExternalSpconvMatmul
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS)
IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main"
......
......@@ -168,7 +168,7 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params(
(128, 128, 32),
(32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
(32, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))),
# *gen_shuffle_params(
# (128, 128, 32),
......@@ -176,11 +176,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
# kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params(
(128, 256, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
(64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))),
*gen_shuffle_params(
(256, 128, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
(64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
......@@ -188,6 +188,17 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
]
SHUFFLE_AMPERE_PARAMS = [
*gen_shuffle_params(
(128, 128, 64),
(64, 64, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))),
*gen_shuffle_params(
(128, 64, 64),
(64, 32, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))),
]
# SHUFFLE_TURING_PARAMS = []
# here we must use f32 for simt f16 accumulators because
# half intristics is VERY SLOW in GTX 1000 series.
......@@ -486,8 +497,415 @@ IMPLGEMM_VOLTA_PARAMS = [
access_per_vector=1),
]
IMPLGEMM_AMPERE_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0),
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f32,f32,f32,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4], ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"f32,f32,f32,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8), top_dtypes="tf32,tf32,f32"),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4, 5],
"f16,f16,f16,f32,f32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 8)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3],
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
"s8,s8,s8,s32,s32",
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
......@@ -743,6 +1161,6 @@ IMPLGEMM_TURING_PARAMS = [
# gen_conv_params(ConvFwdAndBwdInput, )
]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS
......@@ -5,9 +5,9 @@ from cumm.common import CompileInfo
from cumm.conv.main import ConvMainUnitTest
from cumm.gemm.main import GemmMainUnitTest
from pccm.builder.pybind import gen_cmake
from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS,
from spconv.core import (IMPLGEMM_SIMT_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS,
IMPLGEMM_VOLTA_PARAMS, SHUFFLE_SIMT_PARAMS,
SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS)
SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_AMPERE_PARAMS)
from spconv.csrc.hash.core import HashTable
from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator, StaticAllocator
......@@ -24,7 +24,7 @@ def main(include: str,
libname: str = "spconv",
prefix: str = "spconvlib",
inference_only: bool = False):
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
if inference_only:
all_shuffle = list(filter(lambda x: x.shuffle_stride != ShuffleStrideType.ShuffleAB, all_shuffle))
......@@ -32,7 +32,7 @@ def main(include: str,
cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS)
IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
# all_imp = IMPLGEMM_SIMT_PARAMS
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
if inference_only:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment