Unverified Commit c0dfc894 authored by Hendrik Holtmann's avatar Hendrik Holtmann Committed by GitHub
Browse files

SM120 / NVFP4: add device guard and runtime SM dispatch to cutlass_scaled_fp4_mm (#29711)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent 44822d7f
......@@ -15,6 +15,8 @@
*/
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
......@@ -32,19 +34,30 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& A_sf,
const torch::Tensor& B_sf,
const torch::Tensor& alpha) {
// Make sure we’re on A’s device.
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 mm kernel, vLLM should "
"be compiled using CUDA 12.8 and target "
"compute capability 100 or above.");
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment