Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c6d549e7
Unverified
Commit
c6d549e7
authored
Mar 23, 2025
by
fzyzcjy
Committed by
GitHub
Mar 22, 2025
Browse files
Multiple tiny code cleanups (#4608)
parent
3c09548d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
8 deletions
+3
-8
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+1
-2
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+2
-6
No files found.
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
View file @
c6d549e7
...
@@ -185,7 +185,6 @@ class DeepEPDispatcher:
...
@@ -185,7 +185,6 @@ class DeepEPDispatcher:
previous_event
=
None
,
previous_event
=
None
,
num_max_dispatch_tokens_per_rank
:
int
=
128
,
num_max_dispatch_tokens_per_rank
:
int
=
128
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
hidden_shape
=
hidden_states
.
shape
topk_idx
=
topk_idx
.
to
(
torch
.
int64
)
topk_idx
=
topk_idx
.
to
(
torch
.
int64
)
# Todo: enable low latency dispatch
# Todo: enable low latency dispatch
if
True
:
# not forward_mode.is_decode():
if
True
:
# not forward_mode.is_decode():
...
@@ -375,7 +374,7 @@ class DeepEPDispatcher:
...
@@ -375,7 +374,7 @@ class DeepEPDispatcher:
hidden_states
,
self
.
topk_idx
,
self
.
topk_weights
,
self
.
handle
hidden_states
,
self
.
topk_idx
,
self
.
topk_weights
,
self
.
handle
)
)
self
.
handle
=
None
self
.
handle
=
None
return
hidden_states
.
view
(
self
.
hidden_shape
)
return
hidden_states
def
combine_normal
(
self
,
x
:
torch
.
Tensor
,
handle
:
Tuple
,
previous_event
=
None
):
def
combine_normal
(
self
,
x
:
torch
.
Tensor
,
handle
:
Tuple
,
previous_event
=
None
):
combined_x
,
_
,
event
=
self
.
buffer_normal
.
combine
(
combined_x
,
_
,
event
=
self
.
buffer_normal
.
combine
(
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
c6d549e7
...
@@ -250,8 +250,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -250,8 +250,6 @@ class DeepseekV2MoE(nn.Module):
return
self
.
forward_deepep
(
hidden_states
,
forward_mode
)
return
self
.
forward_deepep
(
hidden_states
,
forward_mode
)
def
forward_normal
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_normal
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
...
@@ -264,13 +262,11 @@ class DeepseekV2MoE(nn.Module):
...
@@ -264,13 +262,11 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
def
forward_deepep
(
def
forward_deepep
(
self
,
hidden_states
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
self
,
hidden_states
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
shared_output
=
None
shared_output
=
None
topk_idx
=
torch
.
full
(
topk_idx
=
torch
.
full
(
(
0
,
self
.
top_k
),
-
1
,
dtype
=
torch
.
int
,
device
=
hidden_states
.
device
(
0
,
self
.
top_k
),
-
1
,
dtype
=
torch
.
int
,
device
=
hidden_states
.
device
...
@@ -319,7 +315,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -319,7 +315,7 @@ class DeepseekV2MoE(nn.Module):
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment