Commit 0b229519 authored by 王敏's avatar 王敏
Browse files

[feat]适配sgl moe_fused_gate kernel

parent 1150b65c
...@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) ...@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/moe_fused_gate.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
......
This diff is collapsed.
...@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, ...@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor num_tokens_post_pad, int64_t top_k, torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit); int64_t BLOCK_SIZE_K, int64_t bit);
#endif #endif
\ No newline at end of file
std::vector<torch::Tensor> moe_fused_gate(
torch::Tensor& input,
torch::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor);
\ No newline at end of file
...@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad) -> ()");
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"n_share_experts_fusion, float routed_scaling_factor) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
#ifndef USE_ROCM #ifndef USE_ROCM
m.def( m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
......
...@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache( ...@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, # torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale) # seq_lens, page_table, scale)
# return out # return out
def moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion=0,
routed_scaling_factor=0,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select exerpt groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
return torch.ops._moe_C.moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import functools import functools
import json import json
import os import os
import math
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
...@@ -1182,6 +1183,10 @@ def fused_topk( ...@@ -1182,6 +1183,10 @@ def fused_topk(
return topk_weights, topk_ids return topk_weights, topk_ids
def is_power_of_two(n):
return n > 0 and math.log2(n).is_integer()
# This is used by the Deepseek-V2 and Deepseek-V3 model # This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk( def grouped_topk(
......
...@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm import _custom_ops as ops
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_moe import fused_experts from .fused_moe import fused_experts
...@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor if hasattr(self, "routed_scaling_factor") else None)
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
...@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module): ...@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
): ):
super().__init__() super().__init__()
...@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module): ...@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module):
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.activation = activation self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
if self.scoring_func != "softmax" and not self.use_grouped_topk: if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for " raise ValueError("Only softmax scoring function is supported for "
...@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module): ...@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
setattr(self.quant_method, "routed_scaling_factor", self.routed_scaling_factor)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
...@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module): ...@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None): e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk) fused_topk, grouped_topk, is_power_of_two)
# DeekSeekv2 uses grouped_top_k # DeekSeekv2 uses grouped_top_k
if use_grouped_topk: if use_grouped_topk:
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk( if e_score_correction_bias is not None \
hidden_states=hidden_states, and router_logits.shape[1] // num_expert_group <= 32 \
gating_output=router_logits, and is_power_of_two(e_score_correction_bias.shape[0]):
topk=top_k,
renormalize=renormalize, # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
num_expert_group=num_expert_group, topk_weights, topk_ids = ops.moe_fused_gate(
topk_group=topk_group, router_logits,
scoring_func=scoring_func, e_score_correction_bias,
e_score_correction_bias=e_score_correction_bias) num_expert_group,
topk_group,
top_k,
routed_scaling_factor=routed_scaling_factor,
n_share_experts_fusion=0,
)
else:
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
...@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module): ...@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe
) )
if self.dp_size > 1: if self.dp_size > 1:
......
...@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
topk_group=config.topk_group, topk_group=config.topk_group,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,) e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment