Unverified commit e835a500, authored by Ke Bao, committed by GitHub

Reorg moe code (#2563)

parent 23e5e50f
@@ -19,6 +19,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,6 +40,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -90,6 +91,24 @@ class DeepseekV2MLP(nn.Module):
         return x
 
 
+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        if config.topk_method == "noaux_tc":
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((config.n_routed_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
 class DeepseekV2MoE(nn.Module):
     def __init__(
@@ -114,6 +133,8 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
+        self.gate = MoEGate(config=config)
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
@@ -125,11 +146,9 @@ class DeepseekV2MoE(nn.Module):
             use_grouped_topk=True,
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
         )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-        )
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(
@@ -146,7 +165,7 @@ class DeepseekV2MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states)
         final_hidden_states = (
             self.experts(hidden_states=hidden_states, router_logits=router_logits)
             * self.routed_scaling_factor
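Note, for readers unfamiliar with the `correction_bias` argument wired into the MoE layer above: the toy below is not SGLang's fused kernel, and the scoring function, shapes, and names are illustrative assumptions only. It sketches the usual pattern where a per-expert score-correction bias (here standing in for `MoEGate.e_score_correction_bias`) shifts which experts win top-k selection, while the routing weights still come from the unbiased gate scores.

import torch

def biased_topk_toy(router_logits: torch.Tensor, bias: torch.Tensor, top_k: int):
    # The bias only influences which experts are selected;
    # the returned routing weights are taken from the unbiased scores.
    scores = router_logits.sigmoid()
    selected = (scores + bias).topk(top_k, dim=-1).indices
    weights = scores.gather(-1, selected)
    return selected, weights

logits = torch.randn(4, 8)   # 4 tokens, 8 routed experts
bias = torch.zeros(8)        # stand-in for the gate's e_score_correction_bias
expert_ids, expert_weights = biased_topk_toy(logits, bias, top_k=2)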
@@ -439,7 +458,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        rope_scaling["rope_type"] = "deepseek_yarn"
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -454,6 +476,8 @@
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        else:
+            self.rotary_emb.forward = self.rotary_emb.forward_native
 
         self.attn_mqa = RadixAttention(
             self.num_local_heads,
...
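The two attention hunks above guard the DeepSeek-YaRN setup behind `if rope_scaling:` and, when no scaling is configured, point the rotary module's `forward` at its `forward_native` fallback. A minimal self-contained sketch of that pattern follows; `RotaryStub` is a hypothetical stand-in for the module returned by vllm's `get_rope`, not the real class.

import torch
from torch import nn

class RotaryStub(nn.Module):
    """Hypothetical stand-in for the rotary-embedding module from get_rope()."""
    def forward(self, positions, q, k):
        # Optimized/custom path in the real module.
        return q, k
    def forward_native(self, positions, q, k):
        # Pure-PyTorch fallback path.
        return q, k

rope_scaling = None                 # e.g. a checkpoint configured without YaRN scaling
if rope_scaling:
    rope_scaling["rope_type"] = "deepseek_yarn"

rotary_emb = RotaryStub()
if not rope_scaling:
    # Same idea as the else-branch in the diff: route calls through forward_native.
    rotary_emb.forward = rotary_emb.forward_native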
...@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size ...@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import ( ...@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
......
...@@ -27,8 +27,6 @@ from vllm.distributed import ( ...@@ -27,8 +27,6 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
...@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import ( ...@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
......
...@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import ( ...@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
......
...@@ -29,7 +29,6 @@ from vllm.distributed import ( ...@@ -29,7 +29,6 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import ( ...@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
......
...@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
......
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
class TestFusedMOE(unittest.TestCase): class TestFusedMOE(unittest.TestCase):
......
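Taken together, the remaining hunks are a mechanical path move: the MoE layers now live under `sglang.srt.layers.moe`. Code outside this commit that imported the old paths only needs its imports updated; the mapping below is taken from the hunks above (grouping two symbols per line is just for brevity).

# Old locations (removed in this commit):
#   from sglang.srt.layers.ep_moe.layer import EPMoE
#   from sglang.srt.layers.fused_moe_triton import FusedMoE, fused_moe
# New locations:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, fused_moe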