Commit 2703e2e9 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove remove VLLM_USE_OPT_MOE_SUM

parent 1cb851b0
......@@ -912,7 +912,6 @@ class ModelConfig:
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
"cpu_awq",
"slimquant_marlin",
"slimquant_w4a8_marlin",
"slimquant_compressed_tensors_marlin",
]
......
......@@ -277,7 +277,6 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
......@@ -1773,11 +1772,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "True").lower() in
("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "False").lower() in
......
......@@ -80,134 +80,6 @@ def is_power_of_two(n):
return n > 0 and math.log2(n).is_integer()
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
)
def moe_sum_reduce_triton(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
if token_num <= 32:
BLOCK_M = 1
BLOCK_DIM = 512
NUM_STAGE = 2
num_warps = 4
elif token_num <= 128:
BLOCK_M = 1
BLOCK_DIM = 1024
NUM_STAGE = 0
num_warps = 2
elif token_num <= 4096:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 0
num_warps = 2
else:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 2
num_warps = 8
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def moe_reduce_dispatch(
intermediate_cache3: torch.Tensor,
out_hidden_states: torch.Tensor,
begin_chunk_idx: int,
end_chunk_idx: int,
):
inter_cache_view = intermediate_cache3.view(*intermediate_cache3.shape)
n = intermediate_cache3.shape[0]
# 根据 n 大小选择不同的 reduce 实现
if 1 <= n <= 4:
moe_sum_reduce_torch_compile(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 4 < n <= 1024:
moe_sum_reduce_triton(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 1024 < n <= 32768:
ops.moe_sum_opt1(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
else:
ops.moe_sum(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton
......@@ -434,6 +306,7 @@ def fused_moe_kernel_gptq_awq(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
# SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
......@@ -525,7 +398,6 @@ def fused_moe_kernel_gptq_awq(
+ (offs_k[:, None] // 2) * stride_bk
+ offs_bn[None, :] * stride_bn
)
b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16:
b_ptrs = (
......@@ -671,6 +543,7 @@ def fused_moe_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
# SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
......@@ -713,14 +586,6 @@ def fused_moe_kernel(
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
# num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# group_id = pid // num_pid_in_group
# first_pid_m = group_id * GROUP_SIZE_M
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# pid_n = (pid % num_pid_in_group) // group_size_m
if GROUP_SIZE_M ==1:
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
......@@ -757,7 +622,7 @@ def fused_moe_kernel(
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
off_experts = tl.load(expert_ids_ptr + pid_m) # .to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
......@@ -1235,7 +1100,6 @@ def dispatch_fused_moe_kernel(
block_shape is not None and block_shape[1] > 0
):
assert B_bias is None
# if os.environ.get('moe_wna16_use_cuda') == '1':
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
......@@ -1243,7 +1107,6 @@ def dispatch_fused_moe_kernel(
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8,
)
if use_moe_wna16_cuda:
invoke_fused_moe_wna16_cuda_kernel(
A,
......@@ -1303,7 +1166,6 @@ def dispatch_fused_moe_kernel(
B_bias,
)
@triton.jit
def compute_identity_kernel(
top_k: int,
......@@ -2394,13 +2256,6 @@ def fused_experts_impl(
num_local_tokens=None,
factor=1.0
)
elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(
intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
begin_chunk_idx,
end_chunk_idx
)
else:
ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.size()),
......@@ -2466,7 +2321,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels
return True
def supports_chunking(self) -> bool:
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment