Unverified Commit d06ba4ed authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel] moe wna16 marlin kernel (#14447)


Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 6b40996a
......@@ -609,21 +609,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
"csrc/moe/marlin_moe_ops.cu")
#
# For the Marlin MOE kernels we automatically generate sources for various
# preselected input type pairs and schedules.
#
# The generator is re-run only when the MD5 of the generator script differs
# from the hash stored in the cache by the last successful run, so repeated
# configures do not rewrite the generated sources.
set(MOE_MARLIN_GEN_SCRIPT
    "${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py")
file(MD5 "${MOE_MARLIN_GEN_SCRIPT}" MOE_MARLIN_GEN_SCRIPT_HASH)

message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")

# Both sides of STREQUAL are quoted so that an empty/undefined cache entry
# still yields a well-formed condition.
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
    OR NOT "$CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}" STREQUAL "${MOE_MARLIN_GEN_SCRIPT_HASH}")
  # execute_process() does not go through a shell, so a literal `$PYTHONPATH`
  # would never be expanded. Read the caller's value with $ENV{} instead and
  # append it only when it is non-empty.
  set(MOE_MARLIN_GEN_PYTHONPATH
      "${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}")
  if (NOT "$ENV{PYTHONPATH}" STREQUAL "")
    string(APPEND MOE_MARLIN_GEN_PYTHONPATH ":$ENV{PYTHONPATH}")
  endif()

  # stdout and stderr both go to the log file. No OUTPUT_VARIABLE is requested:
  # combining it with OUTPUT_FILE on the same pipe has unspecified precedence.
  execute_process(
    COMMAND ${CMAKE_COMMAND} -E env
            "PYTHONPATH=${MOE_MARLIN_GEN_PYTHONPATH}"
            "${Python_EXECUTABLE}" "${MOE_MARLIN_GEN_SCRIPT}"
    RESULT_VARIABLE moe_marlin_generation_result
    OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log"
    ERROR_FILE "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log"
  )

  if (NOT moe_marlin_generation_result EQUAL 0)
    message(FATAL_ERROR "Marlin MOE generation failed."
                        " Result: \"${moe_marlin_generation_result}\""
                        "\nCheck the log for details: "
                        "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
  else()
    # Remember the hash only after a successful run so that a failed
    # generation is retried on the next configure.
    set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
        CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
    message(STATUS "Marlin MOE generation completed successfully.")
  endif()
else()
  message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif()
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}"
SRCS "${MOE_WNAA16_MARLIN_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else()
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
......
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess
import jinja2
# Text written at the top of every generated kernel_*.cu file. It opens the
# MARLIN_NAMESPACE_NAME namespace; generate_new_kernels() appends the closing
# brace after the last instantiation.
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
# jinja2 template for one explicit instantiation of the Marlin kernel. The
# placeholders are filled in, in this order, by generate_new_kernels().
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
# Weight types to instantiate. A type without "B" in its name is generated
# with has_zp=true (see generate_new_kernels).
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
# Each tuple is (thread_k, thread_n, threads): the first two entries are
# divided by 16 to obtain thread_k_blocks / thread_n_blocks, the third is the
# number of threads per threadblock.
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
# 0.5 selects the m_block_size_8 variant; it is emitted with
# thread_m_blocks=1 and m_block_size_8=true.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
# Compute dtypes; "fp16" maps to the C++ type half, anything else to
# nv_bfloat16.
DTYPES = ["fp16", "bf16"]
def remove_old_kernels():
    """Delete every previously generated ``kernel_*.cu`` next to this script.

    The files are removed with ``os.remove`` instead of spawning ``rm``, so
    the function also works on platforms without that binary. The directory
    is resolved to an absolute path first: with a bare relative ``__file__``
    the directory part is empty and the pattern would otherwise degrade to
    ``/kernel_*.cu`` (the filesystem root). The directory is also glob-escaped
    so that characters such as ``[`` in the checkout path are taken literally.

    A file that disappears between listing and deletion is ignored, matching
    the ``rm -f`` semantics; other OS errors propagate to the caller.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    pattern = os.path.join(glob.escape(script_dir), "kernel_*.cu")
    for filename in glob.glob(pattern):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Already gone (e.g. removed concurrently): nothing left to do.
            pass
def generate_new_kernels():
    """Write one ``kernel_<dtype>_<weight type>.cu`` per (weight type, dtype).

    Each file consists of FILE_HEAD, one rendered TEMPLATE line for every
    accepted (group_blocks, m_blocks, thread config) combination, and the
    closing brace of the namespace opened by FILE_HEAD. Files are written
    into the directory that contains this script.
    """
    output_dir = os.path.dirname(__file__)

    for weight_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        # Weight types without a "B" in their name are generated with
        # zero points enabled.
        uses_zero_point = "B" not in weight_type
        compute_type = "half" if dtype == "fp16" else "nv_bfloat16"
        instantiations = []

        combos = itertools.product(GROUP_BLOCKS, THREAD_M_BLOCKS,
                                   THREAD_CONFIGS)
        for group_blocks, m_blocks, (tile_k, tile_n, threads) in combos:
            # group_blocks == 0 denotes the act-order case, which is not
            # generated together with zero points.
            act_order = group_blocks == 0
            if act_order and uses_zero_point:
                continue

            # With 256 threads, small m (<= 1 block) is only paired with
            # thread_k == 128 and larger m only with thread_k == 64.
            if threads == 256:
                wanted_k = 128 if m_blocks <= 1 else 64
                if tile_k != wanted_k:
                    continue

            instantiations.append(
                jinja2.Template(TEMPLATE).render(
                    scalar_t=compute_type,
                    w_type_id=weight_type + ".id()",
                    threads=threads,
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=tile_n // 16,
                    thread_k_blocks=tile_k // 16,
                    m_block_size_8=m_blocks == 0.5,
                    stages="pipe_stages",
                    has_act_order=act_order,
                    has_zp=uses_zero_point,
                    group_blocks=group_blocks,
                    is_zp_float=False,
                ))

        body = "\n\n".join(instantiations)
        # weight_type[6:] drops the leading "vllm::" from the type name.
        target = os.path.join(
            output_dir, f"kernel_{dtype}_{weight_type[6:].lower()}.cu")
        with open(target, "w") as out:
            out.write(FILE_HEAD + "\n\n" + body + "\n\n}\n")
def main():
    """Regenerate the kernel sources: drop stale files, then write new ones."""
    remove_old_kernels()
    generate_new_kernels()


if __name__ == "__main__":
    main()
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
// Parameter list of the Marlin MoE kernel, kept in a macro so that the
// declaration below and the generated explicit instantiations
// (`Marlin<...>( MARLIN_KERNEL_PARAMS );`) spell it identically.
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
// Declaration only: the kernel body is not part of this header.
// NOTE(review): the definition presumably lives in marlin_template.h, which
// the generated kernel_*.cu files include right after this header - confirm.
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
// Compile-time check that `scalar_t` is one of the two supported compute
// types (half or nv_bfloat16). Used in the final `else` branch of the
// `if constexpr` chains below to reject any other type.
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Stub used when compiling device code for __CUDA_ARCH__ < 800: the kernel
// keeps the same template and function signature as the real implementation
// but has an empty body, so it performs no work on those architectures.
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
} // namespace MARLIN_NAMESPACE_NAME
#else
// m16n8k16 tensor core mma instruction with fp16 or bf16 inputs (selected by
// `scalar_t`) and fp32 output/accumulation.
//
// The fragments are reinterpreted as raw registers: four 32-bit registers of
// `a_frag` and two of `frag_b` are the multiplicands, and the four floats of
// `frag_c` are both the accumulator input and the destination, so `frag_c`
// is updated in place.
template <typename scalar_t>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else {
// Any other scalar_t is rejected at compile time.
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
}
// Same m16n8k16 instruction as `mma`, but with the operand roles swapped:
// the registers of the two B fragments are interleaved
// (b[0], b2[0], b[1], b2[1]) into the instruction's first (16x16) operand,
// and only the first two registers of `a_frag` are used as its second
// operand. `frag_c` is accumulated in place.
// NOTE(review): presumably this computes the transposed product for the
// m_block_size_8 path - confirm against the call sites.
template <typename scalar_t>
__device__ inline void mma_trans(
const typename ScalarType<scalar_t>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b,
const typename ScalarType<scalar_t>::FragB& frag_b2,
typename ScalarType<scalar_t>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else {
// Any other scalar_t is rejected at compile time.
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
//
// `count` selects the ldmatrix variant (.x4, .x2 or .x1) and therefore how
// many 32-bit registers of `frag_a` are written (4, 2 or 1); any other value
// fails the static_assert. `smem_ptr` is a generic pointer that is converted
// to a shared-memory address before being handed to the instruction.
template <int count, typename scalar_t>
__device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
if constexpr (count == 4) {
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
} else if constexpr (count == 2) {
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
: "r"(smem));
} else if constexpr (count == 1) {
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
: "=r"(a[0])
: "r"(smem));
} else {
static_assert(count == 1 || count == 2 || count == 4, "invalid count");
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
//
// `lut` is the 8-bit truth table passed as the instruction's immediate
// operand; the result of applying it bitwise to (a, b, c) is returned.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
//
// The two sources are the runtime value `a` and the compile-time constant
// `start_byte`; `mask` is the byte-selector immediate of the prmt
// instruction. Both template arguments are emitted as immediates.
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Primary template for dequantizing packed integers in `q` into a B fragment.
// Only declared here; the (scalar_t, bit) combinations that are supported are
// provided as explicit specializations below. The result is written into
// `frag_b` and also returned by value.
template <typename scalar_t, int bit>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
int q, typename ScalarType<scalar_t>::FragB& frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
int q, typename ScalarType<half>::FragB& frag_b) {
// LO / HI select the nibble at bits 0-3 / 4-7 of each 16-bit half of q.
// EX (0x6400 per half) is fp16 1024.0, so OR-ing a nibble into its mantissa
// yields 1024 + n for LO and 1024 + 16 * n for HI.
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
// SUB is fp16 1032 (= 1024 + 8); MUL is fp16 1/16 and ADD is fp16 -72
// (= -(1024 / 16 + 8)), so both lanes end up holding n - 8.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
// bf16 counterpart of the 4-bit dequantization: converts the nibbles at bits
// 0-3 and 4-7 of each 16-bit half of `q` to bf16 and subtracts the symmetric
// zero point 8.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
// EX (0x4300 per half) is bf16 128.0; OR-ing a nibble into its mantissa
// yields 128 + n.
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
// Shift the next nibble into the masked position and repeat.
q >>= 4;
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
// MUL is bf16 1.0 and ADD is bf16 -136 (= -(128 + 8)), so the fma removes
// the 128 bias and the zero point 8 in one step.
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
int q, typename ScalarType<half>::FragB& frag_b) {
// Each selector pairs one byte of q (low byte of a half) with the constant
// byte 0x64 (high byte), forming fp16 values 1024 + byte. `lo` holds bytes
// 0 and 2 of q, `hi` holds bytes 1 and 3.
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
// 0x6480 is fp16 1152 (= 1024 + 128): subtracting it removes the 1024 bias
// and the 128 offset in one step, leaving byte - 128.
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// bf16 counterpart of the 8-bit dequantization. Each byte of `q` is first
// turned into an fp32 value, the 128 offset is removed in fp32, and the
// upper 16 bits of each fp32 result are then packed as bf16.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
// 0x4B000000 is fp32 2^23 (8388608). Placing a byte of q in the lowest
// mantissa byte gives 8388608 + byte. Slots 0..3 receive bytes 0, 2, 1, 3.
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
// 8388736 = 2^23 + 128: removes the bias and the 128 offset together.
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
// Keep only the two high bytes of each fp32 value (its bf16 truncation) and
// pack two of them per 32-bit output word.
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Applies a single quantization scale to a dequantized B fragment (grouped
// quantization only). Entry `i` of the scale fragment, viewed as a flat array
// of scalars, is broadcast to both lanes and multiplied into both halves of
// `frag_b` in place.
template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
                             typename ScalarType<scalar_t>::FragS& frag_s,
                             int i) {
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  scalar_t* scale_values = reinterpret_cast<scalar_t*>(&frag_s);
  const scalar_t2 factor = ScalarType<scalar_t>::num2num2(scale_values[i]);
#pragma unroll
  for (int part = 0; part < 2; ++part) {
    frag_b[part] = __hmul2(frag_b[part], factor);
  }
}
// Computes frag_b * s - zp in place on both halves of the fragment using a
// fused multiply-add; `s` and `zp` are broadcast to both lanes first.
template <typename scalar_t>
__device__ inline void scale_and_sub(
    typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  const scalar_t2 scale_pair = ScalarType<scalar_t>::num2num2(s);
  // Negate once so that the subtraction can be expressed as the fma addend.
  const scalar_t2 neg_zp_pair = __hneg2(ScalarType<scalar_t>::num2num2(zp));
#pragma unroll
  for (int part = 0; part < 2; ++part) {
    frag_b[part] = __hfma2(frag_b[part], scale_pair, neg_zp_pair);
  }
}
// Subtracts one zero point from a dequantized B fragment in place. Lane `i`
// of `frag_zp` is broadcast to both lanes and subtracted from both halves of
// `frag_b`.
template <typename scalar_t>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
                              typename ScalarType<scalar_t>::scalar_t2& frag_zp,
                              int i) {
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  scalar_t* zp_values = reinterpret_cast<scalar_t*>(&frag_zp);
  const scalar_t2 offset = ScalarType<scalar_t>::num2num2(zp_values[i]);
#pragma unroll
  for (int part = 0; part < 2; ++part) {
    frag_b[part] = __hsub2(frag_b[part], offset);
  }
}
// act_order variant of scaling: every lane of the B fragment gets its own
// scale. Entry `i` of the four scale fragments is gathered into two packed
// pairs, (s1, s2) for frag_b[0] and (s3, s4) for frag_b[1], which are then
// multiplied in place.
template <typename scalar_t>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
                              typename ScalarType<scalar_t>::FragS& frag_s_1,
                              typename ScalarType<scalar_t>::FragS& frag_s_2,
                              typename ScalarType<scalar_t>::FragS& frag_s_3,
                              typename ScalarType<scalar_t>::FragS& frag_s_4,
                              int i) {
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;

  scalar_t* s1 = reinterpret_cast<scalar_t*>(&frag_s_1);
  scalar_t* s2 = reinterpret_cast<scalar_t*>(&frag_s_2);
  scalar_t* s3 = reinterpret_cast<scalar_t*>(&frag_s_3);
  scalar_t* s4 = reinterpret_cast<scalar_t*>(&frag_s_4);

  scalar_t2 first_pair;
  first_pair.x = s1[i];
  first_pair.y = s2[i];

  scalar_t2 second_pair;
  second_pair.x = s3[i];
  second_pair.y = s4[i];

  frag_b[0] = __hmul2(frag_b[0], first_pair);
  frag_b[1] = __hmul2(frag_b[1], second_pair);
}
// Multiplies two fp32 values in place by the first two scales stored in `s`,
// after converting each scale to float.
template <typename scalar_t>
__device__ inline void scale_float(float* c,
                                   typename ScalarType<scalar_t>::FragS& s) {
  scalar_t* scale_values = reinterpret_cast<scalar_t*>(&s);
#pragma unroll
  for (int idx = 0; idx < 2; ++idx) {
    c[idx] =
        __fmul_rn(c[idx], ScalarType<scalar_t>::num2float(scale_values[idx]));
  }
}
// Wait until barrier reaches `count`, then lock for current threadblock.
//
// Only thread 0 polls the lock (with an acquire load at gpu scope); the
// trailing __syncthreads() keeps the remaining threads of the block waiting
// until that poll has succeeded.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
//
// All threads of the block first meet at __syncthreads(); thread 0 then
// either resets the lock to 0 (when `reset` is true) or issues a fence
// followed by an atomic add of 1 to the lock.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
// Plain store: the lock is put back to its initial value.
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
// Wait until value of lock to be negative, and then add 1
//
// Thread 0 polls the lock with an acquire load until it reads a negative
// value and then atomically increments it; the other threads of the block
// wait at the trailing __syncthreads().
__device__ inline void wait_negative_and_add(int* lock) {
if (threadIdx.x == 0) {
int state = 0;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state >= 0);
atomicAdd(lock, 1);
}
__syncthreads();
}
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
using FragA = typename ScalarType<scalar_t>::FragA;
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr int pack_factor = 32 / w_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
// parallel: num valid moe blocks
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
int parallel = num_tokens_past_padded / moe_block_size;
int num_valid_blocks = parallel;
if (is_ep) {
for (int i = 0; i < parallel; i++) {
if (expert_ids_ptr[i] == -1) num_valid_blocks--;
}
}
int num_invalid_blocks = parallel - num_valid_blocks;
parallel = num_valid_blocks;
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
if constexpr (!has_act_order && group_blocks != -1) {
if (group_blocks >= thread_k_blocks) {
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters = (group_blocks / thread_k_blocks) *
div_ceil(iters, (group_blocks / thread_k_blocks));
}
}
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count =
0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
int par_id = 0;
int block_id = -1;
int64_t expert_id = 0; // use int64 to avoid computation result overflow
int old_expert_id = 0;
int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh;
int32_t* sh_block_sorted_ids =
reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int4* sh_block_topk_weights_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
scalar_t2* sh_block_topk_weights =
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
slice_col = slice_col_par % n_tiles;
par_id = slice_col_par / n_tiles;
}
if (parallel * n_tiles >= gridDim.x) {
// when parallel * n_tiles >= sms
// then there are at most $sms$ conflict tile blocks
locks_off = blockIdx.x;
} else {
locks_off = (iters * blockIdx.x) / k_tiles - 1;
}
// read moe block data given block_id
// block_sorted_ids / block_num_valid_tokens / block_topk_weights
auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size;
#pragma unroll
for (int i = 0; i < moe_block_size / 4; i++) {
int4 sorted_token_ids_int4 = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
int* sorted_token_ids = reinterpret_cast<int*>(&sorted_token_ids_int4);
#pragma unroll
for (int j = 0; j < 4; j++) {
if (sorted_token_ids[j] >= prob_m * top_k) {
block_num_valid_tokens = i * 4 + j;
break;
}
}
if (block_num_valid_tokens != moe_block_size) break;
}
__syncthreads();
int tid4 = threadIdx.x / 4;
if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) {
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
}
}
}
__syncthreads();
};
// when move to next moe block, find the next block_id and expert_id
// and then read moe block data
auto update_next_moe_block_data = [&]() {
if (par_id >= parallel) return;
old_expert_id = expert_id;
if (num_invalid_blocks > 0) {
int skip_count = block_id == -1 ? par_id : 0;
block_id++;
for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) {
expert_id = expert_ids_ptr[i];
if (expert_id != -1) {
if (skip_count == 0) {
block_id = i;
break;
};
skip_count--;
};
}
} else {
block_id = par_id;
expert_id = expert_ids_ptr[block_id];
}
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) {
zp_ptr += (expert_id - old_expert_id) * zp_expert_stride;
}
if constexpr (has_act_order) {
g_idx += (expert_id - old_expert_id) * prob_k;
}
read_moe_block_data(block_id);
};
// Compute all information about the current slice which is required for
// synchronization.
//
// Updates slice_iters, slice_count, slice_idx and locks_off for the current
// (slice_col_par, slice_row) position. When `first_init` is true and atomic
// adds are in use, the first threadblock of a multi-block slice also zeroes
// the output rows it will accumulate into and arms the slice lock. Finally,
// if the column index has run past the last tile, it wraps to column 0 and
// moves on to the next MoE block.
auto init_slice = [&](bool first_init = false) {
// Number of k-tiles this threadblock still has to process in this slice.
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
// Nothing to do for this threadblock: leave all other state untouched.
if (slice_iters == 0) return;
// Never run past the bottom of the current column slice.
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
// slice_count: how many threadblocks contribute to this column slice;
// slice_idx: this threadblock's position among them.
slice_count = 1;
slice_idx = 0;
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = div_ceil(k_tiles - col_off, iters);
if (col_off > 0) slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) slice_idx--;
}
}
// Advance the lock slot index; when there are at least as many column
// slices as threadblocks, this only happens for the last block of a shared
// slice.
if (parallel * n_tiles >= gridDim.x) {
if (slice_count > 1 && slice_idx == slice_count - 1) {
locks_off++;
}
} else {
locks_off++;
}
if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) {
// Zero the valid output rows of this column slice in C so that the
// contributing threadblocks can accumulate into them.
constexpr int threads_per_m = 16 * thread_n_blocks / 8;
int m_per_thread =
div_ceil(block_num_valid_tokens, threads / threads_per_m);
for (int i = 0; i < m_per_thread; i++) {
int row = threads / threads_per_m * i + threadIdx.x / threads_per_m;
if (row < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[row];
int col = slice_col * 16 * thread_n_blocks / 8 +
threadIdx.x % threads_per_m;
C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0};
}
}
// After write zero to output, write a negative value to lock.
// Every SM that processes the same slice would wait for
// the negative value, and then atomicAdd 1 to it.
// After all SMs are processed, the lock value would back to 0 again.
__syncthreads();
if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count;
}
// Past the last column tile: wrap around and switch to the next MoE block.
if (slice_col == n_tiles) {
slice_col = 0;
par_id++;
update_next_moe_block_data();
}
};
// Select the first MoE block, then initialise the slice bookkeeping
// (first_init = true enables the one-time output zeroing / lock arming).
update_next_moe_block_data();
init_slice(true);
// A sizes/strides
// NOTE(review): the divisions by 8 below presumably express sizes in 16-byte
// (int4) units of 8 half-precision values — confirm against the A/C types.
// stride of the A matrix in global memory
int a_gl_stride = prob_k / 8;
// stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
// delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
// between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
// number of shared write iterations for a tile
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
// Each thread handles one vector per access for 4-bit weights, two otherwise.
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
// Number of scale groups covered by one k-tile; more than one only when a
// group is smaller than the tile (and act_order is off).
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Scale size/strides with act_order
constexpr int tb_k = 16 * thread_k_blocks;
// Size of one g_idx stage (tb_k ints expressed in 16-byte units); zero when
// act_order is disabled.
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides
// Float zero-points use the same stride formulas as the scales; quantized
// zero-points are packed by pack_factor.
int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = is_zp_float
? 16 * thread_n_blocks / 8
: ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
// Offset by the k-tile row at which this threadblock starts.
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
// With m_block_size_8 the lane-to-row mapping uses 8 rows instead of 16.
int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) +
(threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1));
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
// Global B read index of the current thread, offset to the current column
// slice and starting k-tile row.
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
// Shared B write/read indices of the current thread.
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
// k extent handled per register-load iteration, and the [start, finish) range
// of k covered by this threadblock's slice.
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
int slice_k_start = tb_k * slice_row;
int slice_k_finish = slice_k_start + tb_k * slice_iters;
int slice_k_start_shared_fetch = slice_k_start;
int slice_n_offset = act_s_col_tb_stride * slice_col;
// No act_order
// Global scale read index; left uninitialised when act_order is enabled.
int s_gl_rd;
if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
}
}
int s_sh_wr = threadIdx.x;
// Only the first s_sh_stride threads write scales to shared memory.
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
// Global zero-point read index; left uninitialised when has_zp is false.
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
// Only the first zp_sh_stride threads write zero-points to shared memory.
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
// Shared zero-point read index. Note that it stays uninitialised for float
// zero-points with group_blocks == -1 (no assignment on that path).
int zp_sh_rd;
if constexpr (has_zp) {
if constexpr (is_zp_float) {
if constexpr (group_blocks != -1) {
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
}
} else {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
}
// Shared-memory index remapping for A tiles. Writing A tiles to shared memory
// and reading them back in fragment format should both be bank conflict
// free, which calls for an XOR-based layout: neither the reads nor the writes
// of the 16-byte `int4` blocks of 8 consecutive threads may touch the same
// shared memory banks. Further, it seems (based on NSight-Compute) that each
// warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
  const int row = i / a_gl_rd_delta_o;
  const int col = i % a_gl_rd_delta_o;
  // `+` binds tighter than `^`: the XOR applies to the whole linear index.
  return (a_gl_rd_delta_o * row + col) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
// Transformed shared write index for each write iteration of this thread.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
// Transformed shared read index for each (register-load iteration, m-block).
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintaining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines.
// The regions are laid out back to back starting at sh_new, each holding
// `stages` pipeline stages: A, B, g_idx, zero-points, then scales.
int4* sh_a = sh_new;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// The reduction scratch buffer aliases the B staging area.
int4* sh_red = sh_b;
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
// Reset every float of the frag_c accumulator array to zero.
auto zero_accums = [&]() {
  constexpr int num_acc_floats = thread_m_blocks * 4 * 2 * 4;
  float* acc = reinterpret_cast<float*>(frag_c);
#pragma unroll
  for (int idx = 0; idx < num_acc_floats; idx++) {
    acc[idx] = 0;
  }
};
// First group id and number of groups whose scales are currently staged in
// sh_s (act-order path); -1 until the first fetch.
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
// Copy the scale rows for groups starting at `first_group_id` from
// scales_ptr into sh_s, either via cp_async4_pred (`is_async`) or with plain
// stores. Records the staged range in sh_first_group_id / sh_num_groups.
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
                                            int last_group_id) {
  sh_first_group_id = first_group_id;

  // Stage at least sh_max_num_groups rows, but never past the last group.
  sh_num_groups = last_group_id - first_group_id + 1;
  if (sh_num_groups < sh_max_num_groups) sh_num_groups = sh_max_num_groups;
  if (sh_first_group_id + sh_num_groups > num_groups)
    sh_num_groups = num_groups - sh_first_group_id;

  // Only the first s_sh_stride threads move data.
  if (threadIdx.x >= s_sh_stride) return;

  int row_offset = first_group_id * s_gl_stride;
  for (int g = 0; g < sh_num_groups; g++) {
    const auto sh_idx = (g * s_sh_stride) + threadIdx.x;
    const auto gl_idx =
        row_offset + (g * s_gl_stride) + slice_n_offset + threadIdx.x;
    if (is_async) {
      cp_async4_pred(&sh_s[sh_idx], &scales_ptr[gl_idx]);
    } else {
      sh_s[sh_idx] = scales_ptr[gl_idx];
    }
  }
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
//
// pipe:  shared-memory pipeline stage to fill.
// a_off: tile offset (in k-tiles) relative to the current read position;
//        also used as the full pipeline index for the act-order g_idx fetch.
// pred:  when false nothing is fetched, but a fence is still issued.
// Counts down the A-tile loads issued for the current slice; starts at
// `stages`.
int a_remaining_load_count_in_slice = stages;
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
// The A tile load is skipped only when the whole K extent fits in the
// pipeline, this is not the first column slice, and the initial loads have
// already been issued.
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 ||
a_remaining_load_count_in_slice > 0) {
a_remaining_load_count_in_slice--;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
int row = a_idx / a_gl_stride;
// Map the block-local row to its source row in A through the sorted-id
// table; each source row is shared by top_k sorted entries.
int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8)
sorted_row = sh_block_sorted_ids[row] / top_k;
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
// Rows at or beyond block_num_valid_tokens are predicated off.
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens);
}
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
B_ptr[i] + j + B_expert_off);
}
// Advance this B pointer to the next k-tile for the following fetch.
B_ptr[i] += b_gl_rd_delta_o;
}
if constexpr (has_act_order) {
// Fetch g_idx thread-block portion
int full_pipe = a_off;
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
if (cur_k < prob_k && cur_k < slice_k_finish) {
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int4 const* cur_g_idx_stage_ptr =
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
if (threadIdx.x < g_idx_stage) {
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
&cur_g_idx_stage_ptr[threadIdx.x]);
}
}
} else {
if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
// Several groups per tile: fetch one scale row per group.
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
// Several groups per tile: fetch one zero-point row per group.
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
// Copy this thread's zero-point vector from zp_ptr into the start of sh_zp;
// threads whose write predicate is false do nothing.
auto fetch_col_zp_to_shared = [&]() {
  if (!zp_sh_wr_pred) return;
  cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
};
// Copy this thread's scale vector from scales_ptr into the start of sh_s;
// threads whose write predicate is false do nothing.
auto fetch_col_scale_to_shared = [&]() {
  if (!s_sh_wr_pred) return;
  cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
};
// Block until the next thread tile is available in shared memory.
auto wait_for_stage = [&]() {
  // Because of double buffering only `stages - 2` fetches may remain in
  // flight: the next fetch can be issued only once the previous shared memory
  // load has fully completed, as it could otherwise be overwritten.
  cp_async_wait<stages - 2>();
  // Barrier across the threadblock before the tile is consumed.
  __syncthreads();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
};
// Per pipeline stage: whether the first and last g_idx entries of the staged
// k-tile are equal, and the first entry's group id.
bool is_same_group[stages];
int same_group_id[stages];
// Fill is_same_group / same_group_id for stage `pipe` from the g_idx values
// staged in shared memory. Does nothing unless act_order is enabled.
auto init_same_group = [&](int pipe) {
  if constexpr (has_act_order) {
    int* staged_g_idx = reinterpret_cast<int*>(sh_g_idx + g_idx_stage * pipe);
    const int first_gid = staged_g_idx[0];
    const int last_gid = staged_g_idx[tb_k - 1];
    is_same_group[pipe] = (first_gid == last_gid);
    same_group_id[pipe] = first_gid;
  }
};
// Load the scales needed for sub-tile `k` of pipeline position `full_pipe`
// from shared memory into registers: frag_s without act-order, act_frag_s
// with act-order. `k % 2` selects the register buffer.
auto fetch_scales_to_registers = [&](int k, int full_pipe) {
int pipe = full_pipe % stages;
if constexpr (!has_act_order) {
// No act-order case
if constexpr (group_blocks == -1) {
// load only when starting a new slice
if (k == 0 && full_pipe == 0) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
} else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
// A group spans one or more whole tiles: read from the stage that
// holds the scales of the group's first tile.
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
// Several groups per tile: derive the group from this warp's k
// position inside the tile.
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
}
}
return;
}
// Act-order case
// Determine K of the "current" thread-block
int cur_k = slice_k_start + tb_k * full_pipe;
if (cur_k >= prob_k || cur_k >= slice_k_finish) {
return;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k = 0;
// Progress to current iteration
cur_k += k_iter_size * (k % b_sh_wr_iters);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int warp_id = threadIdx.x / 32;
int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
int warp_row = warp_id / n_warps;
int warp_col = warp_id % n_warps;
cur_k += warp_row * 16;
int th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift =
/*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
(th_id / 4) * act_s_col_stride;
// Fast path: the staged tile belongs to a single group, so one scale row
// serves all four k offsets.
if (is_same_group[pipe]) {
if (k % 2 == 0) {
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
s_col_shift];
} else {
// Odd k: reuse the value loaded into the other register buffer.
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
*(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
}
for (int i = 1; i < 4; i++) {
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
}
return;
}
// General path: look up the group of each of the four k offsets in the
// staged g_idx and read the corresponding scale row.
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
constexpr int k_frag_offsets[4] = {0, 1, 8,
9}; // Tensor core offsets per thread
#pragma unroll
for (int i = 0; i < 4; i++) {
int actual_k = cur_k + k_frag_offsets[i];
int group_id = sh_g_idx_int_ptr[actual_k];
int rel_group_id = group_id - sh_first_group_id;
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
sh_s[rel_group_id * s_sh_stride + s_col_shift];
}
};
// Load the zero-points needed for sub-tile `k` of pipeline position
// `full_pipe` from shared memory into registers: frag_qzp for quantized
// zero-points, frag_zpf for float zero-points. No-op when has_zp is false.
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);
if constexpr (has_zp && !is_zp_float) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
// load only when starting a new slice
if (k == 0 && full_pipe == 0) {
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
}
} else if constexpr (group_blocks >= thread_k_blocks) {
// A group spans one or more whole tiles: read from the stage that holds
// the zero-points of the group's first tile.
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
// Several groups per tile: derive the group from this warp's k position
// inside the tile.
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
else if constexpr (has_zp && is_zp_float) {
int pipe = full_pipe % stages;
// Float zero-points are only loaded for grouped quantization.
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
//
// `k` is the sub-tile index within the current pipeline stage; `k2 = k % 2`
// selects the register double-buffer half (frag_a, frag_b_quant, frag_s,
// act_frag_s, frag_qzp, frag_zpf) that was filled for this sub-tile. For each
// of the 4 n-subtiles the quantized B values are dequantized, scaled (and
// zero-point corrected where applicable), and multiplied into frag_c.
//
// True until the first matmul of the slice has consumed the column-wise
// zero-points (only cleared on the has_zp, group_blocks == -1 path).
bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) {
  int k2 = k % 2;
  // Zero-points must be re-derived whenever a new group starts, or once per
  // slice for column-wise quantization.
  const bool is_new_zp =
      ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) ||
      (group_blocks == -1 && is_first_matmul_in_slice);
  if constexpr (has_zp && !is_zp_float) {
    if (is_new_zp) {
      if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
      FragB frag_zp_0;
      FragB frag_zp_1;
      int zp_quant_0, zp_quant_1;
      if constexpr (w_type.size_bits() == 4) {
        zp_quant_0 = frag_qzp[k2][0];
        zp_quant_1 = zp_quant_0 >> 8;
      } else {
        static_assert(w_type.size_bits() == 8);
        zp_quant_0 = frag_qzp[k2][0];
        zp_quant_1 = frag_qzp[k2][1];
      }
      dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0);
      dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1);
      frag_zp[0] = frag_zp_0[0];
      frag_zp[1] = frag_zp_0[1];
      frag_zp[2] = frag_zp_1[0];
      frag_zp[3] = frag_zp_1[1];
    }
  }
  // We have the m dimension as the inner loop in order to encourage
  // overlapping dequantization and matmul operations.
#pragma unroll
  for (int j = 0; j < 4; j++) {
    FragB frag_b0;
    FragB frag_b1;
    int b_quant_0, b_quant_1;
    if constexpr (w_type.size_bits() == 4) {
      b_quant_0 = frag_b_quant[k2][0][j];
      b_quant_1 = b_quant_0 >> 8;
    } else {
      static_assert(w_type.size_bits() == 8);
      int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]);
      b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
      b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
    }
    dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0);
    dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1);
    // Apply scale to frag_b0 / frag_b1
    if constexpr (has_act_order) {
      static_assert(group_blocks != -1);
      // act_frag_s has only two buffers along its first dimension, so every
      // access must use k2 (= k % 2); indexing with the raw `k` would read
      // out of bounds for k >= 2.
      scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                       act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
      scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                       act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
    } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
      // Column-wise scales: pick this thread's pair of scales and fold them
      // into the zero-points once per slice.
      int idx = (threadIdx.x / 4) % 2;
      scalar_t2 s2 = Dtype::nums2num2(
          reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
          reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
      if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
      scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
      scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
    } else if constexpr (has_zp && !is_zp_float && group_blocks != -1) {
      // Grouped scales with quantized zero-points: pre-multiply the
      // zero-points by the scale when a new group starts.
      if (is_new_zp)
        frag_zp[j] = __hmul2(frag_zp[j],
                             *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
      scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
      scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
    } else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
      // NOTE(review): this branch reads frag_s[k2][j].x directly, while the
      // quantized-zp branch above goes through frag_s[k2][j][0].x — confirm
      // FragS exposes .x/.y if this configuration is ever instantiated.
      if (is_new_zp)
        frag_zpf[k2][j] = __hmul2(
            frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
      scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
      scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
    } else if constexpr (group_blocks != -1) {
      scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
      scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
    }
#pragma unroll
    for (int i = 0; i < thread_m_blocks; i++) {
      if constexpr (m_block_size_8) {
        mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]);
      } else {
        mma<scalar_t>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
        mma<scalar_t>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
      }
    }
  }
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
//
// On return the threads with red_idx == 0 hold the reduced sums in frag_c.
// Uses sh_red as scratch space.
auto thread_block_reduce = [&]() {
// Half the number of thread groups that share an output location; a
// reduction is only needed when there are at least two such groups.
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
// Only the groups in [i, 2 * i) write during this round.
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
// After the first round, first fold in what earlier rounds left in
// shared memory before writing the combined value.
if (i < red_off) {
float* c_rd = reinterpret_cast<float*>(
&sh_red[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh_red[red_sh_wr] =
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
// Group 0 adds the final partial sums into its own accumulators.
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) {
float* c_rd =
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
//
// first: this threadblock is the first contributor, so there is nothing to
//        read back from C.
// last:  this threadblock is the last contributor, so the partial result is
//        not written back to C here.
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
bool is_th_active = threadIdx.x < active_threads;
// Only the first active_threads threads take part in the reduction.
if (!is_th_active) {
return;
}
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 4 * (active_threads / 32);
// Block-local index of this thread's first output element; the row part is
// translated through sh_block_sorted_ids before C is accessed.
int c_gl_wr;
if constexpr (m_block_size_8) {
c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) +
(threadIdx.x % 32) / 8;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
} else {
c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
}
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
// Stage the partial results previously written to C into sh_red.
if (!first) {
#pragma unroll
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
int c_idx;
if constexpr (m_block_size_8)
c_idx = c_gl_wr + i * c_gl_stride +
(threadIdx.x % 8) / 4 * c_gl_wr_delta_i;
else
c_idx =
c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
// Skip rows beyond the valid tokens of this block.
if (c_idx / c_gl_stride < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];
int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx];
}
}
}
#pragma unroll
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
// Accumulate the staged values into frag_c in float.
if (!first) {
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
int delta = 0;
if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0;
}
reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] +=
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
}
}
// Convert the running sum to scalar_t and write it back to C for the next
// contributor.
if (!last) {
int4 c;
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
int delta = 0;
if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0;
}
reinterpret_cast<scalar_t*>(&c)[j] =
Dtype::float2num(reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]);
}
int c_idx;
if constexpr (m_block_size_8)
c_idx = c_gl_wr + i * c_gl_stride +
(threadIdx.x % 8) / 4 * c_gl_wr_delta_i;
else
c_idx =
c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
if (c_idx / c_gl_stride < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];
int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
C[true_idx] = c;
}
}
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
//
// first: nothing is read back from C_tmp.
// last:  nothing is written to C_tmp.
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
constexpr int tb_m = thread_m_blocks * 16;
constexpr int tb_n = thread_n_blocks * 16;
// Size of one tile's float accumulators in 16-byte units.
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4;
bool is_th_active = threadIdx.x < active_threads;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
// Number of 16-byte vectors of accumulators held by each thread.
constexpr int th_size = num_floats * sizeof(float) / 16;
// The region of C_tmp used is selected by the lock slot index.
int c_cur_offset = locks_off * c_size;
// Only the first active_threads threads take part in the reduction.
if (!is_th_active) {
return;
}
// Add the partial sums stored in C_tmp into frag_c.
if (!first) {
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
if constexpr (m_block_size_8) {
// Odd vectors are skipped in the 8-row configuration.
if (k % 2) continue;
} else {
// Skip vectors whose row lies beyond the valid tokens of this block.
if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens)
continue;
}
// Bounce the vector through this thread's sh_red slot.
sh_red[threadIdx.x] =
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
#pragma unroll
for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
}
}
}
// Store the running sums to C_tmp for the next contributor.
if (!last) {
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
if constexpr (m_block_size_8) {
if (k % 2) continue;
} else {
if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens)
continue;
}
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
//
// The accumulators are converted to scalar_t, staged in sh_red, and then
// written (or atomically added) to the rows of C given by
// sh_block_sorted_ids, optionally multiplied by the per-row top-k weight.
auto write_result = [&]() {
int c_gl_stride = prob_n / 8;
// The extra +1 pads each shared-memory row by one element.
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
constexpr int c_sh_rd_delta =
c_sh_stride * (threads / (2 * thread_n_blocks));
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr;
if constexpr (m_block_size_8) {
c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) +
(threadIdx.x % 32) / 4;
c_sh_wr += 64 * (threadIdx.x / 32);
} else {
c_sh_wr =
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32);
}
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
// Converts the float pair (c0, c1) to scalar_t2, optionally applies the
// column scale `s`, and stores it at element index `idx` of sh_red.
auto write = [&](int idx, float c0, float c1, FragS& s) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) {
res = __hmul2(res, s[0]);
}
if constexpr (m_block_size_8) {
// The two halves are stored separately, 8 shared rows apart.
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
} else {
((scalar_t2*)sh_red)[idx] = res;
}
};
// Only the first thread_n_blocks / 4 warps stage their accumulators.
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]);
} else {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
}
}
c_sh_wr += 16 * (4 * c_sh_stride);
}
}
__syncthreads();
// Copy the staged tile from shared memory to the output rows in C.
#pragma unroll
for (int i = 0;
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
int row = c_gl_wr / c_gl_stride;
// Note: the index advance at the bottom only happens for valid rows.
if (row < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[row];
int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride;
scalar_t2 topk_weight_score;
if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row];
// Parsed as (use_atomic_add && slice_count > 1) || mul_topk_weights:
// the element-wise path is needed for atomic accumulation or weighting.
if (use_atomic_add && slice_count > 1 || mul_topk_weights) {
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[true_idx]);
scalar_t2* sh_red_half2 =
reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
#pragma unroll
for (int a = 0; a < 4; a++) {
scalar_t2 res = sh_red_half2[a];
if (mul_topk_weights) {
res = __hmul2(res, topk_weight_score);
}
if (use_atomic_add && slice_count > 1) {
atomicAdd(&C_half2[a], res);
} else {
C_half2[a] = res;
};
}
} else {
// Plain case: store the whole 16-byte vector at once.
C[true_idx] = sh_red[c_sh_rd];
}
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
__syncthreads();
};
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++) {
if (has_act_order && i == 0) {
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
fetch_act_order_scales_to_shared(true, g_idx[slice_k_start],
g_idx[last_g_idx]);
}
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
fetch_col_scale_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters);
}
zero_accums();
wait_for_stage();
init_same_group(0);
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
if (slice_iters) {
start_pipes();
}
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_remaining_load_count_in_slice = 0;
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worsen performance after compilation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
}
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
if constexpr (m_block_size_8) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s);
#pragma unroll
for (int i = 0; i < 8; i++) {
frag_s_half2[i] = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]);
}
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]);
if constexpr (!m_block_size_8) {
scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]);
scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]);
}
}
}
}
}
if (slice_count > 1 && !use_atomic_add) {
// only globally reduce if there is more than one block in a slice
barrier_acquire(&locks[locks_off], slice_idx);
if (use_fp32_reduce) {
global_reduce_fp32(slice_idx == 0, last);
} else {
global_reduce_fp16(slice_idx == 0, last);
}
barrier_release(&locks[locks_off], last);
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
if (slice_row) a_remaining_load_count_in_slice = stages;
slice_row = 0;
slice_col_par++;
slice_col++;
is_first_matmul_in_slice = true;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
// Update slice k/n for scales loading
if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row;
slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
}
}
}
}
} // namespace MARLIN_NAMESPACE_NAME
#endif
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "kernel.h"
#include "core/registration.h"
// Compile-time guard: expands to a static_assert that fails the build unless
// `scalar_t` is exactly `half` or `nv_bfloat16`. The expansion already ends
// with a semicolon.
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
namespace MARLIN_NAMESPACE_NAME {
// Empty kernel used as the "no matching kernel" sentinel: kernel lookup in
// this file starts from MarlinDefault and callers compare the returned pointer
// against it to detect an unsupported configuration.
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
// Function-pointer type shared by MarlinDefault and every kernel that takes
// MARLIN_KERNEL_PARAMS.
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Placeholder definition used when __CUDA_ARCH__ is defined and below 800:
// same signature as the real kernel in the #else branch, but with an empty
// body.
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {};
} // namespace MARLIN_NAMESPACE_NAME
// Fallback definition compiled when __CUDA_ARCH__ is defined and below 800.
// It ignores every argument and unconditionally fails through
// TORCH_CHECK_NOT_IMPLEMENTED; the trailing return only satisfies the
// signature.
// NOTE(review): the message names "marlin_gemm(..)" rather than
// moe_wna16_marlin_gemm — confirm whether that wording is intentional.
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
//
// Work distribution: thread block `blockIdx.x` handles MoE blocks
// blockIdx.x, blockIdx.x + gridDim.x, ... Each MoE block contributes
// `moe_block_size` entries of `sorted_token_ids_ptr`; the first entry that is
// >= size_m * top_k ends the valid rows of that block. A MoE block whose
// expert id is -1 is skipped entirely.
//
// For every valid row id `row`, the source row is `row / top_k` of `a` and the
// destination row is `row` of `out`; element k of the destination is element
// perm[k] of the source. `perm_int_ptr` is advanced by `size_k` entries per
// expert, so each expert uses its own permutation.
//
// The column loop advances in steps of `default_threads`, so the kernel is
// expected to be launched with blockDim.x == default_threads.
template <int moe_block_size>
__global__ void permute_cols_kernel(
    int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
    int4* __restrict__ out_int4_ptr,
    const int32_t* __restrict__ sorted_token_ids_ptr,
    const int32_t* __restrict__ expert_ids_ptr,
    const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
    int size_k, int top_k) {
  int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
  int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
  // Filled with int4 (16-byte) stores below, so the storage must be 16-byte
  // aligned; a plain int32_t array only guarantees 4-byte alignment.
  alignas(16) int32_t block_sorted_ids[moe_block_size];
  int block_num_valid_tokens = 0;
  int64_t old_expert_id = 0;
  int64_t expert_id = 0;

  // Row length of `a` / `out` measured in int4 (16-byte) units.
  int row_stride = size_k * sizeof(half) / 16;

  // Loads the sorted token ids of MoE block `block_id` and counts how many
  // leading entries refer to real rows.
  auto read_moe_block_data = [&](int block_id) {
    block_num_valid_tokens = moe_block_size;
    int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
    int4 const* sorted_ids_int4 =
        reinterpret_cast<int4 const*>(sorted_token_ids_ptr);
    for (int i = 0; i < moe_block_size / 4; i++) {
      tmp_block_sorted_ids[i] =
          sorted_ids_int4[block_id * moe_block_size / 4 + i];
    }
    for (int i = 0; i < moe_block_size; i++) {
      if (block_sorted_ids[i] >= size_m * top_k) {
        block_num_valid_tokens = i;
        break;
      }
    }
  };

  // Copies one row, applying the current expert's column permutation. Each
  // thread handles columns threadIdx.x, threadIdx.x + default_threads, ...
  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;

    int in_offset = (row / top_k) * row_stride;
    int out_offset = row * row_stride;

    half const* a_row_half =
        reinterpret_cast<half const*>(a_int4_ptr + in_offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);

    int base_k = 0;
    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];
      out_half[cur_k] = a_row_half[src_pos];
      base_k += default_threads;
    }

    // Tail when size_k is not a multiple of default_threads.
    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];
        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
    old_expert_id = expert_id;
    int tmp_expert_id = expert_ids_ptr[index];
    if (tmp_expert_id == -1) continue;
    expert_id = tmp_expert_id;
    // Move to the permutation table of the expert owning this MoE block.
    perm_int_ptr += (expert_id - old_expert_id) * size_k;
    read_moe_block_data(index);

    for (int i = 0; i < block_num_valid_tokens; i++)
      permute_row(block_sorted_ids[i]);
  }
}
// Thread-block configuration: tile extent along K, tile extent along N, and
// the number of threads per block. A value of -1 in any field marks the
// configuration as unset.
struct thread_config_t {
  int thread_k;
  int thread_n;
  int num_threads;
};
// Candidate thread configurations, tried in order. determine_exec_config()
// uses the small-batch table when thread_m_blocks <= 1 and the large-batch
// table when thread_m_blocks > 1.
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256},
{64, 128, 128}};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256},
{64, 128, 128}};
// Launch configuration: how many thread blocks to schedule per SM, together
// with the thread-block configuration they use.
struct exec_config_t {
  int blocks_per_sm;
  thread_config_t tb_cfg;
};
// Returns the size of the scales cache for one thread block, as computed from
// the tile shape in `th_config`.
//
// group_size == -1 counts one scale group per tile, group_size == 0 assumes
// the worst case of 32-wide groups, and any other value uses
// ceil(thread_k / group_size) groups. When act-order is active without a full
// K, a chunk of 2x the pipeline depth (at least 32 groups) is sized instead of
// one slot per pipeline stage.
//
// prob_m, prob_n, prob_k and num_bits are accepted but not used here.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  int const tile_n = th_config.thread_n;
  int const tile_k = th_config.thread_k;

  // Max scale groups per thread block.
  int groups_per_tile = 1;
  if (group_size == 0) {
    groups_per_tile = div_ceil(tile_k, 32);  // Worst case is 32 group size
  } else if (group_size != -1) {
    groups_per_tile = div_ceil(tile_k, group_size);
  }

  if (has_act_order && !is_k_full) {
    // Chunk size is 2x pipeline over dim K; at least 32 scale groups are
    // loaded.
    int chunk_groups = groups_per_tile * pipe_stages * 2;
    chunk_groups = max(chunk_groups, 32);
    return chunk_groups * tile_n * 2;
  }

  int const scales_per_stage = groups_per_tile * tile_n * 2;
  return scales_per_stage * pipe_stages;
}
// Returns the total shared-memory footprint of one thread block: the sum of
// the A tiles, the packed B tiles, the scales cache, the optional zero-point
// cache, the optional g_idx cache and the per-block metadata
// (block_sorted_ids / block_topk_weights).
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
                          int prob_m, int prob_n, int prob_k, int num_bits,
                          int group_size, bool has_act_order, bool is_k_full,
                          int has_zp, int is_zp_float) {
  int const tile_k = th_config.thread_k;
  int const tile_n = th_config.thread_n;
  int const tile_m = thread_m_blocks * 16;
  // Number of quantized values packed into one 32-bit word.
  int const values_per_word = 32 / num_bits;

  // block_sorted_ids and block_topk_weights each need tile_m * 4 bytes
  // (tile_m * int32 or tile_m * float32).
  int const block_meta_bytes = tile_m * 4 * 2;
  int const a_bytes = pipe_stages * (tile_m * tile_k) * 2;
  int const b_bytes = pipe_stages * (tile_k * tile_n / values_per_word) * 4;
  int const scale_bytes =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);

  int g_idx_bytes = 0;
  if (has_act_order && !is_k_full) {
    g_idx_bytes = pipe_stages * tile_k / 4;
  }

  // Zero points: same size as the scales when stored as floats, otherwise a
  // fraction of it that depends on the bit width.
  int zp_bytes = 0;
  if (has_zp) {
    if (is_zp_float) {
      zp_bytes = scale_bytes;
    } else {
      switch (num_bits) {
        case 4:
          zp_bytes = scale_bytes / 4;
          break;
        case 8:
          zp_bytes = scale_bytes / 2;
          break;
        default:
          break;
      }
    }
  }

  return a_bytes + b_bytes + scale_bytes + zp_bytes + g_idx_bytes +
         block_meta_bytes;
}
// Returns true when `th_config` can be used for the given problem: all fields
// are set, K and N are divisible by the tile extents, the tile extents and
// thread count meet their minimums, and the shared-memory footprint fits into
// `max_shared_mem`.
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int has_zp, int is_zp_float, int max_shared_mem) {
  int const tile_k = th_config.thread_k;
  int const tile_n = th_config.thread_n;
  int const threads = th_config.num_threads;

  // Evaluated left to right with short-circuiting, so the divisibility tests
  // only run once the configuration is known to be set.
  bool const rejected =
      // Sanity: -1 marks an unset field
      tile_k == -1 || tile_n == -1 || threads == -1 ||
      // K/N must be divisible by thread K/N
      prob_k % tile_k != 0 || prob_n % tile_n != 0 ||
      // Minimum thread K/N
      tile_n < min_thread_n || tile_k < min_thread_k ||
      // num_threads must be at least 128 (= 4 warps)
      threads < 128;
  if (rejected) {
    return false;
  }

  // The whole pipeline has to fit into shared memory.
  return get_kernel_cache_size(th_config, thread_m_blocks, prob_m, prob_n,
                               prob_k, num_bits, group_size, has_act_order,
                               is_k_full, has_zp, is_zp_float) <=
         max_shared_mem;
}
// Expands to a single `else if` arm, so it must directly follow an `if` or
// `else if` block. When every runtime selector (q_type, thread_m_blocks,
// thread_n_blocks, thread_k_blocks, m_block_size_8, has_act_order, has_zp,
// group_blocks, num_threads, is_zp_float) equals the corresponding macro
// argument, the arm assigns the matching Marlin<...> instantiation to
// `kernel`. Those selector names, `kernel`, `scalar_t` and `pipe_stages` must
// be in scope at the expansion site.
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
}
// Selection arms for thread_m_blocks == 1 without zero points
// (HAS_ZP = false, IS_ZP_FLOAT = false), each for both values of
// M_BLOCK_SIZE_8:
//   - act-order (HAS_ACT_ORDER = true) with GROUP_BLOCKS = 0
//   - no act-order with GROUP_BLOCKS in {-1, 2, 4, 8}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
// Selection arms for thread_m_blocks in {2, 3, 4} without zero points
// (HAS_ZP = false, IS_ZP_FLOAT = false) and M_BLOCK_SIZE_8 = false:
//   - act-order (HAS_ACT_ORDER = true) with GROUP_BLOCKS = 0
//   - no act-order with GROUP_BLOCKS in {-1, 2, 4, 8}
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
// Selection arms for thread_m_blocks == 1 with integer zero points
// (HAS_ZP = true, IS_ZP_FLOAT = false) and no act-order: GROUP_BLOCKS in
// {-1, 2, 4, 8}, each for both values of M_BLOCK_SIZE_8.
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// Selection arms for thread_m_blocks in {2, 3, 4} with integer zero points
// (HAS_ZP = true, IS_ZP_FLOAT = false), no act-order and
// M_BLOCK_SIZE_8 = false: GROUP_BLOCKS in {-1, 2, 4, 8}.
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// We currently have 4-bit models only with group_blocks == 4
//
// Selection arms with floating-point zero points (HAS_ZP = true,
// IS_ZP_FLOAT = true), no act-order and GROUP_BLOCKS = 4: thread_m_blocks == 1
// for both values of M_BLOCK_SIZE_8, and thread_m_blocks in {2, 3, 4} with
// M_BLOCK_SIZE_8 = false.
// NOTE(review): get_marlin_kernel() in this file does not expand this macro,
// so is_zp_float == true never selects a kernel there — confirm that is
// intended.
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
// Looks up the Marlin kernel instantiation that matches the given runtime
// parameters.
//
// Each *_GET_IF_* macro below expands to a chain of `else if` arms that compare
// the runtime parameters against one compiled combination and, on a match,
// assign the corresponding Marlin<...> instantiation to `kernel`.
//
// Returns MarlinDefault when no arm matches; callers treat that value as
// "unsupported configuration".
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
                                int thread_m_blocks, int thread_n_blocks,
                                int thread_k_blocks, bool m_block_size_8,
                                bool has_act_order, bool has_zp,
                                int group_blocks, int num_threads,
                                bool is_zp_float) {
  auto kernel = MarlinDefault;

  // Empty leading branch so that every macro can expand to `else if` arms.
  if (false) {
  }
  // Without zero points: uint4b8 weights
  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
  GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
  GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
  // Without zero points: uint8b128 weights
  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
  GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
  GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
  // With integer zero points: uint4 weights
  AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
  AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
  AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
  AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)

  return kernel;
}
// Picks the launch configuration for the given problem.
//
// Candidates come from large_batch_thread_configs when thread_m_blocks > 1 and
// from small_batch_thread_configs otherwise, and are visited in table order.
// A candidate is skipped when is_valid_config() rejects it or when no kernel
// instantiation exists for it.
//
//   - thread_m_blocks > 1: the first usable candidate wins, with one block per
//     SM.
//   - otherwise: for every usable candidate the number of blocks that fit on
//     an SM is estimated from the kernel's register count and its
//     shared-memory footprint (plus 1024 bytes per block), clamped to [1, 4];
//     the candidate with the largest estimate wins, earlier entries winning
//     ties.
//
// Returns {1, {-1, -1, -1}} when no candidate is usable.
template <typename scalar_t>
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
                                    int prob_n, int prob_k, int thread_m_blocks,
                                    bool m_block_size_8, int num_bits,
                                    int group_size, bool has_act_order,
                                    bool is_k_full, bool has_zp,
                                    bool is_zp_float, int max_shared_mem) {
  constexpr int device_max_reg_size = 255 * 1024;

  bool const large_batch = thread_m_blocks > 1;
  thread_config_t* candidates =
      large_batch ? large_batch_thread_configs : small_batch_thread_configs;
  int num_candidates =
      large_batch
          ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
          : sizeof(small_batch_thread_configs) / sizeof(thread_config_t);

  // The group_blocks value used for kernel lookup does not depend on the
  // candidate, so it is derived once up front.
  int group_blocks;
  if (has_act_order) {
    group_blocks = 0;
  } else if (group_size == -1) {
    group_blocks = -1;
  } else {
    group_blocks = group_size / 16;
  }

  exec_config_t chosen = exec_config_t{1, thread_config_t{-1, -1, -1}};
  int best_blocks_per_sm = 0;

  for (int idx = 0; idx < num_candidates; ++idx) {
    thread_config_t candidate = candidates[idx];

    if (!is_valid_config(candidate, thread_m_blocks, prob_m, prob_n, prob_k,
                         num_bits, group_size, has_act_order, is_k_full,
                         has_zp, is_zp_float, max_shared_mem)) {
      continue;
    }

    auto kernel = get_marlin_kernel<scalar_t>(
        q_type, thread_m_blocks, candidate.thread_n / 16,
        candidate.thread_k / 16, m_block_size_8, has_act_order, has_zp,
        group_blocks, candidate.num_threads, is_zp_float);
    if (kernel == MarlinDefault) {
      continue;
    }

    if (large_batch) {
      chosen = {1, candidate};
      break;
    }

    int smem_bytes = get_kernel_cache_size(
        candidate, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
        group_size, has_act_order, is_k_full, has_zp, is_zp_float);

    cudaFuncAttributes attr;
    cudaFuncGetAttributes(&attr, kernel);
    // 4 bytes per register, per thread of the block.
    int reg_bytes = max(attr.numRegs, 1) * candidate.num_threads * 4;

    int blocks_per_sm = min(device_max_reg_size / reg_bytes,
                            max_shared_mem / (smem_bytes + 1024));
    blocks_per_sm = max(min(blocks_per_sm, 4), 1);

    if (blocks_per_sm > best_blocks_per_sm) {
      best_blocks_per_sm = blocks_per_sm;
      chosen = {best_blocks_per_sm, candidate};
    }
  }

  return chosen;
}
// Validates the problem description, optionally permutes the columns of A
// (act-order), selects a thread/launch configuration and launches the matching
// Marlin kernel on `stream`.
//
// Pointer arguments are untyped and reinterpreted here: A, B, C, C_tmp, s, zp
// and a_tmp as int4 buffers; g_idx and perm as int buffers; sorted_token_ids,
// expert_ids and num_tokens_past_padded as int32 buffers; topk_weights as a
// float buffer; workspace as the int lock array handed to the kernel.
//
// Behaviour worth knowing:
//   - q_type must be u4/u8 when has_zp is set, uint4b8/uint8b128 otherwise.
//   - group_size is -1 for a single group, 0 for act-order without a full K,
//     otherwise it must be at least 16 so that group_size / 16 is non-zero.
//   - With act-order, A is first permuted into a_tmp; from then on the kernel
//     sees prob_m * top_k rows with top_k == 1. If K is full, the
//     non-act-order kernel is used.
//   - thread_k / thread_n of -1 request automatic configuration through
//     determine_exec_config().
//
// Every failed check raises through TORCH_CHECK.
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
               void* zp, void* g_idx, void* perm, void* a_tmp,
               void* sorted_token_ids, void* expert_ids,
               void* num_tokens_past_padded, void* topk_weights,
               int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
               int prob_m, int prob_n, int prob_k, void* workspace,
               vllm::ScalarType const& q_type, bool has_act_order,
               bool is_k_full, bool has_zp, int num_groups, int group_size,
               int dev, cudaStream_t stream, int thread_k, int thread_n,
               int sms, bool use_atomic_add, bool use_fp32_reduce,
               bool is_zp_float) {
  int thread_m_blocks = div_ceil(moe_block_size, 16);
  bool m_block_size_8 = moe_block_size == 8;

  if (has_zp) {
    TORCH_CHECK(
        q_type == vllm::kU4 || q_type == vllm::kU8,
        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
  } else {
    TORCH_CHECK(
        q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
        "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        q_type.str());
  }

  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      // group_size below 16 would make group_blocks zero and turn the modulo
      // below into a division by zero.
      TORCH_CHECK(group_blocks > 0, "group_size = ", group_size,
                  " must be at least 16");
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      // Same guard as above: never evaluate `prob_k % 0`.
      TORCH_CHECK(group_blocks > 0, "group_size = ", group_size,
                  " must be at least 16");
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int num_bits = q_type.size_bits();
  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  int4* C_tmp_ptr = (int4*)C_tmp;
  const int4* s_ptr = (const int4*)s;
  const int4* zp_ptr = (const int4*)zp;
  const int* g_idx_ptr = (const int*)g_idx;
  const int* perm_ptr = (const int*)perm;
  int4* a_tmp_ptr = (int4*)a_tmp;
  const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
  const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
  const int32_t* num_tokens_past_padded_ptr =
      (const int32_t*)num_tokens_past_padded;
  const float* topk_weights_ptr = (const float*)topk_weights;
  int* locks = (int*)workspace;

  if (has_act_order) {
    // Permute A columns
    auto kernel = permute_cols_kernel<8>;
    switch (moe_block_size) {
      case 8:
        break;
      case 16:
        kernel = permute_cols_kernel<16>;
        break;
      case 32:
        kernel = permute_cols_kernel<32>;
        break;
      case 48:
        kernel = permute_cols_kernel<48>;
        break;
      case 64:
        kernel = permute_cols_kernel<64>;
        break;
      default:
        TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
    }

    // avoid ">>>" being formatted to "> > >"
    // clang-format off
    kernel<<<sms, default_threads, 0, stream>>>(
        A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
        num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
    // clang-format on

    // The permuted copy has one row per (token, expert) pair, so the main
    // kernel no longer needs to divide row ids by top_k.
    A_ptr = a_tmp_ptr;
    prob_m = prob_m * top_k;
    top_k = 1;

    // If we have a full K, then we can run the non-act-order version of Marlin
    // (since the weight rows are reordered by increasing group ids, and by
    // having a full K, we have full original groups)
    if (is_k_full) has_act_order = false;
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Set thread config
  exec_config_t exec_cfg;
  thread_config_t thread_tfg;
  if (thread_k != -1 && thread_n != -1) {
    thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
    exec_cfg = exec_config_t{1, thread_tfg};
    TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
                " is not divisible by thread_n = ", thread_n);
    TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
                " is not divisible by thread_k = ", thread_k);
  } else {
    // Auto config
    exec_cfg = determine_exec_config<scalar_t>(
        q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
        num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
        max_shared_mem);
    thread_tfg = exec_cfg.tb_cfg;
  }

  int num_threads = thread_tfg.num_threads;
  thread_k = thread_tfg.thread_k;
  thread_n = thread_tfg.thread_n;
  int blocks = sms * exec_cfg.blocks_per_sm;
  // When several blocks share an SM, each gets an equal share of the shared
  // memory minus 1024 bytes.
  if (exec_cfg.blocks_per_sm > 1)
    max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
                              prob_k, num_bits, group_size, has_act_order,
                              is_k_full, has_zp, is_zp_float, max_shared_mem),
              "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
              ", thread_k = ", thread_tfg.thread_k,
              ", thread_n = ", thread_tfg.thread_n,
              ", num_threads = ", thread_tfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
              ", max_shared_mem = ", max_shared_mem);

  auto kernel = get_marlin_kernel<scalar_t>(
      q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
      has_act_order, has_zp, group_blocks, num_threads, is_zp_float);

  if (kernel == MarlinDefault) {
    TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                ", ", prob_k, "]", ", has_act_order = ", has_act_order,
                ", num_groups = ", num_groups, ", group_size = ", group_size,
                ", thread_m_blocks = ", thread_m_blocks,
                ", thread_n_blocks = ", thread_n_blocks,
                ", thread_k_blocks = ", thread_k_blocks,
                ", num_bits = ", num_bits);
  }

  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_shared_mem);

  // avoid ">>>" being formatted to "> > >"
  // clang-format off
  kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
      A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
      sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
      topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
      prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
  // clang-format on
}
} // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
int pack_factor = 32 / b_q_type.size_bits();
if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0,
"unsupported moe_block_size=", moe_block_size);
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
"unsupported moe_block_size=", moe_block_size);
}
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
", size_m = ", size_m);
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
", size_k = ", size_k);
// Verify B
TORCH_CHECK(
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
", size_k = ", size_k,
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK(
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
"b_q_weight.size(2) = ", b_q_weight.size(2),
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
int actual_size_n =
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel
int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c;
if (c_or_none.has_value()) {
c = c_or_none.value();
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
TORCH_CHECK(c.size(0) == size_m * top_k,
"Shape mismatch: c.size(0) = ", c.size(0),
", size_m * topk = ", size_m * top_k);
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
", size_n = ", size_n);
} else {
c = torch::empty({size_m * top_k, size_n}, options);
}
// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4
long max_c_tmp_size = min(
(long)size_n * sorted_token_ids.size(0),
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
if (moe_block_size == 8) max_c_tmp_size *= 2;
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
} else {
c_tmp = torch::empty({0}, options_fp32);
}
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
torch::Tensor g_idx, perm, a_tmp;
;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
g_idx = g_idx_or_none.value();
perm = perm_or_none.value();
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
// Verify g_idx and perm
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
" and perm.size(-1) = ", perm.size(-1),
", where size_k = ", size_k);
} else {
g_idx = torch::empty({0}, options);
perm = torch::empty({0}, options);
a_tmp = torch::empty({0}, options);
}
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
if (has_act_order) {
a_tmp = torch::empty({size_m * top_k, size_k}, options);
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
}
} else {
a_tmp = torch::empty({0}, options);
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
group_size = size_k / num_groups;
} else {
group_size = -1;
}
}
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
} else {
b_zeros = torch::empty({0}, options);
}
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type.str());
}
if (has_zp && is_zp_float) {
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
"Computation type must be float16 (half) when using float zero "
"points.");
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
if (is_zp_float) {
TORCH_CHECK(b_zeros.size(2) == size_n,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n = ", size_n);
TORCH_CHECK(num_groups == b_zeros.size(1),
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
} else {
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
}
// Verify workspace size
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
MARLIN_NAMESPACE_NAME::min_thread_n);
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
int min_workspace_size = min(
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
}
return c;
}
#endif
// Register the C++ function `moe_wna16_marlin_gemm` as the implementation of
// the op named "moe_wna16_marlin_gemm" for the CUDA dispatch key of the
// TORCH_EXTENSION_NAME library. Only the implementation is bound here; the
// op schema itself is not declared in this block.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
......@@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file
#endif
......
......@@ -9,7 +9,11 @@
#include <cuda_runtime.h>
#include <iostream>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
// Marlin params
......@@ -23,6 +27,7 @@ static constexpr int pipe_stages =
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
......@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
#endif
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
......@@ -5,7 +5,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t>
class ScalarType {};
......@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
......@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif
};
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
#endif
......@@ -11,16 +11,14 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
torch_moe_single)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE
......@@ -287,14 +285,17 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype])
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("ep_size", [1, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
......@@ -303,9 +304,12 @@ def test_fused_marlin_moe(
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
group_size: int,
act_order: bool,
num_bits: int,
has_zp: bool,
is_k_full: bool,
):
current_platform.seed_everything(7)
......@@ -316,75 +320,110 @@ def test_fused_marlin_moe(
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
if has_zp:
# we don't build kernel for int8 with zero
if num_bits == 8:
return
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
if has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
if has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
......@@ -394,111 +433,91 @@ def test_fused_marlin_moe(
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=num_bits,
is_k_full=is_k_full,
)
assert compute_max_diff(marlin_output, triton_output) < 4e-2
if ops.supports_moe_ops:
token_expert_indicies = torch.empty(m,
topk,
dtype=torch.int32,
device=a.device)
opcheck(torch.ops._moe_C.topk_softmax, (
topk_weights,
topk_ids,
token_expert_indicies,
score.float(),
))
block_size_m = 4
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
e)
is_k_full=is_k_full)
max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
zp = torch.empty((0, 0),
dtype=dtype,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
m, 2 * n, k, True, e, topk, block_size_m, True, False))
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype, group_size: int,
act_order: bool, num_bits: int,
has_zp: bool, is_k_full: bool):
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
if has_zp:
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweights_l = []
qweight_l = []
scales_l = []
zeros_l = []
g_idx_l = []
sort_indices_l = []
for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)
if has_zp:
w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
zeros_l.append(zeros)
else:
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
qweight = stack_and_dev(qweight_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)
g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
zeros = stack_and_dev(zeros_l) if zeros_l else None
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = torch.ops.vllm.single_marlin_moe(
......@@ -510,13 +529,14 @@ def test_single_marlin_moe_multiply(
renormalize=False,
g_idx=g_idx,
sort_indices=sort_indices,
w_zeros=zeros,
num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
torch_output = torch_moe_single(a, w_ref, score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
def test_moe_align_block_size_opcheck():
......
......@@ -1245,6 +1245,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output)
def moe_wna16_marlin_gemm(
    input: torch.Tensor,
    output: Optional[torch.Tensor],
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
    b_qzeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_past_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
) -> torch.Tensor:
    """Call the ``_moe_C::moe_wna16_marlin_gemm`` custom op.

    All arguments are handed to the op positionally, in the order they are
    declared here. The single translation performed is on ``b_q_type``: the
    op receives its integer ``id`` rather than the ``ScalarType`` object.

    Returns:
        Whatever tensor the underlying op returns.
    """
    # Resolve the op first so a missing extension fails before anything else.
    gemm_op = torch.ops._moe_C.moe_wna16_marlin_gemm

    # Tensor (and optional-tensor) operands, in op order.
    tensor_operands = (
        input,
        output,
        b_qweight,
        b_scales,
        b_qzeros,
        g_idx,
        perm,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_past_padded,
        topk_weights,
    )
    # Scalar operands, in op order; the quant type travels as its integer id.
    scalar_operands = (
        moe_block_size,
        top_k,
        mul_topk_weights,
        is_ep,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )
    return gemm_op(*tensor_operands, *scalar_operands)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@register_fake("_moe_C::marlin_gemm_moe")
......@@ -1263,6 +1286,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype=a.dtype,
device=a.device)
@register_fake("_moe_C::moe_wna16_marlin_gemm")
def moe_wna16_marlin_gemm_fake(
    input: torch.Tensor,
    output: Optional[torch.Tensor],
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
    b_qzeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_past_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
) -> torch.Tensor:
    """Fake implementation registered for ``_moe_C::moe_wna16_marlin_gemm``.

    No computation is performed: the result is an uninitialised tensor of
    shape ``(size_m * top_k, size_n)`` that takes its dtype and device from
    ``input``. Every other argument is accepted only to mirror the op's
    signature and is otherwise unused.
    """
    result_shape = (size_m * top_k, size_n)
    return torch.empty(result_shape, dtype=input.dtype, device=input.device)
def reshape_and_cache(
key: torch.Tensor,
......
......@@ -5,17 +5,16 @@ from typing import Optional
import torch
import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op
def get_scalar_type(num_bits: int, has_zp: bool):
    """Select the quantized weight scalar type for a Marlin MoE call.

    Args:
        num_bits: Weight bit width. ``4`` selects the 4-bit type; any other
            value falls through to the 8-bit type.
        has_zp: Whether zero points are supplied alongside the weights.

    Returns:
        ``scalar_types.uint4`` / ``scalar_types.uint8`` when ``has_zp`` is
        true, otherwise ``scalar_types.uint4b8`` / ``scalar_types.uint8b128``.
    """
    if has_zp:
        # A single expression covers both widths; previously an assert and an
        # early ``uint4`` return made the 8-bit case unreachable.
        # NOTE(review): the CUDA entry point may still reject non-u4 types
        # when zero points are present -- confirm before relying on uint8.
        return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
    return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
......@@ -27,9 +26,12 @@ def single_marlin_moe(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
......@@ -62,7 +64,7 @@ def single_marlin_moe(
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
M, K = hidden_states.shape
......@@ -83,39 +85,54 @@ def single_marlin_moe(
block_size_m = config['BLOCK_SIZE_M']
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=hidden_states.device,
requires_grad=False)
has_zero_point = w_zeros is not None
if w_zeros is None:
w_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if g_idx is None:
g_idx = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
if sort_indices is None:
sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, block_size_m, E, expert_map)
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * \
(sorted_token_ids.size(0) // block_size_m)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms)
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, w_zeros is not None)
intermediate_cache = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
is_k_full, E, topk, block_size_m, True, False)
ops.moe_wna16_marlin_gemm(hidden_states,
intermediate_cache,
w,
scales,
w_zeros,
g_idx,
sort_indices,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type=scalar_type,
size_m=M,
size_n=N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False)
intermediate_cache = intermediate_cache.view(-1, topk, N)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
......@@ -127,9 +144,12 @@ def single_marlin_moe_fake(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
......@@ -144,24 +164,26 @@ direct_register_custom_op(
)
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
......@@ -196,27 +218,12 @@ def fused_marlin_moe(
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
......@@ -234,122 +241,128 @@ def fused_marlin_moe(
block_size_m = config["BLOCK_SIZE_M"]
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=current_platform.device_type,
requires_grad=False)
if has_no_zp:
w1_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
w2_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if has_no_act_order:
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, block_size_m, global_num_experts,
expert_map)
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * \
(sorted_token_ids.size(0) // block_size_m)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms * 4)
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = hidden_states.dtype == torch.half or \
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states,
intermediate_cache1,
w1,
sorted_token_ids,
topk_weights,
topk_ids,
w1_scale,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
scalar_type1.id,
M,
2 * N,
K,
is_k_full,
E,
topk,
block_size_m,
True,
False,
)
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type=scalar_type1,
size_m=M,
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False)
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
if expert_map is not None:
intermediate_cache3.zero_()
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
intermediate_cache3,
w2,
sorted_token_ids,
topk_weights,
topk_ids,
w2_scale,
w2_zeros,
g_idx2,
sort_indices2,
workspace,
scalar_type2.id,
M,
K,
N,
is_k_full,
E,
topk,
block_size_m,
False,
True,
)
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=True,
is_ep=expert_map is not None,
b_q_type=scalar_type2,
size_m=M * topk,
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
def fused_marlin_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
dim=1,
out=output)
# Shape-only stand-in for the fused Marlin MoE entry point.
#
# It accepts the same keyword set as the real implementation's call sites
# (including global_num_experts, expert_map, workspace and inplace) but
# performs no computation: the result is an uninitialized tensor with the
# same shape, dtype and device as `hidden_states`.
# NOTE(review): presumably registered as the fake/meta implementation of the
# fused_marlin_moe custom op -- confirm at the registration site.
# NOTE(review): `inplace` is accepted but ignored here; a fresh tensor is
# always returned.
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states)
......
......@@ -773,6 +773,18 @@ def get_default_config(
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
elif is_marlin:
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break
return {"BLOCK_SIZE_M": block_size_m}
elif M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
else:
config = {
"BLOCK_SIZE_M": 64,
......@@ -780,14 +792,6 @@ def get_default_config(
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config
......
......@@ -472,6 +472,7 @@ class FusedMoE(torch.nn.Module):
self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
......
......@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
......@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
if layer.local_num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_one(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return AWQMoEMethod(self)
return AWQMoEMethod(self)
return None
@classmethod
......@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
......@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
......@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
)
......@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
check_marlin_supported, check_moe_marlin_supports_layer,
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
......@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Select the quantization method object for ``layer``.

    Args:
        layer: The module being quantized.
        prefix: Name of the layer; used in the fallback warning and
            forwarded to the delegated lookups.

    Returns:
        For ``FusedMoE`` layers, a ``GPTQMarlinMoEMethod`` when
        ``check_moe_marlin_supports_layer`` accepts the layer, otherwise
        whatever ``MoeWNA16Config`` (built from ``self.full_config``)
        selects. For every other layer, the result of
        ``get_linear_quant_method`` with ``GPTQMarlinLinearMethod``.
    """
    if isinstance(layer, FusedMoE):
        # Imported locally rather than at module import time.
        # NOTE(review): presumably this avoids a circular import with
        # moe_wna16 -- confirm before hoisting it.
        from vllm.model_executor.layers.quantization.moe_wna16 import (
            MoeWNA16Config)

        if not check_moe_marlin_supports_layer(layer, self.group_size):
            # `warning_once` (not `warning_one`) is the logger helper; the
            # misspelled name would raise AttributeError on this path.
            logger.warning_once(
                f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
                "Falling back to Moe WNA16 kernels.")
            return MoeWNA16Config.from_config(
                self.full_config).get_quant_method(layer, prefix)
        return GPTQMarlinMoEMethod(self)
    return get_linear_quant_method(self, layer, prefix,
                                   GPTQMarlinLinearMethod)
......@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts,
scales_size13,
2 * intermediate_size_per_partition,
dtype=torch.half),
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
......@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=torch.half),
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
......@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order
......@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
......@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
is_k_full=self.is_k_full).to(orig_dtype)
workspace=layer.workspace,
is_k_full=self.is_k_full)
......@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
group_size=group_size)[0]
def check_moe_marlin_supports_layer(layer: torch.nn.Module, group_size: int) \
        -> bool:
    """Report whether the Marlin MoE kernel can handle ``layer``.

    Args:
        layer: A fused-MoE layer exposing ``hidden_size`` and
            ``intermediate_size_per_partition``. (Callers pass a
            ``FusedMoE`` module, hence the ``torch.nn.Module`` annotation
            rather than ``LinearBase``.)
        group_size: Quantization group size; ``-1`` is accepted alongside
            32, 64 and 128.

    Returns:
        ``True`` only if the group size is one of the accepted values,
        ``hidden_size`` is a multiple of 128, and
        ``intermediate_size_per_partition`` is a multiple of
        ``max(64, group_size)``.
    """
    hidden_size = layer.hidden_size
    intermediate_size_per_partition = layer.intermediate_size_per_partition
    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
    # down:    (n, k) = (hidden_size, intermediate_size_per_partition)
    # moe marlin requires n % 128 == 0 and k % 64 == 0
    # (2 * intermediate % 128 == 0 is equivalent to intermediate % 64 == 0).
    # max(64, group_size) additionally requires the down-projection k to be
    # a whole number of groups when group_size is 128; for -1 it is just 64.
    k_multiple = max(64, group_size)
    return hidden_size % 128 == 0 and \
        intermediate_size_per_partition % k_multiple == 0 and \
        group_size in [-1, 32, 64, 128]
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment