Unverified Commit 1656ad37 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)


Signed-off-by: default avatarJinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin@redhat.com>
parent fa59fe41
...@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs. # Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that # Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet. # are not supported by Machete yet.
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") # marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS) if (MARLIN_ARCHS)
# #
...@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MARLIN_GEN_SCRIPT set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE marlin_generation_result RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
...@@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: " "\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else() else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin generate script hash" FORCE) CACHE STRING "Last run Marlin generate script hash and arch" FORCE)
message(STATUS "Marlin generation completed successfully.") message(STATUS "Marlin generation completed successfully.")
endif() endif()
else() else()
message(STATUS "Marlin generation script has not changed, skipping generation.") message(STATUS "Marlin generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}") CUDA_ARCHS "${MARLIN_ARCHS}")
...@@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC})
endif()
set(MARLIN_SRCS set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu") "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
...@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}") CUDA_ARCHS "${CUDA_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
# 9.0 for latest bf16 atomicAdd PTX # moe marlin arches
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") # note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS) if (MARLIN_MOE_ARCHS)
# #
...@@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_MARLIN_GEN_SCRIPT set(MOE_MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE moe_marlin_generation_result RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_VARIABLE moe_marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
...@@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: " "\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
else() else()
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin MOE generate script hash" FORCE) CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
message(STATUS "Marlin MOE generation completed successfully.") message(STATUS "Marlin MOE generation completed successfully.")
endif() endif()
...@@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.") message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}" SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}") CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_FP8_SRC}"
CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_FP8_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else() else()
......
...@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: ...@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight=w_q, b_q_weight=w_q,
b_bias=None, b_bias=None,
b_scales=w_s, b_scales=w_s,
a_scales=None,
global_scale=None, global_scale=None,
b_zeros=w_zp, b_zeros=w_zp,
g_idx=g_idx, g_idx=g_idx,
......
...@@ -263,7 +263,7 @@ def bench_run( ...@@ -263,7 +263,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
...@@ -273,7 +273,7 @@ def bench_run( ...@@ -273,7 +273,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
......
kernel_*.cu sm*_kernel_*.cu
\ No newline at end of file kernel_selector.h
...@@ -4,134 +4,282 @@ import glob ...@@ -4,134 +4,282 @@ import glob
import itertools import itertools
import os import os
import subprocess import subprocess
import sys
import jinja2 import jinja2
FILE_HEAD = """ ARCHS = []
// auto generated by generate.py SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off // clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """
)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# AWQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for quant_config in QUANT_CONFIGS:
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
): a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
# act order case only support gptq-int4 and gptq-int8 b_type = quant_config["b_type"]
if group_blocks == 0 and scalar_type not in [ all_group_blocks = quant_config["group_blocks"]
"vllm::kU4B8", all_m_blocks = quant_config["thread_m_blocks"]
"vllm::kU8B128", all_thread_configs = quant_config["thread_configs"]
]:
continue for a_type, c_type in itertools.product(a_types, c_types):
if thread_configs[2] == 256: if not SUPPORT_FP8 and a_type == "kFE4M3fn":
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue continue
if m_blocks > 1 and thread_configs[0] != 64: if "16" in a_type and "16" in c_type and a_type != c_type:
continue continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
# we only support channelwise quantization and group_size == 128 for group_blocks, m_blocks, thread_configs in itertools.product(
# for fp8 all_group_blocks, all_m_blocks, all_thread_configs
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: ):
continue thread_k, thread_n, threads = thread_configs
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32 if threads == 256:
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: # for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue continue
# other quantization methods don't support group_size = 16 if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue continue
k_blocks = thread_configs[0] // 16 config = {
n_blocks = thread_configs[1] // 16 "threads": threads,
threads = thread_configs[2] "s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "false",
}
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" result_dict[(a_type, b_type, c_type)].append(config)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: kernel_selector_str = FILE_HEAD_COMMENT
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, a_type_id=f"vllm::{a_type}.id()",
w_type_id=scalar_type + ".id()", b_type_id=f"vllm::{b_type}.id()",
s_type_id=s_type + ".id()", c_type_id=f"vllm::{c_type}.id()",
threads=threads, s_type_id=f"vllm::{s_type}.id()",
thread_m_blocks=max(m_blocks, 1), **config,
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=False,
) )
all_template_str_list.append(template_str) all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()
......
...@@ -11,8 +11,9 @@ ...@@ -11,8 +11,9 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \
...@@ -20,12 +21,13 @@ ...@@ -20,12 +21,13 @@
const float *__restrict__ topk_weights_ptr, int top_k, \ const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
......
This diff is collapsed.
This diff is collapsed.
...@@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor? b_bias_or_none," "Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? " "Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none," "b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids," "Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, " "Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id," "bool mul_topk_weights, bool is_ep, int b_type_id,"
"int size_m, int size_n, int size_k," "int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add," "bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_fp32_reduce, bool is_zp_float,"
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
......
...@@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { ...@@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
#pragma unroll #pragma unroll
for (int k_idx = 0; k_idx < 2; ++k_idx) { for (int k_idx = 0; k_idx < 2; ++k_idx) {
FType low16 = FType low16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]); C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 = FType high16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]); C_frag[m_idx][n_idx][k_idx * 2 + 1]);
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) | uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
(reinterpret_cast<uint32_t&>(high16) << 16); (reinterpret_cast<uint32_t&>(high16) << 16);
int sts_offset = int sts_offset =
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <iostream> #include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh" #include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::ScalarType; using marlin::MarlinScalarType2;
namespace allspark { namespace allspark {
...@@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C, ...@@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) { for (int i = 0; i < n_mat; ++i) {
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]); sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
} }
C[idx] = ScalarType<FType>::float2num(sum); C[idx] = MarlinScalarType2<FType>::float2num(sum);
} }
template <typename FType> template <typename FType>
......
kernel_*.cu sm*_kernel_*.cu
\ No newline at end of file kernel_selector.h
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits> template <int const num_threads, int const num_bits, bool is_a_8bit>
__global__ void awq_marlin_repack_kernel( __global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
...@@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel( ...@@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor; constexpr int tile_n_ints = target_tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4; constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size; constexpr int stage_k_threads = target_tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
...@@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel( ...@@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int first_n_packed = first_n / pack_factor; int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe; int4* sh_ptr = sh + stage_size * pipe;
...@@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel( ...@@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>( reinterpret_cast<int4 const*>(
...@@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel( ...@@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor; int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor; int cur_n_pos = cur_n % pack_factor;
...@@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel( ...@@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
if constexpr (is_a_8bit) {
int cur_elem = tc_row + i;
int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * cur_elem];
int packed_src_1 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * (cur_elem + 16)];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
} else {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem]; sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
} }
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
...@@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel( ...@@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
...@@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel( ...@@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS) \ #define CALL_IF(NUM_BITS, IS_A_8BIT) \
else if (num_bits == NUM_BITS) { \ else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) { int64_t size_n, int64_t num_bits,
bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
...@@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, ...@@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
if (false) { if (false) {
} }
CALL_IF(4) CALL_IF(4, false)
CALL_IF(8) CALL_IF(8, false)
CALL_IF(4, true)
CALL_IF(8, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", is_a_8bit = ", is_a_8bit);
} }
return out; return out;
......
...@@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>( ...@@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg); frag_b[0] = __hmul2(frag_b[0], bias_reg);
} }
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70707070;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
// Note1: reverse indexing is intentional because weights are permuted
// Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of
// `frag_b[1]` to fit the layout of tensorcore
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <>
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
int q, int32_t* frag_b) {
constexpr int repeated_zp = 0x08080808;
constexpr int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
int s = q & 0x08080808;
int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
q >>= 4;
s = q & 0x08080808;
int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <typename scalar_t2, vllm::ScalarTypeId s_type_id> template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
...@@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>( ...@@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
// Note: reverse indexing is intentional because weights are permuted // Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1); frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2); frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
};
// subtract zero point in quanted format and then dequant
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);
template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
int q, int32_t* frag_b, int zp) {
// INT4 with zp -> INT8
// see https://github.com/vllm-project/vllm/pull/24722
int repeated_zp = 0x01010101 * zp;
int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
true>(int q, __nv_fp8x4_e4m3* frag_b,
int zp) {
// INT4 with zp -> FP8
// see https://github.com/vllm-project/vllm/pull/24722
uint32_t u_q = *reinterpret_cast<uint32_t*>(&q);
uint32_t u_zp = *reinterpret_cast<uint32_t*>(&zp);
uint32_t u_zp1 = u_zp + 1;
uint32_t repeated_zp = 0x01010101 * u_zp;
uint32_t q0, s;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
u_q >>= 4;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
} }
#endif #endif
......
...@@ -4,141 +4,292 @@ import glob ...@@ -4,141 +4,292 @@ import glob
import itertools import itertools
import os import os
import subprocess import subprocess
import sys
import jinja2 import jinja2
FILE_HEAD = """ ARCHS = []
// auto generated by generate.py SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off // clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """
)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# HQQ
{
"a_type": ["kFloat16"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [4],
"is_zp_float": True,
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for quant_config in QUANT_CONFIGS:
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
): a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
# act order case only support gptq-int4 and gptq-int8 b_type = quant_config["b_type"]
if group_blocks == 0 and scalar_type not in [ is_zp_float = quant_config.get("is_zp_float", False)
"vllm::kU4B8", all_group_blocks = quant_config["group_blocks"]
"vllm::kU8B128", all_m_blocks = quant_config["thread_m_blocks"]
]: all_thread_configs = quant_config["thread_configs"]
continue
if thread_configs[2] == 256: for a_type, c_type in itertools.product(a_types, c_types):
# for small batch (m_blocks == 1), we only need (128, 128, 256) if not SUPPORT_FP8 and a_type == "kFE4M3fn":
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue continue
if m_blocks > 1 and thread_configs[0] != 64: if "16" in a_type and "16" in c_type and a_type != c_type:
continue continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
# we only support channelwise quantization and group_size == 128 for group_blocks, m_blocks, thread_configs in itertools.product(
# for fp8 all_group_blocks, all_m_blocks, all_thread_configs
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: ):
continue thread_k, thread_n, threads = thread_configs
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32 if threads == 256:
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: # for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue continue
# other quantization methods don't support group_size = 16 if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue continue
k_blocks = thread_configs[0] // 16 config = {
n_blocks = thread_configs[1] // 16 "threads": threads,
threads = thread_configs[2] "s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false",
}
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" result_dict[(a_type, b_type, c_type)].append(config)
is_zp_float_list = [False] kernel_selector_str = FILE_HEAD_COMMENT
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
# HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16
is_zp_float_list.append(True)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: for (a_type, b_type, c_type), config_list in result_dict.items():
s_type = "vllm::kFE4M3fn" all_template_str_list = []
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: for config in config_list:
s_type = "vllm::kFE8M0fnu" s_type = config["s_type"]
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, a_type_id=f"vllm::{a_type}.id()",
w_type_id=scalar_type + ".id()", b_type_id=f"vllm::{b_type}.id()",
s_type_id=s_type + ".id()", c_type_id=f"vllm::{c_type}.id()",
threads=threads, s_type_id=f"vllm::{s_type}.id()",
thread_m_blocks=max(m_blocks, 1), **config,
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=is_zp_float,
) )
all_template_str_list.append(template_str) all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()
......
...@@ -4,15 +4,18 @@ ...@@ -4,15 +4,18 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm,
bool is_a_8bit>
__global__ void gptq_marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
...@@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4; constexpr int perm_size = target_tile_k_size / 4;
int4* sh_perm_ptr = sh; int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr; int4* sh_pipe_ptr = sh_perm_ptr;
...@@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel(
sh_pipe_ptr += perm_size; sh_pipe_ptr += perm_size;
} }
constexpr int tile_ints = tile_k_size / pack_factor; constexpr int tile_ints = target_tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_n_threads = target_tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) { auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4; int first_k_int4 = (k_tile_id * target_tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr); int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
...@@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
...@@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
int first_k_packed = first_k / pack_factor; int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
...@@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
constexpr int sh_stride = 64; constexpr int sh_stride = target_tile_n_size;
constexpr uint32_t mask = (1 << num_bits) - 1; constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
...@@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
if constexpr (has_perm) { if constexpr (has_perm) {
static_assert(!is_a_8bit);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i]; int k_idx = tc_row + tc_offsets[i];
...@@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel(
#pragma unroll #pragma unroll
for (int i = 0; i < tile_ints; i++) { for (int i = 0; i < tile_ints; i++) {
if constexpr (is_a_8bit) {
b1_vals[i] =
sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8];
} else {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
} }
}
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]);
int cur_int = cur_elem / pack_factor; int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor; int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
if constexpr (is_a_8bit)
vals[4 + i] =
(b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask;
else
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
} }
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
...@@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
...@@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \
is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \ HAS_PERM, IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM, IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits, bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
...@@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
if (false) { if (false) {
} }
CALL_IF(4, false) CALL_IF(4, false, false)
CALL_IF(4, true) CALL_IF(4, true, false)
CALL_IF(8, false) CALL_IF(8, false, false)
CALL_IF(8, true) CALL_IF(8, true, false)
CALL_IF(4, false, true)
CALL_IF(8, false, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm); ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit);
} }
return out; return out;
......
...@@ -11,17 +11,19 @@ ...@@ -11,17 +11,19 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
......
...@@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } ...@@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async // No support for async
#else #else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 4;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 8;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment