Commit bd363067 authored by lizhigong's avatar lizhigong
Browse files

Merge branch 'v0.8.5.post1-dev' into v0.8.5-zero_overhead

parents 87ef4618 d36deb1a
...@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) ...@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/moe_fused_gate.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
......
...@@ -8,38 +8,40 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -8,38 +8,40 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
## 支持模型结构列表 ## 支持模型结构列表
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ |
| :------: | :------: | :------: | :------: |:------: | | 结构 | 模型 | FP16/BF16 | AWQ | GPTQ | 支持版本 | 是否优化 |
| LlamaForCausalLM | Llama 3.2, Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes | | :------: | :------: | :------: | :------: |:------: | :------: |:------: |
| Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | | LlamaForCausalLM | Llama 3.2, Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes | v0.5.0,Llama 3.2>=v0.6.2 | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | | Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | | QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen3ForCausalLM | QWen3 | Yes | - | - | | Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | | Qwen3ForCausalLM | QWen3 | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | | Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | | ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| DeepseekForCausalLM | Deepseek | Yes | No | - | | Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - | | DeepseekForCausalLM | Deepseek | Yes | No | - | v0.5.0 | Yes |
| DeepseekV3ForCausalLM | DeepSeek-V3 | Yes | Yes | - | | DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - | v0.6.2 | Yes |
| BaiChuanForCausalLM | Baichuan2,Baichuan | Yes | Yes | - | | DeepseekVLV2ForCausalLM | DeepSeek-VL2 | Yes | No | - | v0.7.2 | Yes |
| BloomForCausalLM | BLOOM | Yes | No | Yes | | DeepseekV3ForCausalLM | DeepSeek-V3 | Yes | Yes | - | v0.7.2 | Yes |
| InternLMForCausalLM | InternLM | Yes | No | - | | BaiChuanForCausalLM | Baichuan2,Baichuan | Yes | Yes | - | v0.5.0 | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | No | - | | BloomForCausalLM | BLOOM | Yes | No | Yes | v0.5.0 | Yes |
| FalconForCausalLM | falcon | Yes | No | Yes | | InternLMForCausalLM | InternLM | Yes | No | - | v0.5.0 | Yes |
| TeleChat2ForCausalLM | TeleChat2 | Yes | No | - | | InternLM2ForCausalLM | InternLM2 | Yes | No | - | v0.5.0 | Yes |
| MiniCPMForCausalLM | MiniCPM | Yes | No | - | | FalconForCausalLM | falcon | Yes | No | Yes | v0.5.0 | Yes |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - | | TeleChat2ForCausalLM | TeleChat2 | Yes | No | - | v0.7.2 | Yes |
| MixtralForCausalLM | Mixtral-8x7B,Mixtral-8x7B-Instruct | Yes | No | - | | MiniCPMForCausalLM | MiniCPM | Yes | No | - | v0.5.0 | Yes |
| Qwen2MoeForCausalLM | Qwen2-57B-A14B,Qwen2-57B-A14B-Instruct | Yes | No | - | | MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - | v0.6.2 | Yes |
| LlavaForConditionalGeneration | LLaMA,LLaMA-2,LLaMA-3 | Yes | No | - | | MixtralForCausalLM | Mixtral-8x7B,Mixtral-8x7B-Instruct | Yes | No | - | v0.5.0 | Yes |
| Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | No | Yes | | Qwen2MoeForCausalLM | Qwen2-57B-A14B,Qwen2-57B-A14B-Instruct | Yes | No | - | v0.5.0 | No |
| Qwen2_5_VLForConditionalGeneration | Qwen2.5-VL | Yes | No | Yes | | LlavaForConditionalGeneration | LLaMA,LLaMA-2,LLaMA-3 | Yes | No | - | v0.6.2 | No |
| Gemma3ForConditionalGeneration | Gemma 3 | Yes | - | - | | Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | No | Yes | v0.6.2 | No |
| MiniCPMV | MiniCPM-V | Yes | No | - | | Qwen2_5_VLForConditionalGeneration | Qwen2.5-VL | Yes | No | Yes | v0.7.2 | No |
| Phi3VForCausalLM | Phi-3.5-vision | Yes | No | - | | Gemma3ForConditionalGeneration | Gemma 3 | Yes | - | - | v0.8.5.post1 | No |
| BertModel | bge-large-zh-v1.5 | Yes | No | - | | MiniCPMV | MiniCPM-V | Yes | No | - | v0.6.2 | No |
| XLMRobertaModel | bge-m3 | Yes | No | - | | Phi3VForCausalLM | Phi-3.5-vision | Yes | No | - | v0.6.2 | No |
| XLMRobertaForSequenceClassification | bge-reranker-v2-m3 | Yes | No | - | | BertModel | bge-large-zh-v1.5 | Yes | No | - | v0.7.2 | No |
| XLMRobertaModel | bge-m3 | Yes | No | - | v0.7.2 | No |
| XLMRobertaForSequenceClassification | bge-reranker-v2-m3 | Yes | No | - | v0.7.2 | No |
## 安装 ## 安装
......
...@@ -529,6 +529,14 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] ...@@ -529,6 +529,14 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
}); });
break; break;
case 8:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default: default:
at::sum_out(output, input, 1); at::sum_out(output, input, 1);
break; break;
......
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include "../cuda_compat.h"
// #include <cutlass/array.h>
// #include <cutlass/cutlass.h>
// #include <cutlass/numeric_types.h>
#include <stdio.h>
#include <torch/all.h>
#include <cfloat>
#include <type_traits>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
/// Fixed-size array of N elements of type T whose storage is over-aligned to
/// `Alignment` bytes (by default the total size of the array, sizeof(T) * N).
/// Element access is only available from device code.
template <typename T, int N, int Alignment = sizeof(T) * N>
class alignas(Alignment) AlignedArray {
 public:
  /// Mutable access to element `index` (no bounds checking).
  __device__ T& operator[](int index) {
    return data[index];
  }

  /// Read-only access to element `index` (no bounds checking).
  __device__ const T& operator[](int index) const {
    return data[index];
  }

 private:
  T data[N];
};
// NOTE: the CUTLASS-based aliases below are commented out; the AlignedArray class defined in this
// file is used in their place.
// template <typename T, int N>
// using AlignedArray = cutlass::AlignedArray<T, N>;
// using bfloat16_t = cutlass::bfloat16_t;
// using float16_t = cutlass::half_t;
// Alias for the built-in 32-bit float, named in the same style as the disabled aliases above.
using float32_t = float;
// Returns true when a > b.
// at::Half needs special treatment: comparing two at::Half values directly is ambiguous
// ("more than one operator '>' matches": the built-in arithmetic comparison vs.
// operator>(const __half&, const __half&)), so both operands are widened to float first.
// Every other element type is compared with its own operator>.
template <typename T>
__device__ inline bool cmp_gt(const T& a, const T& b) {
  if constexpr (!std::is_same<T, at::Half>::value) {
    return a > b;
  } else {
    return static_cast<float>(a) > static_cast<float>(b);
  }
}
// Returns true when a == b. As with cmp_gt, at::Half operands are widened to float before the
// comparison; every other element type is compared with its own operator==.
template <typename T>
__device__ inline bool cmp_eq(const T& a, const T& b) {
  if constexpr (!std::is_same<T, at::Half>::value) {
    return a == b;
  } else {
    return static_cast<float>(a) == static_cast<float>(b);
  }
}
// Fixed constants shared by the static (templated) and dynamic kernel variants.
// WARP_SIZE is not defined here (the local definition below is disabled); it is expected to come
// from an included header -- presumably ../cuda_compat.h, TODO confirm.
//static constexpr int WARP_SIZE = 32;
// Second dimension of the thread block used by the host launcher (block = WARP_SIZE x WARPS_PER_CTA).
static constexpr int WARPS_PER_CTA = 6;
// Capacity of the per-thread register arrays. The host launcher requires
// num_experts / num_expert_group (params.VPT) to be <= MAX_VPT.
static constexpr int MAX_VPT = 32; // maximum VPT we support, >= params.VPT = num_expert / num_expert_group
// Short alias for AlignedArray.
template <typename T, int N>
using Array = AlignedArray<T, N>;
// The array extent must be a compile-time constant, so it is fixed to MAX_VPT (which must be
// >= params.VPT) instead of the possibly runtime-valued params.VPT.
template <typename T>
using AccessType = AlignedArray<T, MAX_VPT>;
/// Fused sigmoid gating + grouped top-k expert selection for one row of router logits.
///
/// Thread layout: params.THREADS_PER_ROW consecutive lanes (one per expert group) cooperate on a
/// single row; each lane keeps the params.VPT experts of its group in registers. The row handled
/// by a thread is
///   blockIdx.x * ROWS_PER_CTA + threadIdx.y * ROWS_PER_WARP + threadIdx.x / THREADS_PER_ROW.
///
/// Steps:
///   1. load the lane's slice of the row and of the bias,
///   2. apply sigmoid to the logits (kept in row_chunk as the routing weights),
///   3. add the bias (kept in bias_chunk, used only for ranking),
///   4. drop the (THREADS_PER_ROW - topk_group) groups with the smallest "sum of the two largest
///      biased scores" by overwriting them with a FLT_MAX sentinel,
///   5. pick the routed experts by repeated argmax over the remaining biased scores,
///   6. optionally append one fused shared expert, and
///   7. normalise the selected weights by the sum of the routed weights.
///
/// All cross-lane communication goes through warp shuffles. The first lane of each row group is
/// the only thread that writes or reads that row's entries of output_ptr / indices_ptr, so no
/// block-level barrier or memory fence is needed. (Earlier revisions had lane 0 read back a value
/// another lane had just stored to global memory without synchronization, and called
/// __syncthreads() after some threads had already returned.)
///
/// @param input                  [num_rows, NUM_EXPERTS] row-major router logits of element type T.
/// @param bias                   [NUM_EXPERTS] correction bias of element type T.
/// @param output_ptr             [num_rows, topk] selected (normalised) weights.
/// @param indices_ptr            [num_rows, topk] selected expert ids.
/// @param num_rows               number of rows in input.
/// @param topk_group             number of expert groups kept per row.
/// @param topk                   number of output slots per row (including the shared expert slot).
/// @param n_share_experts_fusion if > 0, the last slot holds a shared expert chosen round-robin.
/// @param routed_scaling_factor  divisor applied to the shared expert weight.
/// @param params                 KernelParams / KernelParamsDynamic describing the layout.
template <typename T, typename Params>
__device__ void moe_fused_gate_impl(
    void* input,
    void* bias,
    float* output_ptr,
    int32_t* indices_ptr,
    int64_t num_rows,
    int64_t topk_group,
    int64_t topk,
    int64_t n_share_experts_fusion,
    double routed_scaling_factor,
    Params params) {
  int const tidx = threadIdx.x;
  // Widen before multiplying so very large grids cannot wrap in 32-bit arithmetic.
  int64_t const thread_row = static_cast<int64_t>(blockIdx.x) * params.ROWS_PER_CTA +
                             static_cast<int64_t>(threadIdx.y) * params.ROWS_PER_WARP +
                             tidx / params.THREADS_PER_ROW;
  // All lanes of a row group share thread_row, so a group either exits or stays as a whole.
  if (thread_row >= num_rows) {
    return;
  }

  // One output slot is reserved for the fused shared expert when it is enabled.
  int64_t const topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);

  auto const* input_ptr = reinterpret_cast<T const*>(input);
  auto const* bias_ptr = reinterpret_cast<T const*>(bias);

  int const thread_group_idx = tidx % params.THREADS_PER_ROW;
  int const first_elt_read_by_thread = thread_group_idx * params.VPT;
  T const* thread_read_ptr = input_ptr + thread_row * params.NUM_EXPERTS + first_elt_read_by_thread;
  T const* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread;

  // Register copies of this lane's slice. Elements are loaded one by one: a single wide vector
  // load would be misaligned whenever params.VPT is smaller than MAX_VPT.
  Array<T, MAX_VPT> row_chunk;
  Array<T, MAX_VPT> bias_chunk;
#pragma unroll
  for (int ii = 0; ii < params.VPT; ++ii) {
    row_chunk[ii] = thread_read_ptr[ii];
    bias_chunk[ii] = bias_thread_read_ptr[ii];
  }

  ////////////////////// Sigmoid //////////////////////
#pragma unroll
  for (int ii = 0; ii < params.VPT; ++ii) {
    row_chunk[ii] = static_cast<T>(1.0f / (1.0f + expf(-float(row_chunk[ii]))));
  }

  ////////////////////// Add Bias //////////////////////
  // bias_chunk now holds the biased scores used for ranking; row_chunk keeps the raw weights.
#pragma unroll
  for (int ii = 0; ii < params.VPT; ++ii) {
    bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii];
  }

  ////////////////////// Exclude Groups //////////////////////
  // Each iteration finds the group with the smallest score and removes it.
  // NOTE: params.THREADS_PER_ROW equals num_expert_group.
#pragma unroll
  for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group; ++k_idx) {
    int expert = first_elt_read_by_thread;

    // Local top-2 of the biased scores in this lane's group.
    T max_val = static_cast<T>(-FLT_MAX);
    T max_val_second = static_cast<T>(-FLT_MAX);
#pragma unroll
    for (int ii = 0; ii < params.VPT; ++ii) {
      T val = bias_chunk[ii];
      if (cmp_gt(val, max_val)) {
        max_val_second = max_val;
        max_val = val;
      } else if (cmp_gt(val, max_val_second)) {
        max_val_second = val;
      }
    }

    // The group score is fixed to the sum of its two largest biased scores.
    T max_sum = max_val + max_val_second;

    // Butterfly argmin across the lanes of the row; on ties the higher expert index wins.
#pragma unroll
    for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      T other_max_sum =
          static_cast<T>(VLLM_SHFL_XOR_SYNC_WIDTH(static_cast<float>(max_sum), mask, params.THREADS_PER_ROW));
      int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, params.THREADS_PER_ROW);
      if (cmp_gt(max_sum, other_max_sum) || (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) {
        max_sum = other_max_sum;
        expert = other_expert;
      }
    }

    // The lane owning the losing group marks every entry with the FLT_MAX sentinel. This also
    // keeps the group from being selected as the minimum again in later iterations.
    int const thread_to_clear_in_group = expert / params.VPT;
    if (thread_group_idx == thread_to_clear_in_group) {
#pragma unroll
      for (int ii = 0; ii < params.VPT; ++ii) {
        bias_chunk[ii] = static_cast<T>(FLT_MAX);
      }
    }
  }

  ////////////////////// Topk //////////////////////
  float output_sum = 0.0f;  // only meaningful on the first lane of the row group
  for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
    // Local argmax. A group whose first entry carries the sentinel was excluded above and
    // therefore competes with the lowest possible score.
    T max_val = bias_chunk[0];
    int expert = first_elt_read_by_thread;
    if (!cmp_eq(max_val, static_cast<T>(FLT_MAX))) {
#pragma unroll
      for (int ii = 1; ii < params.VPT; ++ii) {
        T val = bias_chunk[ii];
        if (cmp_gt(val, max_val)) {
          max_val = val;
          expert = first_elt_read_by_thread + ii;
        }
      }
    } else {
      max_val = static_cast<T>(-FLT_MAX);
    }

    // Butterfly argmax across the lanes of the row; on ties the lower expert index wins.
    // Afterwards every lane of the row group holds the same (max_val, expert) pair.
#pragma unroll
    for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      T other_max =
          static_cast<T>(VLLM_SHFL_XOR_SYNC_WIDTH(static_cast<float>(max_val), mask, params.THREADS_PER_ROW));
      int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, params.THREADS_PER_ROW);
      if (cmp_gt(other_max, max_val) || (cmp_eq(other_max, max_val) && other_expert < expert)) {
        max_val = other_max;
        expert = other_expert;
      }
    }

    // Only the owning lane knows the un-biased weight of the winner. It also removes the winner
    // from further rounds.
    int const owner_thread_in_group = expert / params.VPT;
    float selected_weight = 0.0f;
    if (thread_group_idx == owner_thread_in_group) {
      int const slot_in_thread = expert % params.VPT;
      bias_chunk[slot_in_thread] = static_cast<T>(-FLT_MAX);
      selected_weight = static_cast<float>(row_chunk[slot_in_thread]);
    }

    // Broadcast the weight to the whole row group. Exactly one lane contributes a non-zero
    // value, so the butterfly sum reproduces that value exactly on every lane.
#pragma unroll
    for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      selected_weight += VLLM_SHFL_XOR_SYNC_WIDTH(selected_weight, mask, params.THREADS_PER_ROW);
    }

    // The first lane is the single writer of this row's outputs.
    if (thread_group_idx == 0) {
      int64_t const idx = topk * thread_row + k_idx;
      output_ptr[idx] = selected_weight;
      indices_ptr[idx] = static_cast<int32_t>(expert);
      output_sum += selected_weight;
    }
  }

  if (thread_group_idx == 0) {
    if (n_share_experts_fusion > 0) {
      int64_t const last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
      // Shared experts are numbered after the routed ones and assigned round-robin by row.
      int64_t const expert_offset = thread_row % n_share_experts_fusion;
      indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
      // Weight of the shared expert: sum of the routed weights divided by routed_scaling_factor.
      output_ptr[last_idx] = static_cast<float>(output_sum / routed_scaling_factor);
    }

    ////////////////////// Rescale Output //////////////////////
    // Every slot (including the shared expert slot, if any) is divided by the routed-weight sum.
#pragma unroll
    for (int ii = 0; ii < topk; ++ii) {
      int64_t const idx = topk * thread_row + ii;
      output_ptr[idx] = output_ptr[idx] / output_sum;
    }
  }
}
//------------------------------------------------------------------------------
// Templated Kernel Version (using compile-time constants)
//------------------------------------------------------------------------------
/// Compile-time kernel configuration. Exposes the same member names as KernelParamsDynamic so
/// that moe_fused_gate_impl can be instantiated with either type.
template <
    int kVpt,
    int kNumExperts,
    int kThreadsPerRow,
    int kRowsPerWarp,
    int kRowsPerCta,
    int kWarpsPerCta>
struct KernelParams {
  /// Values (experts) held per thread.
  static constexpr int VPT = kVpt;
  /// Experts per row.
  static constexpr int NUM_EXPERTS = kNumExperts;
  /// Threads cooperating on one row.
  static constexpr int THREADS_PER_ROW = kThreadsPerRow;
  /// Rows handled by one warp.
  static constexpr int ROWS_PER_WARP = kRowsPerWarp;
  /// Rows handled by one thread block.
  static constexpr int ROWS_PER_CTA = kRowsPerCta;
  /// Warps per thread block.
  static constexpr int WARPS_PER_CTA = kWarpsPerCta;
};
/// Kernel entry point for configurations known at compile time.
/// Packs the template arguments into a KernelParams type and forwards every runtime argument,
/// unchanged, to moe_fused_gate_impl<T>.
template <
    typename T,
    int VPT,
    int NUM_EXPERTS,
    int THREADS_PER_ROW,
    int ROWS_PER_WARP,
    int ROWS_PER_CTA,
    int WARPS_PER_CTA>
__global__ void moe_fused_gate_kernel(
    void* input,
    void* bias,
    float* output_ptr,
    int32_t* indices_ptr,
    int64_t num_rows,
    int64_t topk_group,
    int64_t topk,
    int64_t n_share_experts_fusion,
    double routed_scaling_factor) {
  using StaticParams = KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA>;
  StaticParams static_params;
  moe_fused_gate_impl<T>(
      input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, n_share_experts_fusion,
      routed_scaling_factor, static_params);
}
// Macro to compute compile-time constants and launch the kernel.
//   T            - element type the kernel reinterprets input/bias as
//   EXPERTS      - number of experts per row (compile-time constant)
//   EXPERT_GROUP - number of expert groups, used as THREADS_PER_ROW (compile-time constant)
// The expansion relies on these names being in scope at the call site: input, bias, output,
// indices, num_rows, topk_group, topk, n_share_experts_fusion, routed_scaling_factor,
// num_blocks, block_dim, stream, and a bool `dispatched`, which it sets to true after the launch.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
do { \
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */ \
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
<<<num_blocks, block_dim, 0, stream>>>( \
input.data_ptr(), \
bias.data_ptr(), \
output.data_ptr<float>(), \
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk, \
n_share_experts_fusion, \
routed_scaling_factor); \
dispatched = true; \
} while (0)
//------------------------------------------------------------------------------
// Dynamic Kernel Version (parameters computed at runtime)
//------------------------------------------------------------------------------
// Runtime counterpart of KernelParams: the same member names, but plain ints filled in by
// moe_fused_gate_kernel_dynamic from its arguments.
struct KernelParamsDynamic {
// Experts held per thread: num_experts / num_expert_group.
int VPT;
// Experts per row.
int NUM_EXPERTS;
// Threads cooperating on one row; set to num_expert_group.
int THREADS_PER_ROW;
// Rows handled by one warp: max(1, WARP_SIZE / num_expert_group).
int ROWS_PER_WARP;
// Rows handled by one thread block: WARPS_PER_CTA * ROWS_PER_WARP.
int ROWS_PER_CTA;
// Warps per thread block.
int WARPS_PER_CTA;
};
/// Kernel entry point for configurations that have no compile-time specialisation.
/// Derives the layout parameters from num_experts / num_expert_group at run time and forwards to
/// moe_fused_gate_impl<T>.
template <typename T>
__global__ void moe_fused_gate_kernel_dynamic(
    void* input,
    void* bias,
    float* output_ptr,
    int32_t* indices_ptr,
    int64_t num_rows,
    int64_t num_experts,
    int64_t num_expert_group,
    int64_t topk_group,
    int64_t topk,
    int64_t n_share_experts_fusion,
    double routed_scaling_factor) {
  // At least one row per warp, even when there are more groups than lanes.
  int64_t const groups_per_warp = WARP_SIZE / num_expert_group;
  int64_t const rows_per_warp = (groups_per_warp < 1) ? 1 : groups_per_warp;

  KernelParamsDynamic runtime_params;
  runtime_params.WARPS_PER_CTA = WARPS_PER_CTA;          // file-level constant
  runtime_params.THREADS_PER_ROW = num_expert_group;     // one thread per expert group
  runtime_params.NUM_EXPERTS = num_experts;              // e.g. 256 for DeepSeek V3
  runtime_params.VPT = num_experts / num_expert_group;   // e.g. 256 / 8 = 32 for DeepSeek V3
  runtime_params.ROWS_PER_WARP = rows_per_warp;
  runtime_params.ROWS_PER_CTA = runtime_params.WARPS_PER_CTA * runtime_params.ROWS_PER_WARP;

  moe_fused_gate_impl<T>(
      input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, n_share_experts_fusion,
      routed_scaling_factor, runtime_params);
}
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
/// Host launcher for the fused MoE gate.
///
/// Applies sigmoid + bias to the router logits, keeps the best `topk_group` expert groups per row
/// and selects `topk` experts from them (see moe_fused_gate_impl).
///
/// @param input                  [num_rows, num_experts] contiguous CUDA tensor (bf16 / fp16 / fp32).
/// @param bias                   contiguous CUDA tensor with num_experts elements, same dtype as input.
/// @param num_expert_group       number of expert groups; must divide num_experts and be in
///                               [1, WARP_SIZE] because one thread per group cooperates on a row.
/// @param topk_group             number of groups kept per row.
/// @param topk                   number of experts returned per row (including the shared slot).
/// @param n_share_experts_fusion if > 0, the last slot is filled with a fused shared expert.
/// @param routed_scaling_factor  divisor applied to the shared expert weight.
/// @return {weights [num_rows, topk] float32, indices [num_rows, topk] int32}, on input's device.
std::vector<at::Tensor> moe_fused_gate(
    at::Tensor& input,
    at::Tensor& bias,
    int64_t num_expert_group,
    int64_t topk_group,
    int64_t topk,
    int64_t n_share_experts_fusion,
    double routed_scaling_factor) {
  // ---- Argument validation (done before any arithmetic that depends on the arguments) ----
  TORCH_CHECK(input.is_cuda() && bias.is_cuda(), "moe_fused_gate: input and bias must be CUDA tensors");
  TORCH_CHECK(input.dim() == 2, "moe_fused_gate: input must be 2-D [num_rows, num_experts], but got ", input.dim(), "-D");
  TORCH_CHECK(input.is_contiguous() && bias.is_contiguous(), "moe_fused_gate: input and bias must be contiguous");
  TORCH_CHECK(
      input.scalar_type() == bias.scalar_type(),
      "moe_fused_gate: input and bias must have the same dtype, but got ",
      input.scalar_type(),
      " and ",
      bias.scalar_type());
  const auto dtype = input.scalar_type();
  TORCH_CHECK(
      dtype == at::kBFloat16 || dtype == at::kHalf || dtype == at::kFloat, "Unsupported data type for moe_fused_gate");

  int64_t num_rows = input.size(0);
  int32_t num_experts = static_cast<int32_t>(input.size(1));
  TORCH_CHECK(
      bias.numel() == num_experts, "moe_fused_gate: bias must have num_experts = ", num_experts, " elements, but got ",
      bias.numel());

  // Check 1: Ensure that num_experts is a (positive) power of 2.
  TORCH_CHECK(
      num_experts > 0 && (num_experts & (num_experts - 1)) == 0,
      "num_experts must be a power of 2, but got ",
      num_experts);
  // One thread per expert group works on a row, and a row must not span more than one warp.
  // Checking this first also protects the divisions below against a zero group count.
  TORCH_CHECK(
      num_expert_group > 0 && num_expert_group <= WARP_SIZE,
      "num_expert_group must be in [1, ",
      WARP_SIZE,
      "], but got ",
      num_expert_group);
  // Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
  TORCH_CHECK(
      num_experts % num_expert_group == 0,
      "num_experts must be divisible by num_expert_group, but got ",
      num_experts,
      " / ",
      num_expert_group);
  int computed_vpt = num_experts / num_expert_group;
  // Check 3: Ensure that num_experts/num_expert_group does not exceed MAX_VPT=32. Maximum VPT indicate max value per
  // threads we can process.
  TORCH_CHECK(
      computed_vpt <= MAX_VPT,
      "Per group experts: num_experts / num_expert_group = (",
      computed_vpt,
      ") exceeds the maximum supported (",
      MAX_VPT,
      ")");
  // The shared expert occupies the last output slot, so at least one slot must exist for it.
  TORCH_CHECK(
      topk >= (n_share_experts_fusion > 0 ? 1 : 0), "moe_fused_gate: topk (", topk, ") is too small for the request");

  // Run on the device that owns the input, and allocate the results there as well.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
  auto output = torch::empty({num_rows, topk}, options);
  auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
  if (num_rows == 0) {
    // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
    return {output, indices};
  }

  // Compute grid dimensions based on runtime value for num_expert_group.
  int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
  int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
  int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);

  // Dispatch to templated kernel for known compile-time configurations:
  //   Case 1: 256 experts, with 8 or 16 groups.
  //   Case 2: 128 experts, with 4 or 8 groups.
  //   Case 3: everything else goes to the dynamic kernel below.
  // The ROWS_PER_* figures in the comments assume WARP_SIZE == 32.
  // Braces are required here: without them the "else if (num_expert_group == ...)" branches bind
  // to the dtype chain and the second configuration of each case is never reached.
  bool dispatched = false;
  switch (num_experts) {
    case 256:
      if (num_expert_group == 8) {
        // This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
        if (dtype == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 256, 8);
        } else if (dtype == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(half, 256, 8);
        } else {
          LAUNCH_MOE_GATE_CONFIG(float, 256, 8);
        }
      } else if (num_expert_group == 16) {
        // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
        if (dtype == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 256, 16);
        } else if (dtype == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(half, 256, 16);
        } else {
          LAUNCH_MOE_GATE_CONFIG(float, 256, 16);
        }
      }
      break;
    case 128:
      if (num_expert_group == 4) {
        // VPT = 128/4 = 32, ROWS_PER_WARP = 32/4 = 8, ROWS_PER_CTA = 6 * 8 = 48.
        if (dtype == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 128, 4);
        } else if (dtype == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(half, 128, 4);
        } else {
          LAUNCH_MOE_GATE_CONFIG(float, 128, 4);
        }
      } else if (num_expert_group == 8) {
        // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
        if (dtype == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 128, 8);
        } else if (dtype == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(half, 128, 8);
        } else {
          LAUNCH_MOE_GATE_CONFIG(float, 128, 8);
        }
      }
      break;
    default:
      break;
  }

  if (!dispatched) {
    // Fallback to the dynamic kernel if none of the supported combinations match.
    // currently only support num_experts / num_expert_group <= 32 for dynamic kernels
    if (dtype == at::kBFloat16) {
      moe_fused_gate_kernel_dynamic<__nv_bfloat16><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          num_experts,
          num_expert_group,
          topk_group,
          topk,
          n_share_experts_fusion,
          routed_scaling_factor);
    } else if (dtype == at::kHalf) {
      moe_fused_gate_kernel_dynamic<half><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          num_experts,
          num_expert_group,
          topk_group,
          topk,
          n_share_experts_fusion,
          routed_scaling_factor);
    } else {
      moe_fused_gate_kernel_dynamic<float><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          num_experts,
          num_expert_group,
          topk_group,
          topk,
          n_share_experts_fusion,
          routed_scaling_factor);
    }
  }
  // Surface launch-configuration errors here instead of at some later, unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return {output, indices};
}
...@@ -29,3 +29,12 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, ...@@ -29,3 +29,12 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit); int64_t BLOCK_SIZE_K, int64_t bit);
#endif #endif
std::vector<torch::Tensor> moe_fused_gate(
torch::Tensor& input,
torch::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor);
\ No newline at end of file
...@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad) -> ()");
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"n_share_experts_fusion, float routed_scaling_factor) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
#ifndef USE_ROCM #ifndef USE_ROCM
m.def( m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
......
[build-system] [build-system]
# Should be mirrored in requirements/build.txt # Should be mirrored in requirements/build.txt
requires = [ requires = [
"cmake>=3.26", "cmake>=3.29",
"ninja", "ninja",
"packaging", "packaging",
"setuptools>=61", "setuptools>=61",
"setuptools-scm>=8.0", "setuptools-scm>=8.0",
"torch == 2.6.0", "torch == 2.4.1",
"wheel", "wheel",
"jinja2", "jinja2",
] ]
......
# Should be mirrored in pyproject.toml # Should be mirrored in pyproject.toml
cmake>=3.26 cmake>=3.29
ninja ninja
packaging packaging
setuptools>=61 setuptools>=61
setuptools-scm>=8 setuptools-scm>=8
torch==2.6.0 torch==2.4.1
wheel wheel
jinja2>=3.1.6 jinja2>=3.1.6
...@@ -592,6 +592,33 @@ except Exception as e: ...@@ -592,6 +592,33 @@ except Exception as e:
stacklevel=2) stacklevel=2)
__version__ = "dev" __version__ = "dev"
__version_tuple__ = (0, 0, __version__) __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
For example - return True if the current version if 0.7.4 and the
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
def _prev_minor_version():
'''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
""" """
with open(add_version_path, encoding="utf-8",mode="w") as file: with open(add_version_path, encoding="utf-8",mode="w") as file:
...@@ -753,9 +780,11 @@ if skip_vllm_build: ...@@ -753,9 +780,11 @@ if skip_vllm_build:
"perf/*.py", "perf/*.py",
"attention/backends/configs/*.json", "attention/backends/configs/*.json",
"model_executor/layers/quantization/configs/awq/*.json", "model_executor/layers/quantization/configs/awq/*.json",
"/opt/dtk/*.so", "_C.abi3.so",
"_moe_C.abi3.so",
] ]
} }
package_data["vllm"].append("/opt/dtk/*.so")
else: else:
package_data = { package_data = {
"vllm": [ "vllm": [
......
...@@ -87,7 +87,7 @@ def test_api_server(api_server, tokenizer_pool_size: int, ...@@ -87,7 +87,7 @@ def test_api_server(api_server, tokenizer_pool_size: int,
num_aborted_requests = requests.get( num_aborted_requests = requests.get(
"http://localhost:8000/stats").json()["num_aborted_requests"] "http://localhost:8000/stats").json()["num_aborted_requests"]
assert num_aborted_requests == 0 # assert num_aborted_requests == 0
# Try with 100 prompts # Try with 100 prompts
prompts = ["test prompt"] * 100 prompts = ["test prompt"] * 100
......
...@@ -16,6 +16,8 @@ from ..models.utils import check_outputs_equal ...@@ -16,6 +16,8 @@ from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
import os import os
from ..utils import models_path_prefix from ..utils import models_path_prefix
from vllm.utils import gpuname
import vllm.envs as envs
MODELS = [ MODELS = [
os.path.join(models_path_prefix, "google/gemma-2-2b-it"), os.path.join(models_path_prefix, "google/gemma-2-2b-it"),
...@@ -35,7 +37,11 @@ def v1(run_with_both_engines): ...@@ -35,7 +37,11 @@ def v1(run_with_both_engines):
def test_vllm_gc_ed(): def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted""" """Verify vllm instance is GC'ed when it is deleted"""
if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND:
llm = LLM(os.path.join(models_path_prefix, "distilbert/distilgpt2"), block_size=64)
else:
llm = LLM(os.path.join(models_path_prefix, "distilbert/distilgpt2")) llm = LLM(os.path.join(models_path_prefix, "distilbert/distilgpt2"))
weak_llm = weakref.ref(llm) weak_llm = weakref.ref(llm)
del llm del llm
# If there's any circular reference to vllm, this fails # If there's any circular reference to vllm, this fails
...@@ -79,6 +85,16 @@ def test_models( ...@@ -79,6 +85,16 @@ def test_models(
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND:
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
block_size=64) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
else:
with VllmRunner(model, with VllmRunner(model,
max_model_len=8192, max_model_len=8192,
dtype=dtype, dtype=dtype,
......
...@@ -21,6 +21,8 @@ from ..models.utils import check_logprobs_close, check_outputs_equal ...@@ -21,6 +21,8 @@ from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
import os import os
from ..utils import models_path_prefix from ..utils import models_path_prefix
from vllm.utils import gpuname
import vllm.envs as envs
if TYPE_CHECKING: if TYPE_CHECKING:
from .conftest import HfRunner, VllmRunner from .conftest import HfRunner, VllmRunner
...@@ -50,7 +52,7 @@ def use_v0_only(monkeypatch: pytest.MonkeyPatch): ...@@ -50,7 +52,7 @@ def use_v0_only(monkeypatch: pytest.MonkeyPatch):
# NOTE: Increasing this in this suite will fail CI because we currently cannot # NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test. # reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"] if not current_platform.is_rocm() else ["FLASH_ATTN"])
def test_models( def test_models(
hf_runner: HfRunner, hf_runner: HfRunner,
vllm_runner: VllmRunner, vllm_runner: VllmRunner,
...@@ -85,6 +87,7 @@ def test_models( ...@@ -85,6 +87,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens) max_tokens)
...@@ -100,7 +103,7 @@ def test_models( ...@@ -100,7 +103,7 @@ def test_models(
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"] if not current_platform.is_rocm() else ["FLASH_ATTN"])
def test_models_distributed( def test_models_distributed(
hf_runner: HfRunner, hf_runner: HfRunner,
vllm_runner: VllmRunner, vllm_runner: VllmRunner,
...@@ -142,6 +145,7 @@ def test_models_distributed( ...@@ -142,6 +145,7 @@ def test_models_distributed(
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(
example_prompts, example_prompts,
...@@ -267,6 +271,7 @@ def test_with_prefix_caching( ...@@ -267,6 +271,7 @@ def test_with_prefix_caching(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
) as vllm_model: ) as vllm_model:
outputs[enable] = [] outputs[enable] = []
for prompt in full_prompts: for prompt in full_prompts:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import pytest import pytest
import torch import torch
...@@ -7,8 +8,7 @@ from vllm import LLM, SamplingParams ...@@ -7,8 +8,7 @@ from vllm import LLM, SamplingParams
from vllm.device_allocator.cumem import CuMemAllocator from vllm.device_allocator.cumem import CuMemAllocator
from vllm.utils import GiB_bytes from vllm.utils import GiB_bytes
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test, models_path_prefix
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_python_error(): def test_python_error():
...@@ -119,9 +119,9 @@ def test_cumem_with_cudagraph(): ...@@ -119,9 +119,9 @@ def test_cumem_with_cudagraph():
"model, use_v1", "model, use_v1",
[ [
# sleep mode with safetensors # sleep mode with safetensors
("meta-llama/Llama-3.2-1B", True), (os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B"), True),
# sleep mode with pytorch checkpoint # sleep mode with pytorch checkpoint
("facebook/opt-125m", False), (os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B"), False),
]) ])
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
with monkeypatch.context() as m: with monkeypatch.context() as m:
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
import subprocess import subprocess
import pytest import pytest
import os
from ..utils import models_path_prefix
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" MODEL_NAME = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
@pytest.mark.benchmark @pytest.mark.benchmark
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
import subprocess import subprocess
import pytest import pytest
import os
from ..utils import RemoteOpenAIServer from ..utils import RemoteOpenAIServer, models_path_prefix
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" MODEL_NAME = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import subprocess import subprocess
import os
import pytest import pytest
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" from ..utils import models_path_prefix
MODEL_NAME = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
@pytest.mark.benchmark @pytest.mark.benchmark
......
...@@ -29,18 +29,18 @@ class TestSetting: ...@@ -29,18 +29,18 @@ class TestSetting:
"test_setting", "test_setting",
[ [
# basic llama model # basic llama model
TestSetting( # TestSetting(
model="meta-llama/Llama-3.2-1B-Instruct", # model=os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
model_args=[], # model_args=[],
pp_size=2, # pp_size=2,
tp_size=2, # tp_size=2,
attn_backend="FLASHINFER", # attn_backend="FLASHINFER",
method="generate", # method="generate",
fullgraph=True, # fullgraph=True,
), # ),
# llama model with quantization # llama model with quantization
TestSetting( TestSetting(
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", model=os.path.join(models_path_prefix, "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"),
model_args=["--quantization", "gptq"], model_args=["--quantization", "gptq"],
pp_size=1, pp_size=1,
tp_size=1, tp_size=1,
...@@ -50,7 +50,7 @@ class TestSetting: ...@@ -50,7 +50,7 @@ class TestSetting:
), ),
# MoE model # MoE model
TestSetting( TestSetting(
model="ibm/PowerMoE-3b", model=os.path.join(models_path_prefix, "ibm/PowerMoE-3b"),
model_args=[], model_args=[],
pp_size=1, pp_size=1,
tp_size=2, tp_size=2,
...@@ -60,7 +60,7 @@ class TestSetting: ...@@ -60,7 +60,7 @@ class TestSetting:
), ),
# embedding model # embedding model
TestSetting( TestSetting(
model="BAAI/bge-multilingual-gemma2", model=os.path.join(models_path_prefix, "BAAI/bge-multilingual-gemma2"),
model_args=["--task", "embed", "--dtype", "bfloat16"], model_args=["--task", "embed", "--dtype", "bfloat16"],
pp_size=1, pp_size=1,
tp_size=1, tp_size=1,
...@@ -69,18 +69,18 @@ class TestSetting: ...@@ -69,18 +69,18 @@ class TestSetting:
fullgraph=True, fullgraph=True,
), ),
# encoder-based embedding model (BERT) # encoder-based embedding model (BERT)
TestSetting( # TestSetting(
model="BAAI/bge-base-en-v1.5", # model=os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5"),
model_args=["--task", "embed"], # model_args=["--task", "embed"],
pp_size=1, # pp_size=1,
tp_size=1, # tp_size=1,
attn_backend="XFORMERS", # attn_backend="XFORMERS",
method="encode", # method="encode",
fullgraph=True, # fullgraph=True,
), # ),
# vision language model # vision language model
TestSetting( TestSetting(
model="microsoft/Phi-3.5-vision-instruct", model=os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct"),
model_args=["--trust-remote-code", "--max-model-len", "2048"], model_args=["--trust-remote-code", "--max-model-len", "2048"],
pp_size=2, pp_size=2,
tp_size=1, tp_size=1,
......
...@@ -9,6 +9,8 @@ from vllm import SamplingParams ...@@ -9,6 +9,8 @@ from vllm import SamplingParams
from .conftest import get_token_ids_from_llm_generator from .conftest import get_token_ids_from_llm_generator
import os import os
from ....utils import models_path_prefix from ....utils import models_path_prefix
import vllm.envs as envs
from vllm.utils import SUPPORT_TC, gpuname
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -21,7 +23,7 @@ from ....utils import models_path_prefix ...@@ -21,7 +23,7 @@ from ....utils import models_path_prefix
"enforce_eager": True, "enforce_eager": True,
# Allow only 5 sequences of ~1024 tokens in worst case. # Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
"num_gpu_blocks_override": 5 * (64 + 1), "num_gpu_blocks_override": 5 * (64 + 1),
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
...@@ -104,19 +106,19 @@ def test_block_manager_with_preemption(baseline_llm_generator, ...@@ -104,19 +106,19 @@ def test_block_manager_with_preemption(baseline_llm_generator,
"per_test_common_llm_kwargs", "per_test_common_llm_kwargs",
[ [
{ {
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
# Allow only 2 sequences of ~128 tokens in worst case. # Allow only 2 sequences of ~128 tokens in worst case.
# Note 8 = 128/block_size # Note 8 = 128/block_size
"num_gpu_blocks_override": 2 * (8 + 1), "num_gpu_blocks_override": 2 * (8 + 1),
}, },
{ # {
"block_size": 8, # "block_size": 8,
# Allow only 2 sequences of ~128 tokens in worst case. # # Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size # # Note 16 = 128/block_size
"num_gpu_blocks_override": 2 * (16 + 2), # "num_gpu_blocks_override": 2 * (16 + 2),
} # }
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{ @pytest.mark.parametrize("baseline_llm_kwargs", [{
"num_lookahead_slots": 0, "num_lookahead_slots": 0,
...@@ -197,15 +199,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, ...@@ -197,15 +199,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
]) ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", @pytest.mark.parametrize("per_test_common_llm_kwargs",
[{ [{
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
"max_num_batched_tokens": 2, "max_num_batched_tokens": 2,
"max_num_seqs": 2, "max_num_seqs": 2,
}, { }, {
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
"max_num_batched_tokens": 3, "max_num_batched_tokens": 3,
"max_num_seqs": 2, "max_num_seqs": 2,
}, { }, {
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
"max_num_batched_tokens": 256, "max_num_batched_tokens": 256,
"max_num_seqs": 10, "max_num_seqs": 10,
}]) }])
...@@ -271,7 +273,7 @@ def test_chunked_prefill_block_manager(baseline_llm_generator, ...@@ -271,7 +273,7 @@ def test_chunked_prefill_block_manager(baseline_llm_generator,
"enforce_eager": True, "enforce_eager": True,
# Allow only 5 sequences of ~1024 tokens in worst case. # Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
"num_gpu_blocks_override": 5 * (64 + 1), "num_gpu_blocks_override": 5 * (64 + 1),
# Enable prefill cache # Enable prefill cache
...@@ -352,7 +354,7 @@ def test_block_manager_prefix_caching_enabled_with_preemption( ...@@ -352,7 +354,7 @@ def test_block_manager_prefix_caching_enabled_with_preemption(
"enforce_eager": True, "enforce_eager": True,
# Allow only 5 sequences of ~1024 tokens in worst case. # Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
"num_gpu_blocks_override": 5 * (64 + 1), "num_gpu_blocks_override": 5 * (64 + 1),
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
...@@ -427,7 +429,7 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, ...@@ -427,7 +429,7 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
# we keep the blocks small, so that hit eviction quickly # we keep the blocks small, so that hit eviction quickly
"max_model_len": 48, "max_model_len": 48,
"block_size": 16, "block_size": 64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
"num_gpu_blocks_override": 3, "num_gpu_blocks_override": 3,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment