zhaoyu6 / sglang · commit e835a500
"vscode:/vscode.git/clone" did not exist on "ce7350357869bab7a2d8665c37bdf326c9e98b61"
Reorg moe code (#2563)

Unverified commit e835a500, authored Dec 24, 2024 by Ke Bao; committed via GitHub on Dec 24, 2024.
Parent: 23e5e50f
Changes: 87

Showing 7 changed files with 38 additions and 14 deletions (+38, -14):
python/sglang/srt/models/deepseek_v2.py    +31  -7
python/sglang/srt/models/grok.py            +1  -1
python/sglang/srt/models/mixtral.py         +2  -2
python/sglang/srt/models/olmoe.py           +1  -1
python/sglang/srt/models/qwen2_moe.py       +1  -1
python/sglang/srt/models/xverse_moe.py      +1  -1
test/srt/test_fused_moe.py                  +1  -1
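All of the hunks below apply the same relocation: the MoE kernels previously imported from sglang.srt.layers.ep_moe and sglang.srt.layers.fused_moe_triton now live under a new sglang.srt.layers.moe package. A minimal before/after sketch of the import change; the module and symbol names are taken from the hunks that follow, while the grouping into one snippet is illustrative and not part of the commit:

# Old import locations (removed in the hunks below):
#   from sglang.srt.layers.ep_moe.layer import EPMoE
#   from sglang.srt.layers.fused_moe_triton import FusedMoE, fused_moe
#   from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe

# New import locations (added in the hunks below); xverse_moe.py imports
# fused_moe from the package itself, test_fused_moe.py from the submodule.
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe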
python/sglang/srt/models/deepseek_v2.py

@@ -19,6 +19,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,6 +40,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -90,6 +91,24 @@ class DeepseekV2MLP(nn.Module):
         return x
 
 
+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        if config.topk_method == "noaux_tc":
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((config.n_routed_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
 class DeepseekV2MoE(nn.Module):
     def __init__(
@@ -114,6 +133,8 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
+        self.gate = MoEGate(config=config)
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
@@ -125,11 +146,9 @@ class DeepseekV2MoE(nn.Module):
             use_grouped_topk=True,
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
         )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-        )
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(
@@ -146,7 +165,7 @@ class DeepseekV2MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states)
         final_hidden_states = (
             self.experts(hidden_states=hidden_states, router_logits=router_logits)
             * self.routed_scaling_factor
@@ -439,7 +458,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        rope_scaling["rope_type"] = "deepseek_yarn"
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -454,6 +476,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        else:
+            self.rotary_emb.forward = self.rotary_emb.forward_native
 
         self.attn_mqa = RadixAttention(
             self.num_local_heads,
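Beyond the import move, this file swaps the DeepseekV2MoE router from a ReplicatedLinear to the new MoEGate: its forward is a bias-free F.linear projection onto n_routed_experts logits, and its optional e_score_correction_bias is handed to the experts as correction_bias when topk_method is "noaux_tc". Below is a minimal, self-contained sketch of that gating step, assuming only PyTorch: MoEGateSketch mirrors the class added above, while the SimpleNamespace config and the random initialization are purely illustrative (the real weights are loaded from the checkpoint).

# Minimal sketch of the gating step added above (not part of the commit).
# MoEGateSketch mirrors MoEGate from the deepseek_v2.py hunk; the config
# stand-in and the random init below are illustrative only.
from types import SimpleNamespace

import torch
import torch.nn.functional as F
from torch import nn


class MoEGateSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Router weight of shape (n_routed_experts, hidden_size).
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
            # Per-expert bias consumed by grouped top-k as correction_bias.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((config.n_routed_experts,))
            )
        else:
            self.e_score_correction_bias = None

    def forward(self, hidden_states):
        # (num_tokens, hidden_size) -> (num_tokens, n_routed_experts)
        return F.linear(hidden_states, self.weight, None)


cfg = SimpleNamespace(n_routed_experts=8, hidden_size=16, topk_method="noaux_tc")
gate = MoEGateSketch(cfg)
nn.init.normal_(gate.weight)                  # checkpoint weights in practice
nn.init.zeros_(gate.e_score_correction_bias)  # checkpoint values in practice
router_logits = gate(torch.randn(4, cfg.hidden_size))
print(router_logits.shape)  # torch.Size([4, 8])

In the model itself, these router logits are passed to MoEImpl (EPMoE or FusedMoE) together with correction_bias, as shown in the constructor hunk above.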
python/sglang/srt/models/grok.py

@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import GeluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
python/sglang/srt/models/mixtral.py

@@ -27,8 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
python/sglang/srt/models/olmoe.py

@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
python/sglang/srt/models/qwen2_moe.py

@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
python/sglang/srt/models/xverse_moe.py

@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
test/srt/test_fused_moe.py

@@ -4,7 +4,7 @@ import torch
 from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 
 
 class TestFusedMOE(unittest.TestCase):