Unverified Commit ab1a6a43 authored by mikaylagawarecki's avatar mikaylagawarecki Committed by GitHub
Browse files

[3/n] Migrate cutlass/scaled_mm_entry.cu torch stable ABI (#37221)


Signed-off-by: default avatarMikayla Gawarecki <mikaylagawarecki@gmail.com>
parent b5e60825
...@@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp") "csrc/cutlass_extensions/common.cpp")
...@@ -490,132 +489,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -490,132 +489,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures") " in CUDA target architectures")
endif() endif()
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later # CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
...@@ -693,55 +566,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -693,55 +566,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MLA_ARCHS) set(MLA_ARCHS)
endif() endif()
# CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
...@@ -787,36 +611,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -787,36 +611,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"in CUDA target architectures.") "in CUDA target architectures.")
endif() endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
# #
# Machete kernels # Machete kernels
...@@ -964,7 +758,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -964,7 +758,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
# #
set(VLLM_STABLE_EXT_SRC set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp") "csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC list(APPEND VLLM_STABLE_EXT_SRC
...@@ -979,6 +775,209 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -979,6 +775,209 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}") CUDA_ARCHS "${CUDA_ARCHS}")
endif() endif()
#
# CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
#
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
#
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
message(STATUS "Enabling C_stable extension.") message(STATUS "Enabling C_stable extension.")
define_extension_target( define_extension_target(
_C_stable_libtorch _C_stable_libtorch
...@@ -987,6 +986,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -987,6 +986,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SOURCES ${VLLM_STABLE_EXT_SRC} SOURCES ${VLLM_STABLE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
...@@ -1000,6 +1000,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -1000,6 +1000,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Needed to use cuda APIs from C-shim # Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA) USE_CUDA)
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
endif() endif()
# #
......
...@@ -6,13 +6,15 @@ ...@@ -6,13 +6,15 @@
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <torch/headeronly/util/shim_utils.h>
/** /**
* Helper function for checking CUTLASS errors * Helper function for checking CUTLASS errors
*/ */
#define CUTLASS_CHECK(status) \ #define CUTLASS_CHECK(status) \
{ \ { \
cutlass::Status error = status; \ cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \ STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \ cutlassGetStatusString(error)); \
} }
......
...@@ -3,6 +3,14 @@ ...@@ -3,6 +3,14 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
// (stable ABI) targets. When compiled under the stable ABI target,
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
// use torch::stable::Tensor instead.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#endif
/* /*
This file defines custom epilogues for fusing channel scales, token scales, This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the bias, and activation zero-points onto a GEMM operation using the
...@@ -15,6 +23,12 @@ ...@@ -15,6 +23,12 @@
namespace vllm::c3x { namespace vllm::c3x {
#ifdef TORCH_TARGET_VERSION
using TensorType = torch::stable::Tensor;
#else
using TensorType = torch::Tensor;
#endif
using namespace cute; using namespace cute;
template <typename T> template <typename T>
...@@ -84,7 +98,7 @@ struct ScaledEpilogueBase { ...@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or // from a tensor. It can handle both row and column, as well as row/column or
// scalar cases. // scalar cases.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) { static auto args_from_tensor(TensorType const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr()); auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> || if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
...@@ -100,7 +114,7 @@ struct ScaledEpilogueBase { ...@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which // This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used. // case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) { static auto args_from_tensor(std::optional<TensorType> const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr; auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> || static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
...@@ -158,8 +172,8 @@ struct ScaledEpilogue ...@@ -158,8 +172,8 @@ struct ScaledEpilogue
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales) { TensorType const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
...@@ -203,9 +217,9 @@ struct ScaledEpilogueBias ...@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& bias) { TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias ...@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& bias) { TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp ...@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
EVTComputeScaleB, Bias>; EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& azp_adj, TensorType const& azp_adj,
std::optional<torch::Tensor> const& bias) { std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken ...@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeScaleB, Bias>; EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& azp_adj, TensorType const& azp_adj,
torch::Tensor const& azp, TensorType const& azp,
std::optional<torch::Tensor> const& bias) { std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
......
#pragma once #pragma once
#include <torch/csrc/stable/tensor.h>
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/* /*
...@@ -52,7 +54,7 @@ struct ScaledEpilogueBase { ...@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or // from a tensor. It can handle both row and column, as well as row/column or
// scalar cases. // scalar cases.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) { static auto args_from_tensor(torch::stable::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr()); auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> || if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
...@@ -68,7 +70,8 @@ struct ScaledEpilogueBase { ...@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which // This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used. // case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) { static auto args_from_tensor(
std::optional<torch::stable::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>); static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr; auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
...@@ -117,8 +120,8 @@ struct ScaledEpilogue ...@@ -117,8 +120,8 @@ struct ScaledEpilogue
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>; cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
...@@ -160,9 +163,9 @@ struct ScaledEpilogueBias ...@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>; EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
torch::Tensor const& bias) { torch::stable::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp ...@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& azp_adj, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken ...@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& azp_adj, torch::stable::Tensor const& b_scales,
torch::Tensor const& azp, torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
......
...@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input, ...@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_s, torch::stable::Tensor& output_s,
int64_t group_size, double eps, double int8_min, int64_t group_size, double eps, double int8_min,
double int8_max); double int8_max);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch);
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif #endif
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
// clang-format will break include orders // clang-format will break include orders
// clang-format off // clang-format off
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <ATen/cuda/CUDAContext.h> #include "libtorch_stable/torch_utils.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
...@@ -25,14 +26,14 @@ ...@@ -25,14 +26,14 @@
namespace vllm::c3x { namespace vllm::c3x {
static inline cute::Shape<int, int, int, int> get_problem_shape( static inline cute::Shape<int, int, int, int> get_problem_shape(
torch::Tensor const& a, torch::Tensor const& b) { torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1); int32_t m = a.size(0), n = b.size(1), k = a.size(1);
return {m, n, k, 1}; return {m, n, k, 1};
} }
template <typename GemmKernel> template <typename GemmKernel>
void cutlass_gemm_caller( void cutlass_gemm_caller(
torch::Device device, cute::Shape<int, int, int, int> prob_shape, torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args, typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args, typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) { typename GemmKernel::TileSchedulerArguments scheduler = {}) {
...@@ -50,19 +51,20 @@ void cutlass_gemm_caller( ...@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args)); CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args); size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(device); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(device.index()); auto stream = get_current_cuda_stream(device.index());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
} }
template <typename Gemm, typename... EpilogueArgs> template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) { EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC; using ElementC = typename Gemm::ElementC;
......
...@@ -4,13 +4,12 @@ ...@@ -4,13 +4,12 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90_int8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias) {
std::optional<torch::Tensor> const& bias) {
if (azp) { if (azp) {
return cutlass_scaled_mm_sm90_int8_epilogue< return cutlass_scaled_mm_sm90_int8_epilogue<
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
......
...@@ -4,17 +4,16 @@ ...@@ -4,17 +4,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales) {
torch::Tensor const& b_scales) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>( cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>( cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
...@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab; static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA; using StrideA = typename Gemm::GemmKernel::StrideA;
...@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
} }
template <typename OutType> template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms; int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
......
...@@ -4,17 +4,16 @@ ...@@ -4,17 +4,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out, void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales) {
torch::Tensor const& b_scales) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>( cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>( cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
...@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 { ...@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA; using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB; using StrideB = typename Gemm::GemmKernel::StrideB;
...@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
} }
template <typename OutType> template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out, void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
int M = a.size(0); int M = a.size(0);
if (M <= 256) { if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm; using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
......
...@@ -5,17 +5,16 @@ ...@@ -5,17 +5,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales) {
torch::Tensor const& b_scales) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>( cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>( cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
...@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA; using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB; using StrideB = typename Gemm::GemmKernel::StrideB;
...@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
int32_t m = a.size(0), n = b.size(1), k = a.size(1); int32_t m = a.size(0), n = b.size(1), k = a.size(1);
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4"); STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
StrideA a_stride; StrideA a_stride;
StrideB b_stride; StrideB b_stride;
...@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
} }
template <typename OutType> template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
// TODO: better heuristics // TODO: better heuristics
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>, OutType, 1, 128, 128, Shape<_128, _128, _128>,
......
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc> template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, void dispatch_scaled_mm(torch::stable::Tensor& c,
torch::Tensor const& b, torch::Tensor const& a_scales, torch::stable::Tensor const& a,
torch::Tensor const& b_scales, torch::stable::Tensor const& b,
std::optional<torch::Tensor> const& bias, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func, Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) { BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(a_scales.scalar_type() ==
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
int M = a.size(0), N = b.size(1), K = a.size(1); int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling // Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) { if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias); fp8_func(c, a, b, a_scales, b_scales, bias);
} else { } else {
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) { if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias); int8_func(c, a, b, a_scales, b_scales, bias);
} else { } else {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
TORCH_CHECK( STD_TORCH_CHECK(
false, "Int8 not supported on SM", version_num, false, "Int8 not supported on SM", version_num,
". Use FP8 quantization instead, or run on older arch (SM < 100)."); ". Use FP8 quantization instead, or run on older arch (SM < 100).");
} }
} }
} else { } else {
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 90) { if (version_num >= 90) {
TORCH_CHECK( STD_TORCH_CHECK(
a.size(0) == a_scales.size(0) && a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
"a_scale_group_shape must be [1, 128]."); "a_scale_group_shape must be [1, 128].");
TORCH_CHECK( STD_TORCH_CHECK(
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128]."); "b_scale_group_shape must be [128, 128].");
} }
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales); blockwise_func(c, a, b, a_scales, b_scales);
} }
} }
#pragma once
#include <torch/csrc/stable/tensor.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
} // namespace vllm
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm100_fp8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales, return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias); b_scales, *bias);
} else { } else {
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
...@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab { ...@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
}; };
template <typename Gemm, typename... EpilogueArgs> template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) { EpilogueArgs&&... epilogue_params) {
static constexpr bool swap_ab = Gemm::swap_ab; static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
...@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, ...@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
template <typename InType, typename OutType, bool EnableBias, template <typename InType, typename OutType, bool EnableBias,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm100_fp8_dispatch(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
torch::Tensor const& b_scales,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType, typename sm100_fp8_config_default<InType, OutType,
...@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
} }
template <bool EnableBias, typename... EpilogueArgs> template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out, void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, EnableBias>( cutlass::bfloat16_t, EnableBias>(
out, a, b, a_scales, b_scales, out, a, b, a_scales, b_scales,
std::forward<EpilogueArgs>(epilogue_args)...); std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, EnableBias>( cutlass::half_t, EnableBias>(
out, a, b, a_scales, b_scales, out, a, b, a_scales, b_scales,
......
...@@ -4,15 +4,16 @@ ...@@ -4,15 +4,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm120_fp8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>( return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
...@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 { ...@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
int M = a.size(0); int M = a.size(0);
...@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, ...@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
template <template <typename, typename, typename> typename Epilogue, template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out, void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>( cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>( cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment