Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2f6af1a3
Unverified
Commit
2f6af1a3
authored
Oct 31, 2025
by
Yuhong Guo
Committed by
GitHub
Oct 31, 2025
Browse files
Enable bailing_moe to support TP=16 (#12369)
parent
50b6842b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
2 deletions
+9
-2
python/sglang/srt/models/bailing_moe.py
python/sglang/srt/models/bailing_moe.py
+9
-2
No files found.
python/sglang/srt/models/bailing_moe.py
View file @
2f6af1a3
...
...
@@ -420,14 +420,21 @@ class BailingMoEAttention(nn.Module):
attn_tp_size
=
get_attention_tp_size
()
assert
self
.
total_num_heads
%
attn_tp_size
==
0
assert
self
.
total_kv_heads
%
attn_tp_size
==
0
if
self
.
total_kv_heads
>=
attn_tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_kv_heads
%
attn_tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
attn_tp_size
%
self
.
total_kv_heads
==
0
assert
self
.
total_num_heads
>=
self
.
total_kv_heads
self
.
num_heads
=
self
.
total_num_heads
//
attn_tp_size
self
.
head_dim
=
config
.
head_dim
or
(
self
.
hidden_size
//
self
.
total_num_heads
)
self
.
q_size
=
self
.
head_dim
*
self
.
num_heads
self
.
num_kv_heads
=
self
.
total_kv_heads
//
attn_tp_size
self
.
num_kv_heads
=
max
(
1
,
self
.
total_kv_heads
//
attn_tp_size
)
self
.
kv_size
=
max
(
1
,
self
.
num_kv_heads
*
self
.
head_dim
)
self
.
scale
=
self
.
head_dim
**-
0.5
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment