"vscode:/vscode.git/clone" did not exist on "b51255f369cf45456e3062e32ecbfebd03a9f169"
Unverified Commit 4d51588e authored by Yifan Qiao, committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zh...
parent 32e45636
......@@ -310,7 +310,9 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC "csrc/minimax_reduce_rms_kernel.cu")
list(APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
......@@ -1051,7 +1053,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/router_gemm.cu")
"csrc/moe/router_gemm.cu"
"csrc/moe/topk_softplus_sqrt_kernels.cu")
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
......
......@@ -20,7 +20,7 @@ else()
FetchContent_Declare(
deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git
GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03
GIT_TAG 891d57b4db1071624b5c8fa0d1e51cb317fa709f
GIT_SUBMODULES "third-party/cutlass" "third-party/fmt"
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
......@@ -120,6 +120,11 @@ if(DEEPGEMM_ARCHS)
COMPONENT _deep_gemm_C
FILES_MATCHING PATTERN "*.py")
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/mega/"
DESTINATION vllm/third_party/deep_gemm/mega
COMPONENT _deep_gemm_C
FILES_MATCHING PATTERN "*.py")
# Generate envs.py (normally generated by DeepGEMM's setup.py build step)
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
"# Pre-installed environment variables\npersistent_envs = dict()\n")
......
......@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1
GIT_TAG a6ec2ba7bd0a7dff98b3f4d3e6b52b159c48d78b
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
......
......@@ -178,7 +178,12 @@ void rotary_embedding_gptj_impl(
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rope_dim_offset, bool inverse) {
TORCH_CHECK(rope_dim_offset == 0,
"rope_dim_offset != 0 is not supported on CPU");
TORCH_CHECK(!inverse, "inverse rotary embedding is not supported on CPU");
int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
......
......@@ -263,7 +263,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
" Tensor cos_sin_cache, bool is_neox, int "
"rope_dim_offset=0, bool inverse=False) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
// Quantization
......
This diff is collapsed.
......@@ -77,7 +77,8 @@ __global__ void rms_norm_kernel(
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
float w = static_cast<float>(src2.val[j]);
dst.val[j] = static_cast<scalar_t>(x * s_variance * w);
}
v_out[i] = dst;
}
......@@ -134,10 +135,17 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[strided_id] = temp;
_f16Vec<scalar_t, width> res = residual_v[id];
_f16Vec<scalar_t, width> w = weight_v[idx];
_f16Vec<scalar_t, width> out;
using Converter = _typeConvert<scalar_t>;
#pragma unroll
for (int j = 0; j < width; ++j) {
float x = Converter::convert(res.data[j]);
float wf = Converter::convert(w.data[j]);
out.data[j] = Converter::convert(x * s_variance * wf);
}
input_v[strided_id] = out;
}
}
......@@ -174,8 +182,8 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * input_stride + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
float w = (float)weight[idx];
input[blockIdx.x * input_stride + idx] = (scalar_t)(x * s_variance * w);
}
}
......
......@@ -65,9 +65,16 @@ __global__ void rms_norm_static_fp8_quant_kernel(
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
float w = static_cast<float>(src2.val[j]);
// Round normalized result through scalar_t to match the precision of the
// unfused composite (rms_norm writes scalar_t, then
// static_scaled_fp8_quant re-loads it as float before FP8 conversion).
// Without this round, the fused path is strictly more accurate and
// disagrees with the composite at exact E4M3 quantization tie boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
scaled_fp8_conversion<true, fp8_type>(static_cast<float>(out_norm),
scale_inv);
}
}
}
......@@ -127,13 +134,21 @@ fused_add_rms_norm_static_fp8_quant_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
_f16Vec<scalar_t, width> res = residual_v[id];
_f16Vec<scalar_t, width> w = weight_v[idx];
using Converter = _typeConvert<scalar_t>;
using HipT = typename Converter::hip_type;
#pragma unroll
for (int i = 0; i < width; ++i) {
out[id * width + i] =
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
float x = Converter::convert(res.data[i]);
float wf = Converter::convert(w.data[i]);
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
// to match the unfused composite path at FP8 boundaries. We use the
// backend's hip_type for the intermediate since c10::Half/BFloat16 has
// ambiguous conversions on CUDA and no implicit conversion on ROCm.
HipT out_norm_h = Converter::convert(x * s_variance * wf);
out[id * width + i] = scaled_fp8_conversion<true, fp8_type>(
Converter::convert(out_norm_h), scale_inv);
}
}
}
......@@ -176,9 +191,12 @@ fused_add_rms_norm_static_fp8_quant_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
float w = (float)weight[idx];
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
// to match the unfused composite path at FP8 boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion<true, fp8_type>(
static_cast<float>(out_norm), scale_inv);
}
}
......
......@@ -12,6 +12,15 @@ void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& gating_output, bool renormalize,
std::optional<torch::Tensor> bias);
void topk_softplus_sqrt(torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize,
double routed_scaling_factor,
const c10::optional<torch::Tensor>& correction_bias,
const c10::optional<torch::Tensor>& input_ids,
const c10::optional<torch::Tensor>& tid2eid);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
......
This diff is collapsed.
......@@ -16,6 +16,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"bias) -> ()");
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
#ifndef USE_ROCM
m.def(
"topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, float "
"routed_scaling_factor, Tensor? "
"bias, Tensor? input_ids, Tensor? tid2eid) -> ()");
m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt);
#endif
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
......
......@@ -100,6 +100,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
bool is_neox, torch::Tensor& position_ids,
int64_t forced_token_heads_per_warp);
void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
......@@ -153,7 +158,8 @@ void silu_and_mul_per_block_quant(torch::Tensor& out,
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rope_dim_offset, bool inverse);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
......
......@@ -18,7 +18,6 @@ namespace persistent {
// Constants
// ============================================================================
constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;
constexpr int RADIX = 256;
......@@ -128,11 +127,12 @@ struct RadixRowState {
struct PersistentTopKParams {
const float* __restrict__ input; // [num_rows, stride]
int32_t* __restrict__ output; // [num_rows, TopK]
int32_t* __restrict__ output; // [num_rows, top_k]
int32_t* __restrict__ lengths; // [num_rows]
RadixRowState* row_states; // large path: per-group state
uint32_t num_rows;
uint32_t stride;
uint32_t top_k; // actual k value for output stride
uint32_t chunk_size; // large path: elements per CTA
uint32_t ctas_per_group; // 1=medium, >1=large
uint32_t max_seq_len; // max seq_len across all rows (for early CTA exit)
......@@ -154,6 +154,7 @@ __device__ __forceinline__ uint32_t decode_bin(float x) {
return key >> 5;
}
template <int TopK>
__device__ __noinline__ void histogram_2048_topk(
const float* __restrict__ logits, int32_t* __restrict__ output_indices,
int32_t seq_len) {
......@@ -418,6 +419,7 @@ __device__ __noinline__ void histogram_2048_topk(
// by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang
// kernel
template <int TopK>
__device__ __noinline__ void histogram_256_topk(
const float* __restrict__ logits, int* __restrict__ output_indices,
int logits_offset, int seq_len) {
......@@ -649,7 +651,7 @@ __device__ __forceinline__ void wait_ge(int* ptr, int target_val,
// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215
// ============================================================================
template <uint32_t VEC_SIZE>
template <int TopK, uint32_t VEC_SIZE>
__device__ void radix_topk(const float* __restrict__ row_input,
int32_t* __restrict__ row_output, uint32_t seq_len,
uint32_t my_chunk_start, uint32_t chunk_size,
......@@ -857,7 +859,7 @@ __device__ void radix_topk(const float* __restrict__ row_input,
// see filtered_topk.cuh)
// ============================================================================
template <uint32_t VEC_SIZE = 1>
template <int TopK = 2048, uint32_t VEC_SIZE = 1>
__global__ void __launch_bounds__(kThreadsPerBlock, 2)
persistent_topk_kernel(PersistentTopKParams params) {
const uint32_t tx = threadIdx.x;
......@@ -915,7 +917,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
if (row_idx >= params.num_rows) break;
const uint32_t seq_len = params.lengths[row_idx];
int32_t* row_output = params.output + row_idx * TopK;
int32_t* row_output = params.output + row_idx * params.top_k;
const float* row_input = params.input + row_idx * params.stride;
if (seq_len <= RADIX_THRESHOLD) {
......@@ -927,19 +929,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
row_output[i] = (i < seq_len) ? static_cast<int32_t>(i) : -1;
}
} else if (seq_len <= static_cast<uint32_t>(HIST2048_THRESHOLD)) {
histogram_2048_topk(row_input, row_output, seq_len);
histogram_2048_topk<TopK>(row_input, row_output, seq_len);
} else {
histogram_256_topk(row_input, row_output, 0, seq_len);
histogram_256_topk<TopK>(row_input, row_output, 0, seq_len);
}
}
continue;
}
const uint32_t my_chunk_start = cta_in_group * chunk_size;
radix_topk<VEC_SIZE>(row_input, row_output, seq_len, my_chunk_start,
chunk_size, local_histogram, suffix_sum,
shared_scalars, shared_ordered, state, cta_in_group,
ctas_per_group, barrier_phase, iter, tx);
radix_topk<TopK, VEC_SIZE>(
row_input, row_output, seq_len, my_chunk_start, chunk_size,
local_histogram, suffix_sum, shared_scalars, shared_ordered, state,
cta_in_group, ctas_per_group, barrier_phase, iter, tx);
}
}
......@@ -1011,7 +1013,6 @@ struct FilteredTopKTraits<float> {
}
};
constexpr uint32_t FILTERED_TOPK_MAX_K = 2048;
constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024;
constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE =
16 * 1024; // 16K indices per buffer
......@@ -1025,7 +1026,7 @@ constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC =
* \tparam IdType Index type (int32_t)
* \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8)
*/
template <typename DType, typename IdType, int VEC_SIZE>
template <typename DType, typename IdType, int VEC_SIZE, uint32_t MAX_K = 2048>
__global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
FilteredTopKUnifiedKernel(const DType* __restrict__ input,
IdType* __restrict__ output,
......@@ -1059,7 +1060,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
alignas(128) __shared__ int s_counter;
alignas(128) __shared__ int s_threshold_bin_id;
alignas(128) __shared__ int s_num_input[2];
alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K];
alignas(128) __shared__ int s_indices[MAX_K];
auto& s_histogram = s_histogram_buf[0];
......@@ -1280,7 +1281,7 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
return static_cast<int>(g);
}
template <typename DType, typename IdType>
template <typename DType, typename IdType, uint32_t MAX_K = 2048>
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
IdType* lengths, uint32_t num_rows,
uint32_t top_k_val, uint32_t max_len,
......@@ -1297,7 +1298,7 @@ cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
#define DISPATCH_VEC_SIZE(VS) \
if (vec_size == VS) { \
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS>; \
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS, MAX_K>; \
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \
......
......@@ -9,28 +9,29 @@ namespace vllm {
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr,
const float* __restrict__ sin_ptr, int rot_offset, int embed_dim,
const bool inverse) {
int x_index, y_index;
scalar_t cos, sin;
float cos_f, sin_f;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index);
cos_f = VLLM_LDG(cos_ptr + x_index);
sin_f = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2);
cos_f = VLLM_LDG(cos_ptr + x_index / 2);
sin_f = VLLM_LDG(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
if (inverse) {
sin_f = -sin_f;
}
const float x_f = static_cast<float>(arr[x_index]);
const float y_f = static_cast<float>(arr[y_index]);
arr[x_index] = static_cast<scalar_t>(x_f * cos_f - y_f * sin_f);
arr[y_index] = static_cast<scalar_t>(y_f * cos_f + x_f * sin_f);
}
template <typename scalar_t, bool IS_NEOX>
......@@ -42,22 +43,23 @@ inline __device__ void apply_rotary_embedding(
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
const float* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride) {
const int64_t head_stride, const int64_t rope_dim_offset,
const bool inverse) {
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
const float* cos_ptr = cache_ptr;
const float* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head =
token_idx * query_stride + head_idx * head_stride;
token_idx * query_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
}
if (key != nullptr) {
......@@ -65,10 +67,10 @@ inline __device__ void apply_rotary_embedding(
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head =
token_idx * key_stride + head_idx * head_stride;
token_idx * key_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
}
}
}
......@@ -84,19 +86,18 @@ __global__ void rotary_embedding_kernel(
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] fp32
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int head_size, const int64_t rope_dim_offset, const bool inverse) {
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const float* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride, head_stride);
token_idx, query_stride, key_stride, head_stride, rope_dim_offset,
inverse);
}
} // namespace vllm
......@@ -115,7 +116,7 @@ void rotary_embedding(
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
bool is_neox, int64_t rope_dim_offset, bool inverse) {
// num_tokens = batch_size * seq_len
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
......@@ -154,6 +155,8 @@ void rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size);
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
......@@ -165,20 +168,23 @@ void rotary_embedding(
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto cache_f32 = cos_sin_cache.to(torch::kFloat32);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size);
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, head_stride, num_heads, num_kv_heads, head_size);
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
}
});
}
......@@ -258,7 +258,13 @@ __device__ bool processHistogramStep(
auto processBins = [&](float logit, int idx) {
if (isPartialMatch<patternShift>(logit, logitPattern)) {
uint32_t binIdx = extractBinIdx<step>(logit);
if (binIdx < thresholdBinIdx) {
// Only write elements with binIdx < thresholdBinIdx when:
// 1. This is step 0 and the threshold bin is small enough (no step 1)
// 2. This is step >= 1 (where pattern matching filters correctly)
// This prevents duplicates when step 0 and step 1 both run.
bool shouldWriteDirectly =
(step == 0 && smemFinalBinSize[0] <= kNumFinalItems) || (step >= 1);
if (binIdx < thresholdBinIdx && shouldWriteDirectly) {
// The element is part of the top-k selection
int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
......
......@@ -10,33 +10,17 @@
#include "persistent_topk.cuh"
#endif
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
namespace {
#ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
template <int TopK>
void launch_persistent_topk(const torch::Tensor& logits,
const torch::Tensor& lengths, torch::Tensor& output,
torch::Tensor& workspace, int64_t max_seq_len) {
namespace P = vllm::persistent;
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
namespace P = vllm::persistent;
TORCH_CHECK(k == P::TopK, "k must be 2048");
TORCH_CHECK(k <= stride, "k out of range");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
static int num_sms = 0;
......@@ -50,18 +34,17 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
}
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>(
cudaError_t status =
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream);
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
TORCH_CHECK(status == cudaSuccess,
"FilteredTopK failed: ", cudaGetErrorString(status));
} else {
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
// Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
// large path. Empirically tuned.
int effective_max_smem;
if (num_rows <= 4) {
effective_max_smem =
......@@ -101,7 +84,7 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
int occupancy = 1;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
&occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
smem_size);
if (occupancy < 1) occupancy = 1;
......@@ -121,15 +104,16 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
params.lengths = lengths.data_ptr<int32_t>();
params.num_rows = static_cast<uint32_t>(num_rows);
params.stride = static_cast<uint32_t>(stride);
params.top_k = static_cast<uint32_t>(TopK);
params.chunk_size = chunk_size;
params.row_states =
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
params.ctas_per_group = ctas_per_group;
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
#define LAUNCH_PERSISTENT(VS) \
#define LAUNCH_PERSISTENT(TOPK_VAL, VS) \
do { \
auto kernel = &P::persistent_topk_kernel<VS>; \
auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
cudaError_t err = cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
TORCH_CHECK(err == cudaSuccess, \
......@@ -138,11 +122,11 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
} while (0)
if (vec_size == 4) {
LAUNCH_PERSISTENT(4);
LAUNCH_PERSISTENT(TopK, 4);
} else if (vec_size == 2) {
LAUNCH_PERSISTENT(2);
LAUNCH_PERSISTENT(TopK, 2);
} else {
LAUNCH_PERSISTENT(1);
LAUNCH_PERSISTENT(TopK, 1);
}
#undef LAUNCH_PERSISTENT
}
......@@ -150,6 +134,46 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess,
"persistent_topk failed: ", cudaGetErrorString(err));
}
#endif
} // anonymous namespace
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
#ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
TORCH_CHECK(k == 512 || k == 1024 || k == 2048,
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
if (k == 512) {
launch_persistent_topk<512>(logits, lengths, output, workspace,
max_seq_len);
} else if (k == 1024) {
launch_persistent_topk<1024>(logits, lengths, output, workspace,
max_seq_len);
} else {
launch_persistent_topk<2048>(logits, lengths, output, workspace,
max_seq_len);
}
#else
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
#endif
......
......@@ -177,6 +177,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int forced_token_heads_per_warp=-1) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#ifndef USE_ROCM
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor! q, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"float eps, int cache_block_size) -> ()");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
#endif
// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
......@@ -240,7 +253,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
" Tensor cos_sin_cache, bool is_neox, int "
"rope_dim_offset=0, bool inverse=False) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Quantization ops
......
......@@ -213,7 +213,7 @@ configuration.
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
......
......@@ -384,6 +384,7 @@ th {
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ |
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ |
| `DeepseekV4ForCausalLM` | DeepSeek-V4 | `deepseek-ai/DeepSeek-V4-Flash`, `deepseek-ai/DeepSeek-V4-Pro`, etc. | | |
| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ |
| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | ✅︎ | ✅︎ |
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ |
......
......@@ -11,6 +11,8 @@ torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytor
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.8.post1
flashinfer-cubin==0.6.8.post1
apache-tvm-ffi==0.1.9
tilelang==0.1.9
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
# breaking changes in 1.19.0
nvidia-cudnn-frontend>=1.13.0,<1.19.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment