Unverified commit e5ce395a, authored Feb 18, 2025 by Ke Bao; committed by GitHub on Feb 18, 2025.
Fix draft decode max batch size (#3676)
Parent: f983213a
Showing 3 changed files, with 5 additions and 2 deletions:

python/sglang/srt/layers/attention/flashinfer_backend.py   +1 −1
python/sglang/srt/layers/attention/triton_backend.py   +1 −1
python/sglang/srt/layers/attention/triton_ops/decode_attention.py   +3 −0
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -1094,7 +1094,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
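FlashInferMultiStepDraftBackend preallocates its per-step KV index buffers at construction time. With top-k speculative drafting, each running request fans out into topk draft sequences during draft decode, so the worst-case decode batch is req_to_token_pool.size * topk rather than req_to_token_pool.size; sizing the buffers for the smaller number is what this commit fixes. A minimal sketch of the sizing logic under that reading (the helper below and the exact kv_indptr layout are illustrative, not sglang's API):

```python
import torch

def draft_decode_buffer_shapes(pool_size: int, topk: int, speculative_num_steps: int):
    # Worst case: every request slot in the pool is active and each one
    # contributes `topk` draft branches to the decode batch.
    max_bs = pool_size * topk
    # Assumed CSR-style layout: one row of offsets per draft step, with
    # max_bs + 1 offsets per row (entries i and i+1 bound request i's KV indices).
    kv_indptr = torch.zeros((speculative_num_steps, max_bs + 1), dtype=torch.int32)
    return max_bs, tuple(kv_indptr.shape)

# e.g. a pool of 64 request slots with topk=4 and 5 draft steps:
print(draft_decode_buffer_shapes(64, 4, 5))   # (256, (5, 257))
```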
python/sglang/srt/layers/attention/triton_backend.py

@@ -474,7 +474,7 @@ class TritonMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
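The Triton multi-step draft backend preallocates its kv_indptr buffer the same way, so it receives the identical one-line sizing fix; the buffer-shape sketch above applies unchanged.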
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -635,6 +635,9 @@ def decode_attention_fwd(
     logit_cap=0.0,
 ):
     assert num_kv_splits == attn_logits.shape[2]
+    assert q.shape[0] <= kv_indptr.shape[0] - 1
+    assert q.shape[0] <= attn_logits.shape[0]
+
     kv_group_num = q.shape[1] // v_buffer.shape[1]

     if kv_group_num == 1:
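The new assertions document the shape contract that the oversized draft-decode buffers rely on: kv_indptr is a CSR-style offset array, so a batch of q.shape[0] queries needs at least q.shape[0] + 1 offsets, and attn_logits needs at least one row per query. They use <= rather than == because the preallocated buffers may be larger than the live batch. A shape-only illustration of the invariants (the attn_logits layout here is an assumption made for the example, not taken from the kernel):

```python
import torch

# Shape-contract illustration only; decode_attention_fwd itself launches Triton kernels.
batch, num_heads, head_dim, num_kv_splits = 8, 32, 128, 4
max_bs = 256  # preallocated for the worst-case draft batch (see the backend fix above)

q = torch.randn(batch, num_heads, head_dim)
kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)  # CSR offsets, one extra entry
# Assumed intermediate-logits layout: (buffer batch, heads, kv splits, head_dim + 1).
attn_logits = torch.empty(max_bs, num_heads, num_kv_splits, head_dim + 1)

# The invariants added by this commit: the live batch must fit inside the
# (possibly larger) preallocated index and logits buffers.
assert num_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
assert q.shape[0] <= attn_logits.shape[0]
```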