Unverified Commit 4d51588e authored by Yifan Qiao's avatar Yifan Qiao Committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: default avatarYifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Signed-off-by: default avatarqizixi <zixi@inferact.ai>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarYongye Zhu <yongye@inferact.ai>
Co-authored-by: default avatarSimon Mo <simon@inferact.ai>
Co-authored-by: default avatarBugen Zhao <i@bugenzhao.com>
Co-authored-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarRoy Wang <yasong.wang@inferact.ai>
Co-authored-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarZhewen Li <jerven.vllm@gmail.com>
Co-authored-by: default avatarZijing Liu <liuzijing2014@gmail.com>
Co-authored-by: default avatarkhluu <khluu000@gmail.com>
Co-authored-by: default avatarqizixi <zixi@inferact.ai>
Co-authored-by: Zh...
parent 32e45636
...@@ -310,7 +310,9 @@ set(VLLM_EXT_SRC ...@@ -310,7 +310,9 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp") "csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC "csrc/minimax_reduce_rms_kernel.cu") list(APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
...@@ -1051,7 +1053,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -1051,7 +1053,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu" "csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu" "csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/router_gemm.cu") "csrc/moe/router_gemm.cu"
"csrc/moe/topk_softplus_sqrt_kernels.cu")
endif() endif()
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
......
...@@ -20,7 +20,7 @@ else() ...@@ -20,7 +20,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
deepgemm deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git
GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03 GIT_TAG 891d57b4db1071624b5c8fa0d1e51cb317fa709f
GIT_SUBMODULES "third-party/cutlass" "third-party/fmt" GIT_SUBMODULES "third-party/cutlass" "third-party/fmt"
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
...@@ -120,6 +120,11 @@ if(DEEPGEMM_ARCHS) ...@@ -120,6 +120,11 @@ if(DEEPGEMM_ARCHS)
COMPONENT _deep_gemm_C COMPONENT _deep_gemm_C
FILES_MATCHING PATTERN "*.py") FILES_MATCHING PATTERN "*.py")
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/mega/"
DESTINATION vllm/third_party/deep_gemm/mega
COMPONENT _deep_gemm_C
FILES_MATCHING PATTERN "*.py")
# Generate envs.py (normally generated by DeepGEMM's setup.py build step) # Generate envs.py (normally generated by DeepGEMM's setup.py build step)
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py" file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
"# Pre-installed environment variables\npersistent_envs = dict()\n") "# Pre-installed environment variables\npersistent_envs = dict()\n")
......
...@@ -19,7 +19,7 @@ else() ...@@ -19,7 +19,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
flashmla flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1 GIT_TAG a6ec2ba7bd0a7dff98b3f4d3e6b52b159c48d78b
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
......
...@@ -178,7 +178,12 @@ void rotary_embedding_gptj_impl( ...@@ -178,7 +178,12 @@ void rotary_embedding_gptj_impl(
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rope_dim_offset, bool inverse) {
TORCH_CHECK(rope_dim_offset == 0,
"rope_dim_offset != 0 is not supported on CPU");
TORCH_CHECK(!inverse, "inverse rotary embedding is not supported on CPU");
int num_tokens = positions.numel(); int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
......
...@@ -263,7 +263,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -263,7 +263,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def( ops.def(
"rotary_embedding(Tensor positions, Tensor! query," "rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size," " Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox, int "
"rope_dim_offset=0, bool inverse=False) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
// Quantization // Quantization
......
/*
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
*
* Horizontally-fused DeepseekV4-MLA kernel:
* - Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last ROPE_DIM
* - KV side: GPT-J RoPE on last ROPE_DIM + UE8M0 FP8 quant on NoPE + paged
* cache insert
*
* Structured after `applyMLARopeAndAssignQKVKernelGeneration` in
* TensorRT-LLM's mlaKernels.cu: one kernel, one grid, with head-slot
* dispatch choosing Q vs KV work per warp. The per-warp RMSNorm/RoPE
* skeleton is adapted from vllm-deepseek_v4's existing
* `fusedQKNormRopeKernel` (csrc/fused_qknorm_rope_kernel.cu).
*
* Assumptions (hard-coded for DeepseekV4 attention):
* HEAD_DIM = 512
* ROPE_DIM = 64 (RoPE applied to dims [NOPE_DIM, HEAD_DIM))
* NOPE_DIM = 448
* QUANT_BLOCK = 64 (UE8M0 FP8 quant block)
* FP8_MAX = 448.0f
* is_neox=false (GPT-J interleaved pairs)
* cos_sin_cache layout [max_pos, rope_dim] = cos || sin (cos first, sin
* second along last dim; each half is rope_dim/2 = 32 values)
*
* Cache layout per paged-cache block (block_size tokens):
* [0, bs*576): token data, 448 fp8 + 128 bf16 each
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
*/
#include <cmath>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef FINAL_MASK
#define FINAL_MASK 0xffffffffu
#endif
namespace vllm {
namespace deepseek_v4_fused_ops {
namespace {
inline int getSMVersion() {
auto* props = at::cuda::getCurrentDeviceProperties();
return props->major * 10 + props->minor;
}
} // namespace
// ────────────────────────────────────────────────────────────────────────────
// Constants
// ────────────────────────────────────────────────────────────────────────────
constexpr int kHeadDim = 512;
constexpr int kRopeDim = 64;
constexpr int kNopeDim = kHeadDim - kRopeDim; // 448
constexpr int kQuantBlock = 64;
constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7
constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad)
constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576
constexpr float kFp8Max = 448.0f;
// Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM.
constexpr int kNumLanes = 32;
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
// ────────────────────────────────────────────────────────────────────────────
// Small inline helpers
// ────────────────────────────────────────────────────────────────────────────
__device__ __forceinline__ float warp4MaxAbs(float val) {
// Reduce absolute max across 4 consecutive lanes (lane id & 3 group).
float peer = __shfl_xor_sync(FINAL_MASK, val, 1);
val = fmaxf(val, peer);
peer = __shfl_xor_sync(FINAL_MASK, val, 2);
val = fmaxf(val, peer);
return val;
}
template <typename T>
__device__ __forceinline__ float warpSum(float val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
}
return val;
}
// ────────────────────────────────────────────────────────────────────────────
// Kernel
// ────────────────────────────────────────────────────────────────────────────
//
// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) /
// warps_per_block) Block: blockDim.x = 256 threads (8 warps per block) Each
// warp handles one (token, head_slot) pair. head_slot < num_heads_q →
// Q branch (RMSNorm + RoPE, in place) head_slot == num_heads_q → KV
// branch (RoPE + UE8M0 quant + insert)
//
// With DP padding, q/kv/position_ids can have more rows than slot_mapping.
// The Q branch covers all `num_tokens_full` rows (downstream attention uses
// them). The KV branch only inserts the first `num_tokens_insert` tokens
// (= slot_mapping length) into the paged cache.
//
template <typename scalar_t_in>
__global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel(
scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place
scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16
uint8_t* __restrict__ k_cache, // [num_blocks, block_stride]
int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64
int64_t const* __restrict__ position_ids, // [N] i64
float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32
float const eps,
int const num_tokens_full, // = q.size(0) = kv.size(0)
int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full
int const num_heads_q, // H
int const cache_block_size, // tokens per paged-cache block
int const kv_block_stride) { // bytes per paged-cache block
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
// BF16 _typeConvert specialization is unavailable on pre-Ampere. The
// DeepseekV4 kernel only runs with bf16 inputs in practice, so compile a
// no-op stub for sm_70/sm_75 to keep multi-arch builds happy.
if constexpr (std::is_same_v<scalar_t_in, c10::BFloat16>) {
return;
} else {
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
int const total_slots_per_token = num_heads_q + 1;
int const tokenIdx = globalWarpIdx / total_slots_per_token;
int const slotIdx = globalWarpIdx % total_slots_per_token;
if (tokenIdx >= num_tokens_full) return;
bool const isKV = (slotIdx == num_heads_q);
// KV branch: skip DP-padded tokens (no slot reserved for them).
if (isKV && tokenIdx >= num_tokens_insert) return;
// PDL: wait for predecessor kernel (upstream q/kv producer) to signal
// before touching any global memory. No-op when PDL is not enabled on
// the launch. The CUDA runtime wrapper emits the griddepcontrol.wait
// PTX with the required memory clobber internally.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaGridDependencySynchronize();
#endif
// Dim range this lane owns within the 512-wide head.
int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16
// ── Load 16 bf16 → 16 fp32 registers (one 16-byte + one 16-byte LDG) ────
float elements[kElemsPerLane];
float sumOfSquares = 0.0f;
scalar_t_in const* src_ptr;
if (isKV) {
src_ptr = kv_in + static_cast<int64_t>(tokenIdx) * kHeadDim + dim_base;
} else {
int64_t const q_row_offset =
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
dim_base;
src_ptr = q_inout + q_row_offset;
}
// Two 16-byte loads per thread (8 bf16 each). Use uint4 as the vector
// type and bitcast to scalar_t_in packed pairs for conversion.
uint4 v0 = *reinterpret_cast<uint4 const*>(src_ptr);
uint4 v1 = *reinterpret_cast<uint4 const*>(src_ptr + 8);
{
typename Converter::packed_hip_type const* p0 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v0);
typename Converter::packed_hip_type const* p1 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v1);
// Each packed_hip_type holds 2 bf16 → 4 packed = 8 elems per uint4.
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p0[i]);
elements[2 * i] = f2.x;
elements[2 * i + 1] = f2.y;
}
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p1[i]);
elements[8 + 2 * i] = f2.x;
elements[8 + 2 * i + 1] = f2.y;
}
}
// ── Q branch: RMSNorm with no weight (has_weight=False) ─────────────────
// Variance + rsqrt + multiply all in fp32, no intermediate bf16 round.
// The downstream bf16 round only happens at the final store.
if (!isKV) {
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
sumOfSquares += elements[i] * elements[i];
}
sumOfSquares = warpSum<float>(sumOfSquares);
float const rms_rcp =
rsqrtf(sumOfSquares / static_cast<float>(kHeadDim) + eps);
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
elements[i] = elements[i] * rms_rcp;
}
}
// ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ─────────────────────────────
// All math in fp32. cos_sin_cache is loaded as fp32 (its native storage).
bool const is_rope_lane = dim_base >= kNopeDim;
if (is_rope_lane) {
int64_t const pos = position_ids[tokenIdx];
constexpr int kHalfRope = kRopeDim / 2; // 32
float const* cos_ptr = cos_sin_cache + pos * kRopeDim;
float const* sin_ptr = cos_ptr + kHalfRope;
int const rope_local_base = dim_base - kNopeDim; // in [0, 64) step 16
#pragma unroll
for (int p = 0; p < kElemsPerLane / 2; p++) {
int const pair_dim = rope_local_base + 2 * p;
int const half_idx = pair_dim / 2;
float const cos_v = VLLM_LDG(cos_ptr + half_idx);
float const sin_v = VLLM_LDG(sin_ptr + half_idx);
float const x_even = elements[2 * p];
float const x_odd = elements[2 * p + 1];
elements[2 * p] = x_even * cos_v - x_odd * sin_v;
elements[2 * p + 1] = x_even * sin_v + x_odd * cos_v;
}
}
// ═══════════════════════════════════════════════════════════════════════
// Q branch: cast to bf16 and store back in place.
// ═══════════════════════════════════════════════════════════════════════
if (!isKV) {
uint4 out0, out1;
typename Converter::packed_hip_type* po0 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
typename Converter::packed_hip_type* po1 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
scalar_t_in* dst =
q_inout +
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
dim_base;
*reinterpret_cast<uint4*>(dst) = out0;
*reinterpret_cast<uint4*>(dst + 8) = out1;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
return;
}
// ═══════════════════════════════════════════════════════════════════════
// KV branch.
// ═══════════════════════════════════════════════════════════════════════
int64_t const slot_id = slot_mapping[tokenIdx];
if (slot_id < 0) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
return;
}
int64_t const block_idx = slot_id / cache_block_size;
int64_t const pos_in_block = slot_id % cache_block_size;
uint8_t* block_base =
k_cache + block_idx * static_cast<int64_t>(kv_block_stride);
uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes;
uint8_t* token_bf16_ptr = token_fp8_ptr + kNopeDim;
uint8_t* token_scale_ptr =
block_base + static_cast<int64_t>(cache_block_size) * kTokenDataBytes +
pos_in_block * kScaleBytesPerToken;
// Round K to bf16 first, matching the unfused reference path where K is
// materialized as bf16 before K quantization. absmax, clamp, and FP8
// quant below all run on these bf16-rounded values.
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
elements[i] = Converter::convert(Converter::convert(elements[i]));
}
// Per-quant-block absmax must be computed by ALL 32 lanes (warp-collective
// shuffle requires full participation). RoPE lanes contribute garbage,
// but their values are gated out below via `!is_rope_lane`.
float local_absmax = 0.0f;
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
local_absmax = fmaxf(local_absmax, fabsf(elements[i]));
}
float const absmax = fmaxf(warp4MaxAbs(local_absmax), 1e-4f);
float const exponent = ceilf(log2f(absmax / kFp8Max));
float const inv_scale = exp2f(-exponent);
if (!is_rope_lane) {
// ── NoPE lane: UE8M0 FP8 quant ───────────────────────────────────────
uint8_t out_bytes[kElemsPerLane];
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
float scaled = elements[i] * inv_scale;
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
__nv_fp8_storage_t s =
__nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3);
out_bytes[i] = static_cast<uint8_t>(s);
}
// One 16-byte STG per lane.
*reinterpret_cast<uint4*>(token_fp8_ptr + dim_base) =
*reinterpret_cast<uint4 const*>(out_bytes);
// Lane (4k) of each 4-lane group writes the scale byte for block k<7.
if ((laneId & 3) == 0) {
int const q_block_idx = laneId >> 2; // 0..6 for NoPE lanes
float encoded = fmaxf(fminf(exponent + 127.0f, 255.0f), 0.0f);
token_scale_ptr[q_block_idx] = static_cast<uint8_t>(encoded);
}
// Lane 0 also writes the padding byte at index 7.
if (laneId == 0) {
token_scale_ptr[kNumQuantBlocks] = 0; // pad
}
} else {
// ── RoPE lane: cast back to bf16 and store to cache bf16 tail ────────
uint4 out0, out1;
typename Converter::packed_hip_type* po0 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
typename Converter::packed_hip_type* po1 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
int const rope_local_base = dim_base - kNopeDim; // in [0, 64)
scalar_t_in* bf16_dst =
reinterpret_cast<scalar_t_in*>(token_bf16_ptr) + rope_local_base;
*reinterpret_cast<uint4*>(bf16_dst) = out0;
*reinterpret_cast<uint4*>(bf16_dst + 8) = out1;
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
}
// ────────────────────────────────────────────────────────────────────────────
// Launch wrapper
// ────────────────────────────────────────────────────────────────────────────
template <typename scalar_t_in>
void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
scalar_t_in* q_inout, scalar_t_in const* kv_in, uint8_t* k_cache,
int64_t const* slot_mapping, int64_t const* position_ids,
float const* cos_sin_cache, float const eps, int const num_tokens_full,
int const num_tokens_insert, int const num_heads_q,
int const cache_block_size, int const kv_block_stride,
cudaStream_t stream) {
constexpr int kBlockSize = 256;
constexpr int kWarpsPerBlock = kBlockSize / 32;
int64_t const total_warps =
static_cast<int64_t>(num_tokens_full) * (num_heads_q + 1);
int const grid =
static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
// PDL: enable programmatic stream serialization whenever the hardware
// supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable,
// so leave numAttrs = 0 and launch as a regular kernel.
static int const sm_version = getSMVersion();
// Host-side guard: the device kernel body is compiled as a no-op for
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
// unavailable there. Refuse the launch loudly instead of silently
// skipping the work.
TORCH_CHECK(
sm_version >= 80,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
"(Ampere or newer); got sm_",
sm_version);
cudaLaunchConfig_t config;
config.gridDim = dim3(grid);
config.blockDim = dim3(kBlockSize);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.attrs = attrs;
config.numAttrs = (sm_version >= 90) ? 1 : 0;
cudaLaunchKernelEx(
&config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel<scalar_t_in>,
q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps,
num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
kv_block_stride);
}
} // namespace deepseek_v4_fused_ops
} // namespace vllm
// ────────────────────────────────────────────────────────────────────────────
// Torch op wrapper
// ────────────────────────────────────────────────────────────────────────────
void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor& q, // [N, H, 512] bf16, in place
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::Tensor const& slot_mapping, // [N] int64
torch::Tensor const& position_ids, // [N] int64
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
double eps, int64_t cache_block_size) {
TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA");
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
"slot_mapping must be int64 CUDA");
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
"position_ids must be int64 CUDA");
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match");
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64]");
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
"cos_sin_cache must be float32");
// With DP padding, slot_mapping can be shorter than q/kv/positions.
// Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them);
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
int const num_tokens_full = static_cast<int>(q.size(0));
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
int const num_heads_q = static_cast<int>(q.size(1));
int const cache_block_size_i = static_cast<int>(cache_block_size);
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
at::cuda::OptionalCUDAGuard device_guard(device_of(q));
auto stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_HALF_TYPES(
q.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
using qkv_scalar_t = scalar_t;
vllm::deepseek_v4_fused_ops::
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
reinterpret_cast<qkv_scalar_t*>(q.data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
num_tokens_full, num_tokens_insert, num_heads_q,
cache_block_size_i, kv_block_stride, stream);
});
}
...@@ -77,7 +77,8 @@ __global__ void rms_norm_kernel( ...@@ -77,7 +77,8 @@ __global__ void rms_norm_kernel(
#pragma unroll #pragma unroll
for (int j = 0; j < VEC_SIZE; j++) { for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]); float x = static_cast<float>(src1.val[j]);
dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j]; float w = static_cast<float>(src2.val[j]);
dst.val[j] = static_cast<scalar_t>(x * s_variance * w);
} }
v_out[i] = dst; v_out[i] = dst;
} }
...@@ -134,10 +135,17 @@ fused_add_rms_norm_kernel( ...@@ -134,10 +135,17 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
int64_t strided_id = blockIdx.x * vec_input_stride + idx; int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = residual_v[id]; _f16Vec<scalar_t, width> res = residual_v[id];
temp *= s_variance; _f16Vec<scalar_t, width> w = weight_v[idx];
temp *= weight_v[idx]; _f16Vec<scalar_t, width> out;
input_v[strided_id] = temp; using Converter = _typeConvert<scalar_t>;
#pragma unroll
for (int j = 0; j < width; ++j) {
float x = Converter::convert(res.data[j]);
float wf = Converter::convert(w.data[j]);
out.data[j] = Converter::convert(x * s_variance * wf);
}
input_v[strided_id] = out;
} }
} }
...@@ -174,8 +182,8 @@ fused_add_rms_norm_kernel( ...@@ -174,8 +182,8 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * input_stride + idx] = float w = (float)weight[idx];
((scalar_t)(x * s_variance)) * weight[idx]; input[blockIdx.x * input_stride + idx] = (scalar_t)(x * s_variance * w);
} }
} }
......
...@@ -65,9 +65,16 @@ __global__ void rms_norm_static_fp8_quant_kernel( ...@@ -65,9 +65,16 @@ __global__ void rms_norm_static_fp8_quant_kernel(
#pragma unroll #pragma unroll
for (int j = 0; j < VEC_SIZE; j++) { for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]); float x = static_cast<float>(src1.val[j]);
float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j]; float w = static_cast<float>(src2.val[j]);
// Round normalized result through scalar_t to match the precision of the
// unfused composite (rms_norm writes scalar_t, then
// static_scaled_fp8_quant re-loads it as float before FP8 conversion).
// Without this round, the fused path is strictly more accurate and
// disagrees with the composite at exact E4M3 quantization tie boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] = out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv); scaled_fp8_conversion<true, fp8_type>(static_cast<float>(out_norm),
scale_inv);
} }
} }
} }
...@@ -127,13 +134,21 @@ fused_add_rms_norm_static_fp8_quant_kernel( ...@@ -127,13 +134,21 @@ fused_add_rms_norm_static_fp8_quant_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id]; _f16Vec<scalar_t, width> res = residual_v[id];
temp *= s_variance; _f16Vec<scalar_t, width> w = weight_v[idx];
temp *= weight_v[idx]; using Converter = _typeConvert<scalar_t>;
using HipT = typename Converter::hip_type;
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
out[id * width + i] = float x = Converter::convert(res.data[i]);
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv); float wf = Converter::convert(w.data[i]);
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
// to match the unfused composite path at FP8 boundaries. We use the
// backend's hip_type for the intermediate since c10::Half/BFloat16 has
// ambiguous conversions on CUDA and no implicit conversion on ROCm.
HipT out_norm_h = Converter::convert(x * s_variance * wf);
out[id * width + i] = scaled_fp8_conversion<true, fp8_type>(
Converter::convert(out_norm_h), scale_inv);
} }
} }
} }
...@@ -176,9 +191,12 @@ fused_add_rms_norm_static_fp8_quant_kernel( ...@@ -176,9 +191,12 @@ fused_add_rms_norm_static_fp8_quant_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; float w = (float)weight[idx];
out[blockIdx.x * hidden_size + idx] = // See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv); // to match the unfused composite path at FP8 boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion<true, fp8_type>(
static_cast<float>(out_norm), scale_inv);
} }
} }
......
...@@ -12,6 +12,15 @@ void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices, ...@@ -12,6 +12,15 @@ void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& gating_output, bool renormalize, torch::Tensor& gating_output, bool renormalize,
std::optional<torch::Tensor> bias); std::optional<torch::Tensor> bias);
void topk_softplus_sqrt(torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize,
double routed_scaling_factor,
const c10::optional<torch::Tensor>& correction_bias,
const c10::optional<torch::Tensor>& input_ids,
const c10::optional<torch::Tensor>& tid2eid);
void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
......
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
* Copyright (c) 2024, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <type_traits>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace vllm {
namespace moe {
/// Aligned array type
template <typename T,
/// Number of elements in the array
int N,
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N>
struct alignas(Alignment) AlignedArray {
T data[N];
};
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
if constexpr (std::is_same_v<T, float>) {
return value;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return __bfloat162float(value);
} else if constexpr (std::is_same_v<T, __half>) {
return __half2float(value);
}
}
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
// ====================== TopK softplus_sqrt things
// ===============================
/*
A Top-K gating softplus_sqrt written to exploit when the number of experts in
the MoE layers are a small power of 2. This allows us to cleanly share the
rows among the threads in a single warp and eliminate communication between
warps (so no need to use shared mem).
It fuses the sigmoid, max and argmax into a single kernel.
Limitations:
1) This implementation is optimized for when the number of experts is a small
power of 2. Additionally it also supports when number of experts is multiple
of 64 which is still faster than the computing sigmoid and topK separately
(only tested on CUDA yet). 2) This implementation assumes k is small, but will
work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG,
int WARP_SIZE_PARAM, bool USE_HASH, typename IndType,
typename InputType = float>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftplusSqrt(
const InputType* input, const bool* finished, float* output,
const int num_rows, IndType* indices, int* source_rows, const int k,
const int start_expert, const int end_expert, const bool renormalize,
double routed_scaling_factor, const float* correction_bias,
const IndType* input_ids, const IndType* tid2eid) {
static_assert(std::is_same_v<InputType, float> ||
std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>,
"InputType must be float, __nv_bfloat16, or __half");
// We begin by enforcing compile time assertions and setting up compile time
// constants.
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
"BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
if constexpr (std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>) {
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
}
// Restrictions based on previous section.
static_assert(
VPT % ELTS_PER_LDG == 0,
"The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0,
"The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
"THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM,
"THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
"The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time
// variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a
// block contains WARPS_PER_CTA warps. This, each block processes a chunk of
// rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows) {
return;
}
const bool row_is_active = finished ? !finished[thread_row] : true;
// We finally start setting up the read pointers for each thread. First, each
// thread jumps to the start of the row it will read.
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the
// first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Finally, we pull in the data from global mem
float row_chunk[VPT];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert
// to float
if constexpr (std::is_same_v<InputType, float>) {
using VecType = AlignedArray<float, ELTS_PER_LDG>;
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
const VecType* vec_thread_read_ptr =
reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr =
reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2));
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __nv_bfloat16* scalar_ptr =
thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __bfloat162float(*scalar_ptr);
}
}
} else if constexpr (std::is_same_v<InputType, __half>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr =
reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
*reinterpret_cast<const __half2*>(vec.data + jj * 2));
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __half2float(*scalar_ptr);
}
}
}
constexpr float threshold = 20.0f;
constexpr float beta = 1.0f;
// Hash MoE path: indices are predetermined from lookup table
if constexpr (USE_HASH) {
const IndType token_id = input_ids[thread_row];
const IndType* expert_indices_for_token = tid2eid + token_id * k;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
float val_b = val * beta;
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
row_chunk[ii] = sqrtf(val);
}
float selected_sum = 0.f;
#pragma unroll
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int expert = expert_indices_for_token[k_idx];
const int idx = k * thread_row + k_idx;
for (int ii = 0; ii < VPT; ++ii) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
if (expert == expert_idx) {
indices[idx] = expert;
selected_sum += row_chunk[ii];
break;
}
}
}
// Compute per-thread scale (using warp reduction when renormalizing).
if (renormalize) {
selected_sum = warpReduceSum(selected_sum);
}
float scale = static_cast<float>(routed_scaling_factor);
if (renormalize) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
scale /= denom;
}
#pragma unroll
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int expert = expert_indices_for_token[k_idx];
const int idx = k * thread_row + k_idx;
for (int ii = 0; ii < VPT; ++ii) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
if (expert == expert_idx) {
output[idx] = row_chunk[ii] * scale;
break;
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
return;
}
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
float val_b = val * beta;
// Compute softplus: log(1 + exp(val)) with numerical stability
// When val > threshold, softplus(x) ≈ x to avoid exp overflow
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
val = sqrtf(val);
if (correction_bias) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
val = val + correction_bias[expert_idx];
}
row_chunk[ii] = val;
}
// Original TopK path: find top-k experts by score
// Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to find
// the topk elements in each row, along with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index
// are processed first and only updated if > (not >=)
if (val > max_val) {
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads
// reach consensus about the max. This will be useful for K > 1 so that the
// threads can agree on "who" had the max value. That thread can then blank out
// their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
float other_max =
VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert =
VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this
// way
if (other_max > max_val ||
(other_max == max_val && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0) {
// Add a guard to ignore experts not included by this node
const bool node_uses_expert =
expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to
// global memory. (This will be a single) thread per row of the
// input/output matrices.
const int idx = k * thread_row + k_idx;
if (correction_bias != nullptr) {
max_val -= correction_bias[expert];
}
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
}
// Finally, we clear the value in the thread with the current max if there
// is another iteration to run.
if (k_idx + 1 < k) {
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group =
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the
// "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group) {
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be
// between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
-10000.f;
}
}
}
// Apply renormalization and routed scaling factor to final weights.
if (thread_group_idx == 0) {
float scale = static_cast<float>(routed_scaling_factor);
if (renormalize) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
scale /= denom;
}
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] * scale;
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM,
typename InputType>
struct TopkConstants {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 ||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0,
"");
static constexpr int VECs_PER_THREAD =
MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
};
} // namespace detail
#define DISPATCH_HASH(use_hash, USE_HASH, ...) \
if (use_hash) { \
const bool USE_HASH = true; \
static_assert(USE_HASH == true, "USE_HASH must be compile-time constant"); \
__VA_ARGS__ \
} else { \
const bool USE_HASH = false; \
static_assert(USE_HASH == false, \
"USE_HASH must be compile-time constant"); \
__VA_ARGS__ \
}
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM,
int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
void topkGatingSoftplusSqrtLauncherHelper(
const InputType* input, const bool* finished, float* output,
IndType* indices, int* source_row, const int num_rows, const int k,
const int start_expert, const int end_expert, const bool renormalize,
double routed_scaling_factor, const float* correction_bias,
const bool use_hash, const IndType* input_ids, const IndType* tid2eid,
cudaStream_t stream) {
static constexpr int BYTES_PER_LDG =
MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
using Constants =
detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
DISPATCH_HASH(use_hash, USE_HASH, {
auto* kernel =
&topkGatingSoftplusSqrt<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG,
WARP_SIZE_PARAM, USE_HASH, IndType, InputType>;
#ifndef USE_ROCM
cudaLaunchConfig_t config = {};
config.gridDim = num_blocks;
config.blockDim = block_dim;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel, input, finished, output, num_rows,
indices, source_row, k, start_expert, end_expert,
renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid);
#else
kernel<<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert,
end_expert, renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid);
#endif
})
}
#ifndef USE_ROCM
#define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, \
MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
routed_scaling_factor, correction_bias, use_hash, input_ids, tid2eid, \
stream);
#else
#define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, \
MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
routed_scaling_factor, correction_bias, use_hash, input_ids, \
tid2eid, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, \
MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
routed_scaling_factor, correction_bias, use_hash, input_ids, \
tid2eid, stream); \
} else { \
assert(false && \
"Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template <typename IndType, typename InputType>
void topkGatingSoftplusSqrtKernelLauncher(
const InputType* gating_output, float* topk_weights, IndType* topk_indices,
int* token_expert_indices, const int num_tokens, const int num_experts,
const int topk, const bool renormalize, double routed_scaling_factor,
const float* correction_bias, const bool use_hash, const IndType* input_ids,
const IndType* tid2eid, cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
// elements can be loaded by a warp
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>)
? 4
: 8;
#endif
switch (num_experts) {
case 1:
LAUNCH_SOFTPLUS_SQRT(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 2:
LAUNCH_SOFTPLUS_SQRT(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 4:
LAUNCH_SOFTPLUS_SQRT(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 8:
LAUNCH_SOFTPLUS_SQRT(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 16:
LAUNCH_SOFTPLUS_SQRT(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 32:
LAUNCH_SOFTPLUS_SQRT(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 64:
LAUNCH_SOFTPLUS_SQRT(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 128:
LAUNCH_SOFTPLUS_SQRT(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 256:
LAUNCH_SOFTPLUS_SQRT(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 512:
LAUNCH_SOFTPLUS_SQRT(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of
// num_experts, alternatively we can test 4 bytes loading and enable it in
// future.
#ifndef USE_ROCM
case 192:
LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 320:
LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 384:
LAUNCH_SOFTPLUS_SQRT(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 448:
LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 576:
LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
#endif
default: {
TORCH_CHECK(false, "Unsupported expert number: ", num_experts);
}
}
}
} // namespace moe
} // namespace vllm
template <typename ComputeType>
void dispatch_topk_softplus_sqrt_launch(
const ComputeType* gating_output, torch::Tensor& topk_weights,
torch::Tensor& topk_indices, torch::Tensor& token_expert_indices,
int num_tokens, int num_experts, int topk, bool renormalize,
double routed_scaling_factor,
const c10::optional<torch::Tensor>& correction_bias,
const c10::optional<torch::Tensor>& input_ids,
const c10::optional<torch::Tensor>& tid2eid, cudaStream_t stream) {
const float* bias_ptr = nullptr;
if (correction_bias.has_value()) {
bias_ptr = correction_bias.value().data_ptr<float>();
}
bool use_hash = false;
if (tid2eid.has_value()) {
TORCH_CHECK(input_ids.has_value(), "input_ids is required for hash MoE");
use_hash = true;
}
if (topk_indices.scalar_type() == at::ScalarType::Int) {
const int* input_ids_ptr = nullptr;
const int* tid2eid_ptr = nullptr;
if (tid2eid.has_value()) {
input_ids_ptr = input_ids.value().data_ptr<int>();
tid2eid_ptr = tid2eid.value().data_ptr<int>();
}
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<int, ComputeType>(
gating_output, topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
const uint32_t* input_ids_ptr = nullptr;
const uint32_t* tid2eid_ptr = nullptr;
if (tid2eid.has_value()) {
input_ids_ptr = input_ids.value().data_ptr<uint32_t>();
tid2eid_ptr = tid2eid.value().data_ptr<uint32_t>();
}
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<uint32_t, ComputeType>(
gating_output, topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(), token_expert_indices.data_ptr<int>(),
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
} else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
const int64_t* input_ids_ptr = nullptr;
const int64_t* tid2eid_ptr = nullptr;
if (tid2eid.has_value()) {
input_ids_ptr = input_ids.value().data_ptr<int64_t>();
tid2eid_ptr = tid2eid.value().data_ptr<int64_t>();
}
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<int64_t, ComputeType>(
gating_output, topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(), token_expert_indices.data_ptr<int>(),
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
}
}
void topk_softplus_sqrt(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize, double routed_scaling_factor,
const c10::optional<torch::Tensor>& correction_bias,
const c10::optional<torch::Tensor>& input_ids,
const c10::optional<torch::Tensor>& tid2eid) {
const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts;
const int topk = topk_weights.size(-1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_softplus_sqrt_launch<float>(
gating_output.data_ptr<float>(), topk_weights, topk_indices,
token_expert_indices, num_tokens, num_experts, topk, renormalize,
routed_scaling_factor, correction_bias, input_ids, tid2eid, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_softplus_sqrt_launch<__half>(
reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
topk_weights, topk_indices, token_expert_indices, num_tokens,
num_experts, topk, renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_softplus_sqrt_launch<__nv_bfloat16>(
reinterpret_cast<const __nv_bfloat16*>(
gating_output.data_ptr<at::BFloat16>()),
topk_weights, topk_indices, token_expert_indices, num_tokens,
num_experts, topk, renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid, stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output data type: ",
gating_output.scalar_type());
}
}
\ No newline at end of file
...@@ -16,6 +16,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -16,6 +16,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"bias) -> ()"); "bias) -> ()");
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid); m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
#ifndef USE_ROCM
m.def(
"topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, float "
"routed_scaling_factor, Tensor? "
"bias, Tensor? input_ids, Tensor? tid2eid) -> ()");
m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt);
#endif
// Calculate the result of moe by summing up the partial results // Calculate the result of moe by summing up the partial results
// from all selected experts. // from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.def("moe_sum(Tensor input, Tensor! output) -> ()");
......
...@@ -100,6 +100,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, ...@@ -100,6 +100,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
bool is_neox, torch::Tensor& position_ids, bool is_neox, torch::Tensor& position_ids,
int64_t forced_token_heads_per_warp); int64_t forced_token_heads_per_warp);
void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size);
void apply_repetition_penalties_(torch::Tensor& logits, void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask, const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask, const torch::Tensor& output_mask,
...@@ -153,7 +158,8 @@ void silu_and_mul_per_block_quant(torch::Tensor& out, ...@@ -153,7 +158,8 @@ void silu_and_mul_per_block_quant(torch::Tensor& out,
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rope_dim_offset, bool inverse);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
......
...@@ -18,7 +18,6 @@ namespace persistent { ...@@ -18,7 +18,6 @@ namespace persistent {
// Constants // Constants
// ============================================================================ // ============================================================================
constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024; constexpr int kThreadsPerBlock = 1024;
constexpr int RADIX = 256; constexpr int RADIX = 256;
...@@ -128,11 +127,12 @@ struct RadixRowState { ...@@ -128,11 +127,12 @@ struct RadixRowState {
struct PersistentTopKParams { struct PersistentTopKParams {
const float* __restrict__ input; // [num_rows, stride] const float* __restrict__ input; // [num_rows, stride]
int32_t* __restrict__ output; // [num_rows, TopK] int32_t* __restrict__ output; // [num_rows, top_k]
int32_t* __restrict__ lengths; // [num_rows] int32_t* __restrict__ lengths; // [num_rows]
RadixRowState* row_states; // large path: per-group state RadixRowState* row_states; // large path: per-group state
uint32_t num_rows; uint32_t num_rows;
uint32_t stride; uint32_t stride;
uint32_t top_k; // actual k value for output stride
uint32_t chunk_size; // large path: elements per CTA uint32_t chunk_size; // large path: elements per CTA
uint32_t ctas_per_group; // 1=medium, >1=large uint32_t ctas_per_group; // 1=medium, >1=large
uint32_t max_seq_len; // max seq_len across all rows (for early CTA exit) uint32_t max_seq_len; // max seq_len across all rows (for early CTA exit)
...@@ -154,6 +154,7 @@ __device__ __forceinline__ uint32_t decode_bin(float x) { ...@@ -154,6 +154,7 @@ __device__ __forceinline__ uint32_t decode_bin(float x) {
return key >> 5; return key >> 5;
} }
template <int TopK>
__device__ __noinline__ void histogram_2048_topk( __device__ __noinline__ void histogram_2048_topk(
const float* __restrict__ logits, int32_t* __restrict__ output_indices, const float* __restrict__ logits, int32_t* __restrict__ output_indices,
int32_t seq_len) { int32_t seq_len) {
...@@ -418,6 +419,7 @@ __device__ __noinline__ void histogram_2048_topk( ...@@ -418,6 +419,7 @@ __device__ __noinline__ void histogram_2048_topk(
// by: DarkSharpness // by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang // which at the same time is an optimized topk kernel copied from tilelang
// kernel // kernel
template <int TopK>
__device__ __noinline__ void histogram_256_topk( __device__ __noinline__ void histogram_256_topk(
const float* __restrict__ logits, int* __restrict__ output_indices, const float* __restrict__ logits, int* __restrict__ output_indices,
int logits_offset, int seq_len) { int logits_offset, int seq_len) {
...@@ -649,7 +651,7 @@ __device__ __forceinline__ void wait_ge(int* ptr, int target_val, ...@@ -649,7 +651,7 @@ __device__ __forceinline__ void wait_ge(int* ptr, int target_val,
// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215 // Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215
// ============================================================================ // ============================================================================
template <uint32_t VEC_SIZE> template <int TopK, uint32_t VEC_SIZE>
__device__ void radix_topk(const float* __restrict__ row_input, __device__ void radix_topk(const float* __restrict__ row_input,
int32_t* __restrict__ row_output, uint32_t seq_len, int32_t* __restrict__ row_output, uint32_t seq_len,
uint32_t my_chunk_start, uint32_t chunk_size, uint32_t my_chunk_start, uint32_t chunk_size,
...@@ -857,7 +859,7 @@ __device__ void radix_topk(const float* __restrict__ row_input, ...@@ -857,7 +859,7 @@ __device__ void radix_topk(const float* __restrict__ row_input,
// see filtered_topk.cuh) // see filtered_topk.cuh)
// ============================================================================ // ============================================================================
template <uint32_t VEC_SIZE = 1> template <int TopK = 2048, uint32_t VEC_SIZE = 1>
__global__ void __launch_bounds__(kThreadsPerBlock, 2) __global__ void __launch_bounds__(kThreadsPerBlock, 2)
persistent_topk_kernel(PersistentTopKParams params) { persistent_topk_kernel(PersistentTopKParams params) {
const uint32_t tx = threadIdx.x; const uint32_t tx = threadIdx.x;
...@@ -915,7 +917,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2) ...@@ -915,7 +917,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
if (row_idx >= params.num_rows) break; if (row_idx >= params.num_rows) break;
const uint32_t seq_len = params.lengths[row_idx]; const uint32_t seq_len = params.lengths[row_idx];
int32_t* row_output = params.output + row_idx * TopK; int32_t* row_output = params.output + row_idx * params.top_k;
const float* row_input = params.input + row_idx * params.stride; const float* row_input = params.input + row_idx * params.stride;
if (seq_len <= RADIX_THRESHOLD) { if (seq_len <= RADIX_THRESHOLD) {
...@@ -927,19 +929,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2) ...@@ -927,19 +929,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
row_output[i] = (i < seq_len) ? static_cast<int32_t>(i) : -1; row_output[i] = (i < seq_len) ? static_cast<int32_t>(i) : -1;
} }
} else if (seq_len <= static_cast<uint32_t>(HIST2048_THRESHOLD)) { } else if (seq_len <= static_cast<uint32_t>(HIST2048_THRESHOLD)) {
histogram_2048_topk(row_input, row_output, seq_len); histogram_2048_topk<TopK>(row_input, row_output, seq_len);
} else { } else {
histogram_256_topk(row_input, row_output, 0, seq_len); histogram_256_topk<TopK>(row_input, row_output, 0, seq_len);
} }
} }
continue; continue;
} }
const uint32_t my_chunk_start = cta_in_group * chunk_size; const uint32_t my_chunk_start = cta_in_group * chunk_size;
radix_topk<VEC_SIZE>(row_input, row_output, seq_len, my_chunk_start, radix_topk<TopK, VEC_SIZE>(
chunk_size, local_histogram, suffix_sum, row_input, row_output, seq_len, my_chunk_start, chunk_size,
shared_scalars, shared_ordered, state, cta_in_group, local_histogram, suffix_sum, shared_scalars, shared_ordered, state,
ctas_per_group, barrier_phase, iter, tx); cta_in_group, ctas_per_group, barrier_phase, iter, tx);
} }
} }
...@@ -1011,7 +1013,6 @@ struct FilteredTopKTraits<float> { ...@@ -1011,7 +1013,6 @@ struct FilteredTopKTraits<float> {
} }
}; };
constexpr uint32_t FILTERED_TOPK_MAX_K = 2048;
constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024; constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024;
constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE = constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE =
16 * 1024; // 16K indices per buffer 16 * 1024; // 16K indices per buffer
...@@ -1025,7 +1026,7 @@ constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC = ...@@ -1025,7 +1026,7 @@ constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC =
* \tparam IdType Index type (int32_t) * \tparam IdType Index type (int32_t)
* \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8) * \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8)
*/ */
template <typename DType, typename IdType, int VEC_SIZE> template <typename DType, typename IdType, int VEC_SIZE, uint32_t MAX_K = 2048>
__global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
FilteredTopKUnifiedKernel(const DType* __restrict__ input, FilteredTopKUnifiedKernel(const DType* __restrict__ input,
IdType* __restrict__ output, IdType* __restrict__ output,
...@@ -1059,7 +1060,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) ...@@ -1059,7 +1060,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
alignas(128) __shared__ int s_counter; alignas(128) __shared__ int s_counter;
alignas(128) __shared__ int s_threshold_bin_id; alignas(128) __shared__ int s_threshold_bin_id;
alignas(128) __shared__ int s_num_input[2]; alignas(128) __shared__ int s_num_input[2];
alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K]; alignas(128) __shared__ int s_indices[MAX_K];
auto& s_histogram = s_histogram_buf[0]; auto& s_histogram = s_histogram_buf[0];
...@@ -1280,7 +1281,7 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) { ...@@ -1280,7 +1281,7 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
return static_cast<int>(g); return static_cast<int>(g);
} }
template <typename DType, typename IdType> template <typename DType, typename IdType, uint32_t MAX_K = 2048>
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
IdType* lengths, uint32_t num_rows, IdType* lengths, uint32_t num_rows,
uint32_t top_k_val, uint32_t max_len, uint32_t top_k_val, uint32_t max_len,
...@@ -1297,7 +1298,7 @@ cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, ...@@ -1297,7 +1298,7 @@ cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
#define DISPATCH_VEC_SIZE(VS) \ #define DISPATCH_VEC_SIZE(VS) \
if (vec_size == VS) { \ if (vec_size == VS) { \
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS>; \ auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS, MAX_K>; \
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \ FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \ FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \
......
...@@ -9,28 +9,29 @@ namespace vllm { ...@@ -9,28 +9,29 @@ namespace vllm {
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding( inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { const float* __restrict__ sin_ptr, int rot_offset, int embed_dim,
const bool inverse) {
int x_index, y_index; int x_index, y_index;
scalar_t cos, sin; float cos_f, sin_f;
if (IS_NEOX) { if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset; x_index = rot_offset;
y_index = embed_dim + rot_offset; y_index = embed_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index); cos_f = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index); sin_f = VLLM_LDG(sin_ptr + x_index);
} else { } else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset; x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1; y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2); cos_f = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2); sin_f = VLLM_LDG(sin_ptr + x_index / 2);
} }
if (inverse) {
const scalar_t x = arr[x_index]; sin_f = -sin_f;
const scalar_t y = arr[y_index]; }
arr[x_index] = x * cos - y * sin; const float x_f = static_cast<float>(arr[x_index]);
arr[y_index] = y * cos + x * sin; const float y_f = static_cast<float>(arr[y_index]);
arr[x_index] = static_cast<scalar_t>(x_f * cos_f - y_f * sin_f);
arr[y_index] = static_cast<scalar_t>(y_f * cos_f + x_f * sin_f);
} }
template <typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
...@@ -42,22 +43,23 @@ inline __device__ void apply_rotary_embedding( ...@@ -42,22 +43,23 @@ inline __device__ void apply_rotary_embedding(
// [batch_size, seq_len, num_kv_heads, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads, const float* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride) { const int64_t head_stride, const int64_t rope_dim_offset,
const bool inverse) {
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const float* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const float* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim; const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) { for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_stride; token_idx * query_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>( apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
} }
if (key != nullptr) { if (key != nullptr) {
...@@ -65,10 +67,10 @@ inline __device__ void apply_rotary_embedding( ...@@ -65,10 +67,10 @@ inline __device__ void apply_rotary_embedding(
for (int i = threadIdx.x; i < nk; i += blockDim.x) { for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = const int64_t token_head =
token_idx * key_stride + head_idx * head_stride; token_idx * key_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>( apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
} }
} }
} }
...@@ -84,19 +86,18 @@ __global__ void rotary_embedding_kernel( ...@@ -84,19 +86,18 @@ __global__ void rotary_embedding_kernel(
// [batch_size, seq_len, num_kv_heads, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] fp32
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride, const int num_heads, const int num_kv_heads, const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) { const int head_size, const int64_t rope_dim_offset, const bool inverse) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const float* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>( apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride, head_stride); token_idx, query_stride, key_stride, head_stride, rope_dim_offset,
inverse);
} }
} // namespace vllm } // namespace vllm
...@@ -115,7 +116,7 @@ void rotary_embedding( ...@@ -115,7 +116,7 @@ void rotary_embedding(
// [num_tokens, num_heads, head_size] // [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox, int64_t rope_dim_offset, bool inverse) {
// num_tokens = batch_size * seq_len // num_tokens = batch_size * seq_len
int64_t num_tokens = positions.numel(); int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim(); int positions_ndim = positions.dim();
...@@ -154,6 +155,8 @@ void rotary_embedding( ...@@ -154,6 +155,8 @@ void rotary_embedding(
int seq_dim_idx = positions_ndim - 1; int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx); int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size);
// Determine head stride: for [*, heads, head_size] use stride of last dim; // Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size // for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size // head_size
...@@ -165,20 +168,23 @@ void rotary_embedding( ...@@ -165,20 +168,23 @@ void rotary_embedding(
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto cache_f32 = cos_sin_cache.to(torch::kFloat32);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) { if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr, key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride, cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size); head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
} else { } else {
vllm::rotary_embedding_kernel<scalar_t, false> vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr, key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
key_stride, head_stride, num_heads, num_kv_heads, head_size); head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
} }
}); });
} }
...@@ -258,7 +258,13 @@ __device__ bool processHistogramStep( ...@@ -258,7 +258,13 @@ __device__ bool processHistogramStep(
auto processBins = [&](float logit, int idx) { auto processBins = [&](float logit, int idx) {
if (isPartialMatch<patternShift>(logit, logitPattern)) { if (isPartialMatch<patternShift>(logit, logitPattern)) {
uint32_t binIdx = extractBinIdx<step>(logit); uint32_t binIdx = extractBinIdx<step>(logit);
if (binIdx < thresholdBinIdx) { // Only write elements with binIdx < thresholdBinIdx when:
// 1. This is step 0 and the threshold bin is small enough (no step 1)
// 2. This is step >= 1 (where pattern matching filters correctly)
// This prevents duplicates when step 0 and step 1 both run.
bool shouldWriteDirectly =
(step == 0 && smemFinalBinSize[0] <= kNumFinalItems) || (step >= 1);
if (binIdx < thresholdBinIdx && shouldWriteDirectly) {
// The element is part of the top-k selection // The element is part of the top-k selection
int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1); int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
......
...@@ -10,33 +10,17 @@ ...@@ -10,33 +10,17 @@
#include "persistent_topk.cuh" #include "persistent_topk.cuh"
#endif #endif
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, namespace {
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
#ifndef USE_ROCM #ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor"); template <int TopK>
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor"); void launch_persistent_topk(const torch::Tensor& logits,
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor"); const torch::Tensor& lengths, torch::Tensor& output,
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported"); torch::Tensor& workspace, int64_t max_seq_len) {
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32"); namespace P = vllm::persistent;
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0); const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1); const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
namespace P = vllm::persistent;
TORCH_CHECK(k == P::TopK, "k must be 2048");
TORCH_CHECK(k <= stride, "k out of range");
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
static int num_sms = 0; static int num_sms = 0;
...@@ -50,18 +34,17 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, ...@@ -50,18 +34,17 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
} }
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) { if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>( cudaError_t status =
logits.data_ptr<float>(), output.data_ptr<int32_t>(), vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows), logits.data_ptr<float>(), output.data_ptr<int32_t>(),
static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream); lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
TORCH_CHECK(status == cudaSuccess, TORCH_CHECK(status == cudaSuccess,
"FilteredTopK failed: ", cudaGetErrorString(status)); "FilteredTopK failed: ", cudaGetErrorString(status));
} else { } else {
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor"); TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8"); TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
// Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
// large path. Empirically tuned.
int effective_max_smem; int effective_max_smem;
if (num_rows <= 4) { if (num_rows <= 4) {
effective_max_smem = effective_max_smem =
...@@ -101,7 +84,7 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, ...@@ -101,7 +84,7 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
int occupancy = 1; int occupancy = 1;
cudaOccupancyMaxActiveBlocksPerMultiprocessor( cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock, &occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
smem_size); smem_size);
if (occupancy < 1) occupancy = 1; if (occupancy < 1) occupancy = 1;
...@@ -121,15 +104,16 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, ...@@ -121,15 +104,16 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
params.lengths = lengths.data_ptr<int32_t>(); params.lengths = lengths.data_ptr<int32_t>();
params.num_rows = static_cast<uint32_t>(num_rows); params.num_rows = static_cast<uint32_t>(num_rows);
params.stride = static_cast<uint32_t>(stride); params.stride = static_cast<uint32_t>(stride);
params.top_k = static_cast<uint32_t>(TopK);
params.chunk_size = chunk_size; params.chunk_size = chunk_size;
params.row_states = params.row_states =
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>()); reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
params.ctas_per_group = ctas_per_group; params.ctas_per_group = ctas_per_group;
params.max_seq_len = static_cast<uint32_t>(max_seq_len); params.max_seq_len = static_cast<uint32_t>(max_seq_len);
#define LAUNCH_PERSISTENT(VS) \ #define LAUNCH_PERSISTENT(TOPK_VAL, VS) \
do { \ do { \
auto kernel = &P::persistent_topk_kernel<VS>; \ auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
cudaError_t err = cudaFuncSetAttribute( \ cudaError_t err = cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
TORCH_CHECK(err == cudaSuccess, \ TORCH_CHECK(err == cudaSuccess, \
...@@ -138,11 +122,11 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, ...@@ -138,11 +122,11 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
} while (0) } while (0)
if (vec_size == 4) { if (vec_size == 4) {
LAUNCH_PERSISTENT(4); LAUNCH_PERSISTENT(TopK, 4);
} else if (vec_size == 2) { } else if (vec_size == 2) {
LAUNCH_PERSISTENT(2); LAUNCH_PERSISTENT(TopK, 2);
} else { } else {
LAUNCH_PERSISTENT(1); LAUNCH_PERSISTENT(TopK, 1);
} }
#undef LAUNCH_PERSISTENT #undef LAUNCH_PERSISTENT
} }
...@@ -150,6 +134,46 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, ...@@ -150,6 +134,46 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess, TORCH_CHECK(err == cudaSuccess,
"persistent_topk failed: ", cudaGetErrorString(err)); "persistent_topk failed: ", cudaGetErrorString(err));
}
#endif
} // anonymous namespace
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
#ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
TORCH_CHECK(k == 512 || k == 1024 || k == 2048,
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
if (k == 512) {
launch_persistent_topk<512>(logits, lengths, output, workspace,
max_seq_len);
} else if (k == 1024) {
launch_persistent_topk<1024>(logits, lengths, output, workspace,
max_seq_len);
} else {
launch_persistent_topk<2048>(logits, lengths, output, workspace,
max_seq_len);
}
#else #else
TORCH_CHECK(false, "persistent_topk is not supported on ROCm"); TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
#endif #endif
......
...@@ -177,6 +177,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -177,6 +177,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int forced_token_heads_per_warp=-1) -> ()"); "int forced_token_heads_per_warp=-1) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#ifndef USE_ROCM
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor! q, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"float eps, int cache_block_size) -> ()");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
#endif
// Apply repetition penalties to logits in-place // Apply repetition penalties to logits in-place
ops.def( ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
...@@ -240,7 +253,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -240,7 +253,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def( ops.def(
"rotary_embedding(Tensor positions, Tensor! query," "rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size," " Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox, int "
"rope_dim_offset=0, bool inverse=False) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Quantization ops // Quantization ops
......
...@@ -213,7 +213,7 @@ configuration. ...@@ -213,7 +213,7 @@ configuration.
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
......
...@@ -384,6 +384,7 @@ th { ...@@ -384,6 +384,7 @@ th {
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ | | `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ |
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ |
| `DeepseekV4ForCausalLM` | DeepSeek-V4 | `deepseek-ai/DeepSeek-V4-Flash`, `deepseek-ai/DeepSeek-V4-Pro`, etc. | | |
| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ |
| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | ✅︎ | ✅︎ | | `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | ✅︎ | ✅︎ |
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ |
...@@ -643,10 +644,10 @@ Some models are supported only via the [Transformers modeling backend](#transfor ...@@ -643,10 +644,10 @@ Some models are supported only via the [Transformers modeling backend](#transfor
!!! note !!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
MobileNet-v5 vision backbone. MobileNet-v5 vision backbone.
Performance is not yet fully optimized mainly due to: Performance is not yet fully optimized mainly due to:
- Both audio and vision MM encoders use `transformers.AutoModel` implementation. - Both audio and vision MM encoders use `transformers.AutoModel` implementation.
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
!!! note !!! note
......
...@@ -11,6 +11,8 @@ torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytor ...@@ -11,6 +11,8 @@ torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytor
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.8.post1 flashinfer-python==0.6.8.post1
flashinfer-cubin==0.6.8.post1 flashinfer-cubin==0.6.8.post1
apache-tvm-ffi==0.1.9
tilelang==0.1.9
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to # Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
# breaking changes in 1.19.0 # breaking changes in 1.19.0
nvidia-cudnn-frontend>=1.13.0,<1.19.0 nvidia-cudnn-frontend>=1.13.0,<1.19.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment