Commit 813bf014 authored by one's avatar one
Browse files

spconv: filter DTK MaskImplicitGemm descriptors

Limit the DTK SIMT MaskImplicitGemm forward descriptor set for the
MV2DFusion VVM/SubM fp32 shape family on sm_93.

For kv=27, input channels=128, and output channels=64, keep only the
SIMT descriptor validated by dense-oracle replay. Apply the same
selection rule in both the Python tuner path and generated C++ tuner
path to keep build-time and runtime behavior aligned.
parent 3610ebfa
......@@ -91,6 +91,40 @@ class BestConvAlgoByProfile:
self.arch = arch
def _is_dtk_certified_maskimplicit_fwd(desp: ConvAlgoDesp,
inp: tv.Tensor,
weight: tv.Tensor,
out: tv.Tensor,
arch: Tuple[int, int],
op_type: ConvOpType,
kv: int) -> bool:
"""Certified DTK/BW150 fp32 SubM MaskImplicitGemm subset.
This intentionally covers only the MV2DFusion VVM/SubM shape family that
has been validated against a dense oracle. Other shapes still use the
broader DTK SIMT set until they are certified separately.
"""
if os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower() != "dtk_simt":
return True
if arch != (9, 3):
return True
if op_type != ConvOpType.kForward:
return True
if not (inp.dtype == tv.float32 and weight.dtype == tv.float32
and out.dtype == tv.float32):
return True
if not (kv == 27 and out.dim(1) == 64 and inp.dim(1) == 128):
return True
return (desp.algo == GemmAlgo.Simt.value
and tuple(desp.tile_shape) == (32, 256, 8)
and tuple(desp.warp_tile_shape) == (32, 64, 8)
and tuple(desp.tensorop) == (-1, -1, -1)
and desp.increment_k_first
and desp.mask_sparse
and not desp.dynamic_mask
and not desp.split_k_serial)
def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
kernel_name: str):
nvrtc_mode = SPCONV_NVRTC_MODE
......@@ -778,6 +812,9 @@ class SimpleConv:
else:
if desp.dynamic_mask:
continue
if not _is_dtk_certified_maskimplicit_fwd(
desp, inp, weight, out, arch, op_type, kv):
continue
finally_algos.append(desp)
return finally_algos
......
......@@ -1122,6 +1122,36 @@ class ConvTunerSimple(pccm.ParameterizedClass):
continue;
}}
}}
bool dtk_certified_maskimplicit_shape =
{pccm.boolean(os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower() == "dtk_simt")} &&
arch == std::make_tuple(9, 3) &&
op_type_cpp == tv::gemm::ConvOpType::kForward &&
inp.dtype() == tv::float32 &&
weight.dtype() == tv::float32 &&
out.dtype() == tv::float32 &&
kv == 27 &&
out.dim(1) == 64 &&
inp.dim(1) == 128;
if (dtk_certified_maskimplicit_shape){{
bool dtk_certified_desp =
desp.algo == {pccm.literal(GemmAlgo.Simt.value)} &&
desp.tile_shape[0] == 32 &&
desp.tile_shape[1] == 256 &&
desp.tile_shape[2] == 8 &&
desp.warp_tile_shape[0] == 32 &&
desp.warp_tile_shape[1] == 64 &&
desp.warp_tile_shape[2] == 8 &&
desp.tensorop[0] == -1 &&
desp.tensorop[1] == -1 &&
desp.tensorop[2] == -1 &&
desp.increment_k_first &&
desp.mask_sparse &&
!desp.dynamic_mask &&
!desp.split_k_serial();
if (!dtk_certified_desp){{
continue;
}}
}}
finally_algos.push_back(desp2);
}}
}}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment