Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e835a500
Unverified
Commit
e835a500
authored
Dec 24, 2024
by
Ke Bao
Committed by
GitHub
Dec 24, 2024
Browse files
Reorg moe code (#2563)
parent
23e5e50f
Changes
87
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
38 additions
and
14 deletions
+38
-14
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+31
-7
python/sglang/srt/models/grok.py
python/sglang/srt/models/grok.py
+1
-1
python/sglang/srt/models/mixtral.py
python/sglang/srt/models/mixtral.py
+2
-2
python/sglang/srt/models/olmoe.py
python/sglang/srt/models/olmoe.py
+1
-1
python/sglang/srt/models/qwen2_moe.py
python/sglang/srt/models/qwen2_moe.py
+1
-1
python/sglang/srt/models/xverse_moe.py
python/sglang/srt/models/xverse_moe.py
+1
-1
test/srt/test_fused_moe.py
test/srt/test_fused_moe.py
+1
-1
No files found.
python/sglang/srt/models/deepseek_v2.py
View file @
e835a500
...
...
@@ -19,6 +19,7 @@
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm
import
_custom_ops
as
ops
...
...
@@ -31,8 +32,6 @@ from vllm.distributed import (
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.ep_moe.layer
import
EPMoE
from
sglang.srt.layers.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -41,6 +40,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe.ep_moe.layer
import
EPMoE
from
sglang.srt.layers.moe.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
...
...
@@ -90,6 +91,24 @@ class DeepseekV2MLP(nn.Module):
return
x
class
MoEGate
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
((
config
.
n_routed_experts
,
config
.
hidden_size
))
)
if
config
.
topk_method
==
"noaux_tc"
:
self
.
e_score_correction_bias
=
nn
.
Parameter
(
torch
.
empty
((
config
.
n_routed_experts
))
)
else
:
self
.
e_score_correction_bias
=
None
def
forward
(
self
,
hidden_states
):
logits
=
F
.
linear
(
hidden_states
,
self
.
weight
,
None
)
return
logits
class
DeepseekV2MoE
(
nn
.
Module
):
def
__init__
(
...
...
@@ -114,6 +133,8 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
self
.
gate
=
MoEGate
(
config
=
config
)
MoEImpl
=
EPMoE
if
global_server_args_dict
[
"enable_ep_moe"
]
else
FusedMoE
self
.
experts
=
MoEImpl
(
num_experts
=
config
.
n_routed_experts
,
...
...
@@ -125,11 +146,9 @@ class DeepseekV2MoE(nn.Module):
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
correction_bias
=
self
.
gate
.
e_score_correction_bias
,
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
n_routed_experts
,
bias
=
False
,
quant_config
=
None
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
DeepseekV2MLP
(
...
...
@@ -146,7 +165,7 @@ class DeepseekV2MoE(nn.Module):
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
(
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
...
...
@@ -439,7 +458,10 @@ class DeepseekV2AttentionMLA(nn.Module):
quant_config
=
quant_config
,
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
rope_scaling
[
"rope_type"
]
=
"deepseek_yarn"
if
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
"deepseek_yarn"
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
...
...
@@ -454,6 +476,8 @@ class DeepseekV2AttentionMLA(nn.Module):
scaling_factor
=
rope_scaling
[
"factor"
]
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
else
:
self
.
rotary_emb
.
forward
=
self
.
rotary_emb
.
forward_native
self
.
attn_mqa
=
RadixAttention
(
self
.
num_local_heads
,
...
...
python/sglang/srt/models/grok.py
View file @
e835a500
...
...
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.activation
import
GeluAndMul
from
sglang.srt.layers.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
...
...
python/sglang/srt/models/mixtral.py
View file @
e835a500
...
...
@@ -27,8 +27,6 @@ from vllm.distributed import (
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.ep_moe.layer
import
EPMoE
from
sglang.srt.layers.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
(
QKVParallelLinear
,
...
...
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe.ep_moe.layer
import
EPMoE
from
sglang.srt.layers.moe.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
...
...
python/sglang/srt/models/olmoe.py
View file @
e835a500
...
...
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.moe.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
...
...
python/sglang/srt/models/qwen2_moe.py
View file @
e835a500
...
...
@@ -29,7 +29,6 @@ from vllm.distributed import (
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
...
...
python/sglang/srt/models/xverse_moe.py
View file @
e835a500
...
...
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.fused_moe_triton
import
fused_moe
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe.fused_moe_triton
import
fused_moe
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
...
...
test/srt/test_fused_moe.py
View file @
e835a500
...
...
@@ -4,7 +4,7 @@ import torch
from
vllm.model_executor.layers.fused_moe
import
fused_moe
as
fused_moe_vllm
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.fused_moe_triton.fused_moe
import
fused_moe
from
sglang.srt.layers.
moe.
fused_moe_triton.fused_moe
import
fused_moe
class
TestFusedMOE
(
unittest
.
TestCase
):
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment