Unverified Commit 4beeb068 authored by Kaicheng Yang's avatar Kaicheng Yang Committed by GitHub
Browse files

fused qknorm+rope kernel optimization for SM9.0 (#37376)


Signed-off-by: default avatarEricccYang <yangyang4991@gmail.com>
Signed-off-by: default avatarKaicheng Yang <53411596+EricccYang@users.noreply.github.com>
Co-authored-by: default avatarClaude Sonnet 4.6 <noreply@anthropic.com>
parent cae98406
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace vllm {
namespace cuda_async {
__device__ __forceinline__ void cp_async_shared_global_16_cg(
void* smem_ptr, const void* glob_ptr) {
#if defined(USE_ROCM)
*reinterpret_cast<int4*>(smem_ptr) = *reinterpret_cast<const int4*>(glob_ptr);
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
:
: "r"(smem), "l"(glob_ptr));
#elif defined(__CUDA_ARCH__)
*reinterpret_cast<int4*>(smem_ptr) = *reinterpret_cast<const int4*>(glob_ptr);
#else
(void)smem_ptr;
(void)glob_ptr;
#endif
}
__device__ __forceinline__ void cp_async_shared_global_ca(void* smem_ptr,
const void* glob_ptr,
int size_bytes) {
#if defined(USE_ROCM)
if (size_bytes == 4) {
*reinterpret_cast<uint32_t*>(smem_ptr) =
*reinterpret_cast<const uint32_t*>(glob_ptr);
} else if (size_bytes == 8) {
*reinterpret_cast<uint64_t*>(smem_ptr) =
*reinterpret_cast<const uint64_t*>(glob_ptr);
} else {
*reinterpret_cast<int4*>(smem_ptr) =
*reinterpret_cast<const int4*>(glob_ptr);
}
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
if (size_bytes == 4) {
asm volatile("cp.async.ca.shared.global [%0], [%1], 4;\n"
:
: "r"(smem), "l"(glob_ptr));
} else if (size_bytes == 8) {
asm volatile("cp.async.ca.shared.global [%0], [%1], 8;\n"
:
: "r"(smem), "l"(glob_ptr));
} else {
asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n"
:
: "r"(smem), "l"(glob_ptr));
}
#elif defined(__CUDA_ARCH__)
if (size_bytes == 4) {
*reinterpret_cast<uint32_t*>(smem_ptr) =
*reinterpret_cast<const uint32_t*>(glob_ptr);
} else if (size_bytes == 8) {
*reinterpret_cast<uint64_t*>(smem_ptr) =
*reinterpret_cast<const uint64_t*>(glob_ptr);
} else {
*reinterpret_cast<int4*>(smem_ptr) =
*reinterpret_cast<const int4*>(glob_ptr);
}
#else
(void)smem_ptr;
(void)glob_ptr;
(void)size_bytes;
#endif
}
__device__ __forceinline__ void cp_async_commit_group() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 && !defined(USE_ROCM)
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
template <int n>
__device__ __forceinline__ void cp_async_wait_group() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 && !defined(USE_ROCM)
asm volatile("cp.async.wait_group %0;\n" : : "n"(n));
#endif
}
} // namespace cuda_async
} // namespace vllm
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
#include <type_traits> #include <type_traits>
#include <torch/cuda.h> #include <torch/cuda.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "async_util.cuh"
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "type_convert.cuh" #include "type_convert.cuh"
...@@ -86,6 +88,9 @@ inline __device__ __host__ T divUp(T m, T n) { ...@@ -86,6 +88,9 @@ inline __device__ __host__ T divUp(T m, T n) {
} // namespace tensorrt_llm::common } // namespace tensorrt_llm::common
namespace tensorrt_llm::kernels { namespace tensorrt_llm::kernels {
using namespace vllm::cuda_async;
// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation, // NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation,
// with added support for passing the cos_sin_cache as an input. // with added support for passing the cos_sin_cache as an input.
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu // https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
...@@ -301,6 +306,237 @@ __global__ void fusedQKNormRopeKernel( ...@@ -301,6 +306,237 @@ __global__ void fusedQKNormRopeKernel(
#endif #endif
} }
// Multi-token-head kernel: one warp processes HEADS_PER_WARP token-heads for
// the same token, sharing cos/sin from shared memory via cp.async.
// When HEADS_PER_WARP > 1 the warp reuses the loaded cos/sin across all heads,
// hiding global-memory latency and improving occupancy for large batches.
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
bool interleave, int HEADS_PER_WARP>
__global__ void fusedQKNormRopeKernelNTokenHeads(
void* qkv_void, int const num_heads_q, int const num_heads_k,
int const num_heads_v, float const eps, void const* q_weight_void,
void const* k_weight_void, void const* cos_sin_cache_void,
int64_t const* position_ids, int const num_tokens, int const rotary_dim) {
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
return;
} else {
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
static_assert(Converter::exists,
"Input QKV data type is not supported for this CUDA "
"architecture or toolkit version.");
using T_in = typename Converter::hip_type;
using T2_in = typename Converter::packed_hip_type;
using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
static_assert(CacheConverter::exists,
"Cache data type is not supported for this CUDA architecture "
"or toolkit version.");
using T_cache = typename CacheConverter::hip_type;
extern __shared__ char smem_storage[];
// Shared memory layout:
// [0, cos_sin_bytes) : cos/sin for each warp (warpsPerBlock *
// rotary_dim * sizeof(T_cache))
// [cos_sin_bytes, ...) : QKV tiles
// per warp (warpsPerBlock * HEADS_PER_WARP * 32 * elemSizeBytes)
T_cache* const smem = reinterpret_cast<T_cache*>(smem_storage);
T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
T_cache const* cos_sin_cache =
reinterpret_cast<T_cache const*>(cos_sin_cache_void);
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
int const total_qk_heads = num_heads_q + num_heads_k;
int const num_heads = num_heads_q + num_heads_k + num_heads_v;
int const head_chunks_per_token =
(total_qk_heads + HEADS_PER_WARP - 1) / HEADS_PER_WARP;
int const warp_global = blockIdx.x * warpsPerBlock + warpId;
int const tokenIdx = warp_global / head_chunks_per_token;
int const headChunk = warp_global % head_chunks_per_token;
int const first_head = headChunk * HEADS_PER_WARP;
int const num_heads_this_warp =
(first_head + HEADS_PER_WARP <= total_qk_heads)
? HEADS_PER_WARP
: (total_qk_heads - first_head);
if (tokenIdx >= num_tokens) return;
static_assert(head_dim % (32 * 2) == 0, "head_dim must be divisible by 64");
constexpr int numElemsPerThread = head_dim / 32;
constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
static_assert(elemSizeBytes % 4 == 0,
"elemSizeBytes must be a multiple of 4");
constexpr int vecSize = elemSizeBytes / 4;
using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;
int const cos_sin_bytes =
warpsPerBlock * rotary_dim * static_cast<int>(sizeof(T_cache));
int const qkv_tile_bytes = 32 * elemSizeBytes;
char* const this_warp_head_smem =
smem_storage + cos_sin_bytes +
warpId * (HEADS_PER_WARP * qkv_tile_bytes);
// === Group 0: async load all heads' QKV into smem (issued first). ===
for (int k = 0; k < num_heads_this_warp; ++k) {
int const localHeadIdx = first_head + k;
bool const isQ = localHeadIdx < num_heads_q;
int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;
int offWarp;
if (isQ) {
offWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
} else {
offWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim +
headIdx * head_dim;
}
int const offThread = offWarp + laneId * numElemsPerThread;
char* smem_dst =
this_warp_head_smem + k * qkv_tile_bytes + laneId * elemSizeBytes;
cp_async_shared_global_ca(smem_dst,
reinterpret_cast<const char*>(&qkv[offThread]),
elemSizeBytes);
}
cp_async_commit_group(); // commit group 0 (QKV)
// === Group 1: async load cos/sin into smem (issued second). ===
int64_t const pos_id = position_ids[tokenIdx];
T_cache const* const cache_ptr = cos_sin_cache + pos_id * rotary_dim;
int const copy_bytes = rotary_dim * static_cast<int>(sizeof(T_cache));
int const num_copies = (copy_bytes + 15) / 16;
for (int copyId = laneId; copyId < num_copies; copyId += 32) {
char* smem_ptr =
reinterpret_cast<char*>(&smem[warpId * rotary_dim]) + copyId * 16;
const char* glob_ptr =
reinterpret_cast<const char*>(cache_ptr) + copyId * 16;
cp_async_shared_global_16_cg(smem_ptr, glob_ptr);
}
cp_async_commit_group(); // commit group 1 (cos/sin)
// wait<1>: allow at most 1 pending group (group 1) → group 0 (QKV) is done.
cp_async_wait_group<1>();
float elements[numElemsPerThread];
float elements2[numElemsPerThread];
int const rotary_lanes = rotary_dim / numElemsPerThread;
int const embed_dim = rotary_dim / 2;
T_cache const* const cos_smem = &smem[warpId * rotary_dim];
T_cache const* const sin_smem = &smem[warpId * rotary_dim + embed_dim];
// Preload weights into registers once, reused across all heads.
float q_w[numElemsPerThread];
float k_w[numElemsPerThread];
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
int const dim = laneId * numElemsPerThread + i;
q_w[i] = Converter::convert(q_weight[dim]);
k_w[i] = Converter::convert(k_weight[dim]);
}
for (int k = 0; k < num_heads_this_warp; ++k) {
int const localHeadIdx = first_head + k;
bool const isQ = localHeadIdx < num_heads_q;
int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;
int offsetWarp;
if (isQ) {
offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
} else {
offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim +
headIdx * head_dim;
}
int const offsetThread = offsetWarp + laneId * numElemsPerThread;
// === Part 1: QK Norm (read from smem; group 0 already done). ===
float sumOfSquares = 0.0f;
{
char const* smem_src =
this_warp_head_smem + k * qkv_tile_bytes + laneId * elemSizeBytes;
vec_T vec = *reinterpret_cast<vec_T const*>(smem_src);
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
float2 vals = Converter::convert(packed_val);
sumOfSquares += vals.x * vals.x;
sumOfSquares += vals.y * vals.y;
elements[2 * i] = vals.x;
elements[2 * i + 1] = vals.y;
}
}
sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
elements[i] *= rms_rcp * (isQ ? q_w[i] : k_w[i]);
}
// On first head: wait for group 1 (cos/sin) before RoPE.
if (k == 0) cp_async_wait_group<0>();
// === Part 2: RoPE using cos/sin from shared memory. ===
if (laneId < rotary_lanes) {
if constexpr (interleave) {
#pragma unroll
for (int i = 0; i < numElemsPerThread / 2; ++i) {
int const idx0 = 2 * i;
int const idx1 = 2 * i + 1;
int const dim_idx = laneId * numElemsPerThread + idx0;
float const val0 = elements[idx0];
float const val1 = elements[idx1];
int const half_dim = dim_idx / 2;
float const cos_val = CacheConverter::convert(cos_smem[half_dim]);
float const sin_val = CacheConverter::convert(sin_smem[half_dim]);
elements[idx0] = val0 * cos_val - val1 * sin_val;
elements[idx1] = val0 * sin_val + val1 * cos_val;
}
} else {
__syncwarp();
int const pairOffset = (rotary_dim / 2) / numElemsPerThread;
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], pairOffset);
if (laneId < pairOffset) elements2[i] = -elements2[i];
int dim_idx = laneId * numElemsPerThread + i;
dim_idx = (dim_idx * 2) % rotary_dim;
int const half_dim = dim_idx / 2;
float const cos_val = CacheConverter::convert(cos_smem[half_dim]);
float const sin_val = CacheConverter::convert(sin_smem[half_dim]);
elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
}
__syncwarp();
}
}
// Store.
{
vec_T vec;
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
T2_in packed_val = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
*(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
}
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
}
}
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
}
// Borrowed from // Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568 // https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
...@@ -321,15 +557,12 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, ...@@ -321,15 +557,12 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
void const* cos_sin_cache, bool const interleave, void const* cos_sin_cache, bool const interleave,
int64_t const* position_ids, cudaStream_t stream) { int64_t const* position_ids, cudaStream_t stream) {
constexpr int blockSize = 256; constexpr int blockSize = 256;
int const warpsPerBlock = blockSize / 32; int const warpsPerBlock = blockSize / 32;
int const totalQKHeads = num_heads_q + num_heads_k; int const totalQKHeads = num_heads_q + num_heads_k;
int const totalWarps = num_tokens * totalQKHeads; int const totalWarps = num_tokens * totalQKHeads;
int const gridSize = common::divUp(totalWarps, warpsPerBlock); int const gridSize = common::divUp(totalWarps, warpsPerBlock);
dim3 gridDim(gridSize); dim3 gridDim(gridSize);
dim3 blockDim(blockSize); dim3 blockDim(blockSize);
switch (head_dim) { switch (head_dim) {
case 64: case 64:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
...@@ -360,6 +593,118 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, ...@@ -360,6 +593,118 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
"Unsupported head dimension for fusedQKNormRope: ", head_dim); "Unsupported head dimension for fusedQKNormRope: ", head_dim);
} }
} }
// Launch: one warp processes token_heads_per_warp token-heads (1, 2, 4, or 8).
// When token_heads_per_warp == 1, delegates to the 1-head baseline above.
template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRopeNTokenHeads(
void* qkv, int const num_tokens, int const num_heads_q,
int const num_heads_k, int const num_heads_v, int const head_dim,
int const rotary_dim, float const eps, void const* q_weight,
void const* k_weight, void const* cos_sin_cache, bool const interleave,
int64_t const* position_ids, int const token_heads_per_warp,
cudaStream_t stream) {
TORCH_CHECK(token_heads_per_warp == 1 || token_heads_per_warp == 2 ||
token_heads_per_warp == 4 || token_heads_per_warp == 8,
"token_heads_per_warp must be 1, 2, 4, or 8, got ",
token_heads_per_warp);
// token_heads_per_warp == 1: delegate to the 1-head baseline kernel.
if (token_heads_per_warp == 1) {
launchFusedQKNormRope<scalar_t_in, scalar_t_cache>(
qkv, num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim,
rotary_dim, eps, q_weight, k_weight, cos_sin_cache, interleave,
position_ids, stream);
return;
}
// NTokenHeads kernel uses cp.async to load cos/sin in 16-byte chunks.
// If rotary_dim * sizeof(cache_dtype) is not a multiple of 16, the last
// cp.async would write past the shared memory allocation.
// Fall back to the base kernel instead of failing.
{
size_t const rotary_bytes =
static_cast<size_t>(rotary_dim) *
(std::is_same_v<scalar_t_cache, float> ? sizeof(float) : 2u);
if (rotary_bytes % 16 != 0) {
launchFusedQKNormRope<scalar_t_in, scalar_t_cache>(
qkv, num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim,
rotary_dim, eps, q_weight, k_weight, cos_sin_cache, interleave,
position_ids, stream);
return;
}
}
constexpr int blockSize = 256;
int const warpsPerBlock = blockSize / 32;
int const totalQKHeads = num_heads_q + num_heads_k;
// Grid: one warp per (token, head_chunk); same token → reuse cos/sin in smem.
int const head_chunks_per_token =
(totalQKHeads + token_heads_per_warp - 1) / token_heads_per_warp;
int const total_warps = num_tokens * head_chunks_per_token;
int const gridSize = common::divUp(total_warps, warpsPerBlock);
dim3 gridDim(gridSize);
dim3 blockDim(blockSize);
// Cache element size: float=4, bfloat16=2 (host-safe; kernel uses same
// layout).
size_t const cache_elem_size =
std::is_same_v<scalar_t_cache, float> ? sizeof(float) : 2u;
// QKV smem: token_heads_per_warp tiles per warp, each tile 32*(head_dim/32*2)
// = 2*head_dim bytes.
size_t const qkv_smem_per_warp = static_cast<size_t>(token_heads_per_warp) *
2u * static_cast<size_t>(head_dim);
size_t const smem_bytes =
warpsPerBlock * static_cast<size_t>(rotary_dim) * cache_elem_size +
warpsPerBlock * qkv_smem_per_warp;
#define LAUNCH_N_TOKEN_HEADS(N) \
do { \
switch (head_dim) { \
case 64: \
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { \
fusedQKNormRopeKernelNTokenHeads<scalar_t_in, scalar_t_cache, 64, \
INTERLEAVE, (N)> \
<<<gridDim, blockDim, smem_bytes, stream>>>( \
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, \
k_weight, cos_sin_cache, position_ids, num_tokens, \
rotary_dim); \
}); \
break; \
case 128: \
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { \
fusedQKNormRopeKernelNTokenHeads<scalar_t_in, scalar_t_cache, 128, \
INTERLEAVE, (N)> \
<<<gridDim, blockDim, smem_bytes, stream>>>( \
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, \
k_weight, cos_sin_cache, position_ids, num_tokens, \
rotary_dim); \
}); \
break; \
case 256: \
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { \
fusedQKNormRopeKernelNTokenHeads<scalar_t_in, scalar_t_cache, 256, \
INTERLEAVE, (N)> \
<<<gridDim, blockDim, smem_bytes, stream>>>( \
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, \
k_weight, cos_sin_cache, position_ids, num_tokens, \
rotary_dim); \
}); \
break; \
default: \
TORCH_CHECK(false, "Unsupported head dimension: ", head_dim); \
} \
} while (0)
if (token_heads_per_warp == 2) {
LAUNCH_N_TOKEN_HEADS(2);
} else if (token_heads_per_warp == 4) {
LAUNCH_N_TOKEN_HEADS(4);
} else if (token_heads_per_warp == 8) {
LAUNCH_N_TOKEN_HEADS(8);
}
#undef LAUNCH_N_TOKEN_HEADS
}
} // namespace tensorrt_llm::kernels } // namespace tensorrt_llm::kernels
void fused_qk_norm_rope( void fused_qk_norm_rope(
...@@ -374,7 +719,8 @@ void fused_qk_norm_rope( ...@@ -374,7 +719,8 @@ void fused_qk_norm_rope(
torch::Tensor& k_weight, // RMSNorm weights for key [head_dim] torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim] torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim]
bool is_neox, // Whether RoPE is applied in Neox style bool is_neox, // Whether RoPE is applied in Neox style
torch::Tensor& position_ids // Position IDs for RoPE [num_tokens] torch::Tensor& position_ids, // Position IDs for RoPE [num_tokens]
int64_t forced_token_heads_per_warp // -1 = auto-select, >0 = forced value
) { ) {
// Input validation // Input validation
CHECK_INPUT(qkv); CHECK_INPUT(qkv);
...@@ -414,15 +760,48 @@ void fused_qk_norm_rope( ...@@ -414,15 +760,48 @@ void fused_qk_norm_rope(
qkv.size(1) == total_heads * head_dim, qkv.size(1) == total_heads * head_dim,
"QKV tensor size must match total number of heads and head dimension"); "QKV tensor size must match total number of heads and head dimension");
auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); auto device_id = qkv.get_device();
auto stream = at::cuda::getCurrentCUDAStream(device_id);
// Select token_heads_per_warp: forced value if >0, else auto-select.
// Auto thresholds are calibrated on SM 9.0 (H100). On other architectures,
// fall back to token_heads_per_warp=1 (base kernel) until profiled.
int token_heads_per_warp;
if (forced_token_heads_per_warp > 0) { // only support SM80+
token_heads_per_warp = static_cast<int>(forced_token_heads_per_warp);
} else {
token_heads_per_warp = 1;
auto* dev_prop = at::cuda::getDeviceProperties(device_id);
int sm_version = dev_prop->major * 10 + dev_prop->minor;
int64_t total_qk_units = num_tokens * (num_heads_q + num_heads_k);
if (sm_version == 90) {
if (head_dim >= 256) {
if (total_qk_units < 4096LL) {
token_heads_per_warp = 1;
} else if (total_qk_units < 8192LL) {
token_heads_per_warp = 2;
} else {
token_heads_per_warp = 4;
}
} else {
if (total_qk_units < 10240LL) {
token_heads_per_warp = 1;
} else if (total_qk_units < 40960LL) {
token_heads_per_warp = 4;
} else {
token_heads_per_warp = 8;
}
}
}
}
VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] { VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using qkv_scalar_t = scalar_t; using qkv_scalar_t = scalar_t;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] { cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using cache_scalar_t = scalar_t; using cache_scalar_t = scalar_t;
tensorrt_llm::kernels::launchFusedQKNormRope<qkv_scalar_t, tensorrt_llm::kernels::launchFusedQKNormRopeNTokenHeads<
cache_scalar_t>( qkv_scalar_t, cache_scalar_t>(
qkv.data_ptr(), static_cast<int>(num_tokens), qkv.data_ptr(), static_cast<int>(num_tokens),
static_cast<int>(num_heads_q), static_cast<int>(num_heads_k), static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<int>(num_heads_v), static_cast<int>(head_dim),
...@@ -430,7 +809,7 @@ void fused_qk_norm_rope( ...@@ -430,7 +809,7 @@ void fused_qk_norm_rope(
q_weight.data_ptr(), k_weight.data_ptr(), q_weight.data_ptr(), k_weight.data_ptr(),
cos_sin_cache.data_ptr(), !is_neox, cos_sin_cache.data_ptr(), !is_neox,
reinterpret_cast<int64_t const*>(position_ids.data_ptr()), reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
stream); token_heads_per_warp, stream);
}); });
}); });
} }
\ No newline at end of file
...@@ -96,7 +96,8 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, ...@@ -96,7 +96,8 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
int64_t num_heads_k, int64_t num_heads_v, int64_t num_heads_k, int64_t num_heads_v,
int64_t head_dim, double eps, torch::Tensor& q_weight, int64_t head_dim, double eps, torch::Tensor& q_weight,
torch::Tensor& k_weight, torch::Tensor& cos_sin_cache, torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
bool is_neox, torch::Tensor& position_ids); bool is_neox, torch::Tensor& position_ids,
int64_t forced_token_heads_per_warp);
void apply_repetition_penalties_(torch::Tensor& logits, void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask, const torch::Tensor& prompt_mask,
...@@ -320,4 +321,4 @@ std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk( ...@@ -320,4 +321,4 @@ std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
torch::Tensor const& norm_weight_k, torch::Tensor workspace, torch::Tensor const& norm_weight_k, torch::Tensor workspace,
int64_t const q_size, int64_t const kv_size, int64_t const rank, int64_t const q_size, int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps); int64_t const nranks, double const eps);
#endif #endif
\ No newline at end of file
...@@ -173,7 +173,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -173,7 +173,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, " "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
"int num_heads_k, int num_heads_v, int head_dim, float eps, " "int num_heads_k, int num_heads_v, int head_dim, float eps, "
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, " "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
"bool is_neox, Tensor position_ids) -> ()"); "bool is_neox, Tensor position_ids, "
"int forced_token_heads_per_warp=-1) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
// Apply repetition penalties to logits in-place // Apply repetition penalties to logits in-place
......
...@@ -435,6 +435,7 @@ def fused_qk_norm_rope( ...@@ -435,6 +435,7 @@ def fused_qk_norm_rope(
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
is_neox: bool, is_neox: bool,
position_ids: torch.Tensor, position_ids: torch.Tensor,
forced_token_heads_per_warp: int = -1,
) -> None: ) -> None:
torch.ops._C.fused_qk_norm_rope( torch.ops._C.fused_qk_norm_rope(
qkv, qkv,
...@@ -448,6 +449,7 @@ def fused_qk_norm_rope( ...@@ -448,6 +449,7 @@ def fused_qk_norm_rope(
cos_sin_cache, cos_sin_cache,
is_neox, is_neox,
position_ids, position_ids,
forced_token_heads_per_warp,
) )
......
...@@ -164,6 +164,7 @@ class QkNormRopePattern: ...@@ -164,6 +164,7 @@ class QkNormRopePattern:
cos_sin_cache=cos_sin_cache, cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox, is_neox=self.is_neox,
position_ids=positions.view(-1), position_ids=positions.view(-1),
forced_token_heads_per_warp=-1,
) )
result_qkv = result[1] result_qkv = result[1]
......
...@@ -168,6 +168,7 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -168,6 +168,7 @@ class FixFunctionalizationPass(VllmInductorPass):
"cos_sin_cache", "cos_sin_cache",
"is_neox", "is_neox",
"position_ids", "position_ids",
"forced_token_heads_per_warp",
) )
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args) self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
elif ( elif (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment