Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
......@@ -16,10 +16,12 @@ namespace vec_op {
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
// NOTE: FP16 (Half) is supported on s390x via custom bit-manipulation
// conversion. PyTorch itself lacks native s390x FP16 support.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
......@@ -86,6 +88,39 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
}
};
// Eight FP16 values kept as raw 16-bit lanes in a single vector register.
struct FP16Vec8 : public Vec<FP16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __vector signed short reg;

  // Loads one full vector (8 x 16-bit lanes) from `ptr`.
  explicit FP16Vec8(const void* ptr)
      : reg(*reinterpret_cast<const __vector signed short*>(ptr)) {}

  // Narrowing conversion from eight FP32 lanes; declared here only.
  explicit FP16Vec8(const FP32Vec8&);

  // Stores the 8 lanes to `ptr`.
  void save(void* ptr) const {
    auto* dst = static_cast<__vector signed short*>(ptr);
    *dst = reg;
  }
};
// Sixteen FP16 values kept as raw 16-bit lanes in two vector registers.
struct FP16Vec16 : public Vec<FP16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  ss16x8x2_t reg;

  // Loads 256 bits from `ptr` as two 128-bit halves (byte offsets 0 and 16).
  explicit FP16Vec16(const void* ptr) {
    signed short* src = (signed short*)ptr;
    reg.val[0] = (__vector signed short)vec_xl(0, src);
    reg.val[1] = (__vector signed short)vec_xl(16, src);
  }

  // Narrowing conversion from sixteen FP32 lanes; declared here only.
  explicit FP16Vec16(const FP32Vec16&);

  // Stores 256 bits to `ptr` as two 128-bit halves (byte offsets 0 and 16).
  void save(void* ptr) const {
    signed short* dst = (signed short*)ptr;
    vec_xst(reg.val[0], 0, dst);
    vec_xst(reg.val[1], 16, dst);
  }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
......@@ -108,6 +143,92 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
const static __vector signed short zero = vec_splats((signed short)0);
// Widens four FP16 bit patterns to FP32.
//
// `x` carries one FP16 pattern in the low 16 bits of each 32-bit lane; the
// upper 16 bits of every lane are masked away and may hold anything.
//
// Handled classes:
//   * NaN/Inf (FP16 exponent 0x1F) -> FP32 exponent 0xFF, payload preserved.
//   * Normal numbers               -> exponent rebias by 112 (127 - 15).
//   * Zero and subnormals (exponent 0) -> exact value m * 2^-24. Previously
//     these were rebiased like normals, which turned +-0 into +-2^-15.
FORCE_INLINE __vector float fp16_to_fp32_bits(__vector unsigned int x) {
  const __vector unsigned int mask_sign = {0x8000, 0x8000, 0x8000, 0x8000};
  const __vector unsigned int mask_exp = {0x7C00, 0x7C00, 0x7C00, 0x7C00};
  const __vector unsigned int mask_mant = {0x03FF, 0x03FF, 0x03FF, 0x03FF};
  const __vector unsigned int mask_abs = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF,
                                          0x7FFFFFFF};
  const __vector unsigned int bias_adj = {112, 112, 112, 112};
  const __vector unsigned int exp_zero = {0, 0, 0, 0};
  const __vector unsigned int exp_max_fp16 = {0x1F, 0x1F, 0x1F,
                                              0x1F};  // FP16 NaN/Inf exponent
  const __vector unsigned int exp_max_fp32 = {0xFF, 0xFF, 0xFF,
                                              0xFF};  // FP32 NaN/Inf exponent
  // Bit pattern of 0.5f. Its ulp is 2^-24, the weight of one FP16 subnormal
  // mantissa step, so (0.5f | m) - 0.5f == m * 2^-24 exactly.
  const __vector unsigned int half_bits = {0x3F000000, 0x3F000000, 0x3F000000,
                                           0x3F000000};

  __vector unsigned int s = (x & mask_sign) << 16;
  __vector unsigned int e = (x & mask_exp) >> 10;
  __vector unsigned int m10 = x & mask_mant;

  __vector __bool int is_nan_inf = vec_cmpeq(e, exp_max_fp16);
  __vector __bool int is_zero_or_sub = vec_cmpeq(e, exp_zero);

  // Normal / NaN / Inf lanes: rebias (or force 0xFF) and widen the mantissa.
  __vector unsigned int e32 = vec_sel(e + bias_adj, exp_max_fp32, is_nan_inf);
  __vector unsigned int normal_bits = (e32 << 23) | (m10 << 13);

  // Zero / subnormal lanes: rebuild the magnitude arithmetically. The mask
  // drops a possible -0 result so the sign comes only from `s`.
  __vector float sub_val =
      (__vector float)(half_bits | m10) - (__vector float)half_bits;
  __vector unsigned int sub_bits = (__vector unsigned int)sub_val & mask_abs;

  __vector unsigned int magnitude =
      vec_sel(normal_bits, sub_bits, is_zero_or_sub);
  return (__vector float)(s | magnitude);
}
// Narrows four FP32 values to FP16 bit patterns (returned in the low 16 bits
// of each 32-bit lane).
//
// Semantics, kept in line with the scalar storeFP32<c10::Half> path:
//   * Finite values are rounded to nearest-even on the 13 dropped mantissa
//     bits; a mantissa carry increments the exponent.
//   * Finite results too large for FP16 saturate to the largest normal value
//     (exponent 0x1E, mantissa 0x3FF) instead of leaking into the NaN/Inf
//     exponent 0x1F.
//   * Results below the smallest FP16 normal are flushed to signed zero
//     (exponent AND mantissa cleared).
//   * Inf stays Inf. NaN stays NaN: it is never rounded, and if its payload
//     lived only in the dropped bits the quiet bit is set so it cannot
//     collapse to Inf.
FORCE_INLINE __vector unsigned int fp32_to_fp16_bits(__vector float f_in) {
  __vector unsigned int in = (__vector unsigned int)f_in;

  const __vector unsigned int mask_sign_32 = {0x80000000, 0x80000000,
                                              0x80000000, 0x80000000};
  const __vector unsigned int mask_exp_32 = {0x7F800000, 0x7F800000, 0x7F800000,
                                             0x7F800000};
  const __vector unsigned int mask_mant_32 = {0x007FFFFF, 0x007FFFFF,
                                              0x007FFFFF, 0x007FFFFF};
  const __vector unsigned int mask_sticky = {0xFFF, 0xFFF, 0xFFF, 0xFFF};
  const __vector unsigned int mant_mask = {0x3FF, 0x3FF, 0x3FF, 0x3FF};
  const __vector unsigned int quiet_bit = {0x200, 0x200, 0x200, 0x200};
  const __vector unsigned int exp_max_fp32 = {0xFF, 0xFF, 0xFF, 0xFF};
  const __vector unsigned int exp_max_fp16 = {0x1F, 0x1F, 0x1F, 0x1F};
  const __vector unsigned int max_normal_exp = {0x1E, 0x1E, 0x1E, 0x1E};
  const __vector unsigned int one_v = {1, 1, 1, 1};
  const __vector unsigned int zero_u = {0, 0, 0, 0};
  // Exponent math is signed so that underflow shows up as a value < 1.
  const __vector signed int bias_adj = {112, 112, 112, 112};
  const __vector signed int one_s = {1, 1, 1, 1};
  const __vector signed int max_normal_exp_s = {30, 30, 30, 30};

  __vector unsigned int s = (in & mask_sign_32) >> 16;
  __vector unsigned int e_u = (in & mask_exp_32) >> 23;
  __vector unsigned int mant32 = in & mask_mant_32;

  // NaN/Inf: exponent = 0xFF in FP32
  __vector __bool int is_nan_inf = vec_cmpeq(e_u, exp_max_fp32);

  // --- Finite path: round to nearest even -------------------------------
  __vector unsigned int m_trunc = mant32 >> 13;
  __vector unsigned int round_bit = (in >> 12) & one_v;
  __vector unsigned int sticky = in & mask_sticky;
  __vector unsigned int lsb = m_trunc & one_v;  // tie-breaking bit

  // Round up if: round_bit && (sticky || lsb)
  __vector __bool int sticky_nonzero = vec_cmpgt(sticky, zero_u);
  __vector __bool int lsb_set = vec_cmpeq(lsb, one_v);
  __vector __bool int round_up =
      vec_and(vec_cmpeq(round_bit, one_v), vec_or(sticky_nonzero, lsb_set));

  __vector unsigned int m = vec_sel(m_trunc, m_trunc + one_v, round_up);
  __vector __bool int mant_overflows = vec_cmpgt(m, mant_mask);
  m = vec_and(m, mant_mask);

  // Rebias, then propagate the mantissa carry into the exponent.
  __vector signed int e_s = vec_sub((__vector signed int)e_u, bias_adj);
  e_s = vec_sel(e_s, e_s + one_s, mant_overflows);

  __vector __bool int too_small = vec_cmpgt(one_s, e_s);  // e_s < 1
  __vector __bool int too_large = vec_cmpgt(e_s, max_normal_exp_s);

  __vector unsigned int e_final = (__vector unsigned int)e_s;
  // Underflow: flush to signed zero.
  e_final = vec_sel(e_final, zero_u, too_small);
  m = vec_sel(m, zero_u, too_small);
  // Overflow: saturate to the largest finite FP16 value.
  e_final = vec_sel(e_final, max_normal_exp, too_large);
  m = vec_sel(m, mant_mask, too_large);

  // --- NaN/Inf path (applied last so it overrides the finite results) ----
  __vector __bool int nan_payload_lost =
      vec_and(vec_cmpgt(mant32, zero_u), vec_cmpeq(m_trunc, zero_u));
  __vector unsigned int m_special =
      vec_sel(m_trunc, quiet_bit, nan_payload_lost);
  e_final = vec_sel(e_final, exp_max_fp16, is_nan_inf);
  m = vec_sel(m, m_special, is_nan_inf);

  return s | (e_final << 10) | m;
}
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
......@@ -180,6 +301,18 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
explicit FP32Vec8(const FP16Vec8& v) {
// Cast to UNSIGNED short vector to prevent sign-extension during unpack
__vector unsigned short raw_u = (__vector unsigned short)v.reg;
// Unpack 8x16-bit to two 4x32-bit vectors (Zero extended)
__vector unsigned int raw_hi = (__vector unsigned int)vec_unpackh(raw_u);
__vector unsigned int raw_lo = (__vector unsigned int)vec_unpackl(raw_u);
reg.val[0] = fp16_to_fp32_bits(raw_hi);
reg.val[1] = fp16_to_fp32_bits(raw_lo);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
......@@ -531,6 +664,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const FP16Vec16& v) {
__vector unsigned int raw_hi_0 =
(__vector unsigned int)vec_unpackh(v.reg.val[0]);
__vector unsigned int raw_lo_0 =
(__vector unsigned int)vec_unpackl(v.reg.val[0]);
reg.val[0] = fp16_to_fp32_bits(raw_hi_0);
reg.val[1] = fp16_to_fp32_bits(raw_lo_0);
__vector unsigned int raw_hi_1 =
(__vector unsigned int)vec_unpackh(v.reg.val[1]);
__vector unsigned int raw_lo_1 =
(__vector unsigned int)vec_unpackl(v.reg.val[1]);
reg.val[2] = fp16_to_fp32_bits(raw_hi_1);
reg.val[3] = fp16_to_fp32_bits(raw_lo_1);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
......@@ -628,8 +777,10 @@ struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
using FP16Vec16 = FP32Vec16;
// VecType specialization: c10::Half tensors use the 8-lane FP16 vector type.
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
......@@ -650,6 +801,52 @@ inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
*ptr = *(v_ptr + 1);
}
// Stores a single FP32 value as FP16 using explicit bit manipulation.
//
// Semantics:
//   * Finite values: round to nearest-even on the 13 dropped mantissa bits,
//     with a mantissa carry bumping the exponent.
//   * Results too large for FP16 saturate to the largest normal value.
//   * Results below the smallest FP16 normal flush to signed zero (both the
//     exponent and the mantissa are cleared).
//   * Inf stays Inf. NaN stays NaN: it is never rounded, and if its payload
//     sat only in the dropped bits the quiet bit is set so it cannot turn
//     into Inf.
template <>
inline void storeFP32<::c10::Half>(float v, ::c10::Half* ptr) {
  uint32_t in;
  std::memcpy(&in, &v, sizeof(in));

  const uint32_t s = (in & 0x80000000u) >> 16;  // Sign
  const uint32_t mant32 = in & 0x007FFFFFu;     // Full FP32 mantissa
  uint32_t e = (in & 0x7F800000u) >> 23;        // Biased FP32 exponent
  uint32_t m = mant32 >> 13;                    // Truncated FP16 mantissa

  if (e == 0xFF) {
    // NaN/Inf: keep the class. Rounding is skipped on purpose; a carry out
    // of the mantissa would otherwise zero it and produce Inf from a NaN.
    e = 0x1F;
    if (mant32 != 0 && m == 0) {
      m = 0x200;  // payload was entirely in the dropped bits
    }
  } else {
    const uint32_t round_bit = (in >> 12) & 1;
    const bool sticky = (in & 0xFFF) != 0;  // Any bits in [11..0]
    const uint32_t lsb = m & 1;             // Tie-breaking bit
    if (round_bit && (sticky || lsb)) {
      ++m;
      if (m > 0x3FF) {  // mantissa carry
        m = 0;
        ++e;
      }
    }

    if (e <= 112) {
      // Below the smallest FP16 normal (bias difference 127 - 15 = 112).
      e = 0;
      m = 0;
    } else {
      e -= 112;
      if (e > 0x1E) {
        e = 0x1E;   // Max normal exponent
        m = 0x3FF;  // Max mantissa
      }
    }
  }

  const uint16_t fp16 = static_cast<uint16_t>(s | (e << 10) | m);
  *reinterpret_cast<uint16_t*>(ptr) = fp16;
}
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
......@@ -803,6 +1000,44 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
// Narrows eight FP32 lanes to FP16 through fp32_to_fp16_bits (explicit bit
// manipulation rather than the hardware conversion intrinsics), handling
// the two 4-lane source vectors one at a time.
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
  // vec_perm byte selector: from every 32-bit lane of both inputs, keep
  // bytes 2-3, which hold the 16-bit conversion result.
  const __vector unsigned char pack_low_halves = {
      2,  3,  6,  7,  10, 11, 14, 15,  // lanes of the first input
      18, 19, 22, 23, 26, 27, 30, 31   // lanes of the second input
  };
  const __vector signed short first =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[0]);
  const __vector signed short second =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[1]);
  reg = vec_perm(first, second, pack_low_halves);
}
// Narrows sixteen FP32 lanes to FP16 through fp32_to_fp16_bits (explicit bit
// manipulation rather than the hardware conversion intrinsics), handling
// the four 4-lane source vectors one at a time.
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
  // vec_perm byte selector: from every 32-bit lane of both inputs, keep
  // bytes 2-3, which hold the 16-bit conversion result.
  const __vector unsigned char pack_low_halves = {
      2,  3,  6,  7,  10, 11, 14, 15,  // lanes of the first input
      18, 19, 22, 23, 26, 27, 30, 31   // lanes of the second input
  };
  const __vector signed short h0 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[0]);
  const __vector signed short h1 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[1]);
  const __vector signed short h2 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[2]);
  const __vector signed short h3 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[3]);
  reg.val[0] = vec_perm(h0, h1, pack_low_halves);
  reg.val[1] = vec_perm(h2, h3, pack_low_halves);
}
// 1D softmax over `n` elements in `input`, writes result to `output`.
// Uses FP32Vec8 for main body, scalar tail handling.
// Requirement: n > 0
......
......@@ -237,12 +237,17 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
};
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
// dummy M size for prepacking weights
// Prepacking weights improves performance and avoids runtime reorders
constexpr dnnl_dim_t kProbeM = 128;
prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc(
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
MSizeCacheKey{.a_m_size = kProbeM,
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true)
/*first_time=*/true)
.weights_desc());
init_runtime_memory_cache(args);
}
......@@ -403,21 +408,19 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
// dummy M size for prepacking weights
// Prepacking weights improves performance and avoids runtime reorders
constexpr dnnl_dim_t kProbeM = 128;
prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc(
MSizeCacheKey{
#ifdef VLLM_USE_ACL
// Arm Compute Library (ACL) backend for oneDNN does
// not support runtime
// dimensions, so we set M to a default value
.a_m_size = 128,
.a_m_stride = b_k_size_,
#else
.a_m_size = DNNL_RUNTIME_DIM_VAL,
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
#endif
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
MSizeCacheKey{// Use a concrete M so oneDNN's kernel
// selector can choose an optimally blocked
// weight layout.
.a_m_size = kProbeM,
.a_m_stride = b_k_size_,
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true)
.weights_desc());
init_runtime_memory_cache(args);
......
......@@ -19,10 +19,11 @@ ISA_TYPES = {
"VEC": 1,
"VEC16": 2,
"NEON": 3,
"VXE": 4,
}
# ISAs supported for head_dims divisible by 32
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16"]
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16", "VXE"]
# ISAs supported for head_dims divisible by 16 only
ISA_FOR_16 = ["VEC16"]
......@@ -118,6 +119,10 @@ def generate_header_file() -> str:
#include "cpu_attn_neon.hpp"
#endif
#ifdef __s390x__
#include "cpu_attn_vxe.hpp"
#endif
"""
header += generate_helper_function()
......@@ -163,6 +168,25 @@ def generate_header_file() -> str:
} \\
}()
"""
# s390x with VXE
header += """#elif defined(__s390x__)
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
[&] { \\
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
switch (encoded_params) { \\
"""
header += generate_cases_for_isa_group(["VXE", "VEC", "VEC16"])
header += """
default: { \\
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
std::to_string(HEAD_DIM) + " isa=" + \\
std::to_string(static_cast<int>(ISA_TYPE))); \\
} \\
} \\
}()
"""
# Fallback: VEC and VEC16 only
......@@ -182,7 +206,7 @@ def generate_header_file() -> str:
} \\
}()
#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ */
#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / __s390x__ */
#endif // CPU_ATTN_DISPATCH_GENERATED_H
"""
......
......@@ -18,8 +18,8 @@ struct KernelVecType<float> {
template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
// Power and s390x architecture-specific vector types
#if defined(__powerpc64__)
// Power specific vector types
using qk_load_vec_type = vec_op::FP32Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
......@@ -38,16 +38,7 @@ struct KernelVecType<c10::BFloat16> {
using qk_vec_type = vec_op::BF16Vec32;
using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__s390x__)
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__aarch64__)
#else
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec16;
......
......@@ -4,6 +4,10 @@
#include <torch/library.h>
// Note: override the externally supplied definition so that libraries built
// for different ISAs can share the same extension name.
#define TORCH_EXTENSION_NAME _C
std::string init_cpu_threads_env(const std::string& cpu_ids);
void release_dnnl_matmul_handler(int64_t handler);
......@@ -118,8 +122,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act,
const std::string& isa);
const torch::Tensor& topk_id, const bool skip_weighted,
const std::string& act, const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
......@@ -319,22 +323,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"bool skip_weighted, "
"str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
// CPU utils
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
cpu_ops.def(
ops.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
ops.def(
"mla_decode_kvcache("
" Tensor! out, Tensor query, Tensor kv_cache,"
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <cassert>
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#endif
// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which
// together enable 256-bit (v8.u32) PTX load/store instructions.
// Use for PTX instruction selection with architecture fallback paths.
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
defined(CUDA_VERSION) && CUDA_VERSION >= 12090
#define VLLM_256B_PTX_ENABLED 1
#else
#define VLLM_256B_PTX_ENABLED 0
#endif
namespace vllm {
// ============================================================
// Types and traits
// ============================================================
// 256-bit (32-byte) aligned vector type: 8 x uint32_t
// Eight 32-bit words (32 bytes total), aligned to 32 bytes.
struct alignas(32) u32x8_t {
uint32_t d[8];
};
// VecTraits — select between 128-bit (int4) and 256-bit
// (u32x8_t) vector types at compile time.
// Primary template is declared only; the two bool specializations below
// provide the members.
template <bool support_256>
struct VecTraits;
// 256-bit variant: 32-byte vectors backed by u32x8_t.
template <>
struct VecTraits<true> {
static constexpr int ARCH_MAX_VEC_SIZE = 32;
using vec_t = u32x8_t;
};
// 128-bit variant: 16-byte vectors backed by int4.
template <>
struct VecTraits<false> {
static constexpr int ARCH_MAX_VEC_SIZE = 16;
using vec_t = int4;
};
// PackedTypeConverter — map between CUDA scalar and packed types
// half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, etc.
// Primary template: any instantiation with a type that has no specialization
// trips the static_assert below.
template <typename T>
struct PackedTypeConverter {
static_assert(sizeof(T) == 0,
"PackedTypeConverter is not specialized for this type.");
};
// half2 -> half
template <>
struct PackedTypeConverter<half2> {
using Type = half;
};
// half -> half2
template <>
struct PackedTypeConverter<half> {
using Type = half2;
};
// __nv_bfloat162 -> __nv_bfloat16
template <>
struct PackedTypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
// __nv_bfloat16 -> __nv_bfloat162
template <>
struct PackedTypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
// float -> float2
template <>
struct PackedTypeConverter<float> {
using Type = float2;
};
// float2 -> float
template <>
struct PackedTypeConverter<float2> {
using Type = float;
};
// c10::Half -> half2 (one-way: there is no mapping back to c10::Half)
template <>
struct PackedTypeConverter<c10::Half> {
using Type = half2;
};
// c10::BFloat16 -> __nv_bfloat162 (one-way, as above)
template <>
struct PackedTypeConverter<c10::BFloat16> {
using Type = __nv_bfloat162;
};
// CUDATypeConverter — map PyTorch scalar types to CUDA scalar
// c10::Half -> half, c10::BFloat16 -> __nv_bfloat16
// Default: a type maps to itself.
template <typename T>
struct CUDATypeConverter {
using Type = T;
};
// c10::Half -> half
template <>
struct CUDATypeConverter<c10::Half> {
using Type = half;
};
// c10::BFloat16 -> __nv_bfloat16
template <>
struct CUDATypeConverter<c10::BFloat16> {
using Type = __nv_bfloat16;
};
// PackedVec — typed vector container for packed element access.
// Derives alignment and element count from VecTraits.
// Type is the CUDA scalar type (e.g. half, __nv_bfloat16).
// Array of packed elements sized and aligned to one vector of
// VecTraits<use_256b>::ARCH_MAX_VEC_SIZE bytes.
template <class Type, bool use_256b>
struct alignas(VecTraits<use_256b>::ARCH_MAX_VEC_SIZE) PackedVec {
// Number of packed elements that fit in one vector: vector size in bytes
// divided by the size of the packed counterpart of `Type`.
static constexpr int NUM_ELTS =
VecTraits<use_256b>::ARCH_MAX_VEC_SIZE /
sizeof(typename PackedTypeConverter<Type>::Type);
// Storage holds the packed type (PackedTypeConverter<Type>::Type), not `Type`.
typename PackedTypeConverter<Type>::Type elts[NUM_ELTS];
};
// ============================================================
// Load / store primitives
// ============================================================
// 256-bit load / store — SM100+ only (PTX v8 instructions).
// Loads eight 32-bit words from `ptr` into `val` with a single
// ld.global.nc.v8.u32 instruction.
// When VLLM_256B_PTX_ENABLED is 0 this only asserts; `val` is not written.
__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
: "l"(ptr));
#else
// Unsupported build configuration: fail loudly in debug builds.
assert(false && "ld256 requires SM100+ with CUDA 12.9+");
#endif
}
// Stores the eight 32-bit words of `val` to `ptr` with a single
// st.global.v8.u32 instruction. `val` is only read, although it is taken by
// non-const reference.
// When VLLM_256B_PTX_ENABLED is 0 this only asserts; nothing is stored.
__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
:
: "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]),
"r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]),
"r"(val.d[7])
: "memory");
#else
// Unsupported build configuration: fail loudly in debug builds.
assert(false && "st256 requires SM100+ with CUDA 12.9+");
#endif
}
// Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec).
// Non-template overloads above are preferred for u32x8_t.
// Generic 256-bit load: reinterprets a 32-byte `T` as u32x8_t and forwards
// to the u32x8_t overload.
template <typename T>
__device__ __forceinline__ void ld256(T& val, const T* ptr) {
  static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type");
  auto& dst_words = reinterpret_cast<u32x8_t&>(val);
  const auto* src_words = reinterpret_cast<const u32x8_t*>(ptr);
  ld256(dst_words, src_words);
}
// Generic 256-bit store: reinterprets a 32-byte `T` as u32x8_t and forwards
// to the u32x8_t overload.
template <typename T>
__device__ __forceinline__ void st256(T& val, T* ptr) {
  static_assert(sizeof(T) == 32, "st256 requires a 32-byte type");
  auto& src_words = reinterpret_cast<u32x8_t&>(val);
  auto* dst_words = reinterpret_cast<u32x8_t*>(ptr);
  st256(src_words, dst_words);
}
// 128-bit load / store via __ldg (read-only cache hint).
// 128-bit load of a 16-byte `T` through __ldg, moved as a single int4.
template <typename T>
__device__ __forceinline__ void ld128(T& val, const T* ptr) {
  static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type");
  const int4* src = reinterpret_cast<const int4*>(ptr);
  int4* dst = reinterpret_cast<int4*>(&val);
  *dst = __ldg(src);
}
// 128-bit store of a 16-byte `T`, moved as a single int4.
template <typename T>
__device__ __forceinline__ void st128(T& val, T* ptr) {
  static_assert(sizeof(T) == 16, "st128 requires a 16-byte type");
  int4* src = reinterpret_cast<int4*>(&val);
  int4* dst = reinterpret_cast<int4*>(ptr);
  *dst = *src;
}
// 256-bit cache-streaming (.cs) load / store — SM100+ only.
// Loads eight 32-bit words from `addr` with ld.global.cs.v8.u32 and returns
// them by value.
// When VLLM_256B_PTX_ENABLED is 0 this asserts and returns a
// value-initialized (all-zero) u32x8_t.
__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
#if VLLM_256B_PTX_ENABLED
u32x8_t val;
asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
: "l"(addr));
return val;
#else
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
return u32x8_t{};
#endif
}
// Stores the eight 32-bit words of `val` to `addr` with st.global.cs.v8.u32.
// When VLLM_256B_PTX_ENABLED is 0 this only asserts; nothing is stored.
// NOTE(review): unlike st256, this asm statement declares no "memory"
// clobber — confirm that is intentional.
__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
#if VLLM_256B_PTX_ENABLED
asm volatile(
"st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr),
"r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]),
"r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]));
#else
assert(false && "st256_cs requires SM100+ with CUDA 12.9+");
#endif
}
// 32-bit load / store.
// ld32: reads one int from `addr` through __ldg.
__device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); }
// st32: plain store of `val` to `addr`.
__device__ __forceinline__ void st32(int* addr, int val) { *addr = val; }
// 32-bit cache-streaming (.cs) load / store.
// Falls back to ld32/st32 on ROCm (no .cs hint).
// 32-bit load with the .cs hint; on ROCm it defers to ld32.
__forceinline__ __device__ int ld32_cs(const int* addr) {
#ifdef USE_ROCM
  const int loaded = ld32(addr);
  return loaded;
#else
  int loaded;
  asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(loaded) : "l"(addr));
  return loaded;
#endif
}
// 32-bit store with the .cs hint; on ROCm it defers to st32.
__forceinline__ __device__ void st32_cs(int* addr, int val) {
#ifdef USE_ROCM
  st32(addr, val);
#else
  asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
#endif
}
// 128-bit cache-streaming (.cs) load / store.
// Falls back to ld128/st128 on ROCm (no .cs hint).
// 128-bit load with the .cs hint; on ROCm it defers to ld128.
__forceinline__ __device__ int4 ld128_cs(const int4* addr) {
  int4 loaded;
#ifdef USE_ROCM
  ld128(loaded, addr);
#else
  asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];"
               : "=r"(loaded.x), "=r"(loaded.y), "=r"(loaded.z),
                 "=r"(loaded.w)
               : "l"(addr));
#endif
  return loaded;
}
// 128-bit store with the .cs hint; on ROCm it defers to st128.
__forceinline__ __device__ void st128_cs(int4* addr, int4 val) {
#ifdef USE_ROCM
  st128(val, addr);
#else
  asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr),
               "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
#endif
}
// Predicated 256-bit / 128-bit cache-global (.cg) loads.
// Returns zero if pred is false. SM100+ only.
// Predicated 256-bit load: all eight words of `val` are first set to zero,
// then overwritten by ld.global.cg.v8.u32 from `ptr` only if `pred` is true.
// When VLLM_256B_PTX_ENABLED is 0 this only asserts; `val` is not written.
__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr,
bool pred) {
#if VLLM_256B_PTX_ENABLED
// %8 is `pred` (as int), %9 is `ptr`; the load is guarded by predicate `pr`.
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %8, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" mov.u32 %4, 0;\n"
" mov.u32 %5, 0;\n"
" mov.u32 %6, 0;\n"
" mov.u32 %7, 0;\n"
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
"}\n"
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
: "r"((int)pred), "l"(ptr));
#else
assert(false && "ld256_cg_or_zero requires SM100+ with CUDA 12.9+");
#endif
}
// Predicated 128-bit load: four temporaries are zeroed, then overwritten by
// ld.global.cg.v4.u32 from `ptr` only if `pred` is true, and finally copied
// into `val`.
// On ROCm this only asserts; `val` is not written.
__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
bool pred) {
#ifndef USE_ROCM
uint32_t r0, r1, r2, r3;
// %4 is `pred` (as int), %5 is `ptr`; the load is guarded by predicate `pr`.
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %4, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
: "r"((int)pred), "l"(ptr));
val = uint4{r0, r1, r2, r3};
#else
assert(false && "ld128_cg_or_zero is not supported on ROCm");
#endif
}
// ============================================================
// Alignment helpers
// ============================================================
// True when the address held in `ptr` is a multiple of 16.
__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  return address % 16 == 0;
}
// True when the address held in `ptr` is a multiple of 32.
__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  return address % 32 == 0;
}
// ============================================================
// Packed type conversion and arithmetic
// ============================================================
// Converts a packed pair (__nv_bfloat162, __half2 or float2) to float2.
// Any other packed_t is rejected at compile time; previously such an
// instantiation compiled and fell off the end without returning a value.
template <typename packed_t>
__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
    return __bfloat1622float2(val);
  } else if constexpr (std::is_same_v<packed_t, __half2>) {
    return __half22float2(val);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return float2(val);
  } else {
    // Dependent condition so it only fires when this branch is instantiated.
    static_assert(sizeof(packed_t) == 0,
                  "cast_to_float2: unsupported packed type");
  }
}
// Converts a float2 to the requested packed pair type (__nv_bfloat162,
// __half2 or float2), using the _rn conversion intrinsics for the 16-bit
// types. Any other packed_t is rejected at compile time; previously such an
// instantiation compiled and fell off the end without returning a value.
template <typename packed_t>
__device__ __forceinline__ packed_t cast_to_packed(const float2& val) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
    return __float22bfloat162_rn(val);
  } else if constexpr (std::is_same_v<packed_t, __half2>) {
    return __float22half2_rn(val);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return float2(val);
  } else {
    // Dependent condition so it only fires when this branch is instantiated.
    static_assert(sizeof(packed_t) == 0,
                  "cast_to_packed: unsupported packed type");
  }
}
// Element-wise product of two packed pairs: __hmul2 for __nv_bfloat162 and
// __half2, per-component multiply for float2. Any other packed_t is rejected
// at compile time; previously such an instantiation compiled and fell off
// the end without returning a value.
template <typename packed_t>
__device__ __forceinline__ packed_t packed_mul(const packed_t& x,
                                               const packed_t& y) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162> ||
                std::is_same_v<packed_t, __half2>) {
    return __hmul2(x, y);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return make_float2(x.x * y.x, x.y * y.y);
  } else {
    // Dependent condition so it only fires when this branch is instantiated.
    static_assert(sizeof(packed_t) == 0,
                  "packed_mul: unsupported packed type");
  }
}
} // namespace vllm
......@@ -2,33 +2,58 @@
#include <torch/cuda.h>
#include <cuda_runtime.h>
// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
// This function assumes that `cpu_tensor` is a CPU tensor,
// and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
// Get raw host pointer from CPU tensor
void* host_ptr = cpu_tensor.data_ptr();
// handle empty tensor
if (cpu_tensor.numel() == 0) {
return torch::empty(cpu_tensor.sizes(),
cpu_tensor.options().device(torch::kCUDA));
}
if (cpu_tensor.is_pinned()) {
// If CPU tensor is pinned, directly get the device pointer.
void* host_ptr = const_cast<void*>(cpu_tensor.data_ptr());
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
return torch::from_blob(
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
[base = cpu_tensor](void*) {}, // keep cpu tensor alive
cpu_tensor.options().device(torch::kCUDA));
}
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
size_t nbytes = contiguous_cpu.nbytes();
void* host_ptr = nullptr;
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
if (err != cudaSuccess) {
AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
}
err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
cudaMemcpyDefault);
if (err != cudaSuccess) {
cudaFreeHost(host_ptr);
AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
}
// Get a device pointer corresponding to the pinned host memory
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
// We'll use the same sizes, strides, and dtype as the CPU tensor.
// TODO: check if layout is respected.
auto sizes = cpu_tensor.sizes();
auto strides = cpu_tensor.strides();
auto options = cpu_tensor.options().device(torch::kCUDA);
// use default no-op deleter, since the memory is owned by the original CPU
// tensor
torch::Tensor cuda_tensor =
torch::from_blob(device_ptr, sizes, strides, options);
TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
return cuda_tensor;
}
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
if (err != cudaSuccess) {
cudaFreeHost(host_ptr);
AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
}
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
contiguous_cpu.strides(), deleter,
contiguous_cpu.options().device(torch::kCUDA));
}
\ No newline at end of file
......@@ -109,16 +109,18 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
#ifndef USE_ROCM
int flag = 0;
CUDA_CHECK(cuDeviceGetAttribute(
CUresult rdma_result = cuDeviceGetAttribute(
&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED,
device));
if (flag) { // support GPUDirect RDMA if possible
device);
if (rdma_result == CUDA_SUCCESS &&
flag) { // support GPUDirect RDMA if possible
prop.allocFlags.gpuDirectRDMACapable = 1;
}
int fab_flag = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&fab_flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
if (fab_flag) { // support fabric handle if possible
CUresult fab_result = cuDeviceGetAttribute(
&fab_flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device);
if (fab_result == CUDA_SUCCESS &&
fab_flag) { // support fabric handle if possible
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
}
#endif
......
/*
* Adapted from
* https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu
* which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/619709fc33bd5dc268f19d6a741fe7ed51c0f8f5/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3FusedAGemm.cu
*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "core/registration.h"
#include <cstdlib>
#include <mutex>
namespace {
// Compute capability of the current device encoded as major * 10 + minor.
inline int getSMVersion() {
  const auto& device_props = *at::cuda::getCurrentDeviceProperties();
  return 10 * device_props.major + device_props.minor;
}
// Reports whether PDL is enabled: requires SM version >= 90 and the
// environment variable TRTLLM_ENABLE_PDL set to exactly "1". The answer is
// computed on the first call and cached for all later calls.
inline bool getEnvEnablePDL() {
  static const bool enabled = [] {
    if (getSMVersion() < 90) {
      return false;
    }
    char const* raw = std::getenv("TRTLLM_ENABLE_PDL");
    return raw != nullptr && raw[0] == '1' && raw[1] == '\0';
  }();
  return enabled;
}
} // namespace
using bf16_t = __nv_bfloat16;
// Issues one `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32`
// instruction: d = a * b + c, with bf16 A/B fragments and f32 accumulators.
//
//   d_reg  - output accumulator fragment (4 floats per thread).
//   a_reg  - A fragment (8 bf16 values, passed to PTX as four 32-bit regs).
//   b_reg  - B fragment (4 bf16 values, passed to PTX as two 32-bit regs).
//   c_reg  - input accumulator fragment (4 floats per thread). It may alias
//            d_reg, which is how the callers in this file use it.
//
// Compiles to a no-op (d_reg untouched) when not building for sm_90+.
__device__ void hmma_16_8_16_f32acc_bf16ab(float (&d_reg)[4],
                                           const bf16_t (&a_reg)[8],
                                           const bf16_t (&b_reg)[4],
                                           float const (&c_reg)[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  // Re-view consecutive bf16 pairs as the packed 32-bit operands PTX expects.
  uint32_t a0 = *reinterpret_cast<uint32_t const*>(a_reg + 0);
  uint32_t a1 = *reinterpret_cast<uint32_t const*>(a_reg + 2);
  uint32_t a2 = *reinterpret_cast<uint32_t const*>(a_reg + 4);
  uint32_t a3 = *reinterpret_cast<uint32_t const*>(a_reg + 6);
  uint32_t b0 = *reinterpret_cast<uint32_t const*>(b_reg + 0);
  uint32_t b1 = *reinterpret_cast<uint32_t const*>(b_reg + 2);
  // The accumulator inputs (%10..%13) come from c_reg, as the signature
  // promises; previously d_reg was read here and c_reg was silently ignored.
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0, %1, %2, %3},"
      "{%4, %5, %6, %7},"
      "{%8, %9},"
      "{%10, %11, %12, %13};\n"
      : "=f"(d_reg[0]), "=f"(d_reg[1]), "=f"(d_reg[2]), "=f"(d_reg[3])
      : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c_reg[0]),
        "f"(c_reg[1]), "f"(c_reg[2]), "f"(c_reg[3]));
#endif
}
// Forward declaration of the device function `__nvvm_get_smem_pointer`, which
// the inline-PTX helpers in this file call to obtain the 32-bit value they
// pass as a shared-memory address operand. It is not defined in this file.
// NOTE(review): presumably the NVVM intrinsic that converts a generic pointer
// into a shared-state-space address -- confirm against the toolchain.
extern "C" {
__device__ uint32_t __nvvm_get_smem_pointer(void*);
}
// Issues one asynchronous 16-byte (128-bit) global -> shared copy via
// `cp.async.cg.shared.global.L2::128B`, but only when `pred` is non-zero.
//
//   gPtr - source address in global memory.
//   sPtr - destination address in shared memory.
//   pred - non-zero to issue the copy, zero to skip it entirely.
//
// Compiles to a no-op when not building for sm_90+.
__device__ void ldgsts_128(void const* gPtr, void* sPtr, uint32_t pred) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  if (!pred) {
    return;
  }
  uint32_t const smem_addr = __nvvm_get_smem_pointer(sPtr);
  asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n"
               :
               : "r"(smem_addr), "l"(gPtr), "n"(16));
#endif
}
// Executes `ldmatrix.sync.aligned.x4.m8n8.shared.b16` from `smem_ptr` and
// stores the four resulting 32-bit registers into reg_ptr[0..3].
//
//   smem_ptr - this thread's shared-memory source address for the instruction.
//   reg_ptr  - destination for four 32-bit values (16 bytes are written).
//
// Compiles to a no-op when not building for sm_90+.
__device__ void ldsm_x4(void* smem_ptr, uint32_t* reg_ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t const smem_addr = __nvvm_get_smem_pointer(smem_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(reg_ptr[0]), "=r"(reg_ptr[1]), "=r"(reg_ptr[2]), "=r"(reg_ptr[3])
      : "r"(smem_addr));
#endif
}
// Returns the swizzled column index for element (row_idx_, col_idx_):
//   col ^ ((row mod 8) * (16 / sizeof(Type)))
// i.e. the column is XOR-ed with the row (taken modulo 8) scaled by the number
// of `Type` elements that fit in 16 bytes. Both indices are treated as
// unsigned 32-bit values for the arithmetic.
template <class Type>
__device__ int apply_swizzle_343_on_elem_row_col(int row_idx_, int col_idx_) {
  constexpr uint32_t kElemsPer16Bytes = 16 / sizeof(Type);
  uint32_t const row_in_group = static_cast<uint32_t>(row_idx_) % 8;
  uint32_t const swizzled_col =
      static_cast<uint32_t>(col_idx_) ^ (row_in_group * kElemsPer16Bytes);
  return static_cast<int>(swizzled_col);
}
// Initialises the 64-bit mbarrier object at `smem_barrier` with
// `mbarrier.init.shared::cta.b64`.
//
//   smem_barrier - user-managed 64-bit barrier located in shared memory.
//   thread_count - count operand handed to mbarrier.init (the original comment
//                  describes it as the number of threads expected to
//                  arrive/wait on this barrier); defaults to 1.
//
// Compiles to a no-op when not building for sm_90+.
__device__ void initialize_barrier(uint64_t* smem_barrier,
                                   int thread_count = 1) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t const barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;\n"
               :
               : "r"(barrier_addr), "r"(thread_count));
#endif
}
// Blocking barrier wait: loops on `mbarrier.try_wait.parity` until it reports
// success for the given parity, then returns.
//
//   smem_barrier - user-managed 64-bit barrier located in shared memory.
//   phase_bit    - parity value passed to try_wait (the phase the caller is
//                  waiting to see completed).
//
// Compiles to a no-op when not building for sm_90+.
__device__ void wait_barrier(uint64_t* smem_barrier, int phase_bit) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t const barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  // Spin: retry try_wait until predicate P1 is set, then fall out at DONE.
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
      "@P1 bra DONE;\n"
      "bra LAB_WAIT;\n"
      "DONE:\n"
      "}\n"
      :
      : "r"(barrier_addr), "r"(phase_bit));
#endif
}
// Non-blocking barrier probe: performs a single
// `mbarrier.try_wait.parity` on the barrier at `smem_ptr`.
//
//   smem_ptr  - user-managed 64-bit barrier located in shared memory.
//   phase_bit - parity value passed to try_wait.
//
// Returns true when the instruction's predicate was set, false otherwise.
// Always returns false when not building for sm_90+.
__device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t const barrier_addr = __nvvm_get_smem_pointer(smem_ptr);
  uint32_t done;
  // selp materialises predicate P1 as 1/0 in the output register.
  asm volatile(
      "{\n\t"
      ".reg .pred P1; \n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
      "selp.b32 %0, 1, 0, P1; \n\t"
      "}"
      : "=r"(done)
      : "r"(barrier_addr), "r"(phase_bit));
  return done != 0;
#else
  return false;
#endif
}
// Barrier arrive: executes `mbarrier.arrive.shared::cta.b64` on the barrier at
// `smem_barrier`. The state token produced by the instruction is discarded.
//
//   smem_barrier - user-managed 64-bit barrier located in shared memory.
//
// Compiles to a no-op when not building for sm_90+.
__device__ void arrive_barrier(uint64_t* smem_barrier) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t const barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile(
      "{\n"
      ".reg .b64 state; \n"
      "mbarrier.arrive.shared::cta.b64 state, [%0];\n"
      "}\n"
      :
      : "r"(barrier_addr));
#endif
}
// Executes `cp.async.mbarrier.arrive.noinc.shared.b64` on the barrier at
// `smem_barrier`; the loaders call this after issuing their cp.async copies
// for a stage.
//
//   smem_barrier - user-managed 64-bit barrier located in shared memory.
//
// Compiles to a no-op when not building for sm_90+.
__device__ void ldgsts_arrive(uint64_t* smem_barrier) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t const barrier_addr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];"
               :
               : "r"(barrier_addr));
#endif
}
// Producer-side helper for operand A. For each of the gemm_k / tile_k
// k-iterations it copies one tile_m x tile_k tile of bf16 values from global
// memory into a stage_cnt-deep ring of shared-memory stages, using 16-byte
// asynchronous copies (ldgsts_128) spread over thread_cnt threads.
//
// Template parameters:
//   gemm_k    - full reduction length, in elements.
//   tile_m    - tile rows.
//   tile_k    - tile extent along k, in elements.
//   stage_cnt - number of shared-memory stages in the ring.
template <int gemm_k, int tile_m, int tile_k, int stage_cnt>
struct GmemLoaderA {
// Bytes per element (bf16) and per copy instruction; vec_elems is therefore
// the number of elements moved by one ldgsts_128.
static constexpr int elem_bytes = 2;
static constexpr int vec_bytes = 16;
static constexpr int vec_elems = vec_bytes / elem_bytes;
// Number of threads sharing the work of loading one tile (see local_tid).
static constexpr int thread_cnt = 64;
static_assert((tile_m * tile_k) % (vec_elems * thread_cnt) == 0);
// Copy instructions each thread issues per tile.
static constexpr int a_inst_cnt_per_iter =
(tile_m * tile_k) / (vec_elems * thread_cnt);
static_assert(gemm_k % tile_k == 0);
// Number of tiles along k, i.e. main-loop trip count.
static constexpr int k_iter_cnt = gemm_k / tile_k;
// Parameters describing how k is split between the MMA warps: a tile's k range
// consists of mma_warp_cnt segments of per_mma_warp_k elements, and segment s
// is read from global k offset s * k_each_chunk (see k_project). The original
// comment states this is done to keep the order of the k reduction.
static constexpr int mma_warp_cnt = 4;
static constexpr int per_mma_warp_k = tile_k / mma_warp_cnt;
static constexpr int k_each_chunk = gemm_k / mma_warp_cnt;
private:
// Maps a k index inside the tile to the corresponding k offset in global
// memory (relative to the current gmem_a position).
__device__ int k_project(int tile_k_idx) {
return (tile_k_idx / per_mma_warp_k * k_each_chunk) +
(tile_k_idx % per_mma_warp_k);
}
public:
// gmem_a_local_ - global-memory base of the rows this CTA loads.
// smem_a_       - shared-memory base of the A stage ring.
// smem_barrier_ - base of the barrier array; slot 2*s is arrived on after
//                 filling stage s and slot 2*s + 1 is waited on before
//                 refilling it (see issue_mainloop).
__device__ GmemLoaderA(bf16_t const* gmem_a_local_, bf16_t* smem_a_,
uint64_t* smem_barrier_)
: gmem_a(gmem_a_local_),
smem_a(smem_a_),
smem_barrier(smem_barrier_),
local_tid(threadIdx.x % thread_cnt) {}
// Precomputes, for every copy this thread will issue, the destination offset
// inside one shared-memory stage with the swizzle applied to the k index.
__device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// The swizzled offsets are identical for every stage, so compute them once.
#pragma unroll
for (int i = 0; i < a_inst_cnt_per_iter; i++) {
int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
int m_idx = linear_idx / tile_k;
int k_idx = linear_idx % tile_k;
k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(m_idx, k_idx);
a_smem_offsets[i] = m_idx * tile_k + k_idx;
}
#endif
}
// Runs the producer loop: for each k-iteration, waits until the current stage
// may be overwritten, issues this thread's copies for the tile, signals the
// stage's barrier via ldgsts_arrive, then advances to the next stage.
__device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll 1
for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
// Block on barrier slot 2*stage + 1 unless the early probe issued during the
// previous iteration already succeeded.
if (need_wait) {
wait_barrier(smem_barrier + 1 + stage_idx * 2, phase_bit);
}
// Advance the ring position; the parity flips each time the ring wraps.
int next_stage_idx = stage_idx + 1;
int next_phase_bit =
next_stage_idx == stage_cnt ? phase_bit ^ 1 : phase_bit;
next_stage_idx = next_stage_idx == stage_cnt ? 0 : next_stage_idx;
// Probe the next stage's barrier early so the blocking wait can be skipped on
// the following iteration if it is already complete.
if (loop_idx != k_iter_cnt - 1) {
need_wait = !try_wait_barrier(smem_barrier + 1 + next_stage_idx * 2,
next_phase_bit);
}
#pragma unroll
for (int i = 0; i < a_inst_cnt_per_iter; i++) {
int smem_offset = a_smem_offsets[i];
bf16_t* smem_ptr_this_iter =
smem_a + stage_idx * tile_m * tile_k + smem_offset;
// Recompute the unswizzled (m, k) coordinates to form the global address.
int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
int m_idx = linear_idx / tile_k;
int k_idx = linear_idx % tile_k;
int gmem_offset = m_idx * gemm_k + k_project(k_idx);
bf16_t const* gmem_ptr_this_iter = gmem_a + gmem_offset;
ldgsts_128(gmem_ptr_this_iter, smem_ptr_this_iter, true);
}
// Signal barrier slot 2*stage for the copies issued above.
ldgsts_arrive(smem_barrier + stage_idx * 2);
stage_idx = next_stage_idx;
phase_bit = next_phase_bit;
// Each k-iteration consumes per_mma_warp_k elements from every k segment.
gmem_a += per_mma_warp_k;
}
#endif
}
// Current global-memory read position (advanced every k-iteration).
bf16_t const* gmem_a;
// Base of the shared-memory stage ring for A.
bf16_t* smem_a;
// Base of the barrier array (two 64-bit slots per stage).
uint64_t* smem_barrier;
// Index of this thread among the thread_cnt loader threads.
int local_tid;
// Ring position and the parity used when waiting on barrier slot 2*stage + 1.
// NOTE(review): the initial parity of 1 presumably lets the first pass over
// freshly initialised barriers proceed without blocking -- confirm against the
// PTX mbarrier phase semantics.
int stage_idx = 0;
int phase_bit = 1;
// True when the next iteration must perform the blocking wait.
bool need_wait = true;
// Swizzled destination offsets within one stage, one per copy instruction.
int a_smem_offsets[a_inst_cnt_per_iter];
};
// Producer-side helper for operand B. For each of the gemm_k / tile_k
// k-iterations it copies one tile_n x tile_k tile of bf16 values from global
// memory into a stage_cnt-deep ring of shared-memory stages, using 16-byte
// asynchronous copies (ldgsts_128) spread over thread_cnt threads. Copies for
// rows at or beyond gemm_n are skipped via per-copy predicates.
//
// Template parameters:
//   gemm_k    - full reduction length, in elements.
//   tile_n    - tile rows.
//   tile_k    - tile extent along k, in elements.
//   stage_cnt - number of shared-memory stages in the ring.
template <int gemm_k, int tile_n, int tile_k, int stage_cnt>
struct GmemLoaderB {
// Bytes per element (bf16) and per copy instruction; vec_elems is therefore
// the number of elements moved by one ldgsts_128.
static constexpr int elem_bytes = 2;
static constexpr int vec_bytes = 16;
static constexpr int vec_elems = vec_bytes / elem_bytes;
// Number of threads sharing the work of loading one tile (see local_tid).
static constexpr int thread_cnt = 64;
static_assert((tile_n * tile_k) % (vec_elems * thread_cnt) == 0);
// Copy instructions each thread issues per tile.
static constexpr int b_inst_cnt_per_iter =
(tile_n * tile_k) / (vec_elems * thread_cnt);
static_assert(gemm_k % tile_k == 0);
// Number of tiles along k, i.e. main-loop trip count.
static constexpr int k_iter_cnt = gemm_k / tile_k;
// Parameters describing how k is split between the MMA warps: a tile's k range
// consists of mma_warp_cnt segments of per_mma_warp_k elements, and segment s
// is read from global k offset s * k_each_chunk (see k_project). The original
// comment states this is done to keep the order of the k reduction.
static constexpr int mma_warp_cnt = 4;
static constexpr int per_mma_warp_k = tile_k / mma_warp_cnt;
static constexpr int k_each_chunk = gemm_k / mma_warp_cnt;
private:
// Maps a k index inside the tile to the corresponding k offset in global
// memory (relative to the current gmem_b position).
__device__ int k_project(int tile_k_idx) {
return (tile_k_idx / per_mma_warp_k * k_each_chunk) +
(tile_k_idx % per_mma_warp_k);
}
public:
// gmem_b_local_ - global-memory base of the rows this CTA loads.
// smem_b_       - shared-memory base of the B stage ring.
// smem_barrier_ - base of the barrier array; slot 2*s is arrived on after
//                 filling stage s and slot 2*s + 1 is waited on before
//                 refilling it (see issue_mainloop).
// gemm_n_       - row bound: tile rows with index >= gemm_n_ are not loaded.
__device__ GmemLoaderB(bf16_t const* gmem_b_local_, bf16_t* smem_b_,
uint64_t* smem_barrier_, int gemm_n_)
: gmem_b(gmem_b_local_),
smem_b(smem_b_),
smem_barrier(smem_barrier_),
gemm_n(gemm_n_),
local_tid(threadIdx.x % thread_cnt) {}
// Precomputes, for every copy this thread will issue, the swizzled destination
// offset inside one shared-memory stage and the predicate that enables it.
__device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// The swizzled offsets are identical for every stage, so compute them once.
#pragma unroll
for (int i = 0; i < b_inst_cnt_per_iter; i++) {
int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
int n_idx = linear_idx / tile_k;
int k_idx = linear_idx % tile_k;
k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(n_idx, k_idx);
b_smem_offsets[i] = n_idx * tile_k + k_idx;
// Only rows below gemm_n are fetched from global memory.
preds[i] = n_idx < gemm_n;
}
#endif
}
// Runs the producer loop: for each k-iteration, waits until the current stage
// may be overwritten, issues this thread's predicated copies for the tile,
// signals the stage's barrier via ldgsts_arrive, then advances the stage.
__device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// NOTE(review): presumably blocks until the grid this launch depends on has
// finished producing its data (programmatic dependent launch; see the launch
// attribute set in invokeFusedAGemm) -- confirm against the PTX documentation.
asm volatile("griddepcontrol.wait;");
#pragma unroll 1
for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
// Block on barrier slot 2*stage + 1 unless the early probe issued during the
// previous iteration already succeeded.
if (need_wait) {
wait_barrier(smem_barrier + 1 + stage_idx * 2, phase_bit);
}
// Advance the ring position; the parity flips each time the ring wraps.
int next_stage_idx = stage_idx + 1;
int next_phase_bit =
next_stage_idx == stage_cnt ? phase_bit ^ 1 : phase_bit;
next_stage_idx = next_stage_idx == stage_cnt ? 0 : next_stage_idx;
// Probe the next stage's barrier early so the blocking wait can be skipped on
// the following iteration if it is already complete.
if (loop_idx != k_iter_cnt - 1) {
need_wait = !try_wait_barrier(smem_barrier + 1 + next_stage_idx * 2,
next_phase_bit);
}
#pragma unroll
for (int i = 0; i < b_inst_cnt_per_iter; i++) {
int smem_offset = b_smem_offsets[i];
bf16_t* smem_ptr_this_iter =
smem_b + stage_idx * tile_n * tile_k + smem_offset;
// Recompute the unswizzled (n, k) coordinates to form the global address.
int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
int n_idx = linear_idx / tile_k;
int k_idx = linear_idx % tile_k;
int gmem_offset = n_idx * gemm_k + k_project(k_idx);
bf16_t const* gmem_ptr_this_iter = gmem_b + gmem_offset;
ldgsts_128(gmem_ptr_this_iter, smem_ptr_this_iter, preds[i]);
}
// Signal barrier slot 2*stage for the copies issued above.
ldgsts_arrive(smem_barrier + stage_idx * 2);
stage_idx = next_stage_idx;
phase_bit = next_phase_bit;
// Each k-iteration consumes per_mma_warp_k elements from every k segment.
gmem_b += per_mma_warp_k;
}
#endif
}
// Current global-memory read position (advanced every k-iteration).
bf16_t const* gmem_b;
// Base of the shared-memory stage ring for B.
bf16_t* smem_b;
// Base of the barrier array (two 64-bit slots per stage).
uint64_t* smem_barrier;
// Row bound used to build the copy predicates.
int gemm_n;
// Index of this thread among the thread_cnt loader threads.
int local_tid;
// Ring position and the parity used when waiting on barrier slot 2*stage + 1.
// NOTE(review): the initial parity of 1 presumably lets the first pass over
// freshly initialised barriers proceed without blocking -- confirm against the
// PTX mbarrier phase semantics.
int stage_idx = 0;
int phase_bit = 1;
// True when the next iteration must perform the blocking wait.
bool need_wait = true;
// Swizzled destination offsets within one stage, one per copy instruction.
int b_smem_offsets[b_inst_cnt_per_iter];
// Per-copy predicates passed to ldgsts_128 (non-zero means "issue the copy").
uint32_t preds[b_inst_cnt_per_iter];
};
// Consumer-side helper run by the thread_cnt (128) compute threads. Each of
// the four warps multiplies its own per_warp_tile_k-wide slice of every
// tile_k-wide stage, accumulating in float registers; epi() then sums the four
// per-warp partial results through shared memory and writes the output tile.
//
// Template parameters:
//   gemm_m, gemm_k          - problem extents (gemm_m is the output row stride
//                             used in epi()).
//   tile_m, tile_n, tile_k  - tile extents in elements.
//   stage_cnt               - number of shared-memory stages in the ring.
template <int gemm_m, int gemm_k, int tile_m, int tile_n, int tile_k,
int stage_cnt>
struct MmaComputer {
static constexpr int elem_bytes = 2;
static constexpr int thread_cnt = 128;
static_assert(gemm_k % tile_k == 0);
static_assert(tile_k % (thread_cnt / 32) == 0);
// Width of the k slice handled by one warp within a tile.
static constexpr int per_warp_tile_k = tile_k / (thread_cnt / 32);
// Number of tiles along k, i.e. main-loop trip count.
static constexpr int k_iter_cnt = gemm_k / tile_k;
// Number of k=16 MMA steps a warp performs per tile.
static constexpr int k_phase_cnt = per_warp_tile_k / 16;
// Number of 16-row / 8-column MMA blocks needed to cover the tile.
static constexpr int m_iter_cnt = (tile_m + 15) / 16;
static constexpr int n_iter_cnt =
(tile_n + 7) /
8; // Possible to have non-1 n_iter_cnt for ab_swap m16 case.
static_assert(m_iter_cnt == 1);
static_assert(n_iter_cnt == 1 || n_iter_cnt == 2);
// gmem_c_local_ - global-memory base of this CTA's output tile.
// smem_a_/smem_b_ - shared-memory bases of the A and B stage rings.
// smem_barrier_ - base of the barrier array (two slots per stage).
// warp_idx_     - warp index within the CTA; thread_cnt / 32 is subtracted so
//                 that the stored warp_idx is relative to the first compute
//                 warp (the kernel passes indices >= 4 here).
// gemm_n_       - bound on the n index when writing the output.
__device__ MmaComputer(bf16_t* gmem_c_local_, bf16_t* smem_a_,
bf16_t* smem_b_, uint64_t* smem_barrier_,
int warp_idx_, int gemm_n_)
: gmem_c(gmem_c_local_),
smem_a(smem_a_),
smem_b(smem_b_),
smem_barrier(smem_barrier_),
warp_idx(warp_idx_ - (thread_cnt / 32)),
gemm_n(gemm_n_) {}
private:
// Maps a lane id to the linear (k-major) element index that seeds this lane's
// ldmatrix address for operand B; the formula differs for tile_n < 8.
__device__ constexpr int internal_b_atom_func(int tid) {
if constexpr (tile_n < 8) {
return (tid % tile_n) + ((tid % 8) / tile_n * 0) + tid / 8 * 8 * tile_n;
} else {
return (tid % 8) + ((tid % 32) / 8 * (tile_n * 8));
}
}
public:
// Precomputes this lane's swizzled shared-memory offsets (within one stage)
// for every ldmatrix it will issue for A and B.
__device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll
for (int i = 0; i < k_phase_cnt; i++) {
int linear_idx = (lane_idx % 16) + (lane_idx / 16) * 128 + i * 256;
int m_idx = linear_idx % tile_m;
// Offset k by this warp's slice of the tile.
int k_idx = linear_idx / tile_m + warp_k_offset_in_tile_k;
k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(m_idx, k_idx);
a_smem_offsets[0][i] = m_idx * tile_k + k_idx;
}
#pragma unroll
for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
#pragma unroll
// Only even k phases get an offset: one ldsm_x4 for B writes 16 bytes, which
// fills b_reg[n][i] and b_reg[n][i + 1] (8 bytes each) in issue_mainloop.
for (int i = 0; i < k_phase_cnt; i += 2) { // Special i+=2 for B.
int linear_idx =
internal_b_atom_func(lane_idx) + i * tile_n * 16 + n_iter_idx * 8;
int n_idx = linear_idx % tile_n;
int k_idx = linear_idx / tile_n + warp_k_offset_in_tile_k;
k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(n_idx, k_idx);
b_smem_offsets[n_iter_idx][i] = n_idx * tile_k + k_idx;
}
}
#endif
}
// Runs the consumer loop: for each k-iteration, waits on barrier slot 2*stage
// until the stage has been filled, loads the A and B fragments with ldmatrix,
// accumulates the MMAs, then arrives on barrier slot 2*stage + 1 so the
// loaders may reuse the stage.
__device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll 1
for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
wait_barrier(smem_barrier + 0 + stage_idx * 2, phase_bit);
// Load all A fragments for this warp's k slice.
#pragma unroll
for (int i = 0; i < k_phase_cnt; i++) {
int smem_offset = a_smem_offsets[0][i];
bf16_t* smem_ptr_this_iter =
smem_a + stage_idx * tile_m * tile_k + smem_offset;
ldsm_x4(smem_ptr_this_iter, reinterpret_cast<uint32_t*>(a_reg[0][i]));
}
// Load the B fragments, two k phases per ldmatrix (see prepare()).
#pragma unroll
for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
#pragma unroll
for (int i = 0; i < k_phase_cnt; i += 2) {
int smem_offset = b_smem_offsets[n_iter_idx][i];
bf16_t* smem_ptr_this_iter =
smem_b + stage_idx * tile_n * tile_k + smem_offset;
ldsm_x4(smem_ptr_this_iter,
reinterpret_cast<uint32_t*>(b_reg[n_iter_idx][i]));
}
}
// Accumulate: acc += A(k phase) * B(k phase) for every n block.
#pragma unroll
for (int k_iter_idx = 0; k_iter_idx < k_phase_cnt; k_iter_idx++) {
#pragma unroll
for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
hmma_16_8_16_f32acc_bf16ab(
acc_reg[0][n_iter_idx], a_reg[0][k_iter_idx],
b_reg[n_iter_idx][k_iter_idx], acc_reg[0][n_iter_idx]);
}
}
// Release the stage back to the loaders.
::arrive_barrier(smem_barrier + 1 + stage_idx * 2);
// Advance the ring position; the parity flips each time the ring wraps.
stage_idx += 1;
phase_bit = stage_idx == stage_cnt ? phase_bit ^ 1 : phase_bit;
stage_idx = stage_idx == stage_cnt ? 0 : stage_idx;
}
#endif
}
// Epilogue: every warp stores its partial accumulators to shared memory,
// then warp 0 sums the four partials and writes the result to gmem_c.
__device__ void epi() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// Named barrier 1 across the thread_cnt compute threads, issued before the A
// stage buffer is reused as scratch below.
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(thread_cnt));
// Regroup the per-thread accumulators into a [row][column] layout.
constexpr int thread_m = 2;
constexpr int thread_n = 2 * n_iter_cnt;
constexpr int cta_mma_n = n_iter_cnt * 8;
float acc_reg_reorg[thread_m][thread_n];
for (int i = 0; i < thread_m; i++) {
for (int j = 0; j < thread_n; j++) {
acc_reg_reorg[i][j] = acc_reg[0][j / 2][(j % 2) + (i * 2)];
}
}
// The float scratch area aliases the A stage buffer; it holds four regions of
// cosize_smem_c floats, one per compute warp.
float* smem_c = reinterpret_cast<float*>(smem_a);
// Maps an (m, n) coordinate to its index inside one warp's scratch region;
// every group of rows totalling 32 floats is followed by group_cnt extra slots.
auto smem_c_index_func = [&](int m_idx, int n_idx) {
int group_rows = 32 / cta_mma_n;
int group_cnt = 2;
return (m_idx % group_rows * cta_mma_n) +
(m_idx / group_rows * (32 + group_cnt)) + n_idx;
};
constexpr int cosize_smem_c = ((tile_m * cta_mma_n) / 32) * (32 + 2);
// Original author's note: these stores should be optimized to STS.64 but
// cannot be STS.128 due to the bank index.
#pragma unroll
for (int m_idx_thread = 0; m_idx_thread < thread_m; m_idx_thread++) {
#pragma unroll
for (int n_idx_thread = 0; n_idx_thread < thread_n; n_idx_thread++) {
int m_idx = (lane_idx / 4) + m_idx_thread * 8;
int n_idx =
((lane_idx % 4) * 2) + (n_idx_thread % 2) + (n_idx_thread / 2) * 8;
smem_c[cosize_smem_c * warp_idx + smem_c_index_func(m_idx, n_idx)] =
acc_reg_reorg[m_idx_thread][n_idx_thread];
}
}
// Make every warp's partial results visible before warp 0 reads them.
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(thread_cnt));
if (warp_idx == 0) {
constexpr int final_acc_reg_cnt = (tile_m * tile_n + 31) / 32;
float acc_final[final_acc_reg_cnt]{};
// Sum the four per-warp partials for each output element owned by this lane.
#pragma unroll
for (int reg_idx = 0; reg_idx < final_acc_reg_cnt; reg_idx++) {
int linear_idx = reg_idx * 32 + lane_idx;
int m_idx = linear_idx % tile_m;
int n_idx = linear_idx / tile_m;
acc_final[reg_idx] +=
smem_c[smem_c_index_func(m_idx, n_idx) + 0 * cosize_smem_c] +
smem_c[smem_c_index_func(m_idx, n_idx) + 1 * cosize_smem_c] +
smem_c[smem_c_index_func(m_idx, n_idx) + 2 * cosize_smem_c] +
smem_c[smem_c_index_func(m_idx, n_idx) + 3 * cosize_smem_c];
}
// Write the results; the float value is converted to bf16_t by the assignment.
// Output is indexed as [n][m] with row stride gemm_m, and n is bounded by
// gemm_n.
#pragma unroll
for (int reg_idx = 0; reg_idx < final_acc_reg_cnt; reg_idx++) {
int linear_idx = reg_idx * 32 + lane_idx;
int m_idx = linear_idx % tile_m;
int n_idx = linear_idx / tile_m;
if (m_idx < tile_m && n_idx < gemm_n) {
gmem_c[n_idx * gemm_m + m_idx] = acc_final[reg_idx];
}
}
}
#endif
}
// Global-memory base of this CTA's output tile.
bf16_t* gmem_c;
// Shared-memory bases of the A and B stage rings (smem_a doubles as the float
// scratch area in epi()).
bf16_t* smem_a;
bf16_t* smem_b;
// Base of the barrier array (two 64-bit slots per stage).
uint64_t* smem_barrier;
// Warp index relative to the first compute warp.
int warp_idx;
// Bound on the n index when writing the output.
int gemm_n;
// Ring position and the parity used when waiting on barrier slot 2*stage.
int stage_idx = 0;
int phase_bit = 0;
// Lane within the warp, and this warp's starting k offset inside a tile.
int lane_idx = threadIdx.x % 32;
int warp_k_offset_in_tile_k = warp_idx * per_warp_tile_k;
// Per-lane swizzled ldmatrix source offsets within one stage.
int a_smem_offsets[m_iter_cnt][k_phase_cnt];
int b_smem_offsets[n_iter_cnt][k_phase_cnt];
// Operand fragments for the current stage and the float accumulators
// (zero-initialised).
bf16_t a_reg[m_iter_cnt][k_phase_cnt][8];
bf16_t b_reg[n_iter_cnt][k_phase_cnt][4];
float acc_reg[m_iter_cnt][n_iter_cnt][4]{};
};
// AB swapped, kernel is k-major, k-major, m-major
// Fused GEMM kernel; per the original note the operands are "AB swapped" and
// the layouts are k-major (A), k-major (B), m-major (output).
//
// Each CTA (256 threads) produces one tile_m x tile_n output tile:
//   - warps 0-1 load A tiles (GmemLoaderA),
//   - warps 2-3 load B tiles (GmemLoaderB),
//   - warps 4-7 run the MMAs and the epilogue (MmaComputer).
// blockIdx.x selects the tile along m, blockIdx.y the tile along n.
//
// Dynamic shared memory layout: a barrier array (two 64-bit slots per stage:
// producer, consumer) padded to a multiple of 1024 bytes, followed by the A
// stage ring and then the B stage ring.
//
//   output - result matrix, written as output[n * gemm_m + m].
//   mat_a  - A operand, read as mat_a[m * gemm_k + k].
//   mat_b  - B operand, read as mat_b[n * gemm_k + k].
//   gemm_n - runtime extent along n.
//
// The body is compiled out when not building for sm_90+.
template <int batch_size, int gemm_m, int gemm_k, int tile_m, int tile_n,
          int tile_k, int stage_cnt>
__global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
    bf16_t* output, bf16_t const* mat_a, bf16_t const* mat_b, int gemm_n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  constexpr int load_thread_cnt = 128;
  constexpr int compute_thread_cnt = 128;
  constexpr int thread_cnt = load_thread_cnt + compute_thread_cnt;
  (void)thread_cnt;
  static_assert(gemm_m % 16 == 0);
  static_assert(gemm_k % tile_k == 0);
  static_assert(gemm_m % tile_m == 0);
  static_assert(
      tile_k == 128 || tile_k == 256 || tile_k == 512 ||
      tile_k == 1024);  // tile_k must be larger than 64 since 4 warp splitK.
  static_assert(tile_m == 16);
  constexpr int g2s_vec_bytes = 16;
  constexpr int a_elem_bytes = 2;
  constexpr int b_elem_bytes = 2;
  // All stages of A and B together must fit in the shared-memory budget.
  static_assert((tile_m * a_elem_bytes + tile_n * b_elem_bytes) * tile_k *
                        stage_cnt <=
                    225 * 1024);
  static_assert((tile_m * tile_k * a_elem_bytes) %
                    (load_thread_cnt * g2s_vec_bytes) ==
                0);
  static_assert((tile_n * tile_k * b_elem_bytes) %
                    (load_thread_cnt * g2s_vec_bytes) ==
                0);

  extern __shared__ char smem[];
  uint64_t* smem_barrier = reinterpret_cast<uint64_t*>(
      smem);  // producer,consumer; producer,consumer; ...
  // Round the barrier region up to a multiple of 1024 bytes using the same
  // formula as the host-side `barrier_bytes` in invokeFusedAGemm. The former
  // `(x + 1024) / 1024 * 1024` reserved 1024 bytes more than the host
  // allocates whenever x was already a multiple of 1024.
  constexpr int barrier_bytes = (stage_cnt * 8 * 2 + 1023) / 1024 * 1024;
  bf16_t* smem_a = reinterpret_cast<bf16_t*>(smem + barrier_bytes);
  bf16_t* smem_b = smem_a + tile_m * tile_k * stage_cnt;

  // Origin of this CTA's tile and the matching global-memory base pointers.
  int cta_m_idx = tile_m * blockIdx.x;
  int cta_n_idx = tile_n * blockIdx.y;
  bf16_t const* gmem_a_local = mat_a + cta_m_idx * gemm_k;
  bf16_t const* gmem_b_local = mat_b + cta_n_idx * gemm_k;
  bf16_t* gmem_c_local = output + cta_n_idx * gemm_m + cta_m_idx;

  // Broadcast lane 0's warp index so the role dispatch below is warp-uniform.
  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
  if (warp_idx == 4) {
    for (int i = 0; i < stage_cnt; i++) {
      initialize_barrier(smem_barrier + i * 2 + 0,
                         load_thread_cnt);  // producer
      initialize_barrier(smem_barrier + i * 2 + 1,
                         compute_thread_cnt);  // consumer
    }
  }
  // Barriers must be initialised before any warp waits or arrives on them.
  __syncthreads();

  if (warp_idx < 2) {
    GmemLoaderA<gemm_k, tile_m, tile_k, stage_cnt> a_loader(
        gmem_a_local, smem_a, smem_barrier);
    a_loader.prepare();
    a_loader.issue_mainloop();
  } else if (warp_idx < 4) {
    GmemLoaderB<gemm_k, tile_n, tile_k, stage_cnt> b_loader(
        gmem_b_local, smem_b, smem_barrier, gemm_n);
    b_loader.prepare();
    b_loader.issue_mainloop();
  } else {
    MmaComputer<gemm_m, gemm_k, tile_m, tile_n, tile_k, stage_cnt> mma_computer(
        gmem_c_local, smem_a, smem_b, smem_barrier, warp_idx, gemm_n);
    mma_computer.prepare();
    mma_computer.issue_mainloop();
    mma_computer.epi();
  }
  // NOTE(review): presumably signals that dependent grids may begin launching
  // (programmatic dependent launch) -- confirm against the PTX documentation.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, int kHdIn, int kHdOut, int kTileN>
void invokeFusedAGemm(T* output, T const* mat_a, T const* mat_b, int num_tokens,
cudaStream_t const stream) {
constexpr int gemm_m = kHdOut; // 2112
int const gemm_n = num_tokens; // 1-16
constexpr int gemm_k = kHdIn; // 7168
constexpr int batch_size = 1;
std::swap(mat_a, mat_b);
constexpr int tile_m = 16;
constexpr int tile_n = kTileN; // 8 or 16
constexpr int tile_k = std::max(256, 1024 / tile_n); // 256
constexpr int max_stage_cnt =
1024 * 192 / ((tile_m + tile_n) * tile_k * sizeof(bf16_t));
constexpr int k_iter_cnt = gemm_k / tile_k;
constexpr int stage_cnt =
k_iter_cnt > max_stage_cnt ? max_stage_cnt : k_iter_cnt;
int cta_m_cnt = gemm_m / tile_m;
int cta_n_cnt = (gemm_n + tile_n - 1) / tile_n;
constexpr int barrier_bytes = (stage_cnt * 16 + 1023) / 1024 * 1024;
constexpr int smem_bytes =
((tile_m * 2 + tile_n * 2) * tile_k * stage_cnt + barrier_bytes + 1023) /
1024 * 1024;
dim3 grid(cta_m_cnt, cta_n_cnt, 1);
dim3 block_size(256);
cudaLaunchConfig_t config;
config.gridDim = grid;
config.blockDim = block_size;
config.dynamicSmemBytes = smem_bytes;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
if (smem_bytes >= (48 * 1024)) {
cudaFuncSetAttribute(fused_a_gemm_kernel<batch_size, gemm_m, gemm_k, tile_m,
tile_n, tile_k, stage_cnt>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes);
}
cudaLaunchKernelEx(&config,
fused_a_gemm_kernel<batch_size, gemm_m, gemm_k, tile_m,
tile_n, tile_k, stage_cnt>,
output, mat_a, mat_b, gemm_n);
}
// Explicit instantiations of invokeFusedAGemm for bf16 with kHdIn = 7168 and
// kHdOut = 2112, one per supported token-tile size (kTileN = 8 and 16). These
// are the two variants dsv3_fused_a_gemm dispatches to.
template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 8>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, int num_tokens,
cudaStream_t);
template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 16>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, int num_tokens,
cudaStream_t);
// Torch entry point: computes output = mat_a @ mat_b with the fused kernel.
//
//   output - [num_tokens, 2112] BFloat16, row major, written in place.
//   mat_a  - [num_tokens, 7168] BFloat16, row major, 1 <= num_tokens <= 16.
//   mat_b  - [7168, 2112] BFloat16, column major.
//
// Raises (via TORCH_CHECK) when a shape, stride, dtype or architecture
// requirement is not met. The kernel addresses all three buffers with
// hard-coded strides, so the outer strides and mat_b.shape[0] are validated
// here as well; without that, a mismatched tensor would be read or written
// out of bounds.
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
                       torch::Tensor const& mat_b) {
  TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && output.dim() == 2,
              "mat_a, mat_b and output must be 2-D tensors");

  int const num_tokens = mat_a.size(0);
  int const hd_in = mat_a.size(1);
  int const hd_out = mat_b.size(1);

  constexpr int kHdIn = 7168;
  constexpr int kHdOut = 2112;

  // Shape requirements.
  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16,
              "required 1 <= mat_a.shape[0] <= 16");
  TORCH_CHECK(hd_in == kHdIn, "required mat_a.shape[1] == 7168");
  TORCH_CHECK(hd_out == kHdOut, "required mat_b.shape[1] == 2112");
  TORCH_CHECK(mat_b.size(0) == hd_in,
              "required mat_b.shape[0] == mat_a.shape[1]");
  TORCH_CHECK(output.size(0) == num_tokens,
              "required output.shape[0] == mat_a.shape[0]");
  TORCH_CHECK(output.size(1) == hd_out,
              "required output.shape[1] == mat_b.shape[1]");

  // Layout requirements: unit inner strides ...
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  // ... and densely packed outer strides, because the kernel steps between
  // rows/columns by exactly 7168 (inputs) or 2112 (output) elements. A
  // size-1 leading dimension is never stepped over, so its stride is exempt.
  TORCH_CHECK(num_tokens == 1 || mat_a.stride(0) == hd_in,
              "mat_a rows must be densely packed (stride(0) == shape[1])");
  TORCH_CHECK(num_tokens == 1 || output.stride(0) == hd_out,
              "output rows must be densely packed (stride(0) == shape[1])");
  TORCH_CHECK(mat_b.stride(1) == hd_in,
              "mat_b columns must be densely packed (stride(1) == shape[0])");

  // Dtype and architecture requirements.
  TORCH_CHECK(mat_a.scalar_type() == torch::kBFloat16 &&
                  mat_b.scalar_type() == torch::kBFloat16,
              "Only BFloat16 input dtype is supported");
  TORCH_CHECK(output.scalar_type() == torch::kBFloat16,
              "Only BFloat16 output dtype is supported");
  TORCH_CHECK(getSMVersion() >= 90, "required CUDA ARCH >= SM_90");

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  // Pick the token-tile size: 8 for up to 8 tokens, 16 otherwise.
  if (num_tokens <= 8) {
    invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 8>(
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), num_tokens,
        stream);
  } else {
    invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 16>(
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), num_tokens,
        stream);
  }
}
// Registers dsv3_fused_a_gemm as the CUDA implementation of the
// "dsv3_fused_a_gemm" operator in the TORCH_EXTENSION_NAME library.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("dsv3_fused_a_gemm", &dsv3_fused_a_gemm);
}
......@@ -15,9 +15,9 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SSMParamsBase {
using index_t = uint32_t;
using index_t = size_t;
int batch, dim, seqlen, dstate, n_groups, n_chunks;
int batch, dim, seqlen, dstate, n_groups;
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
......@@ -72,6 +72,8 @@ struct SSMParamsBase {
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets
void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence
};
......
......@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kVarlen = Ktraits::kVarlen;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems;
constexpr int kNRows = Ktraits::kNRows;
constexpr bool kDirectIO = Ktraits::kDirectIO;
......@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
// }
constexpr int kChunkSize = kNThreads * kNItems;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
const int block_size = params.cache_enabled ? params.block_size : 2048;
const int* batch_cache_indices = cache_indices != nullptr ?
cache_indices + batch_id * params.cache_indices_stride : nullptr;
......@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
block_idx_first_scheduled[batch_id] : 0;
// Determine chunk boundaries from pre-computed metadata (APC mode)
// or fall back to simple block_size chunking.
int first_chunk_idx, n_chunks;
int current_position;
if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
const int last_chunk_idx = last_chunk_indices[batch_id];
first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
n_chunks = last_chunk_idx - first_chunk_idx + 1;
// Derive current_position: if the first chunk is partial (fills remainder
// of a started block), offset into the block accordingly.
const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
? (block_size - first_chunk_tokens) : 0;
current_position = block_idx_first * block_size + chunk_start_offset;
} else {
first_chunk_idx = 0;
n_chunks = (seqlen + block_size - 1) / block_size;
current_position = 0;
}
int tokens_processed = 0;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
const int chunk_tokens = (cu_chunk_seqlen != nullptr)
? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
: min(block_size, seqlen - tokens_processed);
if (chunk_tokens <= 0) break;
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
__syncthreads();
......@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
if constexpr (!kDirectIO) { __syncthreads(); }
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
}
u += kChunkSize;
delta += kChunkSize;
u += chunk_tokens;
delta += chunk_tokens;
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
#pragma unroll
......@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
weight_t B_vals[kNItems], C_vals[kNItems];
if constexpr (kIsVariableB) {
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
smem_load_weight, chunk_tokens);
if constexpr (!kIsVariableC) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
......@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
smem_load_weight_C, chunk_tokens);
if constexpr (!kIsVariableB) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
......@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for (int i = 0; i < kNItems; ++i) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
}
if (threadIdx.x * kNItems + i >= chunk_tokens) {
thread_data[i] = make_float2(1.f, 0.f);
}
}
// Initialize running total
......@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if (threadIdx.x == 0) {
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
// Store state at the end of each chunk when cache is enabled
// Store state at the end of each aligned chunk when cache is enabled
if (params.cache_enabled && batch_cache_indices != nullptr) {
size_t cache_slot;
if (chunk == n_chunks - 1) {
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
} else {
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
cache_slot = batch_cache_indices[block_idx_completed];
}
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
......@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.out_d_stride + tokens_processed;
__syncthreads();
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
}
if constexpr (kHasZ) {
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.z_d_stride + tokens_processed;
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.out_z_d_stride + tokens_processed;
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
input_t z_vals[kNItems];
__syncthreads();
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
float z_val = z_vals[i];
out_vals[r][i] *= z_val / (1 + expf(-z_val));
}
__syncthreads();
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
}
}
Bvar += kChunkSize * 1;
Cvar += kChunkSize * 1;
Bvar += chunk_tokens;
Cvar += chunk_tokens;
tokens_processed += chunk_tokens;
current_position += chunk_tokens;
}
}
......@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
// Reset the parameters
memset(&params, 0, sizeof(params));
......@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;
// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
......@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
......@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices
);
......
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "dsv3_router_gemm_utils.h"
// Packed fused multiply-add on two float lanes: d = a * b + c.
//
// Each float2 operand is reinterpreted as a single 64-bit register ("l"
// constraint) and handed to the PTX instruction `fma.rn.f32x2`, so both lanes
// are computed by one instruction.
//
// NOTE(review): confirm which GPU architectures accept `fma.rn.f32x2` before
// calling this helper; it appears to have no callers in this file.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b,
float2 const& c) {
asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
}
// Widens the first VPT bfloat16 values packed inside `vec` to float and stores
// them in dst[0 .. VPT). The 16-byte uint4 is viewed in place as an array of
// bfloat16 values; nothing is copied before the conversion.
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec,
                                                     float* dst) {
  auto const* packed = reinterpret_cast<__nv_bfloat16 const*>(&vec);
#pragma unroll
  for (int lane = 0; lane < VPT; ++lane) {
    dst[lane] = __bfloat162float(packed[lane]);
  }
}
// Router GEMM kernel with bf16 output:
//   out[m, n] = sum_k mat_a[m, k] * mat_b[n, k]
// for m in [0, kNumTokens) and n in [0, kNumExperts), accumulated in fp32.
//
// Layout, as established by the indexing below:
//   * mat_a is read as kNumTokens dense rows of kHiddenDim elements.
//   * mat_b is read as one dense row of kHiddenDim elements per expert.
//   * out is written as kNumTokens dense rows of kNumExperts elements.
//
// Work split: blockIdx.x selects the expert; the kBlockSize threads of the
// block stride over the hidden dimension, each loading VPT elements (one
// 16-byte vector) per step. Partial sums are reduced first inside each warp
// via shuffles, then across warps through shared memory by thread 0.
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(
    __nv_bfloat16* out, T const* mat_a, T const* mat_b) {
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // Hidden-dimension elements consumed by the whole block per iteration.
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations =
      kHiddenDim / k_elems_per_k_iteration;  // Total K iterations

  // Compile-time guards: without these, an unsupported parameter combination
  // would silently drop the tail of the hidden dimension, mis-size the
  // shared-memory reduction, or violate the hard-coded launch bounds.
  static_assert(sizeof(T) == sizeof(__nv_bfloat16),
                "inputs are reinterpreted as bfloat16 values");
  static_assert(sizeof(T) * VPT == sizeof(uint4),
                "each thread must load exactly one 16-byte vector per step");
  static_assert(kBlockSize % kWarpSize == 0 && kBlockSize <= 128,
                "kBlockSize must be a whole number of warps and must not "
                "exceed __launch_bounds__(128)");
  static_assert(kHiddenDim % k_elems_per_k_iteration == 0,
                "kHiddenDim must be a multiple of VPT * kBlockSize");

  // Each block handles one expert column
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;

  // One fp32 accumulator per token row
  float acc[kNumTokens] = {};

  // Shared memory for the cross-warp reduction: one slot per (token, warp)
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // Start of this expert's weights
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

  // Programmatic dependent launch: wait on the preceding grid before touching
  // global memory (only compiled for sm_90 and newer).
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Load this thread's slice of the expert weights with one vector load
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Load the matching slice of token row m_idx with one vector load
      uint4 a_vec = *reinterpret_cast<uint4 const*>(
          mat_a + (m_idx * kHiddenDim) + k_base);
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        float a = a_float[k];
        float b = b_float[k];
        acc[m_idx] += a * b;
      }
    }
  }

  int const warp_id = tid / kWarpSize;
  int const lane_id = tid % kWarpSize;

  // Warp-level butterfly reduction (xor offsets 16, 8, 4, 2, 1)
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }
    // Only the first lane of each warp publishes the warp total
    if (lane_id == 0) {
      sm_reduction[m][warp_id] = sum;
    }
  }

  __syncthreads();

  // Final reduction across warps (only first thread)
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      // Round the fp32 total to bf16 on store
      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
    }
  }

  // Programmatic dependent launch: allow dependent grids to start.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Launches router_gemm_kernel_bf16_output on `stream` with one 128-thread
// block per expert and no dynamic shared memory.
//
// Parameters:
//   output - device buffer written as [kNumTokens, kNumExperts] bf16 values.
//   mat_a  - device buffer read as [kNumTokens, kHiddenDim] elements of T.
//   mat_b  - device buffer read as [kNumExperts, kHiddenDim] elements of T.
//   stream - CUDA stream the kernel is enqueued on.
//
// The programmatic-stream-serialization launch attribute is set from
// getEnvEnablePDL(). A launch failure is reported through C10_CUDA_CHECK
// instead of being silently dropped.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a,
                                T const* mat_b, cudaStream_t stream) {
  // One 16-byte vector load per thread per step.
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();

  cudaLaunchConfig_t config;
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  cudaError_t const launch_status = cudaLaunchKernelEx(
      &config,
      router_gemm_kernel_bf16_output<T, kBlockSize, VPT, kNumTokens,
                                     kNumExperts, kHiddenDim>,
      output, mat_a, mat_b);
  // Surface invalid-configuration / launch errors to the caller.
  C10_CUDA_CHECK(launch_status);
}
// Explicit instantiations of invokeRouterGemmBf16Output for bf16 inputs with
// a hidden dimension of 7168: every token count from 1 to 16, for each of the
// two supported expert counts (256 and 384).
#define INSTANTIATE_ROUTER_GEMM_BF16_OUT(NUM_TOKENS, NUM_EXPERTS) \
  template void invokeRouterGemmBf16Output<__nv_bfloat16, NUM_TOKENS, \
                                           NUM_EXPERTS, 7168>( \
      __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, \
      cudaStream_t);

#define INSTANTIATE_ROUTER_GEMM_BF16_OUT_ALL_TOKENS(NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(1, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(2, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(3, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(4, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(5, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(6, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(7, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(8, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(9, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(10, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(11, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(12, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(13, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(14, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(15, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_BF16_OUT(16, NUM_EXPERTS)

// Template instantiations for DEFAULT_NUM_EXPERTS experts
INSTANTIATE_ROUTER_GEMM_BF16_OUT_ALL_TOKENS(256)

// Template instantiations for KIMI_K2_NUM_EXPERTS experts
INSTANTIATE_ROUTER_GEMM_BF16_OUT_ALL_TOKENS(384)

#undef INSTANTIATE_ROUTER_GEMM_BF16_OUT_ALL_TOKENS
#undef INSTANTIATE_ROUTER_GEMM_BF16_OUT
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "core/registration.h"
#include "dsv3_router_gemm_utils.h"
// Problem sizes accepted by dsv3_router_gemm below: two expert counts and a
// single hidden dimension. These values are used as template arguments when
// selecting a launcher, so each one needs a matching explicit instantiation of
// the launchers declared here.
static constexpr int DEFAULT_NUM_EXPERTS = 256;
static constexpr int KIMI_K2_NUM_EXPERTS = 384;
static constexpr int DEFAULT_HIDDEN_DIM = 7168;
// Launcher writing float32 output. Declared only; no definition is provided in
// this file.
// NOTE(review): presumably defined and explicitly instantiated in the
// float-output .cu file — confirm.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
cudaStream_t stream);
// Launcher writing bf16 output. Declared only; no definition is provided in
// this file.
// NOTE(review): presumably defined and explicitly instantiated in the
// bf16-output .cu file — confirm.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a,
T const* mat_b, cudaStream_t stream);
// Maps a runtime token count onto the launcher instantiation whose
// compile-time kNumTokens equals it.
//
// LoopUnroller<kBegin, kEnd, ...> compares num_tokens with kBegin; on a match
// it calls the launcher instantiated for kBegin, otherwise it defers to
// LoopUnroller<kBegin + 1, kEnd, ...>. Once kBegin == kEnd there is nothing
// left to try and std::invalid_argument is thrown.
template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
  // Dispatch for float32 output.
  static void unroll_float_output(int num_tokens, float* output,
                                  __nv_bfloat16 const* input,
                                  __nv_bfloat16 const* weights,
                                  cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts,
                                  kHiddenDim>(output, input, weights, stream);
      return;
    }
    if constexpr (kBegin != kEnd) {
      using Next = LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>;
      Next::unroll_float_output(num_tokens, output, input, weights, stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }

  // Dispatch for bf16 output.
  static void unroll_bf16_output(int num_tokens, __nv_bfloat16* output,
                                 __nv_bfloat16 const* input,
                                 __nv_bfloat16 const* weights,
                                 cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts,
                                 kHiddenDim>(output, input, weights, stream);
      return;
    }
    if constexpr (kBegin != kEnd) {
      using Next = LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>;
      Next::unroll_bf16_output(num_tokens, output, input, weights, stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }
};
// Host entry point for the router GEMM:
//   output[m, n] = sum_k mat_a[m, k] * mat_b[n, k]
//
// Parameters:
//   output - [num_tokens, num_experts], float32 or bf16; written in place.
//   mat_a  - [num_tokens, hidden_dim], bf16.
//   mat_b  - [num_experts, hidden_dim], bf16.
//
// Supported configurations (enforced with TORCH_CHECK):
//   hidden_dim == DEFAULT_HIDDEN_DIM,
//   num_experts == DEFAULT_NUM_EXPERTS or KIMI_K2_NUM_EXPERTS,
//   1 <= num_tokens <= 16,
//   90 <= SM version <= 103.
//
// The kernels index all three buffers as dense row-major arrays through raw
// pointers, so the tensors must be CUDA tensors, contiguous, and `output`
// must have exactly the shape implied by the inputs; otherwise the kernel
// would read or write out of bounds.
void dsv3_router_gemm(at::Tensor& output,       // [num_tokens, num_experts]
                      const at::Tensor& mat_a,  // [num_tokens, hidden_dim]
                      const at::Tensor& mat_b   // [num_experts, hidden_dim]
) {
  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);

  const int num_tokens = mat_a.size(0);
  const int num_experts = mat_b.size(0);
  const int hidden_dim = mat_a.size(1);

  TORCH_CHECK(mat_a.size(1) == mat_b.size(1),
              "mat_a and mat_b must have the same hidden_dim");
  TORCH_CHECK(hidden_dim == DEFAULT_HIDDEN_DIM,
              "Expected hidden_dim=", DEFAULT_HIDDEN_DIM,
              ", but got hidden_dim=", hidden_dim);
  TORCH_CHECK(
      num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS,
      "Expected num_experts=", DEFAULT_NUM_EXPERTS,
      " or num_experts=", KIMI_K2_NUM_EXPERTS,
      ", but got num_experts=", num_experts);
  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16,
              "currently num_tokens must be less than or equal to 16 for "
              "router_gemm");
  TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "mat_a must be bf16");
  TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "mat_b must be bf16");
  TORCH_CHECK(output.dtype() == at::kFloat || output.dtype() == at::kBFloat16,
              "output must be float32 or bf16");

  // The kernel writes out[m * num_experts + n] for every (m, n); a smaller or
  // differently shaped output would be overrun.
  TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_experts,
              "output must have shape [num_tokens, num_experts] = [",
              num_tokens, ", ", num_experts, "], but got [", output.size(0),
              ", ", output.size(1), "]");
  // Raw-pointer indexing below ignores strides.
  TORCH_CHECK(
      output.is_contiguous() && mat_a.is_contiguous() && mat_b.is_contiguous(),
      "output, mat_a and mat_b must be contiguous");
  TORCH_CHECK(output.is_cuda() && mat_a.is_cuda() && mat_b.is_cuda(),
              "output, mat_a and mat_b must be CUDA tensors");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90 && sm <= 103, "required SM_103 >= CUDA ARCH >= SM_90");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto const* input_ptr =
      reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr());
  auto const* weight_ptr =
      reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr());

  // Select the launcher by output dtype and expert count; LoopUnroller then
  // picks the instantiation matching num_tokens.
  if (output.dtype() == at::kFloat) {
    auto* out_ptr = reinterpret_cast<float*>(output.mutable_data_ptr());
    if (num_experts == DEFAULT_NUM_EXPERTS) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_float_output(num_tokens, out_ptr, input_ptr, weight_ptr,
                              stream);
    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_float_output(num_tokens, out_ptr, input_ptr, weight_ptr,
                              stream);
    }
  } else if (output.dtype() == at::kBFloat16) {
    auto* out_ptr =
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr());
    if (num_experts == DEFAULT_NUM_EXPERTS) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_bf16_output(num_tokens, out_ptr, input_ptr, weight_ptr,
                             stream);
    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_bf16_output(num_tokens, out_ptr, input_ptr, weight_ptr,
                             stream);
    }
  }
}
// Registers dsv3_router_gemm as the CUDA implementation of the
// "dsv3_router_gemm" operator in the TORCH_EXTENSION_NAME library.
// NOTE(review): the operator schema is presumably declared elsewhere (an
// `m.def` in the extension's bindings) — confirm it matches this signature.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("dsv3_router_gemm", &dsv3_router_gemm);
}
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "dsv3_router_gemm_utils.h"
// Packed fused multiply-add on two float lanes: d = a * b + c.
//
// Each float2 operand is reinterpreted as a single 64-bit register ("l"
// constraint) and handed to the PTX instruction `fma.rn.f32x2`, so both lanes
// are computed by one instruction.
//
// NOTE(review): confirm which GPU architectures accept `fma.rn.f32x2` before
// calling this helper; it appears to have no callers in this file.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b,
float2 const& c) {
asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
}
// Widens the first VPT bfloat16 values packed inside `vec` to float and stores
// them in dst[0 .. VPT). The 16-byte uint4 is viewed in place as an array of
// bfloat16 values; nothing is copied before the conversion.
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec,
                                                     float* dst) {
  auto const* packed = reinterpret_cast<__nv_bfloat16 const*>(&vec);
#pragma unroll
  for (int lane = 0; lane < VPT; ++lane) {
    dst[lane] = __bfloat162float(packed[lane]);
  }
}
// Router GEMM kernel with float32 output:
//   out[m, n] = sum_k mat_a[m, k] * mat_b[n, k]
// for m in [0, kNumTokens) and n in [0, kNumExperts), accumulated in fp32.
//
// Layout, as established by the indexing below:
//   * mat_a is read as kNumTokens dense rows of kHiddenDim elements.
//   * mat_b is read as one dense row of kHiddenDim elements per expert.
//   * out is written as kNumTokens dense rows of kNumExperts elements.
//
// Work split: blockIdx.x selects the expert; the kBlockSize threads of the
// block stride over the hidden dimension, each loading VPT elements (one
// 16-byte vector) per step. Partial sums are reduced first inside each warp
// via shuffles, then across warps through shared memory by thread 0.
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(
    float* out, T const* mat_a, T const* mat_b) {
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // Hidden-dimension elements consumed by the whole block per iteration.
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations =
      kHiddenDim / k_elems_per_k_iteration;  // Total K iterations

  // Compile-time guards: without these, an unsupported parameter combination
  // would silently drop the tail of the hidden dimension, mis-size the
  // shared-memory reduction, or violate the hard-coded launch bounds.
  static_assert(sizeof(T) == sizeof(__nv_bfloat16),
                "inputs are reinterpreted as bfloat16 values");
  static_assert(sizeof(T) * VPT == sizeof(uint4),
                "each thread must load exactly one 16-byte vector per step");
  static_assert(kBlockSize % kWarpSize == 0 && kBlockSize <= 128,
                "kBlockSize must be a whole number of warps and must not "
                "exceed __launch_bounds__(128)");
  static_assert(kHiddenDim % k_elems_per_k_iteration == 0,
                "kHiddenDim must be a multiple of VPT * kBlockSize");

  // Each block handles one expert column
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;

  // One fp32 accumulator per token row
  float acc[kNumTokens] = {};

  // Shared memory for the cross-warp reduction: one slot per (token, warp)
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // Start of this expert's weights
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

  // Programmatic dependent launch: wait on the preceding grid before touching
  // global memory (only compiled for sm_90 and newer).
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Load this thread's slice of the expert weights with one vector load
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Load the matching slice of token row m_idx with one vector load
      uint4 a_vec = *reinterpret_cast<uint4 const*>(
          mat_a + (m_idx * kHiddenDim) + k_base);
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        float a = a_float[k];
        float b = b_float[k];
        acc[m_idx] += a * b;
      }
    }
  }

  int const warp_id = tid / kWarpSize;
  int const lane_id = tid % kWarpSize;

  // Warp-level butterfly reduction (xor offsets 16, 8, 4, 2, 1)
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }
    // Only the first lane of each warp publishes the warp total
    if (lane_id == 0) {
      sm_reduction[m][warp_id] = sum;
    }
  }

  __syncthreads();

  // Final reduction across warps (only first thread)
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      // Write final result
      out[m * kNumExperts + n_idx] = final_sum;
    }
  }

  // Programmatic dependent launch: allow dependent grids to start.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
cudaStream_t stream) {
constexpr int VPT = 16 / sizeof(T);
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(
&config,
router_gemm_kernel_float_output<T, kBlockSize, VPT, kNumTokens,
kNumExperts, kHiddenDim>,
output, mat_a, mat_b);
}
// Explicit instantiations of invokeRouterGemmFloatOutput for bf16 inputs with
// a hidden dimension of 7168: every token count from 1 to 16, for each of the
// two supported expert counts (256 and 384).
#define INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(NUM_TOKENS, NUM_EXPERTS) \
  template void invokeRouterGemmFloatOutput<__nv_bfloat16, NUM_TOKENS, \
                                            NUM_EXPERTS, 7168>( \
      float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);

#define INSTANTIATE_ROUTER_GEMM_FLOAT_OUT_ALL_TOKENS(NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(1, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(2, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(3, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(4, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(5, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(6, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(7, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(8, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(9, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(10, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(11, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(12, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(13, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(14, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(15, NUM_EXPERTS) \
  INSTANTIATE_ROUTER_GEMM_FLOAT_OUT(16, NUM_EXPERTS)

// Template instantiations for DEFAULT_NUM_EXPERTS experts
INSTANTIATE_ROUTER_GEMM_FLOAT_OUT_ALL_TOKENS(256)

// Template instantiations for KIMI_K2_NUM_EXPERTS experts
INSTANTIATE_ROUTER_GEMM_FLOAT_OUT_ALL_TOKENS(384)

#undef INSTANTIATE_ROUTER_GEMM_FLOAT_OUT_ALL_TOKENS
#undef INSTANTIATE_ROUTER_GEMM_FLOAT_OUT
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cstdlib>
#include <mutex>
// Returns the compute capability of the current CUDA device encoded as
// major * 10 + minor (e.g. 90 for compute capability 9.0).
inline int getSMVersion() {
  auto const* const device_props = at::cuda::getCurrentDeviceProperties();
  int const major = device_props->major;
  int const minor = device_props->minor;
  return major * 10 + minor;
}
// Reports whether programmatic dependent launch (PDL) should be enabled.
//
// True only when getSMVersion() is at least 90 and the environment variable
// TRTLLM_ENABLE_PDL is exactly the string "1". The decision is made once, on
// the first call, and the cached value is returned afterwards; later changes
// to the environment or to the current device are not observed.
inline bool getEnvEnablePDL() {
  // Function-local static: initialized exactly once, thread-safely.
  static bool const pdl_enabled = [] {
    if (getSMVersion() < 90) {
      return false;
    }
    char const* const value = std::getenv("TRTLLM_ENABLE_PDL");
    return value != nullptr && value[0] == '1' && value[1] == '\0';
  }();
  return pdl_enabled;
}
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
......@@ -17,8 +17,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "moeTopKFuncs.cuh"
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda/std/limits>
......@@ -30,7 +32,17 @@ namespace vllm {
namespace moe {
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
// Number of threads in a warp.
static constexpr int WARP_SIZE = 32;
// Expert counts named after the model families the launcher special-cases;
// they are used as compile-time MaxNumExperts values and as block sizes.
static constexpr int NumNemotronExperts = 512;
static constexpr int NumKimiK2Experts = 384;
static constexpr int NumDeepseekExperts = 256;
// Largest expert count among the values above.
static constexpr int MaxSupportedExpertCount =
std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
// Number of experts one warp scans in the group-less path
// (WARP_SIZE lanes x MaxNumTopGroups candidates per lane).
static constexpr int MaxNumExpertsUnit = 128;
// How many of a group's best scores are summed to form the group score.
static constexpr int NumTopGroupScores = 2;
// Default compile-time bound on topk (size of the per-thread result arrays).
static constexpr int DefaultMaxNumTopExperts = 8;
// Larger topk bound used for the 512-expert / top-22 special case.
static constexpr int MaxSupportedTopExperts = 22;
// Compile-time bound on the number of selected groups; also the number of
// candidates each lane holds during expert selection.
static constexpr int MaxNumTopGroups = 4;
namespace warp_topk {
......@@ -657,76 +669,335 @@ __global__ void grouped_topk_fused_kernel(
#endif
}
template <typename T, typename BiasT, typename IdxT>
// Fused top-k expert selection for small expert counts. One thread block
// processes one token: every score/output access below is offset by
// blockIdx.x.
//
// Template parameters:
//   T, BiasT, IdxT    element types of scores, routingBias and topkIndices.
//   SF                scoring function, forwarded to apply_scoring<SF>
//                     (defined outside this view).
//   MaxNumExperts     compile-time bound on numExperts; sizes the shared
//                     score arrays.
//   UseGroups         true: one warp per expert group, groups are ranked
//                     first and experts are picked from the best groups;
//                     false: experts are ranked directly.
//   MaxNumTopExperts  compile-time bound on topk (size of topScores).
//
// Selection is done on score + bias; the emitted weight is the un-biased
// score times routedScalingFactor, optionally divided by the sum of the
// selected un-biased scores when renormalize is set.
// numTokens is not read inside this kernel body.
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
int MaxNumExperts, bool UseGroups,
int MaxNumTopExperts = DefaultMaxNumTopExperts>
__global__ void grouped_topk_fused_small_expert_count_kernel(
T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias,
int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup,
int64_t const topk, int64_t const numExperts,
int64_t const numExpertsPerGroup, bool const renormalize,
double const routedScalingFactor) {
// Only compiled for SM90+ device code.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
// declare shared memory structure
// number of experts is bounded by number of threads
// smemScoreSigmoid holds the apply_scoring<SF> output (whatever SF is);
// smemScoreBias holds that value plus the routing bias.
__shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts];
__shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts];
// number of expert groups is bounded by number of warps
int constexpr NumWarps = MaxNumExperts / WARP_SIZE;
__shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];
// needed for warp reduce
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
// for the final reduction of weight norm, only some lanes need to participate
int32_t laneIdx = threadIdx.x % WARP_SIZE;
// Broadcast lane 0's warp index so the value is uniform across the warp.
int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
if constexpr (UseGroups) {
// Warps beyond the number of groups have no work.
// NOTE(review): these warps return before the __syncthreads() calls below;
// confirm this early exit is intended on all supported architectures.
if (warpIdx >= numGroup) {
return;
}
}
// Invalid (out-of-range) slots are filled with -INFINITY so that they can
// never be preferred over a real candidate in the packed top-k comparison.
const float invalidScoreFloat = float{-INFINITY};
// load bias already; each warp represents one expert group
// Without groups, thread t simply handles expert t.
auto threadExpert = threadIdx.x;
bool expertSelected = threadExpert < numExperts;
if constexpr (UseGroups) {
// With groups, lane l of warp w handles expert l of group w.
threadExpert = warpIdx * numExpertsPerGroup + laneIdx;
expertSelected = laneIdx < numExpertsPerGroup;
}
auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert;
auto biasVal = expertSelected ? static_cast<float>(routingBias[threadExpert])
: invalidScoreFloat;
// Move the output pointers to this token's row.
topkValues += blockIdx.x * topk;
topkIndices += blockIdx.x * topk;
// get our assigned thread score; each warp represents one expert group
float score =
expertSelected ? static_cast<float>(scores[scoreIdx]) : invalidScoreFloat;
auto scoreSigmoid = apply_scoring<SF>(score);
// write the activated score to shared for later use
if (expertSelected) {
smemScoreSigmoid[threadExpert] = scoreSigmoid;
}
// get the score with bias
// For unselected threads both the raw score and the bias are -INFINITY
// (see invalidScoreFloat above); the sum is presumably -INFINITY as well,
// depending on what apply_scoring<SF> returns for -INFINITY -- TODO confirm.
auto scoreBias = float{scoreSigmoid + float{biasVal}};
if (expertSelected) {
smemScoreBias[threadExpert] = scoreBias;
}
// registers for top group score reduction
float topExpGroupScores[NumTopGroupScores];
[[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
float topGroups[MaxNumTopGroups]; // bound of numGroup
int32_t topGroupIdx[MaxNumTopGroups];
float expertScoreGroup[MaxNumTopGroups];
int32_t expertIdxGroup[MaxNumTopGroups];
float topScores[MaxNumTopExperts]; // bound of topk
int32_t topExperts[MaxNumTopExperts];
if constexpr (UseGroups) {
// Per warp (= per group): find the NumTopGroupScores best biased scores.
reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias,
threadExpert,
/* minValue */ invalidScoreFloat);
// get the final group score and write it to shared
if (warp.thread_rank() == 0) {
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
smemGroupScores[warpIdx] = groupScore;
}
}
// make group scores available to all warps
// (also publishes smemScoreSigmoid / smemScoreBias written above)
__syncthreads();
if constexpr (UseGroups) {
if (warpIdx == 0) {
// a single warp performs the selection of top groups, and goes on to
// select the final experts
float groupScore =
laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat;
reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
/* minValue */ invalidScoreFloat);
// final expert selection: get relevant indexes and scores from shared
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup
auto groupIdx = topGroupIdx[ii];
expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx;
// Only the first topkGroup groups contribute candidates.
expertScoreGroup[ii] = (ii < topkGroup) && expertSelected
? smemScoreBias[expertIdxGroup[ii]]
: invalidScoreFloat;
}
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup, /* minValue */ invalidScoreFloat,
topk);
}
} else if constexpr (MaxNumExperts > MaxNumExpertsUnit) {
// without groups, and the expert number is larger than MaxNumExpertsUnit,
// we need to use multiple warps to calculate the intermediate topk results
int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1;
int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts;
__shared__ float
__attribute((aligned(128))) smemInterTopScores[NumInterTopK];
__shared__ int32_t
__attribute((aligned(128))) smemInterTopExperts[NumInterTopK];
if (warpIdx < NumExpertWarps) {
// Each participating warp covers a contiguous slice of
// WARP_SIZE * MaxNumTopGroups experts starting at `offset`.
int offset = warpIdx * WARP_SIZE * MaxNumTopGroups;
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
auto expertIdx = ii * WARP_SIZE + laneIdx;
expertIdxGroup[ii] = offset + expertIdx;
expertScoreGroup[ii] = offset + expertIdx < numExperts
? smemScoreBias[offset + expertIdx]
: invalidScoreFloat;
}
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup,
/* minValue */ invalidScoreFloat, topk);
// Publish this warp's partial result; slots in [topk, MaxNumTopExperts)
// are padded with an invalid score and an in-range dummy expert index.
if (laneIdx < topk) {
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
topScores[laneIdx];
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
topExperts[laneIdx];
} else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) {
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
invalidScoreFloat;
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
MaxNumExperts - 1;
}
}
// Partial results must be visible before warp 0 merges them.
__syncthreads();
if (warpIdx == 0) {
// Warp 0 merges the per-warp partial results into the final top-k.
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
float intermediateScore[NumInterTopKPerThread];
int32_t intermediateExpert[NumInterTopKPerThread];
for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE;
i += WARP_SIZE) {
int ii = i / WARP_SIZE;
if (i < NumInterTopK) {
intermediateScore[ii] = smemInterTopScores[i];
intermediateExpert[ii] = smemInterTopExperts[i];
} else {
intermediateScore[ii] = invalidScoreFloat;
intermediateExpert[ii] = MaxNumExperts - 1;
}
}
reduce_topk::reduceTopK(warp, topScores, topExperts, intermediateScore,
intermediateExpert,
/* minValue */ invalidScoreFloat, topk);
}
} else {
// without groups, and the expert number is smaller than MaxNumExpertsUnit
// each thread just takes `MaxNumTopGroups` experts
if (warpIdx == 0) {
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
auto expertIdx = ii * WARP_SIZE + laneIdx;
expertIdxGroup[ii] = expertIdx;
expertScoreGroup[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx]
: invalidScoreFloat;
}
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup,
/* minValue */ invalidScoreFloat, topk);
}
}
if (warpIdx == 0) {
// determine our lane's expert index and write to output
// Lanes >= topk use an in-range dummy index and contribute 0 to the norm.
int32_t expertIdx =
laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1;
// The output weight comes from the un-biased score, not the biased one
// that was used for ranking.
float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F;
float finalScore = static_cast<float>(scoreNorm * routedScalingFactor);
// norm the value
if (renormalize) {
auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
// The small epsilon guards against division by zero.
finalScore /= (redNorm + 1e-20);
}
// store the topk scores and experts to output
if (laneIdx < topk) {
topkValues[laneIdx] = finalScore;
topkIndices[laneIdx] = expertIdx;
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices,
BiasT const* bias, int64_t const num_tokens,
int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk,
bool const renormalize, double const routed_scaling_factor,
int const scoring_func, bool enable_pdl = false,
cudaStream_t const stream = 0) {
bool enable_pdl = false, cudaStream_t const stream = 0) {
cudaLaunchConfig_t config;
// One block per token; one warp per group.
config.gridDim = static_cast<uint32_t>(num_tokens);
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
int32_t const num_warps = static_cast<int32_t>(n_group);
size_t const val_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
size_t const val_bytes_aligned =
warp_topk::round_up_to_multiple_of<256>(val_bytes);
size_t const idx_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
config.dynamicSmemBytes = internal_bytes + extra_bytes;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
auto const sf = static_cast<ScoringFunc>(scoring_func);
switch (sf) {
case SCORING_NONE: {
auto* kernel_instance =
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_NONE>;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
return;
}
case SCORING_SIGMOID: {
auto* kernel_instance =
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_SIGMOID>;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
return;
// Check if we can use the optimized
// grouped_topk_fused_small_expert_count_kernel
bool const is_single_group =
(n_group == 1) && (topk_group == 1) &&
(num_experts <= MaxSupportedExpertCount) &&
(topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts);
int64_t const experts_per_group = num_experts / n_group;
bool const is_multi_group =
(n_group > 1) && (num_experts <= NumDeepseekExperts) &&
(experts_per_group <= WARP_SIZE) &&
(experts_per_group * topk_group <= MaxNumExpertsUnit) &&
(topk <= DefaultMaxNumTopExperts) && (topk_group <= MaxNumTopGroups);
if (is_single_group || is_multi_group) {
auto* kernel_instance =
&grouped_topk_fused_small_expert_count_kernel<T, BiasT, IdxT, SF,
NumDeepseekExperts, true>;
int num_threads = NumDeepseekExperts;
if (is_single_group) {
// Special case for Nemotron, which selects top 22 from 512 experts, and 1
// group only.
if (num_experts == NumNemotronExperts && n_group == 1 &&
topk == MaxSupportedTopExperts) {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, NumNemotronExperts, false,
MaxSupportedTopExperts>;
num_threads = NumNemotronExperts;
} else if (num_experts > NumKimiK2Experts &&
num_experts <= MaxSupportedExpertCount) {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, MaxSupportedExpertCount, false>;
num_threads = MaxSupportedExpertCount;
} else if (num_experts > MaxNumExpertsUnit &&
num_experts <= NumKimiK2Experts) {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, NumKimiK2Experts, false>;
num_threads = NumKimiK2Experts;
} else {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, MaxNumExpertsUnit, false>;
num_threads = MaxNumExpertsUnit;
}
}
default:
// should be guarded by higher level checks.
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
config.gridDim = num_tokens;
config.blockDim = num_threads;
config.dynamicSmemBytes = 0;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, n_group, topk_group,
topk, num_experts, num_experts / n_group, renormalize,
routed_scaling_factor);
} else {
auto* kernel_instance = &grouped_topk_fused_kernel<T, BiasT, IdxT, SF>;
// One block per token; one warp per group.
config.gridDim = static_cast<uint32_t>(num_tokens);
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
int32_t const num_warps = static_cast<int32_t>(n_group);
size_t const val_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
size_t const val_bytes_aligned =
warp_topk::round_up_to_multiple_of<256>(val_bytes);
size_t const idx_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
config.dynamicSmemBytes = internal_bytes + extra_bytes;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
}
}
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
template void invokeNoAuxTc<T, BiasT, IdxT>( \
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \
template void invokeNoAuxTc<T, BiasT, IdxT, SF>( \
T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \
int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, float, int32_t);
INSTANTIATE_NOAUX_TC(float, half, int32_t);
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(half, float, int32_t);
INSTANTIATE_NOAUX_TC(half, half, int32_t);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t);
bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_NONE);
} // end namespace moe
} // namespace vllm
......@@ -762,46 +1033,53 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
auto const sf = static_cast<vllm::moe::ScoringFunc>(scoring_func);
#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \
do { \
switch (sf) { \
case vllm::moe::SCORING_NONE: \
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_NONE>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, false, stream); \
break; \
case vllm::moe::SCORING_SIGMOID: \
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_SIGMOID>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, false, stream); \
break; \
default: \
throw std::invalid_argument("Unsupported scoring_func"); \
break; \
} \
} while (0)
#define LAUNCH_KERNEL(T, IdxT) \
do { \
switch (bias_type) { \
case torch::kFloat16: \
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kFloat32: \
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kBFloat16: \
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
default: \
throw std::invalid_argument( \
"Invalid bias dtype, only supports float16, float32, and " \
"bfloat16"); \
break; \
} \
#define LAUNCH_KERNEL(T, IdxT) \
do { \
switch (bias_type) { \
case torch::kFloat16: \
LAUNCH_KERNEL_SF(T, half, IdxT); \
break; \
case torch::kFloat32: \
LAUNCH_KERNEL_SF(T, float, IdxT); \
break; \
case torch::kBFloat16: \
LAUNCH_KERNEL_SF(T, __nv_bfloat16, IdxT); \
break; \
default: \
throw std::invalid_argument( \
"Invalid bias dtype, only supports float16, float32, and " \
"bfloat16"); \
break; \
} \
} while (0)
switch (data_type) {
......@@ -824,5 +1102,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
break;
}
#undef LAUNCH_KERNEL
#undef LAUNCH_KERNEL_SF
return {topk_values, topk_indices};
}
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
* Copyright (c) 2026, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION. All rights
* reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
namespace vllm {
namespace moe {
namespace reduce_topk {
namespace cg = cooperative_groups;
static constexpr int kWARP_SIZE = 32;
// Packs a (value, index) pair into one unsigned integer key so that a single
// max-reduction over the keys yields both the largest value and its index.
//
// Key layout: the transformed value bits sit in the upper part (shifted left
// by kMoveBits) and the low 16 bits hold (kMaxIdx - idx). Because the index is
// stored inverted, among equal values the smaller index produces the larger
// key. Indices are masked to 16 bits, so they must fit in [0, 65535].
template <typename T_>
struct TopKRedType {
using T = T_;
static_assert(
std::is_same_v<T, float> || std::is_same_v<T, half> ||
std::is_same_v<T, __nv_bfloat16> || std::is_same_v<T, int>,
"Top K reduction only implemented for int, float, float16 and bfloat16");
// 4-byte values use a 64-bit key, 2-byte values a 32-bit key.
using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
// Not referenced inside this struct.
using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
// Shift that places the value bits above the index field.
static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
static constexpr int kMaxIdx = 65535;
// The packed key.
TypeCmp compValIdx;
// Builds the packed key for (val, idx).
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) {
// TwiddleIn is presumably CUB's order-preserving mapping from the value's
// bit pattern to unsigned bits -- verify against the CUB documentation.
auto valueBits = cub::Traits<T>::TwiddleIn(
reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
TypeCmp compactTmp = valueBits;
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
// Use 65535 minus idx to give higher priority to elements with smaller
// indices.
return compactTmp;
}
// Inverse of makeCmpVal: recovers the value and the index from a key.
static __host__ __device__ void unpack(T& value, int32_t& index,
TypeCmp cmp) {
// Since "65535 - idx" is always non-negative and smaller than 65536, it
// fits in, and is read back from, the lower 16 bits.
index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));
auto compactTmp = cmp >> kMoveBits;
// NOTE(review): for 4-byte T this views a 64-bit temporary through a
// 32-bit reference, which reads the intended half only on little-endian
// targets -- confirm this assumption is acceptable.
auto valueBits = cub::Traits<T>::TwiddleOut(
reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
value = reinterpret_cast<T&>(valueBits);
}
__host__ __device__ TopKRedType() = default;
__host__ __device__ TopKRedType(T val, int32_t idx)
: compValIdx(makeCmpVal(val, idx)) {}
// Implicit conversion to the packed key.
__host__ __device__ operator TypeCmp() const noexcept { return compValIdx; }
// Warp-wide reduction of the packed keys with cg::greater; returns the
// largest key held by any thread of the tile.
__device__ inline TypeCmp reduce(
cg::thread_block_tile<kWARP_SIZE> const& warp) {
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Optional index storage. The primary template carries no data; only the
// `true` specialisation below actually holds indices.
template <int Count, bool Enabled>
struct TopKIdx {};

// Enabled variant: exposes the capacity as `K` and stores `K` indices.
template <int Count>
struct TopKIdx<Count, true> {
  static constexpr int K = Count;
  int32_t val[Count];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compare-exchange on the packed keys of topK[I] and topK[J]: afterwards
// topK[I] holds the larger key and topK[J] the smaller one. Expects an array
// named `topK` in the enclosing scope.
#define TOPK_SWAP(I, J)                       \
  do {                                        \
    auto const keyA = topK[I].compValIdx;     \
    auto const keyB = topK[J].compValIdx;     \
    topK[I].compValIdx = max(keyA, keyB);     \
    topK[J].compValIdx = min(keyA, keyB);     \
  } while (0)

// Fixed-size sorting networks that order N packed keys from largest to
// smallest. Only N = 1..4 are provided.
template <int N, typename RedType>
struct Sort;

// One element: nothing to do.
template <typename RedType>
struct Sort<1, RedType> {
  static __device__ void run(RedType* /*topK*/) {}
};

// Two elements: a single compare-exchange.
template <typename RedType>
struct Sort<2, RedType> {
  static __device__ void run(RedType* topK) {
    TOPK_SWAP(0, 1);
  }
};

// Three elements: three compare-exchanges.
template <typename RedType>
struct Sort<3, RedType> {
  static __device__ void run(RedType* topK) {
    TOPK_SWAP(0, 1);
    TOPK_SWAP(1, 2);
    TOPK_SWAP(0, 1);
  }
};

// Four elements: five compare-exchanges.
template <typename RedType>
struct Sort<4, RedType> {
  static __device__ void run(RedType* topK) {
    TOPK_SWAP(0, 2);
    TOPK_SWAP(1, 3);
    TOPK_SWAP(0, 1);
    TOPK_SWAP(2, 3);
    TOPK_SWAP(1, 2);
  }
};
// Warp-level top-K where every lane contributes exactly one candidate
// (value, idx). After the call, out[0..actualK) / outIdx[0..actualK) hold the
// selected values and their indices in the order they were extracted.
//
// Each round takes the warp-wide maximum of the packed keys; the lane whose
// candidate won the previous round replaces it with minValue so that it is
// not picked again.
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(
    cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue,
    int actualK = K) {
  static_assert(K > 0, "Top K must have K > 0");
  static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
  using Packed = TopKRedType<Type>;
  Packed candidate{value, idx};
  typename Packed::TypeCmp winner{};
#pragma unroll
  for (int round = 0; round < actualK; ++round) {
    // Retire this lane's candidate if it was the previous winner.
    if (round > 0 && winner == candidate.compValIdx) {
      candidate = Packed{minValue, idx};
    }
    // get the next largest value
    winner = candidate.reduce(warp);
    Packed::unpack(out[round], outIdx[round], winner);
  }
}
// Warp-level top-K where every lane contributes N candidates (N <= 4).
//
// Each lane packs its candidates into keys and, unless IsSorted is set, sorts
// them from largest to smallest. Every round then reduces the lanes' current
// heads (topK[0]) across the warp; the lane that owned the previous winner
// pops it by shifting its remaining keys forward and padding the tail with
// minValue. The first actualK winners are written to out / outIdx.
template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp,
Type (&out)[K], int32_t (&outIdx)[K],
Type (&value)[N], int32_t (&idx)[N],
Type minValue, int actualK = K) {
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N < 5,
"Only support candidates number less than or equal to 128");
using RedType = TopKRedType<Type>;
RedType topK[N];
// Pack every (value, index) candidate of this lane into a comparable key.
#pragma unroll
for (int nn = 0; nn < N; ++nn) {
topK[nn] = RedType{value[nn], idx[nn]};
}
// Order this lane's keys so that topK[0] is its best candidate.
if constexpr (!IsSorted) {
Sort<N, RedType>::run(topK);
}
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < actualK; ++kk) {
// True only on the lane whose head was selected in the previous round.
bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
for (int nn = 0; nn < N; ++nn) {
// Pop the head: shift every key one slot forward and refill the last slot
// with minValue. The nn == N - 1 case is tested first, so topK[nn + 1] is
// never evaluated for the last slot.
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]}
: update ? topK[nn + 1]
: topK[nn];
}
// get the next largest value
packedMax = topK[0].reduce(warp);
RedType::unpack(out[kk], outIdx[kk], packedMax);
}
};
// Warp-level top-K where every lane contributes N candidates.
//
// For N <= 4 this forwards directly to reduceTopKFunc. For larger N (a
// multiple of 4, at most 16) the candidates are processed in chunks of four:
// each chunk is reduced to a warp-wide top-K, the chunk results are spread
// over the lanes of a small staging buffer, and a final reduceTopKFunc pass
// over that buffer yields the overall result.
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(
cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N],
Type const minValue, int actualK = K) {
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(
N <= 16,
"Only support candidates number less than or equal to 16*32=512");
static_assert(N <= 4 || N % 4 == 0,
"Only support candidates number is a multiple of 4*32=128 or "
"less than or equal to 4");
// NOTE(review): RedType is not used in this function body.
using RedType = TopKRedType<Type>;
if constexpr (N <= 4) {
reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue,
actualK);
} else {
// Number of 4-candidate chunks, and number of staging slots per lane
// needed to hold numLoops * K chunk results across the warp.
constexpr int numLoops = N / 4;
constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;
Type topKBufferValue[numResults];
int32_t topKBufferIdx[numResults];
int32_t laneIdx = threadIdx.x % kWARP_SIZE;
// Start with every staging slot holding minValue and a placeholder index.
for (int ii = 0; ii < numResults; ++ii) {
topKBufferValue[ii] = minValue;
topKBufferIdx[ii] = ii * kWARP_SIZE - 1;
}
for (int loop = 0; loop < numLoops; ++loop) {
int start = loop * 4;
Type topKValue[K];
int32_t topKIdx[K];
Type inValue[4];
int32_t inIdx[4];
// Copy this chunk's four candidates into fixed-size arrays.
for (int i = 0; i < 4; ++i) {
inValue[i] = value[start + i];
inIdx[i] = idx[start + i];
}
reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx,
minValue, actualK);
// Lanes [loop * K, (loop + 1) * K) keep one result of this chunk each in
// staging slot 0.
int inOffset = laneIdx % K;
if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) {
topKBufferValue[0] = topKValue[inOffset];
topKBufferIdx[0] = topKIdx[inOffset];
}
// NOTE(review): only staging slots 0 and 1 are ever written, slot 1 only
// during the last chunk, and slot 1 is indexed even when numResults is 1
// (the condition is then never true). Verify that results are not
// dropped when numLoops * K exceeds kWARP_SIZE.
if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) {
topKBufferValue[1] = topKValue[inOffset];
topKBufferIdx[1] = topKIdx[inOffset];
}
}
// Final pass over the staged chunk results.
reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue,
topKBufferIdx, minValue, actualK);
}
};
#undef TOPK_SWAP
} // namespace reduce_topk
} // namespace moe
} // namespace vllm
......@@ -35,11 +35,11 @@ __global__ void batched_moe_align_block_size_kernel(
int32_t const block_ids_size = sorted_ids_size / block_size;
int32_t const SENTINEL =
num_batches * max_tokens_per_batch; // To denote invalid entries.
// Intialize sorted_ids
// Initialize sorted_ids
for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
sorted_ids[i] = SENTINEL;
}
// Intialize expert_ids with -1
// Initialize expert_ids with -1
for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
block_ids[i] = -1;
}
......@@ -172,7 +172,7 @@ __device__ void _moe_align_block_size(
}
}
// Fill remaining expert_ids with 0
// Fill remaining expert_ids with -1
const size_t fill_start_idx =
cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
......@@ -265,7 +265,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
}
}
// Fill remaining expert_ids with 0
// Fill remaining expert_ids with -1
const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {
expert_ids[expert_ids_offset + i] = inactive_expert_id;
......@@ -332,7 +332,7 @@ __global__ void moe_align_block_size_kernel(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
0, 0, topk_num, nullptr, has_expert_map);
0, -1, topk_num, nullptr, has_expert_map);
}
template <typename scalar_t>
......@@ -373,7 +373,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, block_size, numel, max_num_tokens_padded,
CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr,
CEILDIV(max_num_tokens_padded, block_size), -1, 0, topk_num, nullptr,
has_expert_map);
}
......
......@@ -55,4 +55,19 @@ bool moe_permute_unpermute_supported();
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);
\ No newline at end of file
torch::Tensor& output_tensor);
#ifndef USE_ROCM
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
torch::Tensor const& weight);
// DeepSeek V3 optimized router GEMM kernel for SM90+
// Computes output = mat_a @ mat_b.T where:
// mat_a: [num_tokens, hidden_dim] in bf16
// mat_b: [num_experts, hidden_dim] in bf16
// output: [num_tokens, num_experts] in bf16 or fp32
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
const torch::Tensor& mat_b);
#endif
......@@ -73,10 +73,9 @@ void moe_permute(
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, stream);
get_ptr<int>(sorted_row_idx), get_ptr<int>(inv_permuted_idx),
get_ptr<int>(permuted_idx), get_ptr<int64_t>(expert_first_token_offset),
n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream);
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment