Unverified commit 33c63e95, authored by Jinzhen Lin and committed by GitHub
Browse files

[Kernel] [Quantization] Add MXFP4 and bias support for marlin kernel (#22428)


Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Huzaifa Sidhpurwala <huzaifas@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Animesh Jain <anijain@umich.edu>
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Roger Wang <hey@rogerw.me>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: yan <yan.ma@intel.com>
Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Xiao Liu <xiszishu@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Andy Xie <andy.xning@gmail.com>
Signed-off-by: Haibin Lin <haibin.lin@bytedance.com>
Signed-off-by: David Ben-David <davidb@pliops.com>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Signed-off-by: Abirdcfly <fp544037857@gmail.com>
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: huangweixiao <huangweixiao@msh.team>
Signed-off-by: alyosha-swamy <raghav@arcee.ai>
Signed-off-by: Eric Hanley <ericehanley@google.com>
Signed-off-by: Abatom <abzhonghua@gmail.com>
Signed-off-by: CLFutureX <775523362@qq.com>
Signed-off-by: Linkun Chen <github@lkchen.net>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: tlipoca9 <tlipoca9@gmail.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Benji Beck <benjibeck@meta.com>
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Zhang Jason <ning.zhang2@amd.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: asafg <asafg@ai21.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Lain <fusiyuan2000@hotmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: QscQ <qscqesze@gmail.com>
Signed-off-by: qingjun <qingjun@minimaxi.com>
Signed-off-by: Syed Muhammad Bin Asif <syedmba7@connect.hku.hk>
Signed-off-by: Lionel Villard <villard@us.ibm.com>
Signed-off-by: ycyaw66 <497410282@qq.com>
Signed-off-by: David Chen <530634352@qq.com>
Signed-off-by: Linkun <github@lkchen.net>
Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Adrian Garcia <adrian.garcia@inceptionai.ai>
Signed-off-by: shaojunqi <shaojunqi.sjq@alibaba-inc.com>
Signed-off-by: Ricardo Decal <rdecal@anyscale.com>
Signed-off-by: Andrew Chan <andrewkchan.akc@gmail.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: XIn Li <xinli@nvidia.com>
Signed-off-by: Junhao Li <junhao@ubicloud.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: iAmir97 <Amir.balwel@embeddedllm.com>
Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com>
Signed-off-by: <zyy1102000@gmail.com>
Signed-off-by: Guy Stone <guys@spotify.com>
Signed-off-by: <yyweiss@gmail.com>
Signed-off-by: yyw <yyweiss@gmail.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
Signed-off-by: Pradyun92 <142861237+Pradyun92@users.noreply.github.com>
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io>
Co-authored-by: Huzaifa Sidhpurwala <huzaifas@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Animesh Jain <jainanimesh2305@yahoo.com>
Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com>
Co-authored-by: XiongfeiWei <isaacwxf23@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: JartX <sagformas@gmail.com>
Co-authored-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: kf <kuanfu.liu@embeddedllm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Co-authored-by: Yuxuan Zhang <2448370773@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Xiao <xiszishu@gmail.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: Ning Xie <andy.xning@gmail.com>
Co-authored-by: H <linhaibin.eric@gmail.com>
Co-authored-by: David Ben-David <sdavidbd@gmail.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: TankNee <nee@tanknee.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com>
Co-authored-by: ZiTian.Zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Abirdcfly <fp544037857@gmail.com>
Co-authored-by: Giancarlo Delfin <32987265+TheEpicDolphin@users.noreply.github.com>
Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu>
Co-authored-by: Chenxi Yang <cxyang@meta.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Weixiao Huang <hwx.simle@gmail.com>
Co-authored-by: Raghav Ravishankar <113712354+alyosha-swamy@users.noreply.github.com>
Co-authored-by: ericehanley <ericehanley@google.com>
Co-authored-by: Zhonghua Deng <abzhonghua@gmail.com>
Co-authored-by: Po-Han Huang (NVIDIA) <53919306+nvpohanh@users.noreply.github.com>
Co-authored-by: PiteXChen <44110731+CLFutureX@users.noreply.github.com>
Co-authored-by: lkchen <github@lkchen.net>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: tlipoca9 <160737620+tlipoca9@users.noreply.github.com>
Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Zhang Jason <ning.zhang2@amd.com>
Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: Lain <siyuanf@nvidia.com>
Co-authored-by: tc-mb <157115220+tc-mb@users.noreply.github.com>
Co-authored-by: imning3 <hbning@pku.edu.cn>
Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: qscqesze <qingjun@minimaxi.com>
Co-authored-by: Syed Muhammad Bin Asif <92625830+syedmba@users.noreply.github.com>
Co-authored-by: Lionel Villard <villard@us.ibm.com>
Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
Co-authored-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Co-authored-by: Ming Yang <minos.future@gmail.com>
Co-authored-by: Adrián García García <adrigarvk8@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
Co-authored-by: JaceyShao <65159281+JaceyShao@users.noreply.github.com>
Co-authored-by: shaojunqi <shaojunqi.sjq@alibaba-inc.com>
Co-authored-by: Ricardo Decal <crypdick@users.noreply.github.com>
Co-authored-by: Andrew Chan <andrewkchan.akc@gmail.com>
Co-authored-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Zhiyu <zhiyuc@nvidia.com>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: XIn Li <xinli@nvidia.com>
Co-authored-by: Junhao Li <streaver91@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: iAmir97 <71513472+iAmir97@users.noreply.github.com>
Co-authored-by: iAmir97 <Amir.balwel@embeddedllm.com>
Co-authored-by: Hong Hanh <hanh.usth@gmail.com>
Co-authored-by: Daniel Serebrenik <74646983+pliops-daniels@users.noreply.github.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Guy Stone <guys@spotify.com>
Co-authored-by: yyweiss <70619747+yyweiss@users.noreply.github.com>
Co-authored-by: Pradyun92 <142861237+Pradyun92@users.noreply.github.com>
Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
parent ab9f2cfd
...@@ -351,6 +351,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -351,6 +351,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}") CUDA_ARCHS "${MARLIN_ARCHS}")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
...@@ -364,7 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -364,7 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}" SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_ARCHS}") CUDA_ARCHS "${MARLIN_ARCHS}")
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
else() else()
message(STATUS "Not building Marlin kernels as no compatible archs found" message(STATUS "Not building Marlin kernels as no compatible archs found"
...@@ -854,6 +859,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -854,6 +859,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}" SRCS "${MOE_WNAA16_MARLIN_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}") CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
......
...@@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: ...@@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
a=bt.a, a=bt.a,
c=None, c=None,
b_q_weight=w_q, b_q_weight=w_q,
b_bias=None,
b_scales=w_s, b_scales=w_s,
global_scale=None, global_scale=None,
b_zeros=w_zp, b_zeros=w_zp,
......
...@@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f = ...@@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = static inline constexpr auto kFE4M3fn =
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE8M0fnu =
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
......
...@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { ...@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
TEMPLATE = ("template __global__ void Marlin<" TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, " "{{scalar_t}}, "
"{{w_type_id}}, " "{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
...@@ -77,6 +78,7 @@ def generate_new_kernels(): ...@@ -77,6 +78,7 @@ def generate_new_kernels():
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue continue
# nvfp4 only supports group_size == 16 # nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue continue
# other quantization methods don't support group_size = 16 # other quantization methods don't support group_size = 16
...@@ -89,9 +91,22 @@ def generate_new_kernels(): ...@@ -89,9 +91,22 @@ def generate_new_kernels():
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, scalar_t=c_dtype,
w_type_id=scalar_type + ".id()", w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads, threads=threads,
thread_m_blocks=max(m_blocks, 1), thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks, thread_n_blocks=n_blocks,
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#define MARLIN_KERNEL_PARAMS \ #define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
...@@ -18,12 +19,13 @@ ...@@ -18,12 +19,13 @@
const int32_t *__restrict__ num_tokens_past_padded_ptr, \ const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \ const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \ int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_bfloat16 template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
......
...@@ -280,6 +280,7 @@ __device__ inline void wait_negative_and_add(int* lock) { ...@@ -280,6 +280,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
template <typename scalar_t, // compute dtype, half or nv_bfloat16 template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
...@@ -299,6 +300,7 @@ __global__ void Marlin( ...@@ -299,6 +300,7 @@ __global__ void Marlin(
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
...@@ -318,6 +320,7 @@ __global__ void Marlin( ...@@ -318,6 +320,7 @@ __global__ void Marlin(
int prob_n, // output dimension n int prob_n, // output dimension n
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool has_bias,
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) { int max_shared_mem) {
...@@ -342,12 +345,23 @@ __global__ void Marlin( ...@@ -342,12 +345,23 @@ __global__ void Marlin(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
static_assert(s_type == vllm::kFloat16);
}
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128; w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details // see comments of dequant.h for more details
constexpr bool dequant_skip_flop = constexpr bool dequant_skip_flop =
!is_int_type || w_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value || has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8); has_zp && !is_zp_float && !(w_type == vllm::kU8);
...@@ -365,6 +379,7 @@ __global__ void Marlin( ...@@ -365,6 +379,7 @@ __global__ void Marlin(
const int zp_expert_stride = const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8 is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4); : prob_n * prob_k / group_size / (pack_factor * 4);
const int b_bias_expert_stride = prob_n / 8;
// parallel: num valid moe blocks // parallel: num valid moe blocks
int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
...@@ -475,7 +490,7 @@ __global__ void Marlin( ...@@ -475,7 +490,7 @@ __global__ void Marlin(
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i; int idx = tid4 * 4 + i;
idx = idx < block_num_valid_tokens ? idx : 0; idx = idx < block_num_valid_tokens ? idx : 0;
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
sh_block_topk_weights[idx] = __hmul2( sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num( global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]]))); topk_weights_ptr[sh_block_sorted_ids[idx]])));
...@@ -513,7 +528,7 @@ __global__ void Marlin( ...@@ -513,7 +528,7 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id]; expert_id = expert_ids_ptr[block_id];
} }
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = scale2_ptr[expert_id]; uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val)); global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
} }
...@@ -526,6 +541,9 @@ __global__ void Marlin( ...@@ -526,6 +541,9 @@ __global__ void Marlin(
if constexpr (has_act_order) { if constexpr (has_act_order) {
g_idx += (expert_id - old_expert_id) * prob_k; g_idx += (expert_id - old_expert_id) * prob_k;
} }
if (has_bias) {
b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride;
}
read_moe_block_data(block_id); read_moe_block_data(block_id);
}; };
...@@ -721,7 +739,7 @@ __global__ void Marlin( ...@@ -721,7 +739,7 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4; (threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + warp_row % 2; s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1) } else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
...@@ -734,6 +752,18 @@ __global__ void Marlin( ...@@ -734,6 +752,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4; (threadIdx.x % 32) % 4;
int bias_sh_rd;
if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
} else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
}
int bias_sh_wr = threadIdx.x;
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Zero-points have the same read layout as the scales // Zero-points have the same read layout as the scales
// (without column-wise case) // (without column-wise case)
constexpr int num_col_threads = 8; constexpr int num_col_threads = 8;
...@@ -793,7 +823,19 @@ __global__ void Marlin( ...@@ -793,7 +823,19 @@ __global__ void Marlin(
constexpr int sh_b_size = stages * b_sh_stage; constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh_new; int4* sh_b = sh_new;
int4* sh_red = sh_new; int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max =
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
constexpr int sh_b_red_bias_size =
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh_new + sh_size_b_red_min;
int4* sh_g_idx = sh_new + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage); : (stages * s_sh_stage);
...@@ -803,9 +845,9 @@ __global__ void Marlin( ...@@ -803,9 +845,9 @@ __global__ void Marlin(
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage); stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size; int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used = constexpr int shm_size_used = moe_block_size +
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + stages * (g_idx_stage + zp_sh_stage) +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size); sh_s_size + sh_b_red_bias_size;
// all remaining shared memory is used to cache A (input) // all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` // sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
...@@ -817,6 +859,7 @@ __global__ void Marlin( ...@@ -817,6 +859,7 @@ __global__ void Marlin(
I4 frag_b_quant[2][b_thread_vecs]; I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2]; FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16 FragZP frag_zp; // Zero-points in fp16
...@@ -1065,10 +1108,15 @@ __global__ void Marlin( ...@@ -1065,10 +1108,15 @@ __global__ void Marlin(
if constexpr (w_type_id != vllm::kFE2M1f.id()) { if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else { } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] = reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>( reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
} }
} }
} }
...@@ -1281,9 +1329,9 @@ __global__ void Marlin( ...@@ -1281,9 +1329,9 @@ __global__ void Marlin(
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0]; int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1]; int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0, dequant_fp8_scales<scalar_t2, s_type_id>(
reinterpret_cast<scalar_t2*>(&frag_s[k2])); s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>( dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2); s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
} }
...@@ -1566,7 +1614,7 @@ __global__ void Marlin( ...@@ -1566,7 +1614,7 @@ __global__ void Marlin(
// Write out the reduce final result in the correct layout. We only actually // Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed // reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout. // in fragment layout.
auto write_result = [&]() { auto write_result = [&](bool last) {
int c_gl_stride = prob_n / 8; int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1; constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
...@@ -1592,7 +1640,7 @@ __global__ void Marlin( ...@@ -1592,7 +1640,7 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final // We first reorder in shared memory to guarantee the most efficient final
// global write patterns // global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) { auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res = scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
...@@ -1601,14 +1649,27 @@ __global__ void Marlin( ...@@ -1601,14 +1649,27 @@ __global__ void Marlin(
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]); scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
}
res = __hmul2(res, tmp_scale);
} }
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if (!mul_topk_weights) { if (!mul_topk_weights) {
res = __hmul2(res, global_scale); res = __hmul2(res, global_scale);
} }
} }
// Add the bias to the packed result, only when a bias was supplied and `last`
// is set for this call.
if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0];
// For the m_block_size_8 variant, select one scalar of b_bias[0] using
// (threadIdx.x % 8) / 4 and pass it through Dtype::num2num2
// (presumably a broadcast of the scalar into both lanes — confirm).
if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
}
res = __hadd2(res, tmp_bias);
}
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx] = res.x;
...@@ -1626,19 +1687,25 @@ __global__ void Marlin( ...@@ -1626,19 +1687,25 @@ __global__ void Marlin(
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j; int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]); frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} else { } else {
int wr = c_sh_wr + 8 * j; int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} }
} }
c_sh_wr += 16 * (4 * c_sh_stride); c_sh_wr += 16 * (4 * c_sh_stride);
...@@ -1805,6 +1872,14 @@ __global__ void Marlin( ...@@ -1805,6 +1872,14 @@ __global__ void Marlin(
} }
thread_block_reduce(); thread_block_reduce();
// When a bias is present and `last` is set, copy the bias from
// b_bias_ptr[bias_gl_rd] into sh_bias[bias_sh_wr].
if (has_bias && last) {
// Barrier before writing into sh_bias.
// NOTE(review): presumably needed because sh_bias lies in the region also
// used for B tiles / reduction — confirm.
__syncthreads();
// Predicated copy: only threads with threadIdx.x < 16 * thread_n_blocks / 8
// have the predicate set.
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
threadIdx.x < 16 * thread_n_blocks / 8);
cp_async_fence();
}
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
...@@ -1867,11 +1942,20 @@ __global__ void Marlin( ...@@ -1867,11 +1942,20 @@ __global__ void Marlin(
} }
barrier_release(&locks[locks_off], last); barrier_release(&locks[locks_off], last);
} }
// When a bias is present and `last` is set, read the bias from sh_bias into
// the frag_bias registers.
if (has_bias && last) {
// NOTE(review): presumably waits for all outstanding cp_async copies,
// including the bias copy — confirm against cp_async_wait's definition.
cp_async_wait<0>();
__syncthreads();
// Two int4 reads, taken from sh_bias[bias_sh_rd] and sh_bias[bias_sh_rd + 4],
// fill the first two int4 slots of frag_bias.
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
// Second barrier after all threads have read sh_bias.
__syncthreads();
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0) if (use_atomic_add && slice_count > 1 && slice_idx != 0)
wait_negative_and_add(&locks[locks_off]); wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add) if (last || use_atomic_add)
// only the last block in a slice actually writes the result // only the last block in a slice actually writes the result
write_result(); write_result(last);
int old_slice_row = slice_row; int old_slice_row = slice_row;
slice_row = 0; slice_row = 0;
slice_col_par++; slice_col_par++;
...@@ -1904,6 +1988,7 @@ __global__ void Marlin( ...@@ -1904,6 +1988,7 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
} }
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading // Update slice k/n for scales loading
if constexpr (has_act_order) { if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row; slice_k_start = tb_k * slice_row;
......
...@@ -51,8 +51,9 @@ __global__ void permute_cols_kernel( ...@@ -51,8 +51,9 @@ __global__ void permute_cols_kernel(
} // namespace marlin } // namespace marlin
torch::Tensor moe_wna16_marlin_gemm( torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
...@@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, ...@@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// Get B size // Get B size
int tb_k = th_config.thread_k; int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n; int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16); int tb_m = thread_m_blocks * 16;
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
...@@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, ...@@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
int sh_s_size = int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full); group_size, has_act_order, is_k_full);
...@@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, ...@@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
sh_zp_size = sh_s_size / 2; sh_zp_size = sh_s_size / 2;
} }
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
sh_zp_size + sh_g_idx_size + sh_block_meta_size; sh_g_idx_size + sh_block_meta_size;
return total_size; return total_size;
} }
...@@ -270,7 +276,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, ...@@ -270,7 +276,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem; return cache_size + 512 <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
...@@ -281,9 +287,14 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, ...@@ -281,9 +287,14 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
m_block_size_8 == M_BLOCK_SIZE_8 && \ m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \ is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \ constexpr auto S_TYPE = \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \ W_TYPE == vllm::kFE2M1f \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \ ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
} }
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
...@@ -335,30 +346,44 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, ...@@ -335,30 +346,44 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \ #define NVFP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \ #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4 // We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
...@@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, ...@@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128) COMMON_GET_IF(vllm::kU8B128)
BIGGROUP_GET_IF(vllm::kFE4M3fn) NVFP4_GET_IF(vllm::kFE2M1f)
FP4_GET_IF(vllm::kFE2M1f) BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128) ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
...@@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
} }
template <typename scalar_t> template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, void* s, void* s2, void* zp, void* g_idx, void* perm,
void* sorted_token_ids, void* expert_ids, void* a_tmp, void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights, void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
int prob_m, int prob_n, int prob_k, void* workspace, int prob_m, int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order, vllm::ScalarType const& q_type, bool has_bias,
bool is_k_full, bool has_zp, int num_groups, int group_size, bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int dev, cudaStream_t stream, int thread_k, int thread_n, int group_size, int dev, cudaStream_t stream, int thread_k,
int sms, bool use_atomic_add, bool use_fp32_reduce, int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16); int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8; bool m_block_size_8 = moe_block_size == 8;
...@@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
...@@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>( kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
// clang-format on // clang-format on
} }
...@@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm( torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none, torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
...@@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm(
num_groups = b_scales.size(1); num_groups = b_scales.size(1);
torch::Tensor g_idx, perm, a_tmp; torch::Tensor g_idx, perm, a_tmp;
;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) { if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
g_idx = g_idx_or_none.value(); g_idx = g_idx_or_none.value();
perm = perm_or_none.value(); perm = perm_or_none.value();
...@@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f, TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
"global_scale can only be used for float4_e2m1f."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
"the global_scale parameter must be passed for float4_e2m1f."); "the global_scale parameter must be passed for nvfp4 format.");
}
// Validate the optional bias tensor. When no bias is passed, b_bias is set to
// an empty placeholder so that it is always a defined tensor afterwards.
bool has_bias = b_bias_or_none.has_value();
torch::Tensor b_bias;
if (has_bias) {
  b_bias = b_bias_or_none.value();
  TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
  TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
  // The check is on dimension 1, so the message must name size(1); it
  // previously reported size(0), pointing at the wrong dimension.
  TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(1) != size_n");
  TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
} else {
  b_bias = torch::empty({0}, options);
}
torch::Tensor b_zeros; torch::Tensor b_zeros;
...@@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm(
b_zeros = torch::empty({0}, options); b_zeros = torch::empty({0}, options);
} }
bool has_zp = b_zeros.size(-1) > 0; bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
...@@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm(
if (a.scalar_type() == at::ScalarType::Half) { if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr; void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) { if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>(); scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else { } else {
scales_ptr = b_scales.data_ptr<at::Half>(); scales_ptr = b_scales.data_ptr<at::Half>();
} }
MARLIN_NAMESPACE_NAME::marlin_mm<half>( MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(), c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
is_k_full, has_zp, num_groups, group_size, dev, workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float); use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) { } else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr; void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) { if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>(); scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else { } else {
scales_ptr = b_scales.data_ptr<at::BFloat16>(); scales_ptr = b_scales.data_ptr<at::BFloat16>();
} }
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>( MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr, c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), has_zp, num_groups, group_size, dev,
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else { } else {
TORCH_CHECK(false, TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16"); "moe_wna16_marlin_gemm only supports bfloat16 and float16");
......
...@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " "Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none," "b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids," "Tensor sorted_token_ids,"
......
...@@ -470,11 +470,12 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>( ...@@ -470,11 +470,12 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg); frag_b[0] = __hmul2(frag_b[0], bias_reg);
} }
template <typename scalar_t2> template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <> template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) { __device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
int q, half2* frag_b) {
int Out1 = (q & 0xFF00FF00) >> 1; int Out1 = (q & 0xFF00FF00) >> 1;
; ;
q <<= 8; q <<= 8;
...@@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) { ...@@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
}; };
template <> template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
nv_bfloat162* frag_b) { int q, nv_bfloat162* frag_b) {
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00; constexpr int MASK = 0x7F007F00;
...@@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, ...@@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2); frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
} }
// Specialization for FP8 E8M0 scales with BF16 output: expands the four scale
// bytes packed in `q` into two nv_bfloat162 values, frag_b[0] and frag_b[1].
// Each byte is moved, unchanged, into bits 7..14 of a 16-bit lane (the BF16
// exponent field), leaving the sign and mantissa bits zero.
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
int q, nv_bfloat162* frag_b) {
// In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
// but we assume that such an extreme value would not occur in real models.
// Bytes 1 and 3 of q, shifted right by one into bits 7..14 and 23..30.
// The literal 0xFF00FF00 does not fit in int, so it is unsigned and the
// shift is logical: bit 31 of the result is always clear.
int Out1 = (q & 0xFF00FF00) >> 1;
// Bytes 0 and 2 of q, shifted left by seven into the same bit positions;
// the mask drops everything outside bits 7..14 and 23..30.
q <<= 7;
int Out2 = q & 0x7F807F80;

// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
#endif #endif
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
...@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { ...@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
TEMPLATE = ("template __global__ void Marlin<" TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, " "{{scalar_t}}, "
"{{w_type_id}}, " "{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
...@@ -78,7 +79,8 @@ def generate_new_kernels(): ...@@ -78,7 +79,8 @@ def generate_new_kernels():
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue continue
# nvfp4 only supports group_size == 16 # nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: # mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue continue
# other quantization methods don't support group_size = 16 # other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
...@@ -97,10 +99,23 @@ def generate_new_kernels(): ...@@ -97,10 +99,23 @@ def generate_new_kernels():
# 4bit quantization and fp16 # 4bit quantization and fp16
is_zp_float_list.append(True) is_zp_float_list.append(True)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list: for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, scalar_t=c_dtype,
w_type_id=scalar_type + ".id()", w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads, threads=threads,
thread_m_blocks=max(m_blocks, 1), thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks, thread_n_blocks=n_blocks,
......
...@@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, ...@@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
torch::Tensor gptq_marlin_gemm( torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
...@@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, ...@@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int tb_m = thread_m_blocks * 16; int tb_m = thread_m_blocks * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8); int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
int sh_s_size = int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full); group_size, has_act_order, is_k_full);
...@@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, ...@@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size = sh_s_size / 2; sh_zp_size = sh_s_size / 2;
} }
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + int total_size =
sh_zp_size + sh_g_idx_size; tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;
return total_size; return total_size;
} }
...@@ -237,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, ...@@ -237,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float); has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem; return cache_size + 512 <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
...@@ -248,9 +254,14 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, ...@@ -248,9 +254,14 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \ m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \ is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \ constexpr auto S_TYPE = \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \ W_TYPE == vllm::kFE2M1f \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \ ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
} }
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
...@@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, ...@@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \ #define NVFP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 4, 8, 128) NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// MXFP4 kernel-selection macros. They mirror the NVFP4_GET_IF* family but pass
// 2 instead of 1 as the sixth _GET_IF argument (GROUP_BLOCKS == 2, which the
// kernel generator associates with group_size == 32). For kFE2M1f weights,
// _GET_IF picks the kFE8M0fnu scale type whenever GROUP_BLOCKS != 1.
// NOTE(review): the arguments after K_BLOCKS look like M_BLOCK_SIZE_8,
// GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT -- confirm against the full _GET_IF
// parameter list.
//
// thread_m_blocks == 1: both the M_BLOCK_SIZE_8 == true and == false variants.
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
// thread_m_blocks in {2, 3, 4}: only the M_BLOCK_SIZE_8 == false variant.
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
// All (N_BLOCKS, K_BLOCKS, NUM_THREADS) tile configurations tried for MXFP4;
// the same set of tiles as NVFP4_GET_IF.
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4 // We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
...@@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, ...@@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128) COMMON_GET_IF(vllm::kU8B128)
FP4_GET_IF(vllm::kFE2M1f) NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn) BIGGROUP_GET_IF(vllm::kFE4M3fn)
...@@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, ...@@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
} }
FZP_GET_IF(vllm::kU4) FZP_GET_IF(vllm::kU4)
} }
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
...@@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
} }
template <typename scalar_t> template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, void* s, void* s2, void* zp, void* g_idx, void* perm,
int prob_m, int prob_n, int prob_k, int lda, void* workspace, void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
vllm::ScalarType const& q_type, bool has_act_order, void* workspace, vllm::ScalarType const& q_type, bool has_bias,
bool is_k_full, bool has_zp, int num_groups, int group_size, bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int dev, cudaStream_t stream, int thread_k_init, int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add, int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) { bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) { if (has_zp) {
...@@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
...@@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>( kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new); use_fp32_reduce, max_shared_mem_new);
// clang-format on // clang-format on
...@@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor gptq_marlin_gemm( torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
...@@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm( ...@@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f, TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
"global_scale can only be used for float4_e2m1f."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
"the global_scale parameter must be passed for float4_e2m1f."); "the global_scale parameter must be passed for nvfp4 format.");
}
bool has_bias = b_bias_or_none.has_value();
torch::Tensor b_bias;
if (has_bias) {
b_bias = b_bias_or_none.value();
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n");
TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1");
} else {
b_bias = torch::empty({0}, options);
} }
torch::Tensor b_zeros; torch::Tensor b_zeros;
...@@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm( ...@@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm(
if (a.scalar_type() == at::ScalarType::Half) { if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr; void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) { if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>(); scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else { } else {
scales_ptr = b_scales.data_ptr<at::Half>(); scales_ptr = b_scales.data_ptr<at::Half>();
} }
marlin::marlin_mm<half>( marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(), c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0), perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), is_k_full, has_zp, num_groups, group_size, dev,
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) { } else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr; void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) { if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>(); scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else { } else {
scales_ptr = b_scales.data_ptr<at::BFloat16>(); scales_ptr = b_scales.data_ptr<at::BFloat16>();
} }
marlin::marlin_mm<nv_bfloat16>( marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr, c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
has_act_order, is_k_full, has_zp, num_groups, group_size, dev, has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float); use_atomic_add, use_fp32_reduce, is_zp_float);
} else { } else {
......
...@@ -10,15 +10,18 @@ ...@@ -10,15 +10,18 @@
#define MARLIN_KERNEL_PARAMS \ #define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_bfloat16 template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
......
...@@ -39,6 +39,7 @@ namespace MARLIN_NAMESPACE_NAME { ...@@ -39,6 +39,7 @@ namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_bfloat16 template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
...@@ -271,6 +272,7 @@ __device__ inline void wait_negative_and_add(int* lock) { ...@@ -271,6 +272,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
template <typename scalar_t, // compute dtype, half or nv_bfloat16 template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
...@@ -290,6 +292,7 @@ __global__ void Marlin( ...@@ -290,6 +292,7 @@ __global__ void Marlin(
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
...@@ -303,6 +306,7 @@ __global__ void Marlin( ...@@ -303,6 +306,7 @@ __global__ void Marlin(
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int lda, // A.stride(0), equal to prob_k if A is contiguous int lda, // A.stride(0), equal to prob_k if A is contiguous
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool has_bias,
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) { int max_shared_mem) {
...@@ -326,18 +330,29 @@ __global__ void Marlin( ...@@ -326,18 +330,29 @@ __global__ void Marlin(
using FragZP = typename ScalarType<scalar_t>::FragZP; using FragZP = typename ScalarType<scalar_t>::FragZP;
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
static_assert(s_type == vllm::kFloat16);
}
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128; w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details // see comments of dequant.h for more details
constexpr bool dequant_skip_flop = constexpr bool dequant_skip_flop =
!is_int_type || w_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value || has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8); has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale; scalar_t2 global_scale;
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if constexpr (w_type == vllm::kFE2M1f) { // NVFP4 format requires global scale
uint16_t val = scale2_ptr[0]; uint16_t val = scale2_ptr[0];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val)); global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
} }
...@@ -589,7 +604,7 @@ __global__ void Marlin( ...@@ -589,7 +604,7 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4; (threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + warp_row % 2; s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1) } else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
...@@ -602,6 +617,18 @@ __global__ void Marlin( ...@@ -602,6 +617,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4; (threadIdx.x % 32) % 4;
int bias_sh_rd;
if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
} else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
}
int bias_sh_wr = threadIdx.x;
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Zero-points have the same read layout as the scales // Zero-points have the same read layout as the scales
// (without column-wise case) // (without column-wise case)
constexpr int num_col_threads = 8; constexpr int num_col_threads = 8;
...@@ -670,7 +697,19 @@ __global__ void Marlin( ...@@ -670,7 +697,19 @@ __global__ void Marlin(
constexpr int sh_b_size = stages * b_sh_stage; constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh; int4* sh_b = sh;
int4* sh_red = sh; int4* sh_red = sh;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max =
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
constexpr int sh_b_red_bias_size =
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh + sh_size_b_red_min;
int4* sh_g_idx = sh + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage); : (stages * s_sh_stage);
...@@ -680,15 +719,13 @@ __global__ void Marlin( ...@@ -680,15 +719,13 @@ __global__ void Marlin(
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage); stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size; int4* sh_a = sh_s + sh_s_size;
// constexpr int shm_size_used =
// stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
// (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs]; I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2]; FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16 FragZP frag_zp; // Zero-points in fp16
...@@ -923,10 +960,15 @@ __global__ void Marlin( ...@@ -923,10 +960,15 @@ __global__ void Marlin(
if constexpr (w_type_id != vllm::kFE2M1f.id()) { if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else { } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] = reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>( reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
} }
} }
} }
...@@ -1139,9 +1181,9 @@ __global__ void Marlin( ...@@ -1139,9 +1181,9 @@ __global__ void Marlin(
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0]; int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1]; int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0, dequant_fp8_scales<scalar_t2, s_type_id>(
reinterpret_cast<scalar_t2*>(&frag_s[k2])); s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>( dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2); s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
} }
...@@ -1411,7 +1453,7 @@ __global__ void Marlin( ...@@ -1411,7 +1453,7 @@ __global__ void Marlin(
// Write out the reduce final result in the correct layout. We only actually // Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed // reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout. // in fragment layout.
auto write_result = [&]() { auto write_result = [&](bool last) {
int c_gl_stride = prob_n / 8; int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1; constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
...@@ -1438,7 +1480,7 @@ __global__ void Marlin( ...@@ -1438,7 +1480,7 @@ __global__ void Marlin(
int c_gl_wr_end = c_gl_stride * prob_m; int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final // We first reorder in shared memory to guarantee the most efficient final
// global write patterns // global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) { auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res = scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
...@@ -1447,12 +1489,25 @@ __global__ void Marlin( ...@@ -1447,12 +1489,25 @@ __global__ void Marlin(
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]); scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
}
res = __hmul2(res, tmp_scale);
} }
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
res = __hmul2(res, global_scale); res = __hmul2(res, global_scale);
} }
if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
}
res = __hadd2(res, tmp_bias);
}
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx] = res.x;
...@@ -1470,19 +1525,25 @@ __global__ void Marlin( ...@@ -1470,19 +1525,25 @@ __global__ void Marlin(
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j; int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]); frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} else { } else {
int wr = c_sh_wr + 8 * j; int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} }
} }
c_sh_wr += 16 * (4 * c_sh_stride); c_sh_wr += 16 * (4 * c_sh_stride);
...@@ -1622,6 +1683,14 @@ __global__ void Marlin( ...@@ -1622,6 +1683,14 @@ __global__ void Marlin(
} }
thread_block_reduce(); thread_block_reduce();
if (has_bias && last) {
__syncthreads();
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
threadIdx.x < 16 * thread_n_blocks / 8);
cp_async_fence();
}
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
...@@ -1684,11 +1753,20 @@ __global__ void Marlin( ...@@ -1684,11 +1753,20 @@ __global__ void Marlin(
} }
barrier_release(&locks[locks_off], last); barrier_release(&locks[locks_off], last);
} }
if (has_bias && last) {
cp_async_wait<0>();
__syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads();
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0) if (use_atomic_add && slice_count > 1 && slice_idx != 0)
wait_negative_and_add(&locks[locks_off]); wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add) if (last || use_atomic_add)
// only the last block in a slice actually writes the result // only the last block in a slice actually writes the result
write_result(); write_result(last);
slice_row = 0; slice_row = 0;
slice_col_par++; slice_col_par++;
slice_col++; slice_col++;
...@@ -1706,6 +1784,7 @@ __global__ void Marlin( ...@@ -1706,6 +1784,7 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
} }
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading // Update slice k/n for scales loading
if constexpr (has_act_order) { if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row; slice_k_start = tb_k * slice_row;
......
...@@ -326,6 +326,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -326,6 +326,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ. // gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def( ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none,"
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
......
...@@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe) fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe) fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_fp4_like) rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch) marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
...@@ -476,7 +478,10 @@ def marlin_moe_generate_valid_test_cases(): ...@@ -476,7 +478,10 @@ def marlin_moe_generate_valid_test_cases():
if quant_type == scalar_types.float8_e4m3fn and \ if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]: group_size not in [-1, 128]:
return False return False
if quant_type == scalar_types.float4_e2m1f and group_size != 16: if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
return False
if dtype == torch.float16 and group_size == 32:
return False return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16: if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False return False
...@@ -520,31 +525,6 @@ def test_fused_marlin_moe( ...@@ -520,31 +525,6 @@ def test_fused_marlin_moe(
torch.cuda.manual_seed(0) torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
if quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
# Filter act_order
if act_order:
if quant_type == scalar_types.float8_e4m3fn:
return
if group_size == -1:
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
...@@ -569,12 +549,18 @@ def test_fused_marlin_moe( ...@@ -569,12 +549,18 @@ def test_fused_marlin_moe(
for i in range(w1.shape[0]): for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f: if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = \ w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_fp4_like(w1[i], group_size) rand_marlin_weight_nvfp4_like(w1[i], group_size)
else:
w_ref1, qweight1, scales1 = \
rand_marlin_weight_mxfp4_like(w1[i], group_size)
global_scale1 = None
w_ref1_l.append(w_ref1.T) w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1) qweight1_l.append(qweight1)
scales1_l.append(scales1) scales1_l.append(scales1)
if global_scale1 is not None:
global_scale1_l.append(global_scale1) global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn: elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
...@@ -620,12 +606,18 @@ def test_fused_marlin_moe( ...@@ -620,12 +606,18 @@ def test_fused_marlin_moe(
for i in range(w2.shape[0]): for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f: if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = \ w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_fp4_like(w2[i], group_size) rand_marlin_weight_nvfp4_like(w2[i], group_size)
else:
w_ref2, qweight2, scales2 = \
rand_marlin_weight_mxfp4_like(w2[i], group_size)
global_scale2 = None
w_ref2_l.append(w_ref2.T) w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2) qweight2_l.append(qweight2)
scales2_l.append(scales2) scales2_l.append(scales2)
if global_scale2 is not None:
global_scale2_l.append(global_scale2) global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn: elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
...@@ -677,6 +669,8 @@ def test_fused_marlin_moe( ...@@ -677,6 +669,8 @@ def test_fused_marlin_moe(
a, a,
qweight1, qweight1,
qweight2, qweight2,
None,
None,
scales1, scales1,
scales2, scales2,
score, score,
...@@ -698,6 +692,119 @@ def test_fused_marlin_moe( ...@@ -698,6 +692,119 @@ def test_fused_marlin_moe(
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    """Compare fused Marlin MoE with per-expert biases to the torch reference.

    Both expert weight sets are quantized with ``marlin_quantize`` as uint4b8
    with group size 128, the biases are passed through
    ``marlin_permute_bias``, and the fused output must match ``torch_moe``
    (fed the dequantized reference weights and the raw biases) within an
    absolute tolerance of 5e-2.
    """
    torch.cuda.manual_seed(0)

    e, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    # Random draws happen in a fixed order: a, w1, w2, bias1, bias2.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
    b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10

    def _stack_or_none(tensors):
        # An empty per-expert list is represented as None.
        return stack_and_dev(tensors) if tensors else None

    def _quantize_experts(weights, biases, perm_size):
        # Quantize every expert's weight and permute its bias, then stack
        # the per-expert results into batched tensors.
        refs, packed, scales = [], [], []
        g_idxs, sort_idxs, perm_biases = [], [], []
        for idx in range(weights.shape[0]):
            test_perm = torch.randperm(perm_size)
            w_ref, qweight, scale, g_idx, sort_idx, _ = marlin_quantize(
                weights[idx].transpose(1, 0), quant_type, group_size,
                act_order, test_perm)
            refs.append(w_ref.T)
            packed.append(qweight)
            scales.append(scale)
            g_idxs.append(g_idx)
            sort_idxs.append(sort_idx)
            perm_biases.append(marlin_permute_bias(biases[idx]))
        return (
            stack_and_dev(refs),
            stack_and_dev(packed).contiguous(),
            stack_and_dev(scales),
            _stack_or_none(g_idxs),
            _stack_or_none(sort_idxs),
            _stack_or_none(perm_biases),
        )

    (w_ref1, qweight1, scales1, g_idx1, sort_indices1,
     marlin_bias1) = _quantize_experts(w1, b_bias1, k)
    (w_ref2, qweight2, scales2, g_idx2, sort_indices2,
     marlin_bias2) = _quantize_experts(w2, b_bias2, n)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    # The reference path runs under the test's vLLM config.
    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
                                 b_bias2)

    # This test passes no global scales and no zero points (all None).
    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        marlin_bias1,
        marlin_bias2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=None,
        global_scale2=None,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=None,
        w2_zeros=None,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full)

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck(): def test_moe_align_block_size_opcheck():
num_experts = 4 num_experts = 4
block_size = 4 block_size = 4
......
...@@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import ( ...@@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_scales, marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
query_marlin_supported_quant_types) query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch) marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
...@@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types ...@@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True] ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True] K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True] USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True] USE_FP32_REDUCE_OPTS = [True]
MARLIN_K_CHUNKS = [128] MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256] MARLIN_N_CHUNKS = [64, 256]
...@@ -202,17 +203,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, ...@@ -202,17 +203,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) @pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm( @pytest.mark.parametrize("dtype", DTYPES)
k_chunk, def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
n_chunk, mnk_factors, act_order, is_k_full, use_atomic_add,
quant_type, use_fp32_reduce, dtype):
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
...@@ -231,14 +225,23 @@ def test_gptq_marlin_gemm( ...@@ -231,14 +225,23 @@ def test_gptq_marlin_gemm(
if size_k % group_size != 0: if size_k % group_size != 0:
return return
a_input = rand_data((size_m, size_k)) a_input = rand_data((size_m, size_k), dtype)
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n), dtype)
if quant_type == scalar_types.float4_e2m1f: if quant_type == scalar_types.float4_e2m1f:
if group_size != 16 or act_order: if group_size not in [16, 32] or act_order:
return return
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( if group_size == 32 and dtype == torch.float16:
b_weight.T, group_size) return
if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
else:
w_ref, marlin_q_w, marlin_s = \
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
marlin_s2 = None
g_idx = None g_idx = None
sort_indices = None sort_indices = None
marlin_zp = None marlin_zp = None
...@@ -272,8 +275,8 @@ def test_gptq_marlin_gemm( ...@@ -272,8 +275,8 @@ def test_gptq_marlin_gemm(
workspace = marlin_make_workspace_new(w_ref.device) workspace = marlin_make_workspace_new(w_ref.device)
opcheck(torch.ops._C.gptq_marlin_gemm, opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
sort_indices, workspace, quant_type.id, a_input.shape[0], g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False), use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS) test_utils=DEFAULT_OPCHECK_TEST_UTILS)
...@@ -282,6 +285,7 @@ def test_gptq_marlin_gemm( ...@@ -282,6 +285,7 @@ def test_gptq_marlin_gemm(
a_input, a_input,
None, None,
marlin_q_w, marlin_q_w,
None,
marlin_s, marlin_s,
marlin_s2, marlin_s2,
marlin_zp, marlin_zp,
...@@ -418,6 +422,7 @@ def test_hqq_marlin_gemm( ...@@ -418,6 +422,7 @@ def test_hqq_marlin_gemm(
a_input, a_input,
None, None,
marlin_w_q, marlin_w_q,
None,
marlin_s, marlin_s,
None, None,
marlin_zp, marlin_zp,
...@@ -531,6 +536,7 @@ def test_marlin_gemm_subset_input(): ...@@ -531,6 +536,7 @@ def test_marlin_gemm_subset_input():
a_input, a_input,
None, None,
marlin_q_w, marlin_q_w,
None,
marlin_s, marlin_s,
None, None,
marlin_zp, marlin_zp,
...@@ -555,6 +561,53 @@ def test_marlin_gemm_subset_input(): ...@@ -555,6 +561,53 @@ def test_marlin_gemm_subset_input():
assert max_diff < 0.04 assert max_diff < 0.04
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check ``gptq_marlin_gemm`` with a bias against matmul-plus-bias.

    The weight is quantized as uint4b8 with group size 128; the kernel
    receives the bias after ``marlin_permute_bias`` while the reference adds
    the original bias to ``a_input @ w_ref``. The maximum difference reported
    by ``compute_max_diff`` must stay below 0.04.
    """
    size_k, size_n = 1024, 2048
    weight_type = scalar_types.uint4b8
    quant_group = 128

    # Draw order is kept: activations, weights, then bias.
    activations = rand_data((size_m, size_k))
    dense_weight = rand_data((size_k, size_n))
    bias = rand_data((size_n, )) * 10

    permuted_bias = marlin_permute_bias(bias)

    quantized = marlin_quantize(dense_weight, weight_type, quant_group, False)
    ref_weight, packed_weight, scales, g_idx, sort_indices, _ = quantized

    # uint4b8 carries no zero points, so an empty placeholder is passed.
    empty_zp = marlin_make_empty_g_idx(scales.device)
    workspace = marlin_make_workspace_new(activations.device)

    out_rows = activations.shape[0]
    out_cols = dense_weight.shape[1]
    reduce_dim = activations.shape[1]

    output = ops.gptq_marlin_gemm(
        activations,
        None,
        packed_weight,
        permuted_bias,
        scales,
        None,
        empty_zp,
        g_idx,
        sort_indices,
        workspace,
        weight_type,
        out_rows,
        out_cols,
        reduce_dim,
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    expected = torch.matmul(activations, ref_weight) + bias.view(1, -1)

    torch.cuda.synchronize()

    assert compute_max_diff(output, expected) < 0.04
def test_marlin_gemm_opcheck(): def test_marlin_gemm_opcheck():
size_m = 2048 size_m = 2048
size_n = 4096 size_n = 4096
......
...@@ -1064,6 +1064,8 @@ def torch_experts( ...@@ -1064,6 +1064,8 @@ def torch_experts(
topk_weight: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
global_num_experts: int = -1, global_num_experts: int = -1,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
...@@ -1108,8 +1110,13 @@ def torch_experts( ...@@ -1108,8 +1110,13 @@ def torch_experts(
if mask.sum(): if mask.sum():
if quant_dtype is None: if quant_dtype is None:
tmp1 = a[mask] @ w1[i].transpose(0, 1) tmp1 = a[mask] @ w1[i].transpose(0, 1)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
tmp2 = SiluAndMul()(tmp1) tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1) out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
tmp1.dtype)
elif block_shape is not None: elif block_shape is not None:
# block quantized # block quantized
assert (a_scale is not None and w1_scale is not None assert (a_scale is not None and w1_scale is not None
...@@ -1117,6 +1124,8 @@ def torch_experts( ...@@ -1117,6 +1124,8 @@ def torch_experts(
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape, w1_scale[i], block_shape,
out.dtype) out.dtype)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
tmp2 = SiluAndMul()(tmp1) tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = moe_kernel_quantize_input( tmp2, b_scale = moe_kernel_quantize_input(
tmp2, a2_scale, quant_dtype, per_act_token_quant, tmp2, a2_scale, quant_dtype, per_act_token_quant,
...@@ -1125,6 +1134,9 @@ def torch_experts( ...@@ -1125,6 +1134,9 @@ def torch_experts(
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape, w2_scale[i], block_shape,
out.dtype) out.dtype)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
tmp1.dtype)
else: else:
assert (a_scale is not None and w1_scale is not None assert (a_scale is not None and w1_scale is not None
and w2_scale is not None) and w2_scale is not None)
...@@ -1133,6 +1145,8 @@ def torch_experts( ...@@ -1133,6 +1145,8 @@ def torch_experts(
tmp1 = a[mask].to(f32) * scales tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
tmp1 = (tmp1 @ w1_dq).to(out.dtype) tmp1 = (tmp1 @ w1_dq).to(out.dtype)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)
tmp2 = SiluAndMul()(tmp1).to(out.dtype) tmp2 = SiluAndMul()(tmp1).to(out.dtype)
...@@ -1144,6 +1158,9 @@ def torch_experts( ...@@ -1144,6 +1158,9 @@ def torch_experts(
tmp2 = tmp2.to(f32) * b_scale tmp2 = tmp2.to(f32) * b_scale
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype) out[mask] = (tmp2 @ w2_dq).to(out.dtype)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
out.dtype)
if apply_router_weights_on_input: if apply_router_weights_on_input:
return out return out
...@@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor, ...@@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
score: torch.Tensor, score: torch.Tensor,
topk: int, topk: int,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32) score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk) topk_weight, topk_ids = torch.topk(score, topk)
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
expert_map) b_bias1, b_bias2, expert_map)
def torch_moe_single(a, w, score, topk): def torch_moe_single(a, w, score, topk):
......
...@@ -452,6 +452,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -452,6 +452,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _gptq_marlin_gemm_fake(a: torch.Tensor, def _gptq_marlin_gemm_fake(a: torch.Tensor,
c: Optional[torch.Tensor], c: Optional[torch.Tensor],
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor, b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor], global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor],
...@@ -1048,6 +1049,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, ...@@ -1048,6 +1049,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
def gptq_marlin_gemm(a: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor,
c: Optional[torch.Tensor], c: Optional[torch.Tensor],
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor, b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor], global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor],
...@@ -1062,7 +1064,7 @@ def gptq_marlin_gemm(a: torch.Tensor, ...@@ -1062,7 +1064,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
use_atomic_add: bool = False, use_atomic_add: bool = False,
use_fp32_reduce: bool = False, use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor: is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales,
global_scale, b_zeros, g_idx, perm, global_scale, b_zeros, g_idx, perm,
workspace, b_q_type.id, size_m, workspace, b_q_type.id, size_m,
size_n, size_k, is_k_full, size_n, size_k, is_k_full,
...@@ -1540,7 +1542,9 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, ...@@ -1540,7 +1542,9 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor, b_scales: torch.Tensor, b_qweight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor], global_scale: Optional[torch.Tensor],
b_qzeros: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor], g_idx: Optional[torch.Tensor],
...@@ -1556,11 +1560,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], ...@@ -1556,11 +1560,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
use_fp32_reduce: bool, use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor: is_zp_float: bool) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm( return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros,
perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, g_idx, perm, workspace, sorted_token_ids, expert_ids,
topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, num_tokens_past_padded, topk_weights, moe_block_size, top_k,
b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k,
use_fp32_reduce, is_zp_float) is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
......
...@@ -122,6 +122,7 @@ if TYPE_CHECKING: ...@@ -122,6 +122,7 @@ if TYPE_CHECKING:
VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_MOE_DP_CHUNK_SIZE: int = 256
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
...@@ -182,6 +183,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: ...@@ -182,6 +183,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
return int(value) return int(value)
def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
    """Convert an optional environment-variable string to an optional bool.

    Args:
        value: The raw string, or None when the variable is unset.

    Returns:
        None when ``value`` is None. Integer strings are mapped through
        ``bool(int(value))``, so "0" is False and any non-zero integer is
        True. The words true/false, yes/no and on/off are also accepted,
        case-insensitively and ignoring surrounding whitespace.

    Raises:
        ValueError: If ``value`` is neither an integer string nor one of the
            recognised boolean words.
    """
    if value is None:
        return None

    # Integer strings keep their original meaning.
    try:
        return bool(int(value))
    except ValueError:
        pass

    # Fall back to common boolean spellings instead of failing with an
    # opaque "invalid literal for int()" message.
    word = value.strip().lower()
    if word in ("true", "yes", "on"):
        return True
    if word in ("false", "no", "off"):
        return False
    raise ValueError(
        f"Cannot interpret {value!r} as a boolean; expected an integer or "
        "one of true/false, yes/no, on/off.")
def get_vllm_port() -> Optional[int]: def get_vllm_port() -> Optional[int]:
"""Get the port from VLLM_PORT environment variable. """Get the port from VLLM_PORT environment variable.
...@@ -906,6 +913,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -906,6 +913,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MARLIN_USE_ATOMIC_ADD": "VLLM_MARLIN_USE_ATOMIC_ADD":
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
# Whether to use the Marlin kernel in the MXFP4 quantization method.
# Parsed as an optional bool: the value stays None when the variable is unset.
"VLLM_MXFP4_USE_MARLIN":
lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)),
# Whether to turn on the outlines cache for V0 # Whether to turn on the outlines cache for V0
# This cache is unbounded and on disk, so it's not safe to use in # This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users. # an environment with potentially malicious users.
......
...@@ -18,6 +18,8 @@ from vllm.utils import direct_register_custom_op ...@@ -18,6 +18,8 @@ from vllm.utils import direct_register_custom_op
def fused_marlin_moe(hidden_states: torch.Tensor, def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
bias1: Optional[torch.Tensor],
bias2: Optional[torch.Tensor],
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
...@@ -26,6 +28,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -26,6 +28,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type_id: int, quant_type_id: int,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
activation: Optional[str] = "silu",
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None, global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None, global_scale2: Optional[torch.Tensor] = None,
...@@ -88,6 +91,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -88,6 +91,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8] assert num_bits in [4, 8]
assert topk_weights.dtype == torch.float32
M, K = hidden_states.shape M, K = hidden_states.shape
E = w1.shape[0] E = w1.shape[0]
...@@ -138,6 +142,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -138,6 +142,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
hidden_states, hidden_states,
intermediate_cache1, intermediate_cache1,
w1, w1,
bias1,
w1_scale, w1_scale,
global_scale1, global_scale1,
w1_zeros, w1_zeros,
...@@ -161,8 +166,28 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -161,8 +166,28 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False) is_zp_float=False)
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N)) intermediate_cache1.view(-1, 2 * N))
elif activation == "swiglu_oai":
# NOTE: in gpt-oss, the gate_proj and up_proj is interleaved
# - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
# - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
@torch.compile(dynamic=True)
def swiglu_oai(gate_up):
    """Apply the clamped "swiglu_oai" gated activation to ``gate_up``.

    ``gate_up`` holds gate and up values interleaved along its last
    dimension: even positions are the gate, odd positions are the up
    projection. The result is ``(up + 1) * gate * sigmoid(alpha * gate)``
    computed on the clamped values, so its last dimension is half the
    size of the input's.
    """
    # Scale applied to the gate inside the sigmoid.
    alpha = 1.702
    # Clamp bound used for both the gate and the up values.
    limit = 7.0
    # De-interleave: even indices -> gate, odd indices -> up.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    # The gate is only bounded from above; the up values are bounded on
    # both sides.
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    # The up projection is shifted by one before gating.
    return (up + 1) * glu
intermediate_cache2 = swiglu_oai(intermediate_cache1)
else:
raise ValueError(f"Unsupported activation: {activation}. "
"Only silu and swiglu_oai activations are supported.")
if expert_map is not None: if expert_map is not None:
intermediate_cache3.zero_() intermediate_cache3.zero_()
...@@ -171,6 +196,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -171,6 +196,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache2, intermediate_cache2,
intermediate_cache3, intermediate_cache3,
w2, w2,
bias2,
w2_scale, w2_scale,
global_scale2, global_scale2,
w2_zeros, w2_zeros,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment