Unverified Commit 33c63e95 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel] [Quantization] Add MXFP4 and bias support for marlin kernel (#22428)


Signed-off-by: default avatarrongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: default avatarHuzaifa Sidhpurwala <huzaifas@redhat.com>
Signed-off-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarAnimesh Jain <anijain@umich.edu>
Signed-off-by: default avatarRui Qiao <ruisearch42@gmail.com>
Signed-off-by: default avatarXiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
Signed-off-by: default avatarkf <kuanfu.liu@embeddedllm.com>
Signed-off-by: default avatarvllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarDipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
Signed-off-by: default avatartjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: default avatarYong Hoon Shin <yhshin@meta.com>
Signed-off-by: default avatarChih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.me>
Signed-off-by: default avatarVadim Gimpelson <vadim.gimpelson@centml.ai>
Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Signed-off-by: default avatarzRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: default avatarChih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: default avataryan <yan.ma@intel.com>
Signed-off-by: default avatarYan Ma <yan.ma@intel.com>
Signed-off-by: default avatarXiao Liu <xiszishu@gmail.com>
Signed-off-by: default avatarjiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: default avatarYe (Charlotte) Qi <yeq@meta.com>
Signed-off-by: default avatarLopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
Signed-off-by: default avatarHaibin Lin <haibin.lin@bytedance.com>
Signed-off-by: default avatarDavid Ben-David <davidb@pliops.com>
Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
Signed-off-by: default avatarSeiji Eicher <seiji@anyscale.com>
Signed-off-by: default avatarzitian.zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
Signed-off-by: default avatarAbirdcfly <fp544037857@gmail.com>
Signed-off-by: default avatarGiancarlo Delfin <gdelfin@meta.com>
Signed-off-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: default avatarhuangweixiao <huangweixiao@msh.team>
Signed-off-by: default avataralyosha-swamy <raghav@arcee.ai>
Signed-off-by: default avatarEric Hanley <ericehanley@google.com>
Signed-off-by: default avatarAbatom <abzhonghua@gmail.com>
Signed-off-by: default avatarCLFutureX <775523362@qq.com>
Signed-off-by: default avatarLinkun Chen <github@lkchen.net>
Signed-off-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: default avatartlipoca9 <tlipoca9@gmail.com>
Signed-off-by: default avatarelvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: default avatarzitian zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: default avatarmgoin <michael@neuralmagic.com>
Signed-off-by: default avatarwang.yuqi <noooop@126.com>
Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
Signed-off-by: default avatarSiyuan Liu <lsiyuan@google.com>
Signed-off-by: default avatarBenjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: default avatarisotr0py <2037008807@qq.com>
Signed-off-by: default avatarChen Zhang <zhangch99@outlook.com>
Signed-off-by: default avatarsimon-mo <xmo@berkeley.edu>
Signed-off-by: default avatarLucasWilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: default avatarZhang Jason <ning.zhang2@amd.com>
Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Signed-off-by: default avatarasafg <asafg@ai21.com>
Signed-off-by: default avatarSiyuan Fu <siyuanf@nvidia.com>
Signed-off-by: default avatarLain <fusiyuan2000@hotmail.com>
Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
Signed-off-by: default avatarTao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Signed-off-by: default avatarQscQ <qscqesze@gmail.com>
Signed-off-by: default avatarqingjun <qingjun@minimaxi.com>
Signed-off-by: default avatarSyed Muhammad Bin Asif <syedmba7@connect.hku.hk>
Signed-off-by: default avatarLionel Villard <villard@us.ibm.com>
Signed-off-by: default avatarycyaw66 <497410282@qq.com>
Signed-off-by: default avatarDavid Chen <530634352@qq.com>
Signed-off-by: default avatarLinkun <github@lkchen.net>
Signed-off-by: default avatarMoritz Sanft <58110325+msanft@users.noreply.github.com>
Signed-off-by: default avatarMing Yang <minos.future@gmail.com>
Signed-off-by: default avatarAdrian Garcia <adrian.garcia@inceptionai.ai>
Signed-off-by: default avatarshaojunqi <shaojunqi.sjq@alibaba-inc.com>
Signed-off-by: default avatarRicardo Decal <rdecal@anyscale.com>
Signed-off-by: default avatarAndrew Chan <andrewkchan.akc@gmail.com>
Signed-off-by: default avatarFelix Marty <Felix.Marty@amd.com>
Signed-off-by: default avatarAndrew Sansom <andrew@protopia.ai>
Signed-off-by: default avatarZhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: default avatarShu Wang <shuw@nvidia.com>
Signed-off-by: default avatarPo-Han Huang <pohanh@nvidia.com>
Signed-off-by: default avatarShu Wang. <shuw@nvidia.com>
Signed-off-by: default avatarXIn Li <xinli@nvidia.com>
Signed-off-by: default avatarJunhao Li <junhao@ubicloud.com>
Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: default avatariAmir97 <Amir.balwel@embeddedllm.com>
Signed-off-by: default avatariAmir97 <71513472+iAmir97@users.noreply.github.com>
Signed-off-by: <zyy1102000@gmail.com>
Signed-off-by: default avatarGuy Stone <guys@spotify.com>
Signed-off-by: <yyweiss@gmail.com>
Signed-off-by: default avataryyw <yyweiss@gmail.com>
Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
Signed-off-by: default avatarPradyun Ramadorai <pradyunr@amazon.com>
Signed-off-by: default avatarPradyun92 <142861237+Pradyun92@users.noreply.github.com>
Signed-off-by: default avatarJinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: default avatarrongfu.leng <rongfu.leng@daocloud.io>
Co-authored-by: default avatarHuzaifa Sidhpurwala <huzaifas@redhat.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarAnimesh Jain <jainanimesh2305@yahoo.com>
Co-authored-by: default avatarRui Qiao <161574667+ruisearch42@users.noreply.github.com>
Co-authored-by: default avatarXiongfeiWei <isaacwxf23@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: default avatarJartX <sagformas@gmail.com>
Co-authored-by: default avatarfhl2000 <63384265+fhl2000@users.noreply.github.com>
Co-authored-by: default avatarvllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: default avatarkf <kuanfu.liu@embeddedllm.com>
Co-authored-by: default avatarNicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: default avatarDipika Sikka <dipikasikka1@gmail.com>
Co-authored-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatartjtanaavllm <tunjian.tan@amd.com>
Co-authored-by: default avatarYong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Co-authored-by: default avatarChih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.me>
Co-authored-by: default avatarVadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Co-authored-by: default avatarYuxuan Zhang <2448370773@qq.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: default avatarYan Ma <yan.ma@intel.com>
Co-authored-by: default avatarXiao <xiszishu@gmail.com>
Co-authored-by: default avatarjiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarYe (Charlotte) Qi <yeq@meta.com>
Co-authored-by: default avatarRoberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: default avatarNing Xie <andy.xning@gmail.com>
Co-authored-by: default avatarH <linhaibin.eric@gmail.com>
Co-authored-by: default avatarDavid Ben-David <sdavidbd@gmail.com>
Co-authored-by: default avatarDavid Ben-David <davidb@pliops.com>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: default avatarLi, Jiang <jiang1.li@intel.com>
Co-authored-by: default avatarTankNee <nee@tanknee.cn>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarSeiji Eicher <58963096+eicherseiji@users.noreply.github.com>
Co-authored-by: default avatarZiTian.Zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: default avatarAbirdcfly <fp544037857@gmail.com>
Co-authored-by: default avatarGiancarlo Delfin <32987265+TheEpicDolphin@users.noreply.github.com>
Co-authored-by: default avatarChenxi Yang <cxyang@cs.utexas.edu>
Co-authored-by: default avatarChenxi Yang <cxyang@meta.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarWeixiao Huang <hwx.simle@gmail.com>
Co-authored-by: default avatarRaghav Ravishankar <113712354+alyosha-swamy@users.noreply.github.com>
Co-authored-by: default avatarericehanley <ericehanley@google.com>
Co-authored-by: default avatarZhonghua Deng <abzhonghua@gmail.com>
Co-authored-by: default avatarPo-Han Huang (NVIDIA) <53919306+nvpohanh@users.noreply.github.com>
Co-authored-by: default avatarPiteXChen <44110731+CLFutureX@users.noreply.github.com>
Co-authored-by: default avatarlkchen <github@lkchen.net>
Co-authored-by: default avatarTJian <tunjian.tan@embeddedllm.com>
Co-authored-by: default avatarGregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: default avatartlipoca9 <160737620+tlipoca9@users.noreply.github.com>
Co-authored-by: default avatarelvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: default avatarwang.yuqi <noooop@126.com>
Co-authored-by: default avatarBenji Beck <benjibeck@meta.com>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarSiyuan Liu <lsiyuan@google.com>
Co-authored-by: default avatarBenjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: default avatarLiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: default avatarsimon-mo <xmo@berkeley.edu>
Co-authored-by: default avatarChen Zhang <zhangch99@outlook.com>
Co-authored-by: default avatarHongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: default avatarMinseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarZhang Jason <ning.zhang2@amd.com>
Co-authored-by: default avatarAsaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: default avatarasafg <asafg@ai21.com>
Co-authored-by: default avatarLain <siyuanf@nvidia.com>
Co-authored-by: default avatartc-mb <157115220+tc-mb@users.noreply.github.com>
Co-authored-by: default avatarimning3 <hbning@pku.edu.cn>
Co-authored-by: default avatarMaximilien de Bayser <mbayser@br.ibm.com>
Co-authored-by: default avatarKunshang Ji <kunshang.ji@intel.com>
Co-authored-by: default avatarTao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: default avatarqscqesze <qingjun@minimaxi.com>
Co-authored-by: default avatarSyed Muhammad Bin Asif <92625830+syedmba@users.noreply.github.com>
Co-authored-by: default avatarLionel Villard <villard@us.ibm.com>
Co-authored-by: default avatarWeiQing Chen <40507679+david6666666@users.noreply.github.com>
Co-authored-by: default avatarycyaw66 <497410282@qq.com>
Co-authored-by: default avatarMoritz Sanft <58110325+msanft@users.noreply.github.com>
Co-authored-by: default avatarMing Yang <minos.future@gmail.com>
Co-authored-by: default avatarAdrián García García <adrigarvk8@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin@redhat.com>
Co-authored-by: default avatarJaceyShao <65159281+JaceyShao@users.noreply.github.com>
Co-authored-by: default avatarshaojunqi <shaojunqi.sjq@alibaba-inc.com>
Co-authored-by: default avatarRicardo Decal <crypdick@users.noreply.github.com>
Co-authored-by: default avatarAndrew Chan <andrewkchan.akc@gmail.com>
Co-authored-by: default avatarfxmarty-amd <felmarty@amd.com>
Co-authored-by: default avatarAndrew Sansom <andrew@protopia.ai>
Co-authored-by: default avatarZhiyu <zhiyuc@nvidia.com>
Co-authored-by: default avatarShu Wang <shuw@nvidia.com>
Co-authored-by: default avatarXIn Li <xinli@nvidia.com>
Co-authored-by: default avatarJunhao Li <streaver91@gmail.com>
Co-authored-by: default avatarChauncey <chaunceyjiang@gmail.com>
Co-authored-by: default avatariAmir97 <71513472+iAmir97@users.noreply.github.com>
Co-authored-by: default avatariAmir97 <Amir.balwel@embeddedllm.com>
Co-authored-by: default avatarHong Hanh <hanh.usth@gmail.com>
Co-authored-by: default avatarDaniel Serebrenik <74646983+pliops-daniels@users.noreply.github.com>
Co-authored-by: default avataryewentao256 <zhyanwentao@126.com>
Co-authored-by: default avatarGuy Stone <guys@spotify.com>
Co-authored-by: default avataryyweiss <70619747+yyweiss@users.noreply.github.com>
Co-authored-by: default avatarPradyun92 <142861237+Pradyun92@users.noreply.github.com>
Co-authored-by: default avatarPradyun Ramadorai <pradyunr@amazon.com>
Co-authored-by: default avatarNicolò Lucchesi <nicolo.lucchesi@gmail.com>
parent ab9f2cfd
......@@ -351,6 +351,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
......@@ -364,7 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_ARCHS}")
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
else()
message(STATUS "Not building Marlin kernels as no compatible archs found"
......@@ -854,6 +859,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
......
......@@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
a=bt.a,
c=None,
b_q_weight=w_q,
b_bias=None,
b_scales=w_s,
global_scale=None,
b_zeros=w_zp,
......
......@@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE8M0fnu =
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
......
......@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
......@@ -77,6 +78,7 @@ def generate_new_kernels():
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
......@@ -89,9 +91,22 @@ def generate_new_kernels():
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
......
......@@ -10,6 +10,7 @@
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
......@@ -18,12 +19,13 @@
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......
......@@ -280,6 +280,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......@@ -299,6 +300,7 @@ __global__ void Marlin(
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
......@@ -318,6 +320,7 @@ __global__ void Marlin(
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool has_bias,
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {
......@@ -342,12 +345,23 @@ __global__ void Marlin(
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
static_assert(s_type == vllm::kFloat16);
}
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
w_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
......@@ -365,6 +379,7 @@ __global__ void Marlin(
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
const int b_bias_expert_stride = prob_n / 8;
// parallel: num valid moe blocks
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
......@@ -475,7 +490,7 @@ __global__ void Marlin(
for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i;
idx = idx < block_num_valid_tokens ? idx : 0;
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
......@@ -513,7 +528,7 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id];
}
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
......@@ -526,6 +541,9 @@ __global__ void Marlin(
if constexpr (has_act_order) {
g_idx += (expert_id - old_expert_id) * prob_k;
}
if (has_bias) {
b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride;
}
read_moe_block_data(block_id);
};
......@@ -721,7 +739,7 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
......@@ -734,6 +752,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
int bias_sh_rd;
if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
} else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
}
int bias_sh_wr = threadIdx.x;
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
......@@ -793,7 +823,19 @@ __global__ void Marlin(
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max =
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
constexpr int sh_b_red_bias_size =
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh_new + sh_size_b_red_min;
int4* sh_g_idx = sh_new + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
......@@ -803,9 +845,9 @@ __global__ void Marlin(
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int shm_size_used = moe_block_size +
stages * (g_idx_stage + zp_sh_stage) +
sh_s_size + sh_b_red_bias_size;
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
......@@ -817,6 +859,7 @@ __global__ void Marlin(
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
......@@ -1065,10 +1108,15 @@ __global__ void Marlin(
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
}
}
}
......@@ -1281,9 +1329,9 @@ __global__ void Marlin(
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
......@@ -1566,7 +1614,7 @@ __global__ void Marlin(
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
auto write_result = [&](bool last) {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
......@@ -1592,7 +1640,7 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) {
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
......@@ -1601,14 +1649,27 @@ __global__ void Marlin(
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
}
res = __hmul2(res, tmp_scale);
}
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}
if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
}
res = __hadd2(res, tmp_bias);
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
......@@ -1626,19 +1687,25 @@ __global__ void Marlin(
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]);
frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]);
frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} else {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
}
}
c_sh_wr += 16 * (4 * c_sh_stride);
......@@ -1805,6 +1872,14 @@ __global__ void Marlin(
}
thread_block_reduce();
if (has_bias && last) {
__syncthreads();
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
threadIdx.x < 16 * thread_n_blocks / 8);
cp_async_fence();
}
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
......@@ -1867,11 +1942,20 @@ __global__ void Marlin(
}
barrier_release(&locks[locks_off], last);
}
if (has_bias && last) {
cp_async_wait<0>();
__syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads();
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
write_result(last);
int old_slice_row = slice_row;
slice_row = 0;
slice_col_par++;
......@@ -1904,6 +1988,7 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading
if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row;
......
......@@ -51,8 +51,9 @@ __global__ void permute_cols_kernel(
} // namespace marlin
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
......@@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
int tb_m = thread_m_blocks * 16;
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
......@@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
......@@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
sh_zp_size = sh_s_size / 2;
}
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
sh_g_idx_size + sh_block_meta_size;
return total_size;
}
......@@ -270,7 +276,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
return cache_size + 512 <= max_shared_mem;
}
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
......@@ -281,9 +287,14 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
......@@ -335,30 +346,44 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
......@@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
NVFP4_GET_IF(vllm::kFE2M1f)
FP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel;
}
......@@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
}
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids,
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm,
void* a_tmp, void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
int prob_m, int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, bool use_atomic_add, bool use_fp32_reduce,
vllm::ScalarType const& q_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
......@@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
......@@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
// clang-format on
}
......@@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
......@@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm(
num_groups = b_scales.size(1);
torch::Tensor g_idx, perm, a_tmp;
;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
g_idx = g_idx_or_none.value();
perm = perm_or_none.value();
......@@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
"the global_scale parameter must be passed for nvfp4 format.");
}
bool has_bias = b_bias_or_none.has_value();
torch::Tensor b_bias;
if (has_bias) {
b_bias = b_bias_or_none.value();
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n");
TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
} else {
b_bias = torch::empty({0}, options);
}
torch::Tensor b_zeros;
......@@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm(
b_zeros = torch::empty({0}, options);
}
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
......@@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm(
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
is_k_full, has_zp, num_groups, group_size, dev,
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
......
......@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
......
......@@ -470,11 +470,12 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <typename scalar_t2>
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
int q, half2* frag_b) {
int Out1 = (q & 0xFF00FF00) >> 1;
;
q <<= 8;
......@@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
};
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
nv_bfloat162* frag_b) {
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
int q, nv_bfloat162* frag_b) {
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
......@@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
int q, nv_bfloat162* frag_b) {
// In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
// but we assume that such a extreme value would not occur in real models.
int Out1 = (q & 0xFF00FF00) >> 1;
q <<= 7;
int Out2 = q & 0x7F807F80;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME
......@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
......@@ -78,7 +79,8 @@ def generate_new_kernels():
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
......@@ -97,10 +99,23 @@ def generate_new_kernels():
# 4bit quantization and fp16
is_zp_float_list.append(True)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
......
......@@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
......@@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int tb_m = thread_m_blocks * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8);
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
......@@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size = sh_s_size / 2;
}
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_zp_size + sh_g_idx_size;
int total_size =
tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;
return total_size;
}
......@@ -237,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
return cache_size + 512 <= max_shared_mem;
}
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
......@@ -248,9 +254,14 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
......@@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
......@@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
FP4_GET_IF(vllm::kFE2M1f)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
......@@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
}
FZP_GET_IF(vllm::kU4)
}
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel;
}
......@@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
}
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k_init,
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm,
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
......@@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
......@@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new);
// clang-format on
......@@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
......@@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
"the global_scale parameter must be passed for nvfp4 format.");
}
bool has_bias = b_bias_or_none.has_value();
torch::Tensor b_bias;
if (has_bias) {
b_bias = b_bias_or_none.value();
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n");
TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1");
} else {
b_bias = torch::empty({0}, options);
}
torch::Tensor b_zeros;
......@@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm(
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
......
......@@ -10,15 +10,18 @@
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......
......@@ -39,6 +39,7 @@ namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......@@ -271,6 +272,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......@@ -290,6 +292,7 @@ __global__ void Marlin(
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
......@@ -303,6 +306,7 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int lda, // A.stride(0), equal to prob_k is A is contiguous
int* locks, // extra global storage for barrier synchronization
bool has_bias,
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {
......@@ -326,18 +330,29 @@ __global__ void Marlin(
using FragZP = typename ScalarType<scalar_t>::FragZP;
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
static_assert(s_type == vllm::kFloat16);
}
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
w_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale;
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
// NVFP4 format requires global scale
uint16_t val = scale2_ptr[0];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
......@@ -589,7 +604,7 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
......@@ -602,6 +617,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
int bias_sh_rd;
if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
} else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
}
int bias_sh_wr = threadIdx.x;
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
......@@ -670,7 +697,19 @@ __global__ void Marlin(
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh;
int4* sh_red = sh;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max =
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
constexpr int sh_b_red_bias_size =
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh + sh_size_b_red_min;
int4* sh_g_idx = sh + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
......@@ -680,15 +719,13 @@ __global__ void Marlin(
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
// constexpr int shm_size_used =
// stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
// (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
......@@ -923,10 +960,15 @@ __global__ void Marlin(
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
}
}
}
......@@ -1139,9 +1181,9 @@ __global__ void Marlin(
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
......@@ -1411,7 +1453,7 @@ __global__ void Marlin(
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
auto write_result = [&](bool last) {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
......@@ -1438,7 +1480,7 @@ __global__ void Marlin(
int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) {
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
......@@ -1447,12 +1489,25 @@ __global__ void Marlin(
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
}
res = __hmul2(res, tmp_scale);
}
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
res = __hmul2(res, global_scale);
}
if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
}
res = __hadd2(res, tmp_bias);
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
......@@ -1470,19 +1525,25 @@ __global__ void Marlin(
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]);
frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]);
frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} else {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
}
}
c_sh_wr += 16 * (4 * c_sh_stride);
......@@ -1622,6 +1683,14 @@ __global__ void Marlin(
}
thread_block_reduce();
if (has_bias && last) {
__syncthreads();
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
threadIdx.x < 16 * thread_n_blocks / 8);
cp_async_fence();
}
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
......@@ -1684,11 +1753,20 @@ __global__ void Marlin(
}
barrier_release(&locks[locks_off], last);
}
if (has_bias && last) {
cp_async_wait<0>();
__syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads();
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
write_result(last);
slice_row = 0;
slice_col_par++;
slice_col++;
......@@ -1706,6 +1784,7 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading
if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row;
......
......@@ -326,6 +326,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none,"
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
......
......@@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_fp4_like)
rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
......@@ -476,7 +478,10 @@ def marlin_moe_generate_valid_test_cases():
if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
return False
if dtype == torch.float16 and group_size == 32:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False
......@@ -520,31 +525,6 @@ def test_fused_marlin_moe(
torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
if quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
# Filter act_order
if act_order:
if quant_type == scalar_types.float8_e4m3fn:
return
if group_size == -1:
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
......@@ -569,12 +549,18 @@ def test_fused_marlin_moe(
for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_fp4_like(w1[i], group_size)
rand_marlin_weight_nvfp4_like(w1[i], group_size)
else:
w_ref1, qweight1, scales1 = \
rand_marlin_weight_mxfp4_like(w1[i], group_size)
global_scale1 = None
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
if global_scale1 is not None:
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
......@@ -620,12 +606,18 @@ def test_fused_marlin_moe(
for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_fp4_like(w2[i], group_size)
rand_marlin_weight_nvfp4_like(w2[i], group_size)
else:
w_ref2, qweight2, scales2 = \
rand_marlin_weight_mxfp4_like(w2[i], group_size)
global_scale2 = None
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
if global_scale2 is not None:
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
......@@ -677,6 +669,8 @@ def test_fused_marlin_moe(
a,
qweight1,
qweight2,
None,
None,
scales1,
scales2,
score,
......@@ -698,6 +692,119 @@ def test_fused_marlin_moe(
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
torch.cuda.manual_seed(0)
e, topk = 32, 4
n, k = 2048, 2048
group_size = 128
act_order = False
is_k_full = True
quant_type = scalar_types.uint4b8
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
b_bias1_l = []
w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
b_bias2_l = []
w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
b_bias2)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
marlin_bias1,
marlin_bias2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=None,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck():
num_experts = 4
block_size = 4
......
......@@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_scales,
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
......@@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]
......@@ -202,17 +203,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors, act_order, is_k_full, use_atomic_add,
use_fp32_reduce, dtype):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
......@@ -231,14 +225,23 @@ def test_gptq_marlin_gemm(
if size_k % group_size != 0:
return
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
a_input = rand_data((size_m, size_k), dtype)
b_weight = rand_data((size_k, size_n), dtype)
if quant_type == scalar_types.float4_e2m1f:
if group_size != 16 or act_order:
if group_size not in [16, 32] or act_order:
return
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
b_weight.T, group_size)
if group_size == 32 and dtype == torch.float16:
return
if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
else:
w_ref, marlin_q_w, marlin_s = \
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
marlin_s2 = None
g_idx = None
sort_indices = None
marlin_zp = None
......@@ -272,8 +275,8 @@ def test_gptq_marlin_gemm(
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
sort_indices, workspace, quant_type.id, a_input.shape[0],
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
......@@ -282,6 +285,7 @@ def test_gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
None,
marlin_s,
marlin_s2,
marlin_zp,
......@@ -418,6 +422,7 @@ def test_hqq_marlin_gemm(
a_input,
None,
marlin_w_q,
None,
marlin_s,
None,
marlin_zp,
......@@ -531,6 +536,7 @@ def test_marlin_gemm_subset_input():
a_input,
None,
marlin_q_w,
None,
marlin_s,
None,
marlin_zp,
......@@ -555,6 +561,53 @@ def test_marlin_gemm_subset_input():
assert max_diff < 0.04
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
quant_type = scalar_types.uint4b8
group_size = 128
size_k, size_n = 1024, 2048
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
b_bias = rand_data((size_n, )) * 10
marlin_bias = marlin_permute_bias(b_bias)
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_bias,
marlin_s,
None,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=True,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
def test_marlin_gemm_opcheck():
size_m = 2048
size_n = 4096
......
......@@ -1064,6 +1064,8 @@ def torch_experts(
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
......@@ -1108,8 +1110,13 @@ def torch_experts(
if mask.sum():
if quant_dtype is None:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
tmp1.dtype)
elif block_shape is not None:
# block quantized
assert (a_scale is not None and w1_scale is not None
......@@ -1117,6 +1124,8 @@ def torch_experts(
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
out.dtype)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = moe_kernel_quantize_input(
tmp2, a2_scale, quant_dtype, per_act_token_quant,
......@@ -1125,6 +1134,9 @@ def torch_experts(
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
out.dtype)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
tmp1.dtype)
else:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
......@@ -1133,6 +1145,8 @@ def torch_experts(
tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
tmp1 = (tmp1 @ w1_dq).to(out.dtype)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)
tmp2 = SiluAndMul()(tmp1).to(out.dtype)
......@@ -1144,6 +1158,9 @@ def torch_experts(
tmp2 = tmp2.to(f32) * b_scale
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
out.dtype)
if apply_router_weights_on_input:
return out
......@@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
expert_map)
b_bias1, b_bias2, expert_map)
def torch_moe_single(a, w, score, topk):
......
......@@ -452,6 +452,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _gptq_marlin_gemm_fake(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
......@@ -1048,6 +1049,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
def gptq_marlin_gemm(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
......@@ -1062,7 +1064,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales,
global_scale, b_zeros, g_idx, perm,
workspace, b_q_type.id, size_m,
size_n, size_k, is_k_full,
......@@ -1540,7 +1542,9 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor, b_scales: torch.Tensor,
b_qweight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
......@@ -1556,11 +1560,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx,
perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded,
topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep,
b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add,
use_fp32_reduce, is_zp_float)
input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros,
g_idx, perm, workspace, sorted_token_ids, expert_ids,
num_tokens_past_padded, topk_weights, moe_block_size, top_k,
mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k,
is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
......
......@@ -122,6 +122,7 @@ if TYPE_CHECKING:
VLLM_MOE_DP_CHUNK_SIZE: int = 256
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
......@@ -182,6 +183,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
return int(value)
def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
if value is None:
return None
return bool(int(value))
def get_vllm_port() -> Optional[int]:
"""Get the port from VLLM_PORT environment variable.
......@@ -906,6 +913,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MARLIN_USE_ATOMIC_ADD":
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
# Whether to use marlin kernel in mxfp4 quantization method
"VLLM_MXFP4_USE_MARLIN":
lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)),
# Whether to turn on the outlines cache for V0
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
......
......@@ -18,6 +18,8 @@ from vllm.utils import direct_register_custom_op
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
bias1: Optional[torch.Tensor],
bias2: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
......@@ -26,6 +28,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: Optional[str] = "silu",
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
......@@ -88,6 +91,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
assert topk_weights.dtype == torch.float32
M, K = hidden_states.shape
E = w1.shape[0]
......@@ -138,6 +142,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
hidden_states,
intermediate_cache1,
w1,
bias1,
w1_scale,
global_scale1,
w1_zeros,
......@@ -161,8 +166,28 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
use_fp32_reduce=True,
is_zp_float=False)
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
elif activation == "swiglu_oai":
# NOTE: in gpt-oss, the gate_proj and up_proj is interleaved
# - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
# - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
@torch.compile(dynamic=True)
def swiglu_oai(gate_up):
alpha = 1.702
limit = 7.0
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
return (up + 1) * glu
intermediate_cache2 = swiglu_oai(intermediate_cache1)
else:
raise ValueError(f"Unsupported activation: {activation}. "
"Only silu and swiglu_oai activations are supported.")
if expert_map is not None:
intermediate_cache3.zero_()
......@@ -171,6 +196,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache2,
intermediate_cache3,
w2,
bias2,
w2_scale,
global_scale2,
w2_zeros,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment