Unverified Commit bc32444b authored by Vel's avatar Vel Committed by GitHub
Browse files

[Kernel] Add enable_sm120_or_later for SM121 (DGX Spark) CUTLASS support (#33517)


Signed-off-by: default avatarcode4me2 <velvetmoon222999@gmail.com>
parent 18e85452
...@@ -152,3 +152,14 @@ struct enable_sm120_only : Kernel { ...@@ -152,3 +152,14 @@ struct enable_sm120_only : Kernel {
#endif #endif
} }
}; };
// SM12x family includes SM120 (RTX 5090) and SM121 (DGX Spark GB10)
template <typename Kernel>
struct enable_sm120_family : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && (__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300)
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
...@@ -103,7 +103,8 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -103,7 +103,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
MainloopScheduler MainloopScheduler
>::CollectiveOp; >::CollectiveOp;
using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal< // SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark)
using KernelType = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>; Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {}; struct GemmKernel : public KernelType {};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment