Unverified Commit 7a16db9b authored by hlu1, committed by GitHub

Make sm100 fp8 kernels available on sm103 (#9789)


Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
parent 09a1df22
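
The change widens the runtime compute-capability checks in three fp8 GEMM entry points so that the CUTLASS SM100 kernels are also dispatched on SM103 devices. The `sm_version == 103` clause is wrapped in `#if CUDA_VERSION >= 12090`, presumably because sm_103 is only recognized from CUDA 12.9 onward, so builds against older toolkits keep compiling the sm100-only check. The last hunk also corrects the error message in fp8_blockwise_scaled_grouped_mm, which previously reported fp8_blockwise_scaled_mm. A consolidated sketch of the pattern follows the diff below.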
@@ -260,7 +260,11 @@ torch::Tensor fp8_blockwise_scaled_mm(
 #if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
 #if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version == 100) {
+  if (sm_version == 100
+#if CUDA_VERSION >= 12090
+      || sm_version == 103
+#endif
+  ) {
     if (out_dtype == torch::kBFloat16) {
       sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
           out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
...
@@ -1212,7 +1212,11 @@ torch::Tensor fp8_scaled_mm(
   auto sm_version = getSMVersion();
 #if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version >= 100) {
+  if (sm_version == 100
+#if CUDA_VERSION >= 12090
+      || sm_version == 103
+#endif
+  ) {
     if (out_dtype == torch::kBFloat16) {
       sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
     } else {
...
@@ -708,7 +708,11 @@ void fp8_blockwise_scaled_grouped_mm(
 #if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
 #if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version == 100) {
+  if (sm_version == 100
+#if CUDA_VERSION >= 12090
+      || sm_version == 103
+#endif
+  ) {
     if (output.scalar_type() == torch::kBFloat16) {
       sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
           output,
...
@@ -802,5 +806,5 @@ void fp8_blockwise_scaled_grouped_mm(
   }
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
-      can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
+      can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version);
 }
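
Below is a minimal, self-contained sketch of the guard pattern applied in all three hunks. The dispatch-site function name is hypothetical, and getSMVersion() is assumed to encode compute capability as major * 10 + minor, matching the values 100 and 103 tested in the diff; the real dispatch sites are the fp8 GEMM entry points shown above.

#include <cuda.h>              // CUDA_VERSION
#include <cuda_runtime_api.h>  // cudaGetDevice, cudaDeviceGetAttribute

// Assumed equivalent of the getSMVersion() seen in the diff:
// compute capability as major * 10 + minor (SM100 -> 100, SM103 -> 103).
static int getSMVersion() {
  int device = 0, major = 0, minor = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return major * 10 + minor;
}

// Hypothetical dispatch site illustrating the pattern this commit applies.
bool dispatch_fp8_kernel() {
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
  int sm_version = getSMVersion();
  if (sm_version == 100
#if CUDA_VERSION >= 12090
      // SM100 fp8 kernels are also valid on SM103, but sm_103 is only
      // known to CUDA 12.9+, so the extra clause is compiled conditionally.
      || sm_version == 103
#endif
  ) {
    // ... launch the SM100 CUTLASS fp8 kernel here ...
    return true;
  }
#endif
  return false;  // caller falls back or raises "not implemented"
}

Guarding the sm103 clause at the preprocessor level rather than at runtime keeps the same source buildable with CUDA 12.8, where the toolkit does not know sm_103; a 12.8 build simply never matches SM103 devices at runtime.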