Unverified commit e835a500, authored by Ke Bao and committed by GitHub

Reorg moe code (#2563)

parent 23e5e50f
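This commit consolidates SGLang's MoE implementations under a single `sglang.srt.layers.moe` package: `fused_moe_triton` and `ep_moe` move from `sglang.srt.layers.*` to `sglang.srt.layers.moe.*`, so most hunks below are one-line import swaps. Only `deepseek_v2` picks up new behavior (a dedicated `MoEGate` router and a guarded YaRN rope override). As a quick orientation, a sketch of the path change (module names are taken from the hunks below; the snippet itself is illustrative, not part of the commit):

```python
# Before this commit, the MoE layers lived directly under sglang.srt.layers:
#   from sglang.srt.layers.fused_moe_triton import FusedMoE, fused_moe
#   from sglang.srt.layers.ep_moe.layer import EPMoE

# After this commit, they are grouped under the new `moe` package:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, fused_moe
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
```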
@@ -19,6 +19,7 @@
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -41,6 +40,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
@@ -90,6 +91,24 @@ class DeepseekV2MLP(nn.Module):
        return x


class MoEGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((config.n_routed_experts))
            )
        else:
            self.e_score_correction_bias = None

    def forward(self, hidden_states):
        logits = F.linear(hidden_states, self.weight, None)
        return logits


class DeepseekV2MoE(nn.Module):
    def __init__(
@@ -114,6 +133,8 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
self.gate = MoEGate(config=config)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
@@ -125,11 +146,9 @@
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
self.gate = ReplicatedLinear(
config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
@@ -146,7 +165,7 @@
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
final_hidden_states = (
self.experts(hidden_states=hidden_states, router_logits=router_logits)
* self.routed_scaling_factor
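In `DeepseekV2MoE`, the previous `ReplicatedLinear` gate is replaced by the dedicated `MoEGate` above: it returns raw router logits rather than an `(output, bias)` tuple, and for the `noaux_tc` top-k method it carries an `e_score_correction_bias` that is handed to the expert layer. Below is a minimal, self-contained sketch of the router in isolation; `_DummyConfig` and the weight initialization are stand-ins for illustration (real weights come from the checkpoint):

```python
import torch
import torch.nn.functional as F
from torch import nn


class _DummyConfig:
    # Toy stand-in for the HF config fields MoEGate reads; values are illustrative.
    n_routed_experts = 8
    hidden_size = 16
    topk_method = "noaux_tc"


class MoEGate(nn.Module):
    """Router from the hunk above: a bias-free linear scorer over the routed experts."""

    def __init__(self, config):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
            # Per-expert correction bias, consumed by the fused/EP MoE top-k selection.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((config.n_routed_experts))
            )
        else:
            self.e_score_correction_bias = None

    def forward(self, hidden_states):
        # Raw logits; unlike the old ReplicatedLinear gate, no (output, bias) tuple.
        return F.linear(hidden_states, self.weight, None)


if __name__ == "__main__":
    cfg = _DummyConfig()
    gate = MoEGate(cfg)
    nn.init.normal_(gate.weight, std=0.02)        # placeholder init for the demo
    nn.init.zeros_(gate.e_score_correction_bias)  # placeholder init for the demo
    router_logits = gate(torch.randn(4, cfg.hidden_size))
    print(router_logits.shape)  # torch.Size([4, 8]) -> (num_tokens, n_experts)
```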
@@ -439,7 +458,10 @@ class DeepseekV2AttentionMLA(nn.Module):
quant_config=quant_config,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
rope_scaling["rope_type"] = "deepseek_yarn"
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@@ -454,6 +476,8 @@
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
self.attn_mqa = RadixAttention(
self.num_local_heads,
......
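The attention hunks above also make the YaRN override conditional: `rope_type` is forced to `"deepseek_yarn"` only when `rope_scaling` is configured, and without scaling the rotary embedding is pointed at its native forward. A toy sketch of that fallback pattern follows; `_ToyRotaryEmbedding` is a hypothetical stand-in for the module returned by `get_rope`, used only to show the attribute swap:

```python
import torch
from torch import nn


class _ToyRotaryEmbedding(nn.Module):
    """Hypothetical stand-in for get_rope()'s return value, exposing both a
    default forward and a pure-PyTorch forward_native as the hunk assumes."""

    def forward(self, positions, query):
        return query  # a real module would apply the fused rotary embedding here

    def forward_native(self, positions, query):
        return query  # a real module would apply rotary embedding in plain PyTorch


rope_scaling = None  # e.g. the model config carries no rope_scaling section
rotary_emb = _ToyRotaryEmbedding()

if rope_scaling:
    # Only force DeepSeek's YaRN rope type when scaling is actually configured.
    rope_scaling["rope_type"] = "deepseek_yarn"
else:
    # No scaling: route future calls through the native implementation.
    rotary_emb.forward = rotary_emb.forward_native

out = rotary_emb(torch.arange(4), torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 8]), served by forward_native
```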
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
......
@@ -27,8 +27,6 @@ from vllm.distributed import (
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
......
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
......
@@ -29,7 +29,6 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
......
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
......
@@ -4,7 +4,7 @@ import torch
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
class TestFusedMOE(unittest.TestCase):
......
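The test update above is purely mechanical; the kernel itself is unchanged. A hypothetical smoke test (not part of the commit) that only checks the relocated import paths resolve, using both the package-level and module-level forms seen in the hunks above:

```python
import unittest


class TestFusedMoeImportPaths(unittest.TestCase):
    def test_relocated_imports_resolve(self):
        # Package-level re-export, as used by the model files above.
        from sglang.srt.layers.moe.fused_moe_triton import fused_moe as pkg_fused_moe

        # Module-level import, as used by the updated test above.
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
            fused_moe as mod_fused_moe,
        )

        self.assertTrue(callable(pkg_fused_moe))
        self.assertTrue(callable(mod_fused_moe))


if __name__ == "__main__":
    unittest.main()
```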