"vllm/vscode:/vscode.git/clone" did not exist on "99caa4910651754f3f68de518ca42349c8c424d1"
Unverified commit 1656ad37, authored by Jinzhen Lin, committed by GitHub
Browse files

[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)


Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
parent fa59fe41
...@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs. # Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that # Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet. # are not supported by Machete yet.
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") # marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS) if (MARLIN_ARCHS)
# #
...@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MARLIN_GEN_SCRIPT set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE marlin_generation_result RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
...@@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: " "\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else() else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin generate script hash" FORCE) CACHE STRING "Last run Marlin generate script hash and arch" FORCE)
message(STATUS "Marlin generation completed successfully.") message(STATUS "Marlin generation completed successfully.")
endif() endif()
else() else()
message(STATUS "Marlin generation script has not changed, skipping generation.") message(STATUS "Marlin generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}") CUDA_ARCHS "${MARLIN_ARCHS}")
...@@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC})
endif()
set(MARLIN_SRCS set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu") "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
...@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}") CUDA_ARCHS "${CUDA_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
# 9.0 for latest bf16 atomicAdd PTX # moe marlin arches
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") # note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS) if (MARLIN_MOE_ARCHS)
# #
...@@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_MARLIN_GEN_SCRIPT set(MOE_MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE moe_marlin_generation_result RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_VARIABLE moe_marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
...@@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: " "\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
else() else()
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin MOE generate script hash" FORCE) CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
message(STATUS "Marlin MOE generation completed successfully.") message(STATUS "Marlin MOE generation completed successfully.")
endif() endif()
...@@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.") message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}" SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}") CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_FP8_SRC}"
CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_FP8_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else() else()
......
...@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: ...@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight=w_q, b_q_weight=w_q,
b_bias=None, b_bias=None,
b_scales=w_s, b_scales=w_s,
a_scales=None,
global_scale=None, global_scale=None,
b_zeros=w_zp, b_zeros=w_zp,
g_idx=g_idx, g_idx=g_idx,
......
...@@ -263,7 +263,7 @@ def bench_run( ...@@ -263,7 +263,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
...@@ -273,7 +273,7 @@ def bench_run( ...@@ -273,7 +273,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
......
kernel_*.cu sm*_kernel_*.cu
\ No newline at end of file kernel_selector.h
...@@ -4,134 +4,282 @@ import glob ...@@ -4,134 +4,282 @@ import glob
import itertools import itertools
import os import os
import subprocess import subprocess
import sys
import jinja2 import jinja2
FILE_HEAD = """ ARCHS = []
// auto generated by generate.py SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off // clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """
)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# AWQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for quant_config in QUANT_CONFIGS:
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
): a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
# act order case only support gptq-int4 and gptq-int8 b_type = quant_config["b_type"]
if group_blocks == 0 and scalar_type not in [ all_group_blocks = quant_config["group_blocks"]
"vllm::kU4B8", all_m_blocks = quant_config["thread_m_blocks"]
"vllm::kU8B128", all_thread_configs = quant_config["thread_configs"]
]:
continue for a_type, c_type in itertools.product(a_types, c_types):
if thread_configs[2] == 256: if not SUPPORT_FP8 and a_type == "kFE4M3fn":
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue continue
if m_blocks > 1 and thread_configs[0] != 64: if "16" in a_type and "16" in c_type and a_type != c_type:
continue continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
# we only support channelwise quantization and group_size == 128 for group_blocks, m_blocks, thread_configs in itertools.product(
# for fp8 all_group_blocks, all_m_blocks, all_thread_configs
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: ):
continue thread_k, thread_n, threads = thread_configs
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32 if threads == 256:
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: # for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue continue
# other quantization methods don't support group_size = 16 if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue continue
k_blocks = thread_configs[0] // 16 config = {
n_blocks = thread_configs[1] // 16 "threads": threads,
threads = thread_configs[2] "s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "false",
}
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" result_dict[(a_type, b_type, c_type)].append(config)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: kernel_selector_str = FILE_HEAD_COMMENT
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, a_type_id=f"vllm::{a_type}.id()",
w_type_id=scalar_type + ".id()", b_type_id=f"vllm::{b_type}.id()",
s_type_id=s_type + ".id()", c_type_id=f"vllm::{c_type}.id()",
threads=threads, s_type_id=f"vllm::{s_type}.id()",
thread_m_blocks=max(m_blocks, 1), **config,
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=False,
) )
all_template_str_list.append(template_str) all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()
......
...@@ -11,8 +11,9 @@ ...@@ -11,8 +11,9 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \
...@@ -20,12 +21,13 @@ ...@@ -20,12 +21,13 @@
const float *__restrict__ topk_weights_ptr, int top_k, \ const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
......
...@@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME { ...@@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <typename scalar_t, // compute dtype, half or nv_float16 template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
...@@ -49,6 +49,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16 ...@@ -49,6 +49,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1 // only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks, // number of consecutive 16x16 blocks const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale // with a separate quantization scale
const bool is_zp_float // is zero point of float16 type? const bool is_zp_float // is zero point of float16 type?
...@@ -76,8 +77,8 @@ __global__ void Marlin( ...@@ -76,8 +77,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce bool use_fp32_reduce // whether to use fp32 global reduce
int max_shared_mem) {} ) {}
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
...@@ -85,14 +86,17 @@ __global__ void Marlin( ...@@ -85,14 +86,17 @@ __global__ void Marlin(
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation. // output/accumulation.
template <typename scalar_t> template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag, __device__ inline void mma(
const typename ScalarType<scalar_t>::FragB& frag_b, const typename MarlinScalarType<type_id>::FragA& a_frag,
typename ScalarType<scalar_t>::FragC& frag_c) { const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c); using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) { if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile( asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
...@@ -100,28 +104,65 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag, ...@@ -100,28 +104,65 @@ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile( asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else { } else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} }
} }
template <typename scalar_t> template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans( __device__ inline void mma_trans(
const typename ScalarType<scalar_t>::FragA& a_frag, const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b, const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename ScalarType<scalar_t>::FragB& frag_b2, const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename ScalarType<scalar_t>::FragC& frag_c) { typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2); const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c); float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) { if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile( asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
...@@ -129,21 +170,64 @@ __device__ inline void mma_trans( ...@@ -129,21 +170,64 @@ __device__ inline void mma_trans(
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile( asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else { } else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
asm volatile(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} }
} }
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
template <int count, typename scalar_t> template <int count, vllm::ScalarTypeId type_id>
__device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, __device__ inline void ldsm(typename MarlinScalarType<type_id>::FragA& frag_a,
const void* smem_ptr) { const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a); uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
...@@ -167,47 +251,54 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, ...@@ -167,47 +251,54 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b, __device__ inline void scale(typename MarlinScalarType<type_id>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s, typename MarlinScalarType<type_id>::FragS& frag_s,
int i) { int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
scalar_t2 s = using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]); scalar_t2 s = MarlinScalarType<type_id>::num2num2(
reinterpret_cast<scalar_t*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s); frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale_and_sub( __device__ inline void scale_and_sub(
typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) { typename MarlinScalarType<type_id>::FragB& frag_b,
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; typename MarlinScalarType<type_id>::scalar_t s,
scalar_t2 s2 = ScalarType<scalar_t>::num2num2(s); typename MarlinScalarType<type_id>::scalar_t zp) {
scalar_t2 zp2 = ScalarType<scalar_t>::num2num2(zp); using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 s2 = MarlinScalarType<type_id>::num2num2(s);
scalar_t2 zp2 = MarlinScalarType<type_id>::num2num2(zp);
frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));
frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));
} }
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b, __device__ inline void sub_zp(
typename ScalarType<scalar_t>::scalar_t2& frag_zp, typename MarlinScalarType<type_id>::FragB& frag_b,
int i) { typename MarlinScalarType<type_id>::scalar_t2& frag_zp, int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
scalar_t2 zp = using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]); scalar_t2 zp = MarlinScalarType<type_id>::num2num2(
reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp); frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp); frag_b[1] = __hsub2(frag_b[1], zp);
} }
// Same as above, but for act_order (each K is multiplied individually) // Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b, __device__ inline void scale4(
typename ScalarType<scalar_t>::FragS& frag_s_1, typename MarlinScalarType<type_id>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s_2, typename MarlinScalarType<type_id>::FragS& frag_s_1,
typename ScalarType<scalar_t>::FragS& frag_s_3, typename MarlinScalarType<type_id>::FragS& frag_s_2,
typename ScalarType<scalar_t>::FragS& frag_s_4, typename MarlinScalarType<type_id>::FragS& frag_s_3,
int i) { typename MarlinScalarType<type_id>::FragS& frag_s_4, int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 s_val_1_2; scalar_t2 s_val_1_2;
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i]; s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i]; s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
...@@ -221,12 +312,13 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b, ...@@ -221,12 +312,13 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
} }
// Given 2 floats multiply by 2 scales (halves) // Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale_float(float* c, __device__ inline void scale_float(
typename ScalarType<scalar_t>::FragS& s) { float* c, typename MarlinScalarType<type_id>::FragS& s) {
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s); scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0])); c[0] = __fmul_rn(c[0], MarlinScalarType<type_id>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1])); c[1] = __fmul_rn(c[1], MarlinScalarType<type_id>::num2float(s_ptr[1]));
} }
// Wait until barrier reaches `count`, then lock for current threadblock. // Wait until barrier reaches `count`, then lock for current threadblock.
...@@ -278,9 +370,10 @@ __device__ inline void wait_negative_and_add(int* lock) { ...@@ -278,9 +370,10 @@ __device__ inline void wait_negative_and_add(int* lock) {
__syncthreads(); __syncthreads();
} }
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
...@@ -301,13 +394,18 @@ __global__ void Marlin( ...@@ -301,13 +394,18 @@ __global__ void Marlin(
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr, const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // float scales of input matrix, only used when is_a_8bit == true.
// (k/groupsize)xn // shape (m,)
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 const float* __restrict__ a_scales_ptr,
// only) // fp16 quantization scales. shape (k/groupsize, n)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape const int4* __restrict__ scales_ptr,
// (k/groupsize)x(n/pack_factor) // fp16 global scale (for nvfp4// only)
const int* __restrict__ g_idx, // int32 group indices of shape k const uint16_t* __restrict__ global_scale_ptr,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const int4* __restrict__ zp_ptr,
// int32 group indices of shape k
const int* __restrict__ g_idx,
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
...@@ -322,8 +420,8 @@ __global__ void Marlin( ...@@ -322,8 +420,8 @@ __global__ void Marlin(
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool has_bias, bool has_bias,
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce bool use_fp32_reduce // whether to use fp32 global reduce
int max_shared_mem) { ) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the // Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 * // same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
...@@ -335,18 +433,37 @@ __global__ void Marlin( ...@@ -335,18 +433,37 @@ __global__ void Marlin(
// ensures good utilization of all SMs for many kinds of shape and GPU // ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock // configurations, while requiring as few slow global cross-threadblock
// reductions as possible. // reductions as possible.
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890
using FragA = typename ScalarType<scalar_t>::FragA; // FP8 computation is only supported for Ada Lovelace or newer architectures.
using FragB = typename ScalarType<scalar_t>::FragB; if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
using FragC = typename ScalarType<scalar_t>::FragC; #endif
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP; int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
using scalar_32bit_t = typename MarlinScalarType<a_type_id>::scalar_32bit_t;
using c_scalar_t = typename MarlinScalarType<c_type_id>::scalar_t;
using c_scalar_t2 = typename MarlinScalarType<c_type_id>::scalar_t2;
using FragA = typename MarlinScalarType<a_type_id>::FragA;
using FragB = typename MarlinScalarType<a_type_id>::FragB;
using FragC = typename MarlinScalarType<a_type_id>::FragC;
using FragS = typename MarlinScalarType<c_type_id>::FragS;
using FragZP = typename MarlinScalarType<c_type_id>::FragZP;
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id);
static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id);
static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2); s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
...@@ -355,34 +472,37 @@ __global__ void Marlin( ...@@ -355,34 +472,37 @@ __global__ void Marlin(
static_assert(s_type == vllm::kFloat16); static_assert(s_type == vllm::kFloat16);
} }
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_a_8bit = a_type.size_bits() == 8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || if constexpr (!is_a_8bit) {
w_type == vllm::kU4B8 || w_type == vllm::kU8B128; static_assert(std::is_same<scalar_t, c_scalar_t>::value);
}
constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8;
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
// see comments of dequant.h for more details // see comments of dequant.h for more details
constexpr bool dequant_skip_flop = constexpr bool dequant_skip_flop =
w_type == vllm::kFE4M3fn || is_a_8bit || b_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value || has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8); has_zp && !is_zp_float && !(b_type == vllm::kU8);
scalar_t2 global_scale; c_scalar_t2 global_scale;
constexpr bool has_act_order = group_blocks == 0; constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits(); constexpr int pack_factor = 32 / b_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8); static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size = const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = const int scales_expert_stride =
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride = const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8 is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4); : prob_n * prob_k / group_size / (pack_factor * 4);
const int b_bias_expert_stride = prob_n / 8; const int b_bias_expert_stride = prob_n / 8;
// parallel: num valid moe blocks // parallel: num valid moe blocks
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
int parallel = num_tokens_past_padded / moe_block_size; int parallel = num_tokens_past_padded / moe_block_size;
int num_valid_blocks = parallel; int num_valid_blocks = parallel;
if (is_ep) { if (is_ep) {
...@@ -395,7 +515,23 @@ __global__ void Marlin( ...@@ -395,7 +515,23 @@ __global__ void Marlin(
int k_tiles = prob_k / 16 / thread_k_blocks; int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks; int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
int global_mn_tiles = parallel * n_tiles;
int part2_mn_tiles = global_mn_tiles;
int part1_mn_iters = 0;
bool in_part2 = false;
// we use DP + two-tile SK here
// part1: DP
// part2: two-tile SK
// see https://github.com/vllm-project/vllm/pull/24722 for more details
if (global_mn_tiles > gridDim.x) {
part2_mn_tiles = global_mn_tiles % gridDim.x;
if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x;
part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x;
}
int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x);
if constexpr (!has_act_order && group_blocks != -1) { if constexpr (!has_act_order && group_blocks != -1) {
if (group_blocks >= thread_k_blocks) { if (group_blocks >= thread_k_blocks) {
...@@ -407,14 +543,15 @@ __global__ void Marlin( ...@@ -407,14 +543,15 @@ __global__ void Marlin(
} }
} }
int slice_row = (iters * blockIdx.x) % k_tiles; int slice_row = 0;
int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col_par = blockIdx.x;
int slice_col = slice_col_par; int slice_col;
int slice_iters; // number of threadblock tiles in the current slice int slice_iters =
int slice_count = k_tiles; // number of threadblock tiles in the current slice
0; // total number of active threadblocks in the current slice // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to int slice_count = 1;
// top // index of threadblock in current slice; numbered bottom to top
int slice_idx = 0;
int par_id = 0; int par_id = 0;
int block_id = -1; int block_id = -1;
...@@ -422,87 +559,89 @@ __global__ void Marlin( ...@@ -422,87 +559,89 @@ __global__ void Marlin(
int old_expert_id = 0; int old_expert_id = 0;
int64_t B_expert_off = 0; int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh; float* sh_a_s = reinterpret_cast<float*>(sh);
int4* sh_block_sorted_ids_int4 = sh + (is_a_8bit ? (4 * thread_m_blocks) : 0);
int4* sh_rd_block_sorted_ids_int4 = int4* sh_rd_block_sorted_ids_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4; sh_block_sorted_ids_int4 + moe_block_size / 4;
int4* sh_block_topk_weights_int4 = int4* sh_block_topk_weights_int4 =
sh_rd_block_sorted_ids_int4 + moe_block_size / 4; sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
// sh_block_topk_weights_int4 only need (moe_block_size / 4); // sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes // but we pad to align to 256 bytes
int4* sh_new = int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 2;
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
int32_t* sh_block_sorted_ids = int32_t* sh_block_sorted_ids =
reinterpret_cast<int*>(sh_block_sorted_ids_int4); reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int32_t* sh_rd_block_sorted_ids = int32_t* sh_rd_block_sorted_ids =
reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4); reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
scalar_t2* sh_block_topk_weights = c_scalar_t2* sh_block_topk_weights =
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4); reinterpret_cast<c_scalar_t2*>(sh_block_topk_weights_int4);
int32_t block_num_valid_tokens = 0; int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0; int32_t locks_off = 0;
// We can easily implement parallel problem execution by just remapping // We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers // indices and advancing global pointers
if (slice_col_par >= n_tiles) { if (part2_mn_tiles >= gridDim.x) {
slice_col = slice_col_par % n_tiles; // when part2_mn_tiles >= sms
par_id = slice_col_par / n_tiles;
}
if (parallel * n_tiles >= gridDim.x) {
// when parallel * n_tiles >= sms
// then there are at most $sms$ conflict tile blocks // then there are at most $sms$ conflict tile blocks
locks_off = blockIdx.x; locks_off = blockIdx.x;
} else { } else {
locks_off = (iters * blockIdx.x) / k_tiles - 1; locks_off = (iters * blockIdx.x) / k_tiles - 1;
} }
int prob_m_top_k = prob_m * top_k;
// read moe block data given block_id // read moe block data given block_id
// block_sorted_ids / block_num_valid_tokens / block_topk_weights // block_sorted_ids / block_num_valid_tokens / block_topk_weights
auto read_moe_block_data = [&](int block_id) { auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size; block_num_valid_tokens = moe_block_size;
#pragma unroll
for (int i = 0; i < moe_block_size / 4; i++) { cp_async4_pred(sh_block_sorted_ids_int4 + threadIdx.x,
int4 sorted_token_ids_int4 = reinterpret_cast<const int4*>( reinterpret_cast<const int4*>(sorted_token_ids_ptr) +
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; (block_id * moe_block_size / 4 + threadIdx.x),
int* sorted_token_ids = reinterpret_cast<int*>(&sorted_token_ids_int4); threadIdx.x < moe_block_size / 4);
#pragma unroll
for (int j = 0; j < 4; j++) { cp_async_fence();
if (sorted_token_ids[j] >= prob_m * top_k) { cp_async_wait<0>();
block_num_valid_tokens = i * 4 + j;
break;
}
}
if (block_num_valid_tokens != moe_block_size) break;
}
__syncthreads(); __syncthreads();
int tid4 = threadIdx.x / 4;
if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) {
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#pragma unroll if (threadIdx.x >= threads - 32) {
for (int i = 0; i < 4; i++) constexpr int size_per_thread = div_ceil(moe_block_size, 32);
sh_rd_block_sorted_ids[tid4 * 4 + i] = int lane_id = threadIdx.x - (threads - 32);
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (mul_topk_weights) { int local_count = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < size_per_thread; i++) {
int idx = tid4 * 4 + i; int j = lane_id * size_per_thread + i;
if (idx < block_num_valid_tokens) { if (j < moe_block_size) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { int idx = sh_block_sorted_ids[j];
sh_block_topk_weights[idx] = if (idx < prob_m_top_k) local_count++;
__hmul2(global_scale,
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
} }
} }
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
if (lane_id == 0)
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
} }
if (threadIdx.x < moe_block_size) {
int idx = sh_block_sorted_ids[threadIdx.x];
sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k;
if (mul_topk_weights) {
idx = idx < prob_m_top_k ? idx : 0;
c_scalar_t2 topk_weight_val =
Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx]));
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
topk_weight_val = __hmul2(topk_weight_val, global_scale);
} }
sh_block_topk_weights[threadIdx.x] = topk_weight_val;
} }
}
__syncthreads();
block_num_valid_tokens = reinterpret_cast<int*>(sh_new)[0];
__syncthreads(); __syncthreads();
}; };
...@@ -513,9 +652,8 @@ __global__ void Marlin( ...@@ -513,9 +652,8 @@ __global__ void Marlin(
old_expert_id = expert_id; old_expert_id = expert_id;
if (num_invalid_blocks > 0) { if (num_invalid_blocks > 0) {
int skip_count = block_id == -1 ? par_id : 0; int skip_count = par_id;
block_id++; for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) {
for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) {
expert_id = expert_ids_ptr[i]; expert_id = expert_ids_ptr[i];
if (expert_id != -1) { if (expert_id != -1) {
if (skip_count == 0) { if (skip_count == 0) {
...@@ -530,9 +668,9 @@ __global__ void Marlin( ...@@ -530,9 +668,9 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id]; expert_id = expert_ids_ptr[block_id];
} }
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = scale2_ptr[expert_id]; uint16_t val = global_scale_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val)); global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
} }
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
...@@ -552,10 +690,11 @@ __global__ void Marlin( ...@@ -552,10 +690,11 @@ __global__ void Marlin(
// Compute all information about the current slice which is required for // Compute all information about the current slice which is required for
// synchronization. // synchronization.
auto init_slice = [&](bool first_init = false) { bool first_init = true;
auto init_part2_slice = [&]() {
slice_iters = slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0;
if (slice_iters == 0) return; if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1; slice_count = 1;
...@@ -573,7 +712,7 @@ __global__ void Marlin( ...@@ -573,7 +712,7 @@ __global__ void Marlin(
if (col_off > 0) slice_idx--; if (col_off > 0) slice_idx--;
} }
} }
if (parallel * n_tiles >= gridDim.x) { if (part2_mn_tiles >= gridDim.x) {
if (slice_count > 1 && slice_idx == slice_count - 1) { if (slice_count > 1 && slice_idx == slice_count - 1) {
locks_off++; locks_off++;
} }
...@@ -607,25 +746,61 @@ __global__ void Marlin( ...@@ -607,25 +746,61 @@ __global__ void Marlin(
par_id++; par_id++;
update_next_moe_block_data(); update_next_moe_block_data();
} }
if (is_a_8bit && (first_init || slice_col == 0)) {
__syncthreads();
cp_async1_ca_pred(&sh_a_s[threadIdx.x],
&a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]],
threadIdx.x < block_num_valid_tokens);
}
}; };
auto init_part1_slice = [&]() {
if (part1_mn_iters) {
part1_mn_iters--;
par_id = slice_col_par / n_tiles;
slice_col = slice_col_par % n_tiles;
slice_iters = k_tiles;
update_next_moe_block_data(); update_next_moe_block_data();
init_slice(true); if (is_a_8bit) {
__syncthreads();
cp_async1_ca_pred(&sh_a_s[threadIdx.x],
&a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]],
threadIdx.x < block_num_valid_tokens);
}
}
};
auto init_slice = [&]() {
if (!in_part2 && !part1_mn_iters) {
in_part2 = true;
slice_col_par = (iters * blockIdx.x) / k_tiles;
slice_row = (iters * blockIdx.x) % k_tiles;
slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles;
par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles;
update_next_moe_block_data();
}
if (!in_part2) {
init_part1_slice();
} else {
init_part2_slice();
first_init = false;
}
};
init_slice();
// A sizes/strides // A sizes/strides
// stride of the A matrix in global memory // stride of the A matrix in global memory
int a_gl_stride = prob_k / 8; int a_gl_stride = prob_k / (is_a_8bit ? 16 : 8);
// stride of an A matrix tile in shared memory // stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 16 * thread_k_blocks / 8; constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8);
// delta between subsequent A tiles in global memory // delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8);
// between subsequent accesses within a tile // between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes // between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile // within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16; constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile // overall size of a tile
...@@ -634,24 +809,25 @@ __global__ void Marlin( ...@@ -634,24 +809,25 @@ __global__ void Marlin(
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides // B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4); int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4));
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; constexpr int b_sh_stride =
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4);
constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1);
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs; constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs; constexpr int b_sh_stage =
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1);
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order // Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8; int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_tb_groups = constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) ? thread_k_blocks / group_blocks
: 1; : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride; constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride; int s_gl_rd_delta = s_gl_stride;
...@@ -664,7 +840,8 @@ __global__ void Marlin( ...@@ -664,7 +840,8 @@ __global__ void Marlin(
constexpr int act_s_max_num_groups = 32; constexpr int act_s_max_num_groups = 32;
int act_s_col_stride = 1; int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8; int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4);
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides // Zero-points sizes/strides
...@@ -679,7 +856,6 @@ __global__ void Marlin( ...@@ -679,7 +856,6 @@ __global__ void Marlin(
// Global A read index of current thread. // Global A read index of current thread.
int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
// Shared write index of current thread. // Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o); (threadIdx.x % a_gl_rd_delta_o);
...@@ -687,17 +863,22 @@ __global__ void Marlin( ...@@ -687,17 +863,22 @@ __global__ void Marlin(
int a_sh_rd = int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) +
(threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1));
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters;
int b_gl_rd;
if (threads <= b_sh_stride) {
b_gl_rd = threadIdx.x;
} else {
b_gl_rd =
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
}
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + b_gl_rd += B_expert_off + b_sh_stride * slice_col;
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row; b_gl_rd += b_gl_rd_delta_o * slice_row;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs;
b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1));
// For act_order // For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
int slice_k_start = tb_k * slice_row; int slice_k_start = tb_k * slice_row;
int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_finish = slice_k_start + tb_k * slice_iters;
int slice_k_start_shared_fetch = slice_k_start; int slice_k_start_shared_fetch = slice_k_start;
...@@ -708,58 +889,54 @@ __global__ void Marlin( ...@@ -708,58 +889,54 @@ __global__ void Marlin(
if constexpr (!has_act_order) { if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else { } else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x; s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
} }
} }
auto s_sh_wr = threadIdx.x; auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stage;
// Zero-points // Zero-points
int zp_gl_rd; int zp_gl_rd;
if constexpr (has_zp) { if constexpr (has_zp) {
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else { } else if constexpr (group_blocks >= thread_k_blocks) {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x; zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / zp_sh_stride) +
zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
} }
} }
auto zp_sh_wr = threadIdx.x; auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage;
// We use a different scale layout for grouped and column-wise quantization as // We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in // we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case. // row-major in the latter case.
int s_sh_rd; int s_sh_rd;
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { if constexpr (is_a_8bit) {
auto warp_id = threadIdx.x / 32; s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4);
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1) } else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4;
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop))) (m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8;
(threadIdx.x % 32) / 8;
else else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4;
(threadIdx.x % 32) % 4;
int bias_sh_rd; int bias_sh_rd;
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8;
(threadIdx.x % 32) / 8;
} else { } else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + bias_sh_rd = (is_a_8bit ? 4 : 8) * ((threadIdx.x / 32) % tb_n_warps) +
(threadIdx.x % 32) % 4; (threadIdx.x % 32) % 4;
} }
...@@ -775,12 +952,16 @@ __global__ void Marlin( ...@@ -775,12 +952,16 @@ __global__ void Marlin(
if constexpr (has_zp) { if constexpr (has_zp) {
if constexpr (is_zp_float) { if constexpr (is_zp_float) {
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + zp_sh_rd =
(threadIdx.x % 32) / 4; 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4;
} }
} else if (is_a_8bit) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % tb_n_warps / 2) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
} else { } else {
zp_sh_rd = num_ints_per_thread * num_col_threads * zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) + ((threadIdx.x / 32) % tb_n_warps) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
} }
} }
...@@ -807,18 +988,13 @@ __global__ void Marlin( ...@@ -807,18 +988,13 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < thread_m_blocks; j++) for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] = a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd);
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
} }
// Since B-accesses have non-constant stride they have to be computed at // Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by // runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny // maintining multiple pointers (we have enough registers), a tiny
// optimization. // optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
...@@ -847,19 +1023,12 @@ __global__ void Marlin( ...@@ -847,19 +1023,12 @@ __global__ void Marlin(
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage); stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size; int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used = moe_block_size +
stages * (g_idx_stage + zp_sh_stage) +
sh_s_size + sh_b_red_bias_size;
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int sh_a_max_row =
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs]; I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2]; FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2];
FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2];
FragS frag_s[2][4]; // No act-order FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4]; FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order FragS act_frag_s[2][4][4]; // For act-order
...@@ -867,6 +1036,24 @@ __global__ void Marlin( ...@@ -867,6 +1036,24 @@ __global__ void Marlin(
FragZP frag_zp; // Zero-points in fp16 FragZP frag_zp; // Zero-points in fp16
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
if constexpr (is_a_8bit && group_blocks != -1) {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
}
}
// Zero accumulators. // Zero accumulators.
auto zero_accums = [&]() { auto zero_accums = [&]() {
#pragma unroll #pragma unroll
...@@ -910,18 +1097,11 @@ __global__ void Marlin( ...@@ -910,18 +1097,11 @@ __global__ void Marlin(
} }
} }
}; };
// Asynchronously fetch the next A, B and s tile from global to the next // Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location. // shared memory pipeline location.
bool should_load_a = true; auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
int max_num_stage_groups =
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
max_num_stage_groups = max(max_num_stage_groups, 1);
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
int pipe_a = 0) {
if (pred) { if (pred) {
if (should_load_a) { int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe;
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) { for (int i = 0; i < a_sh_wr_iters; i++) {
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
...@@ -933,20 +1113,20 @@ __global__ void Marlin( ...@@ -933,20 +1113,20 @@ __global__ void Marlin(
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens); row < block_num_valid_tokens);
} }
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) {
#pragma unroll constexpr int count = div_ceil(b_sh_stride, threads);
for (int j = 0; j < b_thread_vecs; j++) { int b_gl_idx =
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], b_gl_rd + (i % count) * threads +
B_ptr[i] + j + B_expert_off); b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride);
}
B_ptr[i] += b_gl_rd_delta_o; cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]);
} }
b_gl_rd += b_gl_rd_delta_o;
if constexpr (has_act_order) { if constexpr (has_act_order) {
// Fetch g_idx thread-block portion // Fetch g_idx thread-block portion
int full_pipe = a_off; int full_pipe = a_off;
...@@ -966,44 +1146,24 @@ __global__ void Marlin( ...@@ -966,44 +1146,24 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) { if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
} }
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta * s_tb_groups;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} }
} }
if constexpr (has_zp && group_blocks != -1) { if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) { // Only fetch zero points if this tile starts a new group
// Only fetch zero-points if this tile starts a new group if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) { if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
} }
zp_gl_rd += zp_gl_rd_delta; zp_gl_rd += zp_gl_rd_delta * zp_tb_groups;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} }
} }
} }
...@@ -1037,18 +1197,18 @@ __global__ void Marlin( ...@@ -1037,18 +1197,18 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe // Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer. // into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>( ldsm<m_block_size_8 ? 2 : 4, a_type_id>(
frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_thread_vecs; i++) { for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>( frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]);
} }
}; };
...@@ -1072,53 +1232,54 @@ __global__ void Marlin( ...@@ -1072,53 +1232,54 @@ __global__ void Marlin(
auto fetch_scales_to_registers = [&](int k, int full_pipe) { auto fetch_scales_to_registers = [&](int k, int full_pipe) {
int pipe = full_pipe % stages; int pipe = full_pipe % stages;
using IT1 = typename std::conditional_t<is_a_8bit, int2, int4>;
using IT0 = typename std::conditional_t<is_a_8bit, int, int2>;
constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 2 : 1);
if constexpr (!has_act_order) { if constexpr (!has_act_order) {
// No act-order case // No act-order case
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
// load only when starting a new slice // load only when starting a new slice
if (k == 0 && full_pipe == 0) { if (k == 0 && full_pipe == 0 && dequant_skip_flop) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
} }
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
constexpr int g = group_blocks / thread_k_blocks;
if (pipe % g == 0) {
if (k % b_sh_wr_iters == 0) { if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage = int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g));
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else { } else {
reinterpret_cast<int4*>(&frag_s[1])[0] = reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0]; reinterpret_cast<int4*>(&frag_s[0])[0];
} }
} else { }
} else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) {
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int warp_row = warp_id / tb_n_warps;
int warp_row = warp_id / n_warps; int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int cur_group_id = k_blocks / group_blocks2;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (w_type_id != vllm::kFE2M1f.id()) { if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { } else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] = reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>( reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else { } else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] = reinterpret_cast<int2*>(&frag_s[1])[0] =
reinterpret_cast<int2*>( reinterpret_cast<int2*>(&frag_s[0])[0];
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
} }
} }
} }
...@@ -1139,18 +1300,15 @@ __global__ void Marlin( ...@@ -1139,18 +1300,15 @@ __global__ void Marlin(
cur_k = 0; cur_k = 0;
// Progress to current iteration // Progress to current iteration
cur_k += k_iter_size * (k % b_sh_wr_iters); cur_k += k % b_sh_wr_iters;
// Determine "position" inside the thread-block (based on warp and // Determine "position" inside the thread-block (based on warp and
// thread-id) // thread-id)
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = int warp_row = warp_id / tb_n_warps;
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_col = warp_id % tb_n_warps;
int warp_row = warp_id / n_warps;
int warp_col = warp_id % n_warps;
cur_k += warp_row * 16; cur_k += warp_row * 16 * b_sh_wr_iters;
auto th_id = threadIdx.x % 32; auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
...@@ -1205,18 +1363,16 @@ __global__ void Marlin( ...@@ -1205,18 +1363,16 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
// load only when starting a new slice // load only when starting a new slice
if (k == 0 && full_pipe == 0) { if (k == 0 && full_pipe == 0 || is_a_8bit) {
#pragma unroll #pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i]; frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
} }
} }
} else if constexpr (group_blocks >= thread_k_blocks) { } else if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) { constexpr int g = group_blocks / thread_k_blocks;
int4* sh_zp_stage = if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) {
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g));
(pipe / (group_blocks / thread_k_blocks)));
#pragma unroll #pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = frag_qzp[k % 2][i] =
...@@ -1225,21 +1381,11 @@ __global__ void Marlin( ...@@ -1225,21 +1381,11 @@ __global__ void Marlin(
} }
} else { } else {
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16; int warp_row = warp_id / tb_n_warps;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
#pragma nv_diagnostic push int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 2 : 1);
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
...@@ -1258,29 +1404,18 @@ __global__ void Marlin( ...@@ -1258,29 +1404,18 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) { constexpr int g = group_blocks / thread_k_blocks;
int4* sh_zp_stage = if (pipe % g == 0 && k % b_sh_wr_iters == 0) {
sh_zp + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g));
zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd]; sh_zp_stage[zp_sh_rd];
} }
} else { } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) {
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16; int warp_row = warp_id / tb_n_warps;
cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int cur_group_id = k_blocks / group_blocks; int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
...@@ -1291,33 +1426,46 @@ __global__ void Marlin( ...@@ -1291,33 +1426,46 @@ __global__ void Marlin(
} }
}; };
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) {
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr); if constexpr (a_type.size_bits() != b_type.size_bits()) {
if constexpr (is_a_8bit && has_zp) {
sub_zp_and_dequant<scalar_32bit_t, b_type_id, dequant_skip_flop>(
q, frag_b_ptr, zp);
} else {
dequant<scalar_32bit_t, b_type_id, dequant_skip_flop>(q, frag_b_ptr);
}
}
}; };
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true; bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) { auto matmul = [&](int k, int pipe) {
if (is_a_8bit) return;
int k2 = k % 2; int k2 = k % 2;
constexpr int g =
group_blocks > 0 ? div_ceil(group_blocks, thread_k_blocks) : 1;
const bool is_new_zp = const bool is_new_zp =
((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || (group_blocks == 0) ||
((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) &&
(pipe % g == 0) ||
(group_blocks == -1 && is_first_matmul_in_slice); (group_blocks == -1 && is_first_matmul_in_slice);
if constexpr (has_zp && !is_zp_float) { if constexpr (has_zp && !is_zp_float) {
if (is_new_zp) { if (is_new_zp) {
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
int zp_quant_0, zp_quant_1; int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) { if constexpr (b_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k2][0]; zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = zp_quant_0 >> 8; zp_quant_1 = zp_quant_0 >> 8;
} else { } else {
static_assert(w_type.size_bits() == 8); static_assert(b_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k2][0]; zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = frag_qzp[k2][1]; zp_quant_1 = frag_qzp[k2][1];
} }
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp)); dequant_data(zp_quant_0, reinterpret_cast<scalar_32bit_t*>(&frag_zp));
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2); dequant_data(zp_quant_1,
reinterpret_cast<scalar_32bit_t*>(&frag_zp) + 2);
} }
} }
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
...@@ -1327,14 +1475,14 @@ __global__ void Marlin( ...@@ -1327,14 +1475,14 @@ __global__ void Marlin(
} }
} }
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (b_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0]; int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1]; int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2, s_type_id>( dequant_fp8_scales<c_scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2])); s_quant_0, reinterpret_cast<c_scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>( dequant_fp8_scales<c_scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2); s_quant_1, reinterpret_cast<c_scalar_t2*>(&frag_s[k2]) + 2);
} }
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
...@@ -1345,61 +1493,168 @@ __global__ void Marlin( ...@@ -1345,61 +1493,168 @@ __global__ void Marlin(
FragB frag_b1; FragB frag_b1;
int b_quant_0, b_quant_1; int b_quant_0, b_quant_1;
if constexpr (w_type_id == vllm::kFE2M1f.id()) { if constexpr (b_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j]; b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8; b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) { } else if constexpr (b_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j]; b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8; b_quant_1 = b_quant_0 >> 8;
} else { } else {
static_assert(w_type.size_bits() == 8); static_assert(b_type.size_bits() == 8);
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]); int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]);
b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
} }
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0)); dequant_data(b_quant_0, reinterpret_cast<scalar_32bit_t*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1)); dequant_data(b_quant_1, reinterpret_cast<scalar_32bit_t*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0); sub_zp<a_type_id>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1); sub_zp<a_type_id>(frag_b1, frag_zp[j], 1);
} }
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order && !is_a_8bit) {
static_assert(group_blocks != -1); static_assert(group_blocks != -1);
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<a_type_id>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<a_type_id>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) { group_blocks == -1 && !is_a_8bit) {
int idx = (threadIdx.x / 4) % 2; int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2( scalar_t2 s2 = Adtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x); scale_and_sub<a_type_id>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y); scale_and_sub<a_type_id>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 &&
!is_a_8bit) {
if (is_new_zp) if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j], frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j])); *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); scale_and_sub<a_type_id>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); scale_and_sub<a_type_id>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1 && !is_a_8bit) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0); scale<a_type_id>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1); scale<a_type_id>(frag_b1, frag_s[k2][j], 1);
} }
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else { } else {
mma<scalar_t>(frag_a[k2][i], frag_b0, frag_c[i][j][0]); mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<scalar_t>(frag_a[k2][i], frag_b1, frag_c[i][j][1]); mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
}
}
}
};
auto matmul_a8 = [&](int k) {
int k2 = k % 2;
#pragma unroll
for (int j = 0; j < 2; j++) {
FragB frag_b[2];
if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) {
dequant_data(frag_b_quant[k2][0][j * 2],
reinterpret_cast<scalar_32bit_t*>(&frag_b));
dequant_data(frag_b_quant[k2][0][j * 2 + 1],
reinterpret_cast<scalar_32bit_t*>(&frag_b) + 2);
} else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) {
int off = (threadIdx.x / 32) % 2 * 2 + j;
int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF;
dequant_data(frag_b_quant[k2][0][j * 2],
reinterpret_cast<scalar_32bit_t*>(&frag_b), zp);
zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF;
dequant_data(frag_b_quant[k2][0][j * 2 + 1],
reinterpret_cast<scalar_32bit_t*>(&frag_b) + 2, zp);
} else {
reinterpret_cast<int2*>(&frag_b)[0] =
reinterpret_cast<int2*>(&frag_b_quant[k2][j])[0];
reinterpret_cast<int2*>(&frag_b)[1] =
reinterpret_cast<int2*>(&frag_b_quant[k2][j])[1];
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
if constexpr (group_blocks != -1) {
if (group_blocks == 2 || k == 1) {
if constexpr (a_type == vllm::kS8) {
int2 s_vals[2];
s_vals[0] = {
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2][0])[0],
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2][0])[1]};
s_vals[1] = {
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2 + 1][0])[0],
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2 + 1][0])[1]};
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
int scale = reinterpret_cast<int*>(&s_vals[0])[g % 2];
*reinterpret_cast<int32_t*>(&frag_c[i][j][0][g]) +=
*reinterpret_cast<int32_t*>(&frag_c_tmp[i][j][0][g]) *
scale;
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
int scale = reinterpret_cast<int*>(&s_vals[1])[g % 2];
*reinterpret_cast<int32_t*>(&frag_c[i][j][1][g]) +=
*reinterpret_cast<int32_t*>(&frag_c_tmp[i][j][1][g]) *
scale;
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
} else {
float2 s_vals[2];
if constexpr (s_type_id != vllm::kFE8M0fnu.id()) {
static_assert(a_type.size_bits() == 16 ||
s_type.size_bits() == 16);
s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]);
s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]);
} else {
int32_t* s_vals_int = reinterpret_cast<int32_t*>(&s_vals[0]);
int32_t s_vals_e8m0 =
*reinterpret_cast<int32_t*>(&frag_s[k2][j][0]);
s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23;
s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15;
s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7;
s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1;
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&s_vals[0])[g % 2];
frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale;
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&s_vals[1])[g % 2];
frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale;
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
}
} }
} }
} }
...@@ -1413,7 +1668,8 @@ __global__ void Marlin( ...@@ -1413,7 +1668,8 @@ __global__ void Marlin(
constexpr int red_off = threads / b_sh_stride_threads / 2; constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) { if (red_off >= 1) {
auto red_idx = threadIdx.x / b_sh_stride_threads; auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_stride =
b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2;
constexpr int red_sh_delta = b_sh_stride_threads; constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads); (threadIdx.x % b_sh_stride_threads);
...@@ -1428,7 +1684,8 @@ __global__ void Marlin( ...@@ -1428,7 +1684,8 @@ __global__ void Marlin(
for (int i = red_off; i > 0; i /= 2) { for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) { if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2;
j += (m_block_size_8 ? 2 : 1)) {
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
...@@ -1437,24 +1694,26 @@ __global__ void Marlin( ...@@ -1437,24 +1694,26 @@ __global__ void Marlin(
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += reinterpret_cast<FragC*>(
frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k]; c_rd[k] + c_wr[k];
} }
sh_red[red_sh_wr] = sh_red[red_sh_wr] = reinterpret_cast<int4*>(
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j];
} }
} }
__syncthreads(); __syncthreads();
} }
if (red_idx == 0) { if (red_idx == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2;
i += (m_block_size_8 ? 2 : 1)) {
float* c_rd = float* c_rd =
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]); reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += reinterpret_cast<FragC*>(
c_rd[j]; frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j];
} }
} }
__syncthreads(); __syncthreads();
...@@ -1470,13 +1729,13 @@ __global__ void Marlin( ...@@ -1470,13 +1729,13 @@ __global__ void Marlin(
// We are very careful here to reduce directly in the output buffer to // We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out // maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute). // results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4; constexpr int active_threads = 32 * tb_n_warps;
bool is_th_active = threadIdx.x < active_threads; bool is_th_active = threadIdx.x < active_threads;
if (!is_th_active) { if (!is_th_active) {
return; return;
} }
int c_gl_stride = prob_n / 8; int c_gl_stride = prob_n / 8 * (is_a_8bit ? 2 : 1);
int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr; int c_gl_wr;
...@@ -1487,7 +1746,7 @@ __global__ void Marlin( ...@@ -1487,7 +1746,7 @@ __global__ void Marlin(
} else { } else {
c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
4 * (threadIdx.x / 32) + threadIdx.x % 4; 4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col; c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1);
} }
constexpr int c_sh_wr_delta = active_threads; constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x; int c_sh_wr = threadIdx.x;
...@@ -1506,37 +1765,51 @@ __global__ void Marlin( ...@@ -1506,37 +1765,51 @@ __global__ void Marlin(
if (c_idx / c_gl_stride < block_num_valid_tokens) { if (c_idx / c_gl_stride < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];
int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
if constexpr (is_a_8bit) {
int2* sh_red_int2 = reinterpret_cast<int2*>(sh_red);
int2* c_int2 = reinterpret_cast<int2*>(C);
sh_red_int2[c_sh_wr + c_sh_wr_delta * i] = c_int2[true_idx];
} else {
sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx];
} }
} }
} }
}
#pragma unroll #pragma unroll
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
if (!first) { if (!first) {
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; c_scalar_t* c_red_f16;
if constexpr (is_a_8bit) {
int2 tmp =
reinterpret_cast<int2*>(sh_red)[c_sh_wr + i * c_sh_wr_delta];
c_red_f16 = reinterpret_cast<c_scalar_t*>(&tmp);
} else {
int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta];
c_red_f16 = reinterpret_cast<c_scalar_t*>(&tmp);
}
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * 4; j++) { for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) {
int delta = 0; int delta = 0;
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0; delta = j % 2 == 1 ? -2 : 0;
} }
reinterpret_cast<float*>( reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) +
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]); delta] += Cdtype::num2float(c_red_f16[j]);
} }
} }
if (!last) { if (!last) {
int4 c; c_scalar_t c_f16[is_a_8bit ? 4 : 8];
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * 4; j++) { for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) {
int delta = 0; int delta = 0;
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0; delta = j % 2 == 1 ? -2 : 0;
} }
reinterpret_cast<scalar_t*>(&c)[j] = c_f16[j] = Cdtype::float2num(reinterpret_cast<float*>(
Dtype::float2num(reinterpret_cast<float*>( &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) +
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); delta]);
} }
int c_idx; int c_idx;
...@@ -1549,7 +1822,12 @@ __global__ void Marlin( ...@@ -1549,7 +1822,12 @@ __global__ void Marlin(
if (c_idx / c_gl_stride < block_num_valid_tokens) { if (c_idx / c_gl_stride < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];
int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
C[true_idx] = c; if constexpr (is_a_8bit) {
int2* c_int2 = reinterpret_cast<int2*>(C);
c_int2[true_idx] = *reinterpret_cast<int2*>(c_f16);
} else {
C[true_idx] = *reinterpret_cast<int4*>(c_f16);
}
} }
} }
} }
...@@ -1563,10 +1841,10 @@ __global__ void Marlin( ...@@ -1563,10 +1841,10 @@ __global__ void Marlin(
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4; constexpr int active_threads = 32 * tb_n_warps;
bool is_th_active = threadIdx.x < active_threads; bool is_th_active = threadIdx.x < active_threads;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4;
constexpr int th_size = num_floats * sizeof(float) / 16; constexpr int th_size = num_floats * sizeof(float) / 16;
int c_cur_offset = locks_off * c_size; int c_cur_offset = locks_off * c_size;
...@@ -1634,7 +1912,7 @@ __global__ void Marlin( ...@@ -1634,7 +1912,7 @@ __global__ void Marlin(
} else { } else {
c_sh_wr = c_sh_wr =
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32); c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32);
} }
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
...@@ -1643,49 +1921,49 @@ __global__ void Marlin( ...@@ -1643,49 +1921,49 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final // We first reorder in shared memory to guarantee the most efficient final
// global write patterns // global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res = c_scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for // For per-column quantization we finally apply the scale here (only for
// 4-bit) // 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit &&
w_type.size_bits() == 4 && b_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
scalar_t2 tmp_scale = s[0]; c_scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2( tmp_scale = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]); reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
} }
res = __hmul2(res, tmp_scale); res = __hmul2(res, tmp_scale);
} }
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if (!mul_topk_weights) { if (!mul_topk_weights) {
res = __hmul2(res, global_scale); res = __hmul2(res, global_scale);
} }
} }
if (has_bias && last) { if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0]; c_scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2( tmp_bias = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]); reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
} }
res = __hadd2(res, tmp_bias); res = __hadd2(res, tmp_bias);
} }
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x; ((c_scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
} else { } else {
((scalar_t2*)sh_red)[idx] = res; ((c_scalar_t2*)sh_red)[idx] = res;
} }
}; };
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < tb_n_warps) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < (is_a_8bit ? 2 : 4); j++) {
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j; int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
...@@ -1723,24 +2001,26 @@ __global__ void Marlin( ...@@ -1723,24 +2001,26 @@ __global__ void Marlin(
if (row < block_num_valid_tokens) { if (row < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[row]; int64_t sorted_row = sh_block_sorted_ids[row];
int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride;
scalar_t2 topk_weight_score; c_scalar_t2 topk_weight_score;
if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row];
if (use_atomic_add && slice_count > 1 || mul_topk_weights) { if (use_atomic_add && slice_count > 1 || mul_topk_weights) {
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[true_idx]); c_scalar_t2* C_half2 = reinterpret_cast<c_scalar_t2*>(&C[true_idx]);
scalar_t2* sh_red_half2 = c_scalar_t2* sh_red_half2 =
reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]); reinterpret_cast<c_scalar_t2*>(&sh_red[c_sh_rd]);
if (mul_topk_weights) {
#pragma unroll #pragma unroll
for (int a = 0; a < 4; a++) { for (int a = 0; a < 4; a++) {
scalar_t2 res = sh_red_half2[a]; sh_red_half2[a] = __hmul2(sh_red_half2[a], topk_weight_score);
if (mul_topk_weights) { }
res = __hmul2(res, topk_weight_score);
} }
if (use_atomic_add && slice_count > 1) { if (use_atomic_add && slice_count > 1) {
atomicAdd(&C_half2[a], res); #pragma unroll
for (int a = 0; a < 4; a++) {
atomicAdd(&C_half2[a], sh_red_half2[a]);
}
} else { } else {
C_half2[a] = res; C[true_idx] = *reinterpret_cast<int4*>(sh_red_half2);
};
} }
} else { } else {
C[true_idx] = sh_red[c_sh_rd]; C[true_idx] = sh_red[c_sh_rd];
...@@ -1774,7 +2054,7 @@ __global__ void Marlin( ...@@ -1774,7 +2054,7 @@ __global__ void Marlin(
} }
} }
} }
fetch_to_shared(i, i, i < slice_iters, i); fetch_to_shared(i, i, i < slice_iters);
} }
zero_accums(); zero_accums();
...@@ -1799,30 +2079,27 @@ __global__ void Marlin( ...@@ -1799,30 +2079,27 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at // have even length meaning that the next iteration will always start at
// index 0. // index 0.
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
stage_group_id++) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
#pragma unroll #pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) { for (int k = 0; k < b_sh_wr_iters; k++) {
int idx = fetch_to_registers(k + 1, pipe % stages);
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
? (pipe - stages)
: (pipe + stage_group_id * stages);
fetch_to_registers(k + 1, pipe % stages, idx);
fetch_scales_to_registers(k + 1, pipe); fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe); fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) { if (k == b_sh_wr_iters - 2) {
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
? (pipe - 1)
: (pipe + (stage_group_id + 1) * stages - 1);
fetch_to_shared((pipe + stages - 1) % stages, pipe, fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages, idx); slice_iters >= stages);
pipe++; pipe++;
wait_for_stage(); wait_for_stage();
init_same_group(pipe % stages); init_same_group(pipe % stages);
} }
matmul(k);
if constexpr (!is_a_8bit) {
matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0));
} else {
static_assert(group_blocks != 0 && group_blocks != 1);
matmul_a8(k);
}
} }
slice_iters--; slice_iters--;
if (slice_iters == 0) { if (slice_iters == 0) {
...@@ -1850,22 +2127,52 @@ __global__ void Marlin( ...@@ -1850,22 +2127,52 @@ __global__ void Marlin(
} }
} }
} }
if (slice_iters == 0) {
break;
}
}
// Process results and, if necessary, proceed to the next column slice. // Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing // While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation. // the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) { if (slice_iters == 0) {
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];
for (int i = 0; i < 2 * thread_m_blocks; i++)
frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4];
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float c_val = frag_c[i][j][0][g];
if constexpr (a_type == vllm::kS8) {
c_val = __int2float_rn(*reinterpret_cast<int32_t*>(&c_val));
}
float s_val = frag_a_s[i * 2 + g / 2];
frag_c[i][j][0][g] = c_val * s_val;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float c_val = frag_c[i][j][1][g];
if constexpr (a_type == vllm::kS8) {
c_val = __int2float_rn(*reinterpret_cast<int32_t*>(&c_val));
}
float s_val = frag_a_s[i * 2 + g / 2];
frag_c[i][j][1][g] = c_val * s_val;
}
}
}
}
cp_async_wait<0>(); cp_async_wait<0>();
bool last = slice_idx == slice_count - 1; bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before // For per-column scales, we only fetch them here in the final step before
// write-out // write-out
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
} }
...@@ -1883,20 +2190,27 @@ __global__ void Marlin( ...@@ -1883,20 +2190,27 @@ __global__ void Marlin(
} }
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if constexpr (is_a_8bit) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < tb_n_warps) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
}
} else if (b_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < tb_n_warps) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
int idx = (threadIdx.x / 4) % 2; int idx = (threadIdx.x / 4) % 2;
scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s); c_scalar_t2* frag_s_half2 =
reinterpret_cast<c_scalar_t2*>(frag_s);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
frag_s_half2[i] = Dtype::num2num2( frag_s_half2[i] = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]); reinterpret_cast<c_scalar_t*>(&frag_s_half2[i])[idx]);
} }
} }
} }
...@@ -1906,26 +2220,48 @@ __global__ void Marlin( ...@@ -1906,26 +2220,48 @@ __global__ void Marlin(
// For 8-bit channelwise, we apply the scale before the global reduction // For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible // that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16) // overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) {
w_type.size_bits() == 8 && #pragma unroll
for (int j = 0; j < 2; j++) {
float2 aa[2];
aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]);
aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]);
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&aa[0])[g % 2];
frag_c[i][j][0][g] *= scale;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&aa[1])[g % 2];
frag_c[i][j][1][g] *= scale;
}
}
}
} else if (!has_act_order && group_blocks == -1 &&
b_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < tb_n_warps) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][0][0]), reinterpret_cast<float*>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]); frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][0][2]), reinterpret_cast<float*>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]);
if constexpr (!m_block_size_8) { if constexpr (!m_block_size_8) {
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][1][0]), reinterpret_cast<float*>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1]);
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][1][2]), reinterpret_cast<float*>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1]);
} }
...@@ -1949,6 +2285,7 @@ __global__ void Marlin( ...@@ -1949,6 +2285,7 @@ __global__ void Marlin(
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd]; reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
if constexpr (!is_a_8bit)
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads(); __syncthreads();
} }
...@@ -1958,37 +2295,22 @@ __global__ void Marlin( ...@@ -1958,37 +2295,22 @@ __global__ void Marlin(
if (last || use_atomic_add) if (last || use_atomic_add)
// only the last block in a slice actually writes the result // only the last block in a slice actually writes the result
write_result(last); write_result(last);
int old_slice_row = slice_row;
slice_row = 0; slice_row = 0;
if (!in_part2) {
slice_col_par += gridDim.x;
} else {
slice_col_par++; slice_col_par++;
slice_col++; slice_col++;
}
is_first_matmul_in_slice = true; is_first_matmul_in_slice = true;
init_slice(); init_slice();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if (slice_col == 0 || old_slice_row ||
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
should_load_a = true;
} else {
should_load_a = false;
}
if (slice_iters) { if (slice_iters) {
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); a_gl_rd_col =
#pragma unroll a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
for (int i = 0; i < b_sh_wr_iters; i++) b_gl_rd = B_expert_off + b_gl_stride * (threadIdx.x / b_sh_stride) +
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; (threadIdx.x % b_sh_stride);
if (slice_col == 0) { b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading // Update slice k/n for scales loading
...@@ -1998,8 +2320,26 @@ __global__ void Marlin( ...@@ -1998,8 +2320,26 @@ __global__ void Marlin(
slice_k_start_shared_fetch = slice_k_start; slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col; slice_n_offset = act_s_col_tb_stride * slice_col;
} else { } else {
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd =
s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd =
zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd =
s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
zp_gl_rd =
zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / zp_sh_stride) +
zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
}
} }
start_pipes(); start_pipes();
} }
......
...@@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; ...@@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {};
} // namespace marlin
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// For a given "a" of size [M,K] performs a permutation of the K columns based // For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices. // on the given "perm" indices.
template <int moe_block_size> template <int moe_block_size>
...@@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, ...@@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp, bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) { int is_zp_float, bool is_a_8bit) {
int pack_factor = 32 / num_bits; int pack_factor = 32 / num_bits;
// Get B size // Get B size
...@@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, ...@@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4; int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2; int sh_bias_size = tb_n * 2;
...@@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, ...@@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k, int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order, int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float, bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) { int max_shared_mem, bool is_a_8bit) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
...@@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, ...@@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
} }
// Check that pipeline fits into cache // Check that pipeline fits into cache
int cache_size = get_kernel_cache_size( int cache_size =
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); prob_n, prob_k, num_bits, group_size, has_act_order,
return cache_size + 512 <= max_shared_mem; is_k_full, has_zp, is_zp_float, is_a_8bit);
return cache_size <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinFuncPtr get_marlin_kernel(
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ const vllm::ScalarType a_type, const vllm::ScalarType b_type,
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ const vllm::ScalarType c_type, const vllm::ScalarType s_type,
thread_n_blocks == THREAD_N_BLOCKS && \ int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
thread_k_blocks == THREAD_K_BLOCKS && \ bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \ int threads, bool is_zp_float) {
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ int num_bits = b_type.size_bits();
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f) #include "kernel_selector.h"
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
template <typename scalar_t> exec_config_t determine_exec_config(
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
int prob_n, int prob_k, int thread_m_blocks, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
bool m_block_size_8, int num_bits, int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
int group_size, bool has_act_order, bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_zp_float, int max_shared_mem) { bool is_a_8bit) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
...@@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order, prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) { is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
continue; continue;
} }
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
int group_blocks = 0; int group_blocks = 0;
if (!has_act_order) { if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : (group_size / 16); group_blocks = group_size == -1 ? -1 : (group_size / 16);
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel =
q_type, thread_m_blocks, th_config.thread_n / 16, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, th_config.thread_n / 16, th_config.thread_k / 16,
group_blocks, th_config.num_threads, is_zp_float); m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
if (thread_m_blocks > 1) {
exec_cfg = {1, th_config};
break;
} else {
cudaFuncAttributes attr; cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel); cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size, int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1024)); max_shared_mem / (cache_size + 1536));
if (thread_m_blocks == 1)
allow_count = max(min(allow_count, 4), 1); allow_count = max(min(allow_count, 4), 1);
else
allow_count = max(min(allow_count, 2), 1);
if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) {
allow_count =
max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1);
}
if (allow_count > count) { if (allow_count > count) {
count = allow_count; count = allow_count;
exec_cfg = {count, th_config}; exec_cfg = {count, th_config};
}; };
} }
}
return exec_cfg; return exec_cfg;
} }
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm, void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* a_tmp, void* sorted_token_ids, void* expert_ids, void* perm, void* a_tmp, void* sorted_token_ids,
void* num_tokens_past_padded, void* topk_weights, void* expert_ids, void* num_tokens_past_padded,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, void* topk_weights, int moe_block_size, int num_experts,
int prob_m, int prob_n, int prob_k, void* workspace, int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
vllm::ScalarType const& q_type, bool has_bias, int prob_n, int prob_k, void* workspace,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups, vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
int group_size, int dev, cudaStream_t stream, int thread_k, vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
bool is_zp_float) { int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16); int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8; bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
...@@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} }
} }
int num_bits = q_type.size_bits(); int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A; const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias; const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const float* a_s_ptr = (const float*)a_s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
...@@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
// Set thread config // Set thread config
exec_config_t exec_cfg; exec_config_t exec_cfg;
thread_config_t thread_tfg; thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) { if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64};
exec_cfg = exec_config_t{1, thread_tfg}; if (blocks_per_sm == -1) blocks_per_sm = 1;
exec_cfg = exec_config_t{blocks_per_sm, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n); " is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k); " is not divisible by thread_k = ", thread_k);
} else { } else {
// Auto config // Auto config
exec_cfg = determine_exec_config<scalar_t>( exec_cfg = determine_exec_config(
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
max_shared_mem); has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
} }
...@@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_k_blocks = thread_k / 16; int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16; int thread_n_blocks = thread_n / 16;
TORCH_CHECK( TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, prob_m, prob_n, prob_k, num_bits, group_size,
prob_n, prob_k, num_bits, group_size, has_act_order, has_act_order, is_k_full, has_zp, is_zp_float,
is_k_full, has_zp, is_zp_float, max_shared_mem), max_shared_mem, is_a_8bit),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks, "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k, ", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n, ", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", ", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits, prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order, ", group_size = ", group_size,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
has_act_order, has_zp, group_blocks, num_threads, is_zp_float); thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
...@@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>( kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on // clang-format on
} }
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm( torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
...@@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float, int64_t thread_k, int64_t thread_n,
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); int64_t blocks_per_sm) {
int pack_factor = 32 / b_q_type.size_bits(); vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
int num_experts = b_q_weight.size(0);
if (moe_block_size != 8) { if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0, TORCH_CHECK(moe_block_size % 16 == 0,
...@@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as torch::Tensor a_scales;
// auto -1) auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
int thread_k = -1; auto options_fp32 =
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as torch::TensorOptions().dtype(at::kFloat).device(a.device());
// auto -1)
int thread_n = -1; if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// sms: number of SMs to use for the kernel // sms: number of SMs to use for the kernel
int sms = -1; int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c; torch::Tensor c;
if (c_or_none.has_value()) { if (c_or_none.has_value()) {
c = c_or_none.value(); c = c_or_none.value();
...@@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce // Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp; torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) { if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4 // max num of threadblocks is sms * 4
long max_c_tmp_size = min( long max_c_tmp_size = min(
...@@ -846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format."); "the global_scale parameter must be passed for nvfp4 format.");
} }
...@@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0; bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, b_type == vllm::kU4 || b_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else { } else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, b_type == vllm::kS4 || b_type == vllm::kS8 ||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"float4_e2m1f when " "b_type must be uint4b8, uint8b128, int4, int8, "
"has_zp = False. Got = ", "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_q_type.str()); b_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
...@@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size); " is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr; TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
if (b_q_type == vllm::kFE2M1f) { "scalar type of a_scales must be float");
if (group_size == 16) TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>(); "scalar type of global_scale must be the same with c");
else if (group_size == 32) if (a_type.size_bits() == 16) {
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>(); TORCH_CHECK(
else a.scalar_type() == c.scalar_type(),
TORCH_CHECK(false, "scalar type of a must be the same with c for 16 bit activation");
"float4_e2m1f only supports group_size == 16 (NVFP4) ", }
"and group_size == 32 (MXFP4)");
} else { MARLIN_NAMESPACE_NAME::marlin_mm(
scales_ptr = b_scales.data_ptr<at::Half>(); a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
} b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
MARLIN_NAMESPACE_NAME::marlin_mm<half>( perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr, topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(), mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
sorted_token_ids.data_ptr(), expert_ids.data_ptr(), has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, is_zp_float);
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
}
return c; return c;
} }
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
} }
...@@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor? b_bias_or_none," "Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? " "Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none," "b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids," "Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, " "Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id," "bool mul_topk_weights, bool is_ep, int b_type_id,"
"int size_m, int size_n, int size_k," "int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add," "bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_fp32_reduce, bool is_zp_float,"
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
......
...@@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { ...@@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
#pragma unroll #pragma unroll
for (int k_idx = 0; k_idx < 2; ++k_idx) { for (int k_idx = 0; k_idx < 2; ++k_idx) {
FType low16 = FType low16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]); C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 = FType high16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]); C_frag[m_idx][n_idx][k_idx * 2 + 1]);
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) | uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
(reinterpret_cast<uint32_t&>(high16) << 16); (reinterpret_cast<uint32_t&>(high16) << 16);
int sts_offset = int sts_offset =
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <iostream> #include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh" #include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::ScalarType; using marlin::MarlinScalarType2;
namespace allspark { namespace allspark {
...@@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C, ...@@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) { for (int i = 0; i < n_mat; ++i) {
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]); sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
} }
C[idx] = ScalarType<FType>::float2num(sum); C[idx] = MarlinScalarType2<FType>::float2num(sum);
} }
template <typename FType> template <typename FType>
......
kernel_*.cu sm*_kernel_*.cu
\ No newline at end of file kernel_selector.h
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits> template <int const num_threads, int const num_bits, bool is_a_8bit>
__global__ void awq_marlin_repack_kernel( __global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
...@@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel( ...@@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor; constexpr int tile_n_ints = target_tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4; constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size; constexpr int stage_k_threads = target_tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
...@@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel( ...@@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int first_n_packed = first_n / pack_factor; int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe; int4* sh_ptr = sh + stage_size * pipe;
...@@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel( ...@@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>( reinterpret_cast<int4 const*>(
...@@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel( ...@@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor; int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor; int cur_n_pos = cur_n % pack_factor;
...@@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel( ...@@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
if constexpr (is_a_8bit) {
int cur_elem = tc_row + i;
int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * cur_elem];
int packed_src_1 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * (cur_elem + 16)];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
} else {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem]; sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
} }
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
...@@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel( ...@@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
...@@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel( ...@@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS) \ #define CALL_IF(NUM_BITS, IS_A_8BIT) \
else if (num_bits == NUM_BITS) { \ else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) { int64_t size_n, int64_t num_bits,
bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
...@@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, ...@@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
if (false) { if (false) {
} }
CALL_IF(4) CALL_IF(4, false)
CALL_IF(8) CALL_IF(8, false)
CALL_IF(4, true)
CALL_IF(8, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", is_a_8bit = ", is_a_8bit);
} }
return out; return out;
......
...@@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>( ...@@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg); frag_b[0] = __hmul2(frag_b[0], bias_reg);
} }
// Repositions the eight FP4 (E2M1) values packed in `q` into eight FP8 (E4M3)
// bit patterns and stores them as two __nv_fp8x4_e4m3 words in frag_b[0..1].
// Only bit movement happens here; this function applies no exponent-bias
// correction or scaling.
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
// Exponent widths of the FP4 (E2M1) source and FP8 (E4M3) target formats
constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
// Selects bits 6..4 of every byte, i.e. the non-sign bits of the high nibble
constexpr int MASK = 0x70707070;
// Keep bit 7 of every byte as the sign and shift the remaining high-nibble
// bits down by RIGHT_SHIFT so they land in bits 4..2 of the FP8 byte
int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
// Move the low nibble of every byte into the high-nibble position and repeat
q <<= 4;
int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
// Note1: reverse indexing is intentional because weights are permuted
// NOTE(review): the note previously attached here said that 8-bit dequant
// writes to `frag_b[2]` instead of `frag_b[1]` to fit the tensor core layout,
// but this specialization writes frag_b[1] and frag_b[0] -- confirm intent
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
// Converts the eight 4-bit values packed in `q` (uint4b8: stored with a +8
// bias) into eight signed 8-bit integers packed four per int32_t.
//   q       - eight packed 4-bit values, two per byte
//   frag_b  - output; frag_b[0] receives the values taken from the low nibble
//             of every byte, frag_b[1] those taken from the high nibble
template <>
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
    int q, int32_t* frag_b) {
  // All arithmetic is done on uint32_t: the per-byte subtraction below wraps
  // past INT_MIN for small top nibbles, which is undefined for signed int but
  // well defined (and bit-identical) for unsigned.
  constexpr uint32_t repeated_zp = 0x08080808u;
  constexpr uint32_t MASK = 0x80808080u;
  constexpr uint32_t NIBBLE_MASK = 0x0F0F0F0Fu;

  uint32_t u_q = static_cast<uint32_t>(q);

  // Setting bit 7 of every byte before subtracting keeps borrows from crossing
  // byte boundaries; XOR-ing the same bit afterwards leaves (nibble - 8) in
  // two's complement in every byte.
  frag_b[0] =
      static_cast<int32_t>((((u_q & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
  u_q >>= 4;
  frag_b[1] =
      static_cast<int32_t>((((u_q & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
}
// Rewrites the eight 4-bit values packed in `q` as eight bytes that are stored
// into frag_b as two __nv_fp8x4_e4m3 words (low nibbles first, then high
// nibbles). Per byte, bit 3 of the nibble is moved to bit 7 and, when set,
// also adds one to the remaining three low bits.
// NOTE(review): how the resulting bit patterns map back to (value - 8) is not
// visible in this function; presumably the caller compensates -- confirm.
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
    int q, __nv_fp8x4_e4m3* frag_b) {
  // Low nibble of every byte.
  int top_bit = q & 0x08080808;
  int low_bits = q & 0x07070707;
  int first_word = (low_bits | (top_bit << 4)) + (top_bit >> 3);

  // High nibble of every byte, brought down into the low-nibble position.
  q >>= 4;
  top_bit = q & 0x08080808;
  low_bits = q & 0x07070707;
  int second_word = (low_bits | (top_bit << 4)) + (top_bit >> 3);

  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&first_word);
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&second_word);
}
template <typename scalar_t2, vllm::ScalarTypeId s_type_id> template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
...@@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>( ...@@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
// Note: reverse indexing is intentional because weights are permuted // Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1); frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2); frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
};
// Subtracts the zero point `zp` from the packed quantized values in `q` while
// they are still in quantized form, then dequantizes the result into `frag_b`.
//   scalar_t2  - element type written to `frag_b`
//   w_type_id  - id of the quantized weight type held in `q`
//   skip_flop  - defaults to false; its effect is not visible in this
//                declaration (the specializations below all pass true)
// This is a declaration only; the behavior comes from explicit
// specializations.
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);
// INT4 with zero point -> INT8.
// Subtracts `zp` from each of the eight 4-bit values packed in `q` and stores
// the differences as signed 8-bit integers, four per int32_t.
//   q       - eight packed 4-bit values, two per byte
//   frag_b  - output; frag_b[0] receives the values taken from the low nibble
//             of every byte, frag_b[1] those taken from the high nibble
//   zp      - zero point applied to every value (assumed to fit in 4 bits so
//             that 0x01010101 * zp does not carry between bytes -- TODO confirm)
// see https://github.com/vllm-project/vllm/pull/24722
template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
    int q, int32_t* frag_b, int zp) {
  // All arithmetic is done on uint32_t: the per-byte subtraction below can
  // wrap past INT_MIN, which is undefined for signed int but well defined
  // (and bit-identical) for unsigned.
  constexpr uint32_t MASK = 0x80808080u;
  constexpr uint32_t NIBBLE_MASK = 0x0F0F0F0Fu;
  const uint32_t repeated_zp = 0x01010101u * static_cast<uint32_t>(zp);

  uint32_t u_q = static_cast<uint32_t>(q);

  // Setting bit 7 of every byte before subtracting keeps borrows from crossing
  // byte boundaries; XOR-ing the same bit afterwards leaves (nibble - zp) in
  // two's complement in every byte.
  frag_b[0] =
      static_cast<int32_t>((((u_q & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
  u_q >>= 4;
  frag_b[1] =
      static_cast<int32_t>((((u_q & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
}
// INT4 with zero point -> FP8 (E4M3) bit patterns.
// Processes the low nibble of every byte of `q` into frag_b[0] and the high
// nibble into frag_b[1], combining each nibble with `zp` using byte-parallel
// integer arithmetic.
// NOTE(review): the encoding of `zp` expected here and the numeric meaning of
// the produced bytes are not visible in this function -- confirm against the
// caller.
// see https://github.com/vllm-project/vllm/pull/24722
template <>
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
                                          true>(int q, __nv_fp8x4_e4m3* frag_b,
                                                int zp) {
  uint32_t packed = *reinterpret_cast<uint32_t*>(&q);
  const uint32_t zp_bits = *reinterpret_cast<uint32_t*>(&zp);
  const uint32_t zp_plus_one = zp_bits + 1;
  const uint32_t zp_per_byte = 0x01010101 * zp_bits;

  // Low nibbles: seed every byte with 0x70 so that adding the zero point sets
  // bit 7 exactly when nibble + zp reaches 16 (assumes zp fits in 4 bits --
  // TODO confirm). That bit is kept as bit 7 of the output byte and also
  // selects the bytes that get (zp + 1) added before masking to 4 bits.
  uint32_t seeded = (packed & 0x0F0F0F0F) | 0x70707070;
  uint32_t flag = (seeded + zp_per_byte) & 0x80808080;
  uint32_t first_word =
      ((seeded + (flag >> 7) * zp_plus_one) & 0x0F0F0F0F) | flag;

  // High nibbles: same steps after shifting them into the low position.
  packed >>= 4;
  seeded = (packed & 0x0F0F0F0F) | 0x70707070;
  flag = (seeded + zp_per_byte) & 0x80808080;
  uint32_t second_word =
      ((seeded + (flag >> 7) * zp_plus_one) & 0x0F0F0F0F) | flag;

  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&first_word);
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&second_word);
}
#endif #endif
......
...@@ -4,141 +4,292 @@ import glob ...@@ -4,141 +4,292 @@ import glob
import itertools import itertools
import os import os
import subprocess import subprocess
import sys
import jinja2 import jinja2
FILE_HEAD = """ ARCHS = []
// auto generated by generate.py SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off // clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """
)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# HQQ
{
"a_type": ["kFloat16"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [4],
"is_zp_float": True,
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for quant_config in QUANT_CONFIGS:
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
): a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
# act order case only support gptq-int4 and gptq-int8 b_type = quant_config["b_type"]
if group_blocks == 0 and scalar_type not in [ is_zp_float = quant_config.get("is_zp_float", False)
"vllm::kU4B8", all_group_blocks = quant_config["group_blocks"]
"vllm::kU8B128", all_m_blocks = quant_config["thread_m_blocks"]
]: all_thread_configs = quant_config["thread_configs"]
continue
if thread_configs[2] == 256: for a_type, c_type in itertools.product(a_types, c_types):
# for small batch (m_blocks == 1), we only need (128, 128, 256) if not SUPPORT_FP8 and a_type == "kFE4M3fn":
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue continue
if m_blocks > 1 and thread_configs[0] != 64: if "16" in a_type and "16" in c_type and a_type != c_type:
continue continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
# we only support channelwise quantization and group_size == 128 for group_blocks, m_blocks, thread_configs in itertools.product(
# for fp8 all_group_blocks, all_m_blocks, all_thread_configs
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: ):
continue thread_k, thread_n, threads = thread_configs
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32 if threads == 256:
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: # for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue continue
# other quantization methods don't support group_size = 16 if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue continue
k_blocks = thread_configs[0] // 16 config = {
n_blocks = thread_configs[1] // 16 "threads": threads,
threads = thread_configs[2] "s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false",
}
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" result_dict[(a_type, b_type, c_type)].append(config)
is_zp_float_list = [False] kernel_selector_str = FILE_HEAD_COMMENT
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
# HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16
is_zp_float_list.append(True)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: for (a_type, b_type, c_type), config_list in result_dict.items():
s_type = "vllm::kFE4M3fn" all_template_str_list = []
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: for config in config_list:
s_type = "vllm::kFE8M0fnu" s_type = config["s_type"]
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, a_type_id=f"vllm::{a_type}.id()",
w_type_id=scalar_type + ".id()", b_type_id=f"vllm::{b_type}.id()",
s_type_id=s_type + ".id()", c_type_id=f"vllm::{c_type}.id()",
threads=threads, s_type_id=f"vllm::{s_type}.id()",
thread_m_blocks=max(m_blocks, 1), **config,
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=is_zp_float,
) )
all_template_str_list.append(template_str) all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()
......
...@@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm( ...@@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm(
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(false,
...@@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, ...@@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float); has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size + 512 <= max_shared_mem; return cache_size <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinFuncPtr get_marlin_kernel(
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ const vllm::ScalarType a_type, const vllm::ScalarType b_type,
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ const vllm::ScalarType c_type, const vllm::ScalarType s_type,
thread_n_blocks == THREAD_N_BLOCKS && \ int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
thread_k_blocks == THREAD_K_BLOCKS && \ bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \ int threads, bool is_zp_float) {
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ int num_bits = b_type.size_bits();
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 4, 8, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8) #include "kernel_selector.h"
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, half>::value) {
if (false) {
}
FZP_GET_IF(vllm::kU4)
}
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
template <typename scalar_t> exec_config_t determine_exec_config(
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
int prob_n, int prob_k, int thread_m_blocks, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
bool m_block_size_8, int num_bits, int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
int group_size, bool has_act_order, int num_bits, int group_size, bool has_act_order, bool is_k_full,
bool is_k_full, bool has_zp, bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
bool is_zp_float, int max_shared_mem,
int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
...@@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) { is_zp_float, max_shared_mem - 512)) {
continue; continue;
} }
...@@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
group_blocks = group_size == -1 ? -1 : group_size / 16; group_blocks = group_size == -1 ? -1 : group_size / 16;
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel =
q_type, thread_m_blocks, th_config.thread_n / 16, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, th_config.thread_n / 16, th_config.thread_k / 16,
group_blocks, th_config.num_threads, is_zp_float); m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
...@@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
return exec_cfg; return exec_cfg;
} }
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm, void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type, bool has_bias, int lda, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups, bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k_init, int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add, int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) { bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
...@@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} }
} }
int num_bits = q_type.size_bits(); int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A; const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias; const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const float* a_s_ptr = (const float*)a_s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp; int4* a_tmp_ptr = (int4*)a_tmp;
int* locks = (int*)workspace; int* locks = (int*)workspace;
if (has_act_order) { if (has_act_order) {
...@@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
int max_par = 16; int max_par = 16;
if (prob_n <= 4096) max_par = 16 * 8; if (prob_n <= 4096) max_par = 16 * 8;
int max_shared_mem_new = max_shared_mem; int max_shared_mem_new = max_shared_mem;
...@@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_n = thread_n_init; int thread_n = thread_n_init;
int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);
int m_block_size_8 = prob_m_split <= 8; int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16;
// Set thread config // Set thread config
exec_config_t exec_cfg; exec_config_t exec_cfg;
...@@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
" is not divisible by thread_k = ", thread_k); " is not divisible by thread_k = ", thread_k);
} else { } else {
// Auto config // Auto config
exec_cfg = determine_exec_config<scalar_t>( exec_cfg = determine_exec_config(
q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
max_shared_mem, sms); is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
if (thread_tfg.thread_n != -1) {
if (prob_n / thread_tfg.thread_n *
div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <=
sms) {
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem_new)) {
thread_tfg = {128, 64, 128};
exec_cfg = {1, thread_tfg};
}
}
}
if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {
max_thread_m_blocks--; max_thread_m_blocks--;
continue; continue;
...@@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem_new = ", max_shared_mem_new); ", max_shared_mem_new = ", max_shared_mem_new);
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
is_zp_float); num_threads, is_zp_float);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
...@@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>( kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr,
g_idx_ptr, num_groups, g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new); use_fp32_reduce, max_shared_mem_new);
// clang-format on // clang-format on
A_ptr += prob_m_split * (lda / 8); bool is_a_8bit = a_type.size_bits() == 8;
A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8));
a_s_ptr += prob_m_split;
C_ptr += prob_m_split * (prob_n / 8); C_ptr += prob_m_split * (prob_n / 8);
rest_m -= prob_m_split; rest_m -= prob_m_split;
} }
...@@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm( ...@@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
int pack_factor = 32 / b_q_type.size_bits();
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
// Verify A // Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
...@@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm( ...@@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
torch::Tensor a_scales;
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1) // auto -1)
int thread_k = -1; int thread_k = -1;
...@@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm( ...@@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c; torch::Tensor c;
if (c_or_none.has_value()) { if (c_or_none.has_value()) {
c = c_or_none.value(); c = c_or_none.value();
...@@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm( ...@@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce // Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp; torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce) { if (use_fp32_reduce) {
int max_m_block_size = (size_m + 16 - 1) / 16 * 16; int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
max_m_block_size = min(max_m_block_size, 64); max_m_block_size = min(max_m_block_size, 64);
...@@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm( ...@@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format."); "the global_scale parameter must be passed for nvfp4 format.");
} }
...@@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm( ...@@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0; bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, b_type == vllm::kU4 || b_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else { } else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, b_type == vllm::kS4 || b_type == vllm::kS8 ||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"float4_e2m1f when " "b_type must be uint4b8, uint8b128, int4, int8, "
"has_zp = False. Got = ", "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_q_type.str()); b_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
...@@ -902,58 +819,26 @@ torch::Tensor gptq_marlin_gemm( ...@@ -902,58 +819,26 @@ torch::Tensor gptq_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size); " is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
marlin::marlin_mm<half>( TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), "scalar type of a_scales must be float");
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr, TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(), "scalar type of global_scale must be the same with c");
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, if (a_type.size_bits() == 16) {
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, TORCH_CHECK(
is_k_full, has_zp, num_groups, group_size, dev, a.scalar_type() == c.scalar_type(),
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, "scalar type of a must be the same with c for 16 bit activation");
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
} }
marlin::marlin_mm<nv_bfloat16>( marlin::marlin_mm(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias,
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float); use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}
return c; return c;
} }
......
...@@ -4,15 +4,18 @@ ...@@ -4,15 +4,18 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm,
bool is_a_8bit>
__global__ void gptq_marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
...@@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4; constexpr int perm_size = target_tile_k_size / 4;
int4* sh_perm_ptr = sh; int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr; int4* sh_pipe_ptr = sh_perm_ptr;
...@@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel(
sh_pipe_ptr += perm_size; sh_pipe_ptr += perm_size;
} }
constexpr int tile_ints = tile_k_size / pack_factor; constexpr int tile_ints = target_tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_n_threads = target_tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) { auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4; int first_k_int4 = (k_tile_id * target_tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr); int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
...@@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
...@@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
int first_k_packed = first_k / pack_factor; int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
...@@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
constexpr int sh_stride = 64; constexpr int sh_stride = target_tile_n_size;
constexpr uint32_t mask = (1 << num_bits) - 1; constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
...@@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
if constexpr (has_perm) { if constexpr (has_perm) {
static_assert(!is_a_8bit);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i]; int k_idx = tc_row + tc_offsets[i];
...@@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel(
#pragma unroll #pragma unroll
for (int i = 0; i < tile_ints; i++) { for (int i = 0; i < tile_ints; i++) {
if constexpr (is_a_8bit) {
b1_vals[i] =
sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8];
} else {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
} }
}
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]);
int cur_int = cur_elem / pack_factor; int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor; int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
if constexpr (is_a_8bit)
vals[4 + i] =
(b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask;
else
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
} }
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
...@@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
...@@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \
is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \ HAS_PERM, IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM, IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits, bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
...@@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
if (false) { if (false) {
} }
CALL_IF(4, false) CALL_IF(4, false, false)
CALL_IF(4, true) CALL_IF(4, true, false)
CALL_IF(8, false) CALL_IF(8, false, false)
CALL_IF(8, true) CALL_IF(8, true, false)
CALL_IF(4, false, true)
CALL_IF(8, false, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm); ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit);
} }
return out; return out;
......
...@@ -11,17 +11,19 @@ ...@@ -11,17 +11,19 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
......
...@@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } ...@@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async // No support for async
#else #else
// Predicated asynchronous 4-byte copy from global memory to shared memory,
// issued through the PTX `cp.async.ca.shared.global` instruction.
//
//   smem_ptr  destination; converted to a shared-space address below.
//   glob_ptr  source address, passed to the instruction as a 64-bit operand.
//   pred      the copy instruction is only executed when this is true.
//
// NOTE(review): this helper only issues the copy; committing / waiting on the
// async group is presumably done by the caller - confirm at the call sites.
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  // Transfer size in bytes; must be a compile-time immediate ("n" operand).
  constexpr int kCopyBytes = 4;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The predicate is materialised from an integer register inside the asm.
  const int issue_copy = static_cast<int>(pred);
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(issue_copy), "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
// Predicated asynchronous 8-byte copy from global memory to shared memory,
// issued through the PTX `cp.async.ca.shared.global` instruction.
//
//   smem_ptr  destination; converted to a shared-space address below.
//   glob_ptr  source address, passed to the instruction as a 64-bit operand.
//   pred      the copy instruction is only executed when this is true.
//
// NOTE(review): this helper only issues the copy; committing / waiting on the
// async group is presumably done by the caller - confirm at the call sites.
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  // Transfer size in bytes; must be a compile-time immediate ("n" operand).
  constexpr int kCopyBytes = 8;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The predicate is materialised from an integer register inside the asm.
  const int issue_copy = static_cast<int>(pred);
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(issue_copy), "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
// Predicated asynchronous 16-byte copy from global memory to shared memory,
// issued through the PTX `cp.async.ca.shared.global` instruction.
//
//   smem_ptr  destination; converted to a shared-space address below.
//   glob_ptr  source address, passed to the instruction as a 64-bit operand.
//   pred      the copy instruction is only executed when this is true.
//
// NOTE(review): this helper only issues the copy; committing / waiting on the
// async group is presumably done by the caller - confirm at the call sites.
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
                                         bool pred = true) {
  // Transfer size in bytes; must be a compile-time immediate ("n" operand).
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The predicate is materialised from an integer register inside the asm.
  const int issue_copy = static_cast<int>(pred);
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(issue_copy), "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
#ifndef _data_types_cuh #ifndef _data_types_cuh
#define _data_types_cuh #define _data_types_cuh
#include "marlin.cuh" #include "marlin.cuh"
#include "core/scalar_type.hpp"
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h>
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin #define MARLIN_NAMESPACE_NAME marlin
...@@ -11,14 +13,16 @@ ...@@ -11,14 +13,16 @@
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t> template <long scalar_type_id>
class ScalarType {}; class MarlinScalarType {};
template <> template <>
class ScalarType<half> { class MarlinScalarType<vllm::kFloat16.id()> {
public: public:
using scalar_t = half; using scalar_t = half;
using scalar_t2 = half2; using scalar_t2 = half2;
using scalar_t4 = half2;
using scalar_32bit_t = half2;
// Matrix fragments for tensor core instructions; their precise layout is // Matrix fragments for tensor core instructions; their precise layout is
// documented here: // documented here:
...@@ -27,6 +31,7 @@ class ScalarType<half> { ...@@ -27,6 +31,7 @@ class ScalarType<half> {
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; using FragS = Vec<half2, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<half2, 4>; using FragZP = Vec<half2, 4>;
static __device__ float inline num2float(const half x) { static __device__ float inline num2float(const half x) {
...@@ -44,18 +49,25 @@ class ScalarType<half> { ...@@ -44,18 +49,25 @@ class ScalarType<half> {
static __host__ __device__ half inline float2num(const float x) { static __host__ __device__ half inline float2num(const float x) {
return __float2half(x); return __float2half(x);
} }
static __host__ __device__ float2 inline num22float2(const half2 x) {
return __half22float2(x);
}
}; };
template <> template <>
class ScalarType<nv_bfloat16> { class MarlinScalarType<vllm::kBFloat16.id()> {
public: public:
using scalar_t = nv_bfloat16; using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162; using scalar_t2 = nv_bfloat162;
using scalar_t4 = nv_bfloat162;
using scalar_32bit_t = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>; using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>; using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>; using FragS = Vec<nv_bfloat162, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<nv_bfloat162, 4>; using FragZP = Vec<nv_bfloat162, 4>;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
...@@ -75,9 +87,63 @@ class ScalarType<nv_bfloat16> { ...@@ -75,9 +87,63 @@ class ScalarType<nv_bfloat16> {
static __host__ __device__ nv_bfloat16 inline float2num(const float x) { static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x); return __float2bfloat16(x);
} }
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
return __bfloat1622float2(x);
}
#endif #endif
}; };
template <>
class MarlinScalarType<vllm::kFE4M3fn.id()> {
public:
using scalar_t = __nv_fp8_e4m3;
using scalar_t2 = __nv_fp8x2_e4m3;
using scalar_t4 = __nv_fp8x4_e4m3;
using scalar_32bit_t = __nv_fp8x4_e4m3;
using FragA = Vec<__nv_fp8x4_e4m3, 4>;
using FragB = Vec<__nv_fp8x4_e4m3, 2>;
using FragC = Vec<float, 4>;
using FragZP = Vec<__nv_fp8x2_e4m3, 4>;
static __host__ __device__
float2 inline num22float2(const __nv_fp8x2_e4m3 x) {
return (float2)x;
}
};
// Type traits for the scalar type identified by vllm::kS8: the element type,
// the integer types used for 2 and 4 packed elements, and the fragment types
// built from them.
template <>
class MarlinScalarType<vllm::kS8.id()> {
 public:
  // One element, then the integer types wide enough for 2 and 4 of them.
  using scalar_t = int8_t;
  using scalar_t2 = int16_t;
  using scalar_t4 = int32_t;
  // For this type the 32-bit unit is the 4-element integer.
  using scalar_32bit_t = scalar_t4;

  // Fragment types; A and B are built from the 32-bit unit, C from float,
  // and the zero-point fragment from the 2-element integer.
  using FragA = Vec<scalar_32bit_t, 4>;
  using FragB = Vec<scalar_32bit_t, 2>;
  using FragC = Vec<float, 4>;
  using FragZP = Vec<scalar_t2, 4>;
};
// Lookup of the Marlin type traits by C++ element type instead of by scalar
// type id. The primary template is intentionally empty; each supported
// element type gets a specialization that inherits every alias and helper
// from the corresponding id-keyed MarlinScalarType specialization.
template <typename scalar_t>
class MarlinScalarType2 {};
// half -> traits keyed by vllm::kFloat16.
template <>
class MarlinScalarType2<half> : public MarlinScalarType<vllm::kFloat16.id()> {};
// nv_bfloat16 -> traits keyed by vllm::kBFloat16.
template <>
class MarlinScalarType2<nv_bfloat16>
: public MarlinScalarType<vllm::kBFloat16.id()> {};
// __nv_fp8_e4m3 -> traits keyed by vllm::kFE4M3fn.
template <>
class MarlinScalarType2<__nv_fp8_e4m3>
: public MarlinScalarType<vllm::kFE4M3fn.id()> {};
// int8_t -> traits keyed by vllm::kS8.
template <>
class MarlinScalarType2<int8_t> : public MarlinScalarType<vllm::kS8.id()> {};
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment