sglang · Commit c0fb25e9 (Unverified)

DP Enhancement (#8280)

Authored Jul 24, 2025 by Cheng Wan; committed via GitHub on Jul 24, 2025.
Parent: 28d4d472

Showing 20 changed files with 665 additions and 1116 deletions (+665, -1116).
Changed files:

  python/sglang/srt/distributed/parallel_state.py                          +9    -0
  python/sglang/srt/layers/attention/base_attn_backend.py                  +3    -1
  python/sglang/srt/layers/communicator.py                                 +12   -12
  python/sglang/srt/layers/dp_attention.py                                 +72   -24
  python/sglang/srt/layers/logits_processor.py                             +34   -24
  python/sglang/srt/layers/radix_attention.py                              +5    -3
  python/sglang/srt/managers/schedule_batch.py                             +2    -3
  python/sglang/srt/model_executor/cuda_graph_runner.py                    +61   -25
  python/sglang/srt/model_executor/forward_batch_info.py                   +193  -22
  python/sglang/srt/model_executor/model_runner.py                         +21   -4
  python/sglang/srt/models/deepseek_v2.py                                  +1    -2
  python/sglang/srt/models/qwen2_moe.py                                    +0    -4
  python/sglang/srt/models/qwen3_moe.py                                    +1    -6
  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py           +33   -27
  python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py    +37   -36
  python/sglang/srt/speculative/eagle_utils.py                             +45   -23
  python/sglang/srt/speculative/eagle_worker.py                            +59   -44
  python/sglang/srt/two_batch_overlap.py                                   +1    -0
  test/srt/test_deepep_small.py                                            +6    -6
  test/srt/test_hybrid_dp_ep_tp_mtp.py                                     +70   -850
python/sglang/srt/distributed/parallel_state.py

@@ -545,6 +545,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def reduce_scatter_tensor(
+        self,
+        output: torch.Tensor,
+        input: torch.Tensor,
+    ) -> None:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+        return output
+
     def reduce_scatter(
         self,
         output: torch.Tensor,
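Illustrative note (not part of the commit): the new GroupCoordinator.reduce_scatter_tensor is a thin wrapper over torch.distributed.reduce_scatter_tensor, which sum-reduces the input across ranks and leaves each rank holding only its own 1/world_size slice of the result. A minimal single-process sketch of the same arithmetic, assuming the input length divides evenly by the world size:

    import torch

    def simulated_reduce_scatter_tensor(inputs, rank):
        # Single-process stand-in for torch.distributed.reduce_scatter_tensor:
        # inputs[i] is the full-length tensor held by rank i; the return value is
        # the slice that `rank` keeps after a sum-reduce followed by a scatter.
        world_size = len(inputs)
        reduced = torch.stack(inputs).sum(dim=0)       # reduce (sum) across ranks
        return reduced.tensor_split(world_size)[rank]  # scatter: keep own slice

    inputs = [torch.arange(8, dtype=torch.float32), torch.ones(8)]  # two fake ranks
    print(simulated_reduce_scatter_tensor(inputs, rank=0))  # first half of the sum
    print(simulated_reduce_scatter_tensor(inputs, rank=1))  # second half of the sum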
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
python/sglang/srt/layers/communicator.py

@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,

@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states

@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
                 ].clone(),
                 residual,
             )
-            attn_tp_all_gather(
-                list(residual.tensor_split(context.attn_tp_size)), local_residual
-            )
+            attn_tp_all_gather_into_tensor(residual, local_residual)
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual

@@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         *,
         residual_input_mode,
     ):
-        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
-        hidden_states = tensor_list[context.attn_tp_rank]
-        attn_tp_reduce_scatter(hidden_states, tensor_list)
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
             residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:

@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states, residual
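Illustrative note (not part of the commit): switching from attn_tp_all_gather to attn_tp_all_gather_into_tensor replaces a list-of-chunks all-gather with one contiguous output tensor, so the caller no longer has to build a list of tensor_split views. A single-process sketch of why the two forms fill the same buffer, assuming the gathered rows divide evenly across ranks:

    import torch

    world_size = 4
    hidden = torch.zeros(8, 16)      # gathered buffer shared by both styles
    local = torch.randn(2, 16)       # this rank's shard (8 rows / 4 ranks)

    # Old style: the collective writes through a list of views into the buffer.
    out_list = list(hidden.tensor_split(world_size))
    out_list[1].copy_(local)         # fake what the collective does for rank 1

    # New style: the collective writes directly into one contiguous tensor.
    hidden2 = torch.zeros(8, 16)
    hidden2[2:4].copy_(local)        # rank 1's slice of the contiguous buffer

    assert torch.equal(hidden[2:4], hidden2[2:4])  # same layout either way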
python/sglang/srt/layers/dp_attention.py

@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, List
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 import triton

@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
 _LOCAL_ATTN_DP_RANK = None
 
 
+class DPPaddingMode(IntEnum):
+
+    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Padding tokens to sum length and then gather tokens using `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # we choose the mode that minimizes the communication cost
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0

@@ -162,7 +191,7 @@ def disable_dp_size():
     _ATTN_DP_SIZE = old_dp_size
 
 
-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
     # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
     dp_rank = get_attention_dp_rank()

@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def _dp_gather(
+def _dp_gather_via_all_reduce(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,

@@ -238,13 +267,6 @@ def _dp_gather(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )

@@ -263,6 +285,38 @@ def _dp_gather(
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(global_tokens, local_tokens, forward_batch, is_partial)
+    else:
+        _dp_gather_via_all_reduce(global_tokens, local_tokens, forward_batch, is_partial)
+
+
 def dp_gather_partial(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,

@@ -296,24 +350,18 @@ def dp_scatter(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
 
 
-def attn_tp_reduce_scatter(
-    output: torch.Tensor,
-    input_list: List[torch.Tensor],
-):
-    return get_attention_tp_group().reduce_scatter(output, input_list)
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)
 
 
-def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
+    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
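Illustrative note (not part of the commit): DPPaddingMode.get_dp_padding_mode picks the gather strategy by comparing rough communication volume: SUM_LEN moves about 2 * sum(global_num_tokens) rows through an all-reduce, while MAX_LEN moves about max(global_num_tokens) * dp_size rows through an all-gather. A standalone sketch of that decision rule:

    from typing import List

    def choose_padding_mode(global_num_tokens: List[int], dp_size: int) -> str:
        # Mirror of the `sum_len * 2 > max_len * dp_size` test in DPPaddingMode.
        max_len = max(global_num_tokens)
        sum_len = sum(global_num_tokens)
        return "MAX_LEN" if sum_len * 2 > max_len * dp_size else "SUM_LEN"

    # Balanced ranks: padding to the max costs about the same as the sum, so the
    # cheaper all_gather_into_tensor path (MAX_LEN) wins.
    print(choose_padding_mode([100, 96, 99, 101], dp_size=4))   # MAX_LEN

    # Highly skewed ranks: padding everyone to 1000 tokens would be wasteful, so
    # the all_reduce path over the exact sum (SUM_LEN) wins.
    print(choose_padding_mode([1000, 1, 1, 1], dp_size=4))      # SUM_LEN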
python/sglang/srt/layers/logits_processor.py

@@ -27,7 +27,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
     attn_tp_all_gather,
+    attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,

@@ -111,7 +113,8 @@ class LogitsMetadata:
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # The gather mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for padding
     padded_static_len: int = -1

@@ -163,12 +166,12 @@ class LogitsMetadata:
             forward_batch_gathered_buffer=forward_batch.gathered_buffer,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.SUM_LEN,
         )
 
-    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
-        if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing cuda graph
-            return
+    def compute_dp_attention_metadata(self):
+        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
+        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()

@@ -179,18 +182,9 @@ class LogitsMetadata:
         else:
             dp_local_start_pos = cumtokens[dp_rank - 1]
         dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
-        gathered_buffer = torch.zeros(
-            (
-                sum(self.global_num_tokens_for_logprob_cpu),
-                hidden_states.shape[1],
-            ),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
 
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
-        self.gathered_buffer = gathered_buffer
 
 
 class LogitsProcessor(nn.Module):

@@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module):
         guarantee the given hidden_states follow this constraint.
         """
         if self.do_tensor_parallel_all_gather_dp_attn:
-            logits_metadata.compute_dp_attention_metadata(hidden_states)
+            logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
                 torch.empty_like(logits_metadata.gathered_buffer),
                 hidden_states,

@@ -463,6 +457,21 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather:
             if self.use_attn_tp_group:
+                if self.config.vocab_size % self.attn_tp_size == 0:
+                    global_logits = torch.empty(
+                        (
+                            self.attn_tp_size,
+                            logits.shape[0],
+                            self.config.vocab_size // self.attn_tp_size,
+                        ),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    attn_tp_all_gather_into_tensor(global_logits, logits)
+                    global_logits = global_logits.permute(1, 0, 2).reshape(
+                        logits.shape[0], self.config.vocab_size
+                    )
+                else:
                     global_logits = torch.empty(
                         (self.config.vocab_size, logits.shape[0]),
                         device=logits.device,

@@ -470,7 +479,8 @@ class LogitsProcessor(nn.Module):
                     )
                     global_logits = global_logits.T
                     attn_tp_all_gather(
-                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
+                        logits,
                     )
                 logits = global_logits
             else:
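Illustrative note (not part of the commit): when vocab_size is divisible by attn_tp_size, the new logits path gathers the per-rank shards into a [attn_tp_size, num_tokens, vocab_size // attn_tp_size] tensor and rebuilds the full vocab dimension with permute(1, 0, 2).reshape(...). A single-process check that this layout round-trips, assuming contiguous vocab sharding:

    import torch

    attn_tp_size, num_tokens, vocab_size = 4, 3, 32
    full = torch.randn(num_tokens, vocab_size)

    # Each rank holds a contiguous slice of the vocab dimension.
    shards = list(full.tensor_split(attn_tp_size, dim=-1))

    # all_gather_into_tensor stacks the shards along a new leading (rank) dim.
    gathered = torch.stack(shards, dim=0)   # [attn_tp_size, num_tokens, vocab/attn_tp]

    rebuilt = gathered.permute(1, 0, 2).reshape(num_tokens, vocab_size)
    assert torch.equal(rebuilt, full)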
python/sglang/srt/layers/radix_attention.py

@@ -12,14 +12,16 @@
 # limitations under the License.
 # ==============================================================================
 """Radix attention."""
+from __future__ import annotations
 
 from enum import Enum
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from torch import nn
 
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class AttentionType(Enum):
python/sglang/srt/managers/schedule_batch.py

@@ -45,7 +45,6 @@ import triton
 import triton.language as tl
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (

@@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list, support_triton
 
 if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -1880,7 +1880,7 @@ class ModelWorkerBatch:
     sampling_info: SamplingBatchInfo
 
     # The input Embeds
-    input_embeds: Optional[torch.tensor] = None
+    input_embeds: Optional[torch.Tensor] = None
 
     # For corss-encoder model
     token_type_ids: Optional[torch.Tensor] = None

@@ -1890,7 +1890,6 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    spec_num_draft_tokens: Optional[int] = None
     hicache_consumer_index: int = 0
 
     # Overlap event
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,

@@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     # is very small. We add more values here to make sure we capture the maximum bs.
     capture_bs += [model_runner.req_to_token_pool.size]
 
+    mul_base = 1
+
     if server_args.enable_two_batch_overlap:
-        capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
+        mul_base *= 2
+
+    if require_gathered_buffer(server_args):
+        mul_base *= get_attention_tp_size()
+
+    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
 
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]

@@ -306,20 +313,37 @@ class CudaGraphRunner:
             self.encoder_lens = None
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
 
         self.custom_mask = torch.ones(
             (

@@ -342,9 +366,9 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size

@@ -480,16 +504,19 @@ class CudaGraphRunner:
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens] * self.dp_size,
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(

@@ -498,10 +525,15 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
             gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
-            global_num_tokens = None
             gathered_buffer = None
 
         spec_info = self.get_spec_info(num_tokens)

@@ -531,7 +563,9 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,

@@ -635,12 +669,13 @@ class CudaGraphRunner:
         # Pad
         if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]

@@ -670,7 +705,8 @@ class CudaGraphRunner:
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
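Illustrative note (not part of the commit): the new mul_base logic in get_batch_sizes_to_capture keeps only capture batch sizes that satisfy every divisibility constraint at once: two-batch overlap wants even sizes, and a gathered buffer wants sizes divisible by the attention TP size so padded lengths can be reduce-scattered. A standalone sketch of the filter, with illustrative values in place of the real server args:

    def filter_capture_bs(capture_bs, enable_two_batch_overlap, attn_tp_size, needs_gathered_buffer):
        # Keep only batch sizes divisible by every active constraint (mirrors mul_base).
        mul_base = 1
        if enable_two_batch_overlap:
            mul_base *= 2
        if needs_gathered_buffer:
            mul_base *= attn_tp_size
        return [bs for bs in capture_bs if bs % mul_base == 0]

    candidates = [1, 2, 4, 6, 8, 12, 16, 24, 32]
    # Two-batch overlap plus attention TP size 4 -> only multiples of 8 survive.
    print(filter_capture_bs(candidates, True, 4, True))   # [8, 16, 24, 32]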
python/sglang/srt/model_executor/forward_batch_info.py

@@ -38,6 +38,11 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_dp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
     flatten_nested_list,

@@ -48,6 +53,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.layers.logits_processor import LogitsProcessorOutput
     from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -242,7 +248,7 @@ class ForwardBatch:
     lora_paths: Optional[List[str]] = None
 
     # For input embeddings
-    input_embeds: Optional[torch.tensor] = None
+    input_embeds: Optional[torch.Tensor] = None
 
     # For cross-encoder model
     token_type_ids: Optional[torch.Tensor] = None

@@ -261,6 +267,8 @@ class ForwardBatch:
     # Has to be None when cuda graph is captured.
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # The padding mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch

@@ -286,7 +294,7 @@ class ForwardBatch:
     # For two-batch overlap
     tbo_split_seq_index: Optional[int] = None
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
-    tbo_children: Optional[List["ForwardBatch"]] = None
+    tbo_children: Optional[List[ForwardBatch]] = None
 
     @classmethod
     def init_new(

@@ -340,20 +348,38 @@ class ForwardBatch:
             len(batch.input_ids), dtype=torch.int32
         ).to(device, non_blocking=True)
 
-        # For DP attention
+        # For MLP sync
         if batch.global_num_tokens is not None:
-            spec_num_draft_tokens = (
-                batch.spec_num_draft_tokens
-                if batch.spec_num_draft_tokens is not None else 1
-            )
-            global_num_tokens = [
-                x * spec_num_draft_tokens for x in batch.global_num_tokens
-            ]
-            global_num_tokens_for_logprob = [
-                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
-            ]
+            from sglang.srt.speculative.eagle_utils import (
+                EagleDraftInput,
+                EagleVerifyInput,
+            )
+
+            assert batch.global_num_tokens_for_logprob is not None
+            # process global_num_tokens and global_num_tokens_for_logprob
+            if batch.spec_info is not None:
+                if isinstance(batch.spec_info, EagleDraftInput):
+                    global_num_tokens = [
+                        x * batch.spec_info.num_tokens_per_batch
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.num_tokens_for_logprob_per_batch
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+                else:
+                    assert isinstance(batch.spec_info, EagleVerifyInput)
+                    global_num_tokens = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+            else:
+                global_num_tokens = batch.global_num_tokens
+                global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
 
             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(

@@ -365,15 +391,8 @@ class ForwardBatch:
                 global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            sum_len = sum(global_num_tokens)
-            ret.gathered_buffer = torch.zeros(
-                (sum_len, model_runner.model_config.hidden_size),
-                dtype=model_runner.dtype,
-                device=device,
-            )
-
         if ret.forward_mode.is_idle():
-            ret.positions = torch.empty((0,), device=device)
+            ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
             TboForwardBatchPreparer.prepare(
                 ret, is_draft_worker=model_runner.is_draft_worker
             )

@@ -573,6 +592,158 @@ class ForwardBatch:
             )
             self.prefix_chunk_kv_indices.append(chunk_kv_indices)
 
+    def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
+        if value == 0:
+            return torch.cat(
+                [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])],
+                dim=0,
+            )
+        else:
+            return torch.cat(
+                [
+                    tensor,
+                    tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value),
+                ],
+                dim=0,
+            )
+
+    def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
+
+        from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+        assert self.global_num_tokens_cpu is not None
+        assert self.global_num_tokens_for_logprob_cpu is not None
+
+        global_num_tokens = self.global_num_tokens_cpu
+        sync_group_size = len(global_num_tokens)
+        attn_tp_size = get_attention_tp_size()
+
+        for i in range(sync_group_size):
+            # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
+            # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob
+            global_num_tokens[i] = (
+                (global_num_tokens[i] - 1) // attn_tp_size + 1
+            ) * attn_tp_size
+
+        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        self.dp_padding_mode = dp_padding_mode
+
+        if dp_padding_mode.is_max_len():
+            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
+            # where transferred tokens should be padded to the same length.
+            max_num_tokens = max(global_num_tokens)
+            global_num_tokens = [max_num_tokens] * sync_group_size
+            buffer_len = max_num_tokens * sync_group_size
+        else:
+            buffer_len = sum(global_num_tokens)
+
+        self.gathered_buffer = torch.zeros(
+            (buffer_len, model_runner.model_config.hidden_size),
+            dtype=model_runner.dtype,
+            device=model_runner.device,
+        )
+
+        bs = self.batch_size
+        if len(global_num_tokens) > 1:
+            num_tokens = global_num_tokens[get_attention_dp_rank()]
+        else:
+            num_tokens = global_num_tokens[0]
+
+        # padding
+        self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
+        self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
+
+        seq_len_fill_value = (
+            model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+        self.seq_lens = self._pad_tensor_to_size(
+            self.seq_lens, bs, value=seq_len_fill_value
+        )
+        if self.seq_lens_cpu is not None:
+            self.seq_lens_cpu = self._pad_tensor_to_size(
+                self.seq_lens_cpu, bs, value=seq_len_fill_value
+            )
+
+        self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens)
+        if self.encoder_lens is not None:
+            self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
+        self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
+
+        self.global_num_tokens_cpu = global_num_tokens
+        self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
+            global_num_tokens
+        )
+
+        if self.mrope_positions is not None:
+            self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
+
+        if self.extend_seq_lens is not None:
+            self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
+
+        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+            spec_info = self.spec_info
+            self.output_cache_loc_backup = self.out_cache_loc
+            self.hidden_states_backup = spec_info.hidden_states
+            if spec_info.topk_p is not None:
+                spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
+            if spec_info.topk_index is not None:
+                spec_info.topk_index = self._pad_tensor_to_size(
+                    spec_info.topk_index, bs
+                )
+            if spec_info.accept_length is not None:
+                spec_info.accept_length = self._pad_tensor_to_size(
+                    spec_info.accept_length, bs
+                )
+            spec_info.hidden_states = self._pad_tensor_to_size(
+                spec_info.hidden_states, num_tokens
+            )
+
+    def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
+
+        bs = self.batch_size
+
+        if self.spec_info is not None:
+            if self.forward_mode.is_decode():
+                # draft
+                num_tokens = self.hidden_states_backup.shape[0]
+                self.positions = self.positions[:num_tokens]
+                self.seq_lens = self.seq_lens[:bs]
+                self.req_pool_indices = self.req_pool_indices[:bs]
+                if self.seq_lens_cpu is not None:
+                    self.seq_lens_cpu = self.seq_lens_cpu[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_target_verify():
+                # verify
+                num_tokens = bs * self.spec_info.draft_token_num
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_draft_extend():
+                # draft extend
+                self.spec_info.accept_length = self.spec_info.accept_length[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+            elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+            if hasattr(self, "hidden_states_backup"):
+                self.spec_info.hidden_states = self.hidden_states_backup
+            if hasattr(self, "output_cache_loc_backup"):
+                self.out_cache_loc = self.output_cache_loc_backup
+        elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
+            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+        elif self.forward_mode.is_extend():
+            num_tokens = self.seq_lens_sum
+            logits_output.next_token_logits = logits_output.next_token_logits[
+                :num_tokens
+            ]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+
     # Here we suppose the length of each chunk is equal
     # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
     # num_prefix_chunks = cdiv(1024, 256) = 4
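Illustrative note (not part of the commit): _pad_tensor_to_size only ever grows a tensor along dim 0, appending zeros or a chosen fill value (e.g. the CUDA-graph seq-len filler) until it reaches the per-rank padded length agreed on in prepare_mlp_sync_batch. A standalone sketch with the same behavior, assuming the target size is at least the current length:

    import torch

    def pad_tensor_to_size(tensor: torch.Tensor, size: int, value: int = 0) -> torch.Tensor:
        # Append rows of `value` along dim 0 until `tensor` has `size` rows.
        pad_rows = size - tensor.shape[0]
        if pad_rows <= 0:
            return tensor
        pad = tensor.new_full((pad_rows, *tensor.shape[1:]), value)
        return torch.cat([tensor, pad], dim=0)

    seq_lens = torch.tensor([5, 9, 3])
    print(pad_tensor_to_size(seq_lens, 6, value=1))   # tensor([5, 9, 3, 1, 1, 1])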
python/sglang/srt/model_executor/model_runner.py

@@ -1464,8 +1464,12 @@ class ModelRunner:
         tensor_parallel(self.model, device_mesh)
 
     def forward_decode(
-        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
-        self.attn_backend.init_forward_metadata(forward_batch)
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}

@@ -1578,8 +1582,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_decode():
-            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            return ret, can_run_cuda_graph
+
+        # For MLP sync
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.prepare_mlp_sync_batch(self)
+
+        if forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_extend():
             ret = self.forward_extend(
                 forward_batch,

@@ -1597,6 +1611,9 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.post_forward_mlp_sync_batch(ret)
+
         return ret, can_run_cuda_graph
 
     def _preprocess_logits(
python/sglang/srt/models/deepseek_v2.py

@@ -550,9 +550,8 @@ class DeepseekV2MoE(nn.Module):
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
         shared_output = None
-        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+        if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
python/sglang/srt/models/qwen2_moe.py

@@ -43,10 +43,6 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
python/sglang/srt/models/qwen3_moe.py

@@ -38,10 +38,6 @@ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,

@@ -193,8 +189,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
-        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+        if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
View file @
c0fb25e9
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
import
torch
import
torch
from
sglang.srt.layers.dp_attention
import
DPPaddingMode
from
sglang.srt.model_executor.cuda_graph_runner
import
(
from
sglang.srt.model_executor.cuda_graph_runner
import
(
CUDA_GRAPH_CAPTURE_FAILED_MSG
,
CUDA_GRAPH_CAPTURE_FAILED_MSG
,
CudaGraphRunner
,
CudaGraphRunner
,
...
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
)
)
if
self
.
require_gathered_buffer
:
if
self
.
require_gathered_buffer
:
self
.
gathered_buffer
=
torch
.
zeros
(
(
self
.
max_num_token
,
self
.
model_runner
.
model_config
.
hidden_size
,
),
dtype
=
self
.
model_runner
.
dtype
,
)
if
self
.
require_mlp_tp_gather
:
if
self
.
require_mlp_tp_gather
:
self
.
global_num_tokens_gpu
=
torch
.
zeros
(
self
.
global_num_tokens_gpu
=
torch
.
zeros
(
(
self
.
dp_size
,),
dtype
=
torch
.
int32
(
self
.
dp_size
,),
dtype
=
torch
.
int32
...
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
self
.
global_num_tokens_for_logprob_gpu
=
torch
.
zeros
(
self
.
global_num_tokens_for_logprob_gpu
=
torch
.
zeros
(
(
self
.
dp_size
,),
dtype
=
torch
.
int32
(
self
.
dp_size
,),
dtype
=
torch
.
int32
)
)
self
.
gathered_buffer
=
torch
.
zeros
(
(
self
.
max_num_token
*
self
.
dp_size
,
self
.
model_runner
.
model_config
.
hidden_size
,
),
dtype
=
self
.
model_runner
.
dtype
,
)
else
:
else
:
assert
self
.
require_attn_tp_gather
assert
self
.
require_attn_tp_gather
self
.
global_num_tokens_gpu
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
int32
)
self
.
global_num_tokens_gpu
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
int32
)
self
.
global_num_tokens_for_logprob_gpu
=
torch
.
zeros
(
self
.
global_num_tokens_for_logprob_gpu
=
torch
.
zeros
(
(
1
,),
dtype
=
torch
.
int32
(
1
,),
dtype
=
torch
.
int32
)
)
self
.
gathered_buffer
=
torch
.
zeros
(
(
self
.
max_num_token
,
self
.
model_runner
.
model_config
.
hidden_size
,
),
dtype
=
self
.
model_runner
.
dtype
,
)
else
:
self
.
global_num_tokens_gpu
=
None
self
.
global_num_tokens_for_logprob_gpu
=
None
self
.
gathered_buffer
=
None
# Capture
# Capture
try
:
try
:
...
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
def
can_run
(
self
,
forward_batch
:
ForwardBatch
):
def
can_run
(
self
,
forward_batch
:
ForwardBatch
):
if
self
.
require_mlp_tp_gather
:
if
self
.
require_mlp_tp_gather
:
cuda_graph_bs
=
(
cuda_graph_bs
=
(
sum
(
forward_batch
.
global_num_tokens_cpu
)
//
self
.
num_tokens_per_bs
max
(
forward_batch
.
global_num_tokens_cpu
)
//
self
.
num_tokens_per_bs
if
self
.
model_runner
.
spec_algorithm
.
is_eagle
()
if
self
.
model_runner
.
spec_algorithm
.
is_eagle
()
else
sum
(
forward_batch
.
global_num_tokens_cpu
)
else
max
(
forward_batch
.
global_num_tokens_cpu
)
)
)
else
:
else
:
cuda_graph_bs
=
forward_batch
.
batch_size
cuda_graph_bs
=
forward_batch
.
batch_size
...
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
if
self
.
require_mlp_tp_gather
:
if
self
.
require_mlp_tp_gather
:
self
.
global_num_tokens_gpu
.
copy_
(
self
.
global_num_tokens_gpu
.
copy_
(
torch
.
tensor
(
torch
.
tensor
(
[
[
num_tokens
]
*
self
.
dp_size
,
num_tokens
//
self
.
dp_size
+
(
i
<
(
num_tokens
%
self
.
dp_size
))
for
i
in
range
(
self
.
dp_size
)
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
input_ids
.
device
,
device
=
self
.
input_ids
.
device
,
)
)
)
)
self
.
global_num_tokens_for_logprob_gpu
.
copy_
(
self
.
global_num_tokens_for_logprob_gpu
.
copy_
(
torch
.
tensor
(
torch
.
tensor
(
[
[
num_tokens
]
*
self
.
dp_size
,
num_tokens
//
self
.
dp_size
+
(
i
<
(
num_tokens
%
self
.
dp_size
))
for
i
in
range
(
self
.
dp_size
)
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
input_ids
.
device
,
device
=
self
.
input_ids
.
device
,
)
)
)
)
global_num_tokens
=
self
.
global_num_tokens_gpu
global_num_tokens
=
self
.
global_num_tokens_gpu
gathered_buffer
=
self
.
gathered_buffer
[:
num_tokens
]
gathered_buffer
=
self
.
gathered_buffer
[:
num_tokens
*
self
.
dp_size
]
global_num_tokens_for_logprob
=
self
.
global_num_tokens_for_logprob_gpu
global_num_tokens_for_logprob
=
self
.
global_num_tokens_for_logprob_gpu
elif
self
.
require_attn_tp_gather
:
elif
self
.
require_attn_tp_gather
:
self
.
global_num_tokens_gpu
.
copy_
(
self
.
global_num_tokens_gpu
.
copy_
(
...
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
return_logprob
=
False
,
return_logprob
=
False
,
positions
=
positions
,
positions
=
positions
,
global_num_tokens_gpu
=
global_num_tokens
,
global_num_tokens_gpu
=
global_num_tokens
,
dp_padding_mode
=
DPPaddingMode
.
get_default_mode_in_cuda_graph
(),
gathered_buffer
=
gathered_buffer
,
gathered_buffer
=
gathered_buffer
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
spec_info
=
spec_info
,
spec_info
=
spec_info
,
...
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
# Pad
# Pad
if
self
.
require_mlp_tp_gather
:
if
self
.
require_mlp_tp_gather
:
total_batch_size
=
(
max_num_tokens
=
max
(
forward_batch
.
global_num_tokens_cpu
)
sum
(
forward_batch
.
global_num_tokens_cpu
)
//
self
.
num_tokens_per_bs
max_batch_size
=
(
max_num_tokens
//
self
.
num_tokens_per_bs
if
self
.
model_runner
.
spec_algorithm
.
is_eagle
()
if
self
.
model_runner
.
spec_algorithm
.
is_eagle
()
else
sum
(
forward_batch
.
global
_num_tokens
_cpu
)
else
max
_num_tokens
)
)
index
=
bisect
.
bisect_left
(
self
.
capture_bs
,
total
_batch_size
)
index
=
bisect
.
bisect_left
(
self
.
capture_bs
,
max
_batch_size
)
else
:
else
:
index
=
bisect
.
bisect_left
(
self
.
capture_bs
,
raw_bs
)
index
=
bisect
.
bisect_left
(
self
.
capture_bs
,
raw_bs
)
bs
=
self
.
capture_bs
[
index
]
bs
=
self
.
capture_bs
[
index
]
...
@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
            self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
            self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

+       # TODO(ch-wan): support num_token_non_padded
        if self.require_gathered_buffer:
-           self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-           self.global_num_tokens_for_logprob_gpu.copy_(
-               forward_batch.global_num_tokens_for_logprob_gpu
-           )
-           forward_batch.gathered_buffer = self.gathered_buffer
+           self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+           self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)

        # Attention backend
        if bs != raw_bs:
...
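The hunks above replace the uniform per-rank entry [num_tokens] * self.dp_size with a balanced split across data-parallel ranks. A minimal standalone sketch of that split (plain Python; the helper name is hypothetical and not part of the commit):

    # Rank i receives num_tokens // dp_size tokens, plus one extra token for the
    # first (num_tokens % dp_size) ranks, so the entries always sum to num_tokens.
    def split_tokens_evenly(num_tokens: int, dp_size: int) -> list:
        return [
            num_tokens // dp_size + (i < (num_tokens % dp_size))
            for i in range(dp_size)
        ]

    assert split_tokens_evenly(10, 4) == [3, 3, 2, 2]
    assert sum(split_tokens_evenly(10, 4)) == 10

Because each rank now reports roughly num_tokens / dp_size local tokens instead of num_tokens, the gathered buffer slice grows to num_tokens * self.dp_size so it can hold every rank's shard after the gather.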
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
View file @ c0fb25e9
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
import torch

+from sglang.srt.layers.dp_attention import DPPaddingMode
from sglang.srt.model_executor.cuda_graph_runner import (
    CUDA_GRAPH_CAPTURE_FAILED_MSG,
    CudaGraphRunner,
...
@@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
        )
        if self.require_gathered_buffer:
-           self.gathered_buffer = torch.zeros(
-               (
-                   self.max_num_token,
-                   self.model_runner.model_config.hidden_size,
-               ),
-               dtype=self.model_runner.dtype,
-           )
            if self.require_mlp_tp_gather:
                self.global_num_tokens_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
...
@@ -123,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
                )
+               self.gathered_buffer = torch.zeros(
+                   (
+                       self.max_num_token * self.dp_size,
+                       self.model_runner.model_config.hidden_size,
+                   ),
+                   dtype=self.model_runner.dtype,
+               )
            else:
                assert self.require_attn_tp_gather
                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (1,), dtype=torch.int32
                )
+               self.gathered_buffer = torch.zeros(
+                   (
+                       self.max_num_token,
+                       self.model_runner.model_config.hidden_size,
+                   ),
+                   dtype=self.model_runner.dtype,
+               )
+       else:
+           self.global_num_tokens_gpu = None
+           self.global_num_tokens_for_logprob_gpu = None
+           self.gathered_buffer = None

        # Capture
        try:
            with model_capture_mode():
...
@@ -141,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
    def can_run(self, forward_batch: ForwardBatch):
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
-               sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+               max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-               else sum(forward_batch.global_num_tokens_cpu)
+               else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.seq_lens.numel()
...
@@ -180,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
-                   [num_tokens] * self.dp_size,
+                   [
+                       num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                       for i in range(self.dp_size)
+                   ],
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
-                   [bs] * self.dp_size,
+                   [
+                       num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                       for i in range(self.dp_size)
+                   ],
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
-           global_num_tokens = self.global_num_tokens_gpu
-           gathered_buffer = self.gathered_buffer[:num_tokens]
-           global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+           gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
...
@@ -211,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
-                   [num_tokens],
+                   [bs],
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
-           global_num_tokens = self.global_num_tokens_gpu
            gathered_buffer = self.gathered_buffer[:num_tokens]
-           global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
        else:
-           global_num_tokens = None
            gathered_buffer = None
-           global_num_tokens_for_logprob = None
        spec_info = EagleDraftInput(
            hidden_states=hidden_states,
...
@@ -243,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
-           global_num_tokens_gpu=global_num_tokens,
-           global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+           global_num_tokens_gpu=self.global_num_tokens_gpu,
+           global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+           dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
            gathered_buffer=gathered_buffer,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
...
@@ -306,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
        raw_bs = forward_batch.batch_size
        num_tokens = forward_batch.input_ids.shape[0]
        if self.require_mlp_tp_gather:
-           total_batch_size = (
-               sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+           max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+           max_batch_size = (
+               max_num_tokens // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-               else sum(forward_batch.global_num_tokens_cpu)
+               else max_num_tokens
            )
-           index = bisect.bisect_left(self.capture_bs, total_batch_size)
+           index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
...
@@ -334,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
        self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

+       # TODO(ch-wan): support num_token_non_padded
        if self.require_gathered_buffer:
-           self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-           self.global_num_tokens_for_logprob_gpu.copy_(
-               forward_batch.global_num_tokens_for_logprob_gpu
-           )
-           forward_batch.gathered_buffer = self.gathered_buffer
+           self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+           self.global_num_tokens_for_logprob_gpu.fill_(bs)

        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
...
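As in the draft runner above, the gathered buffer here moves inside the branch and its size depends on the gather mode: with MLP-TP gather each of the dp_size ranks can contribute up to max_num_token padded tokens, while with attention-TP gather only the local tokens are held. A rough sketch of the sizing rule (hypothetical helper, assuming the shapes used in the hunks above):

    import torch

    def alloc_gathered_buffer(max_num_token, hidden_size, dp_size,
                              require_mlp_tp_gather, dtype=torch.float16):
        # With MLP-TP gather the CUDA-graph buffer must hold one padded shard
        # per DP rank; otherwise the local shard alone is enough.
        rows = max_num_token * dp_size if require_mlp_tp_gather else max_num_token
        return torch.zeros((rows, hidden_size), dtype=dtype)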
python/sglang/srt/speculative/eagle_utils.py
View file @ c0fb25e9
...
@@ -71,9 +71,20 @@ class EagleDraftInput:
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None

+   # Shape info for padding
+   num_tokens_per_batch: int = -1
+   num_tokens_for_logprob_per_batch: int = -1
+
+   # Inputs for draft extend
+   # shape: (b,)
+   seq_lens_for_draft_extend: torch.Tensor = None
+   req_pool_indices_for_draft_extend: torch.Tensor = None
+
    def prepare_for_extend(self, batch: ScheduleBatch):
        if batch.forward_mode.is_idle():
            return

        # Prefill only generate 1 token.
        assert len(self.verified_id) == len(batch.seq_lens)
@@ -95,7 +106,7 @@ class EagleDraftInput:
...
@@ -95,7 +106,7 @@ class EagleDraftInput:
capture_hidden_mode
:
CaptureHiddenMode
,
capture_hidden_mode
:
CaptureHiddenMode
,
):
):
return
cls
(
return
cls
(
verified_id
=
None
,
verified_id
=
torch
.
empty
((
0
,),
device
=
device
,
dtype
=
torch
.
int32
)
,
hidden_states
=
torch
.
empty
((
0
,
hidden_size
),
device
=
device
,
dtype
=
dtype
),
hidden_states
=
torch
.
empty
((
0
,
hidden_size
),
device
=
device
,
dtype
=
dtype
),
topk_p
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
float32
),
topk_p
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
float32
),
topk_index
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
int64
),
topk_index
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
int64
),
...
@@ -109,7 +120,10 @@ class EagleDraftInput:
        batch: ScheduleBatch,
        speculative_num_steps: int,
    ):
-       batch.forward_mode = ForwardMode.DRAFT_EXTEND
+       if batch.forward_mode.is_idle():
+           return
+
        batch.input_ids = self.verified_id
        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
        batch.extend_num_tokens = sum(batch.extend_lens)
@@ -316,7 +330,7 @@ class EagleVerifyInput:
...
@@ -316,7 +330,7 @@ class EagleVerifyInput:
def
verify
(
def
verify
(
self
,
self
,
batch
:
ScheduleBatch
,
batch
:
ScheduleBatch
,
logits_output
:
torch
.
Tensor
,
logits_output
:
LogitsProcessorOutput
,
token_to_kv_pool_allocator
:
BaseTokenToKVPoolAllocator
,
token_to_kv_pool_allocator
:
BaseTokenToKVPoolAllocator
,
page_size
:
int
,
page_size
:
int
,
vocab_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
# For grammar
vocab_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
# For grammar
...
@@ -599,13 +613,14 @@ class EagleVerifyInput:
            batch.out_cache_loc = tgt_cache_loc
            batch.seq_lens.add_(accept_length + 1)

-           draft_input = EagleDraftInput()
-           draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
-           draft_input.verified_id = verified_id
-           draft_input.accept_length = accept_length
-           draft_input.accept_length_cpu = accept_length.tolist()
-           draft_input.seq_lens_for_draft_extend = batch.seq_lens
-           draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
+           draft_input = EagleDraftInput(
+               hidden_states=batch.spec_info.hidden_states[accept_index],
+               verified_id=verified_id,
+               accept_length=accept_length,
+               accept_length_cpu=accept_length.tolist(),
+               seq_lens_for_draft_extend=batch.seq_lens,
+               req_pool_indices_for_draft_extend=batch.req_pool_indices,
+           )

            return EagleVerifyOutput(
                draft_input=draft_input,
...
@@ -628,7 +643,6 @@ class EagleVerifyInput:
            batch.seq_lens.add_(accept_length + 1)
            accept_length_cpu = accept_length.tolist()

-           draft_input = EagleDraftInput()
            if len(unfinished_accept_index) > 0:
                unfinished_accept_index = torch.cat(unfinished_accept_index)
                unfinished_index_device = torch.tensor(
...
@@ -659,18 +673,26 @@ class EagleVerifyInput:
                    next_power_of_2(self.draft_token_num),
                )
-               draft_input.hidden_states = batch.spec_info.hidden_states[
-                   unfinished_accept_index
-               ]
-               draft_input.verified_id = predict[unfinished_accept_index]
-               draft_input.accept_length_cpu = draft_input_accept_length_cpu
-               draft_input.accept_length = accept_length[unfinished_index_device]
-               draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                   unfinished_index_device
-               ]
-               draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                   unfinished_index_device
-               ]
+               draft_input = EagleDraftInput(
+                   hidden_states=batch.spec_info.hidden_states[
+                       unfinished_accept_index
+                   ],
+                   verified_id=predict[unfinished_accept_index],
+                   accept_length_cpu=draft_input_accept_length_cpu,
+                   accept_length=accept_length[unfinished_index_device],
+                   seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                   req_pool_indices_for_draft_extend=batch.req_pool_indices[
+                       unfinished_index_device
+                   ],
+               )
+           else:
+               draft_input = EagleDraftInput.create_idle_input(
+                   device=batch.device,
+                   hidden_size=batch.model_config.hidden_size,
+                   dtype=batch.model_config.dtype,
+                   topk=self.topk,
+                   capture_hidden_mode=CaptureHiddenMode.LAST,
+               )

            return EagleVerifyOutput(
                draft_input=draft_input,
...
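The eagle_utils.py hunks construct EagleDraftInput in a single call instead of mutating an empty instance, and the idle input now carries zero-length tensors rather than None. A simplified stand-in (field names follow the diff, but the real dataclass has many more fields) showing why empty tensors keep shape-based checks working on a rank with no local requests:

    import torch
    from dataclasses import dataclass

    @dataclass
    class DraftInputSketch:
        verified_id: torch.Tensor
        hidden_states: torch.Tensor

    idle = DraftInputSketch(
        verified_id=torch.empty((0,), dtype=torch.int32),
        hidden_states=torch.empty((0, 8), dtype=torch.float32),
    )
    # "No tokens" is an explicit shape rather than None, so .shape[0] and .numel()
    # remain valid, which the worker relies on when deciding whether to run.
    assert idle.verified_id.shape[0] == 0 and idle.verified_id.numel() == 0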
python/sglang/srt/speculative/eagle_worker.py
View file @ c0fb25e9
...
@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
-   ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+   ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
        """Run speculative decoding forward.

        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
...
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
                self.verify(batch, spec_info)
            )
-           if self.check_forward_draft_extend_after_decode(batch):
-               with self.draft_tp_context(self.draft_model_runner.tp_group):
-                   self.forward_draft_extend_after_decode(
-                       batch,
-                   )
+           with self.draft_tp_context(self.draft_model_runner.tp_group):
+               # NOTE: We should use `check_forward_draft_extend_after_decode`
+               # when DP attention is enabled, but it is slow. Skip it for now.
+               if (
+                   self.server_args.enable_dp_attention
+                   or batch.spec_info.verified_id.shape[0] > 0
+               ):
+                   # decode is not finished
+                   self.forward_draft_extend_after_decode(batch)

            return (
                logits_output,
                verify_output.verified_id,
...
@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
        )

    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-       local_need_forward = (
-           batch.spec_info.verified_id is not None
-           and batch.spec_info.verified_id.shape[0] > 0
-       )
+       local_need_forward = batch.spec_info.verified_id.shape[0] > 0
        if not self.server_args.enable_dp_attention:
            return local_need_forward
...
@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):
    def forward_target_extend(
        self, batch: ScheduleBatch
-   ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+   ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
        """Run the target extend.

        Args:
...
@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-       model_worker_batch.spec_num_draft_tokens = 1
        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
...
@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
        self._draft_preprocess_decode(batch)

        spec_info = batch.spec_info
+       assert isinstance(spec_info, EagleDraftInput)
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+       spec_info.num_tokens_per_batch = self.topk
+       spec_info.num_tokens_for_logprob_per_batch = self.topk
        batch.return_hidden_states = False

        # Get forward batch
        model_worker_batch = batch.get_model_worker_batch()
-       model_worker_batch.spec_num_draft_tokens = self.topk
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
...
@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
                forward_batch
            )
        else:
+           forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                # Initialize attention backend
                self.draft_attn_backend.init_forward_metadata(forward_batch)
...
@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
    def draft_forward(self, forward_batch: ForwardBatch):
        # Parse args
        spec_info = forward_batch.spec_info
+       assert isinstance(spec_info, EagleDraftInput)
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
...
@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
        spec_info.hidden_states = hidden_states

        # Run forward
-       logits_output = self.draft_model_runner.model.forward(
-           forward_batch.input_ids, forward_batch.positions, forward_batch
+       logits_output, _ = self.draft_model_runner.forward(
+           forward_batch, skip_attn_backend_init=True
        )
        self._detect_nan_if_needed(logits_output)
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
...
@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
            else ForwardMode.IDLE
        )
        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=spec_info.seq_lens_cpu
        )
-       model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
        if batch.has_grammar:
...
@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
-       next_token_ids: List[int],
-       seq_lens_cpu: torch.Tensor,
+       next_token_ids: torch.Tensor,
+       seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Run draft model extend. This API modifies the states of the batch.
...
@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
+           num_tokens_per_batch=1,
+           num_tokens_for_logprob_per_batch=1,
        )
        batch.return_hidden_states = False
        batch.spec_info.prepare_for_extend(batch)
...
@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=seq_lens_cpu
        )
-       model_worker_batch.spec_num_draft_tokens = 1
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
@@ -814,20 +820,16 @@ class EAGLEWorker(TpModelWorker):
...
@@ -814,20 +820,16 @@ class EAGLEWorker(TpModelWorker):
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
def
forward_draft_extend_after_decode
(
self
,
batch
:
ScheduleBatch
):
def
forward_draft_extend_after_decode
(
self
,
batch
:
ScheduleBatch
):
assert
isinstance
(
batch
.
spec_info
,
EagleDraftInput
)
# Backup fields that will be modified in-place
# Backup fields that will be modified in-place
seq_lens_backup
=
batch
.
seq_lens
.
clone
()
seq_lens_backup
=
batch
.
seq_lens
.
clone
()
req_pool_indices_backup
=
batch
.
req_pool_indices
req_pool_indices_backup
=
batch
.
req_pool_indices
accept_length_backup
=
batch
.
spec_info
.
accept_length
accept_length_backup
=
batch
.
spec_info
.
accept_length
return_logprob_backup
=
batch
.
return_logprob
return_logprob_backup
=
batch
.
return_logprob
input_is_idle
=
batch
.
forward_mode
.
is_idle
()
input_is_idle
=
batch
.
forward_mode
.
is_idle
()
if
not
input_is_idle
:
# Prepare metadata
if
not
input_is_idle
and
batch
.
spec_info
.
verified_id
.
numel
()
==
0
:
if
batch
.
spec_info
.
verified_id
is
not
None
:
batch
.
spec_info
.
prepare_extend_after_decode
(
batch
,
self
.
speculative_num_steps
,
)
else
:
batch
=
batch
.
copy
()
batch
=
batch
.
copy
()
batch
.
prepare_for_idle
()
batch
.
prepare_for_idle
()
hidden_size
=
(
hidden_size
=
(
...
@@ -842,9 +844,21 @@ class EAGLEWorker(TpModelWorker):
                topk=self.topk,
                capture_hidden_mode=CaptureHiddenMode.LAST,
            )
+       batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
+       batch.spec_info.num_tokens_for_logprob_per_batch = 1
+       batch.spec_info.prepare_extend_after_decode(
+           batch,
+           self.speculative_num_steps,
+       )
+       batch.forward_mode = (
+           ForwardMode.DRAFT_EXTEND
+           if not batch.forward_mode.is_idle()
+           else ForwardMode.IDLE
+       )
        batch.return_hidden_states = False
        model_worker_batch = batch.get_model_worker_batch()
-       model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
...
@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
            )
            forward_batch.spec_info.hidden_states = logits_output.hidden_states
        else:
+           forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                self.draft_model_runner.attn_backend.init_forward_metadata(
                    forward_batch
                )
-           logits_output = self.draft_model_runner.model.forward(
-               forward_batch.input_ids, forward_batch.positions, forward_batch
+           logits_output, _ = self.draft_model_runner.forward(
+               forward_batch, skip_attn_backend_init=True
            )
        self.capture_for_decode(logits_output, forward_batch.spec_info)
...
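The eagle_worker.py hunks make every rank enter forward_draft_extend_after_decode when DP attention is enabled, since that forward involves collective communication all ranks must join; a rank with no verified tokens now runs an idle pass instead of skipping the call. A minimal sketch of that participation rule (hypothetical helper, not part of the commit):

    def need_local_draft_extend(enable_dp_attention: bool, num_verified: int) -> bool:
        # With DP attention an "empty" rank still participates (as an idle batch);
        # without it, the rank can simply skip the draft-extend forward.
        return enable_dp_attention or num_verified > 0

    assert need_local_draft_extend(True, 0)        # idle rank still joins the collective
    assert not need_local_draft_extend(False, 0)   # safe to skip without DP attention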
python/sglang/srt/two_batch_overlap.py
View file @ c0fb25e9
...
@@ -545,6 +545,7 @@ class TboForwardBatchPreparer:
            tbo_children=None,
            global_num_tokens_gpu=None,
            global_num_tokens_cpu=None,
+           dp_padding_mode=None,
            gathered_buffer=gathered_buffer,
            global_num_tokens_for_logprob_gpu=None,
            global_num_tokens_for_logprob_cpu=None,
...
test/srt/test_deepep_small.py
View file @ c0fb25e9
...
@@ -35,7 +35,7 @@ class TestPureDP(CustomTestCase):
                "--cuda-graph-max-bs",
                "128",
                "--max-running-requests",
-               "128",
+               "512",
                "--mem-fraction-static",
                "0.5",
            ],
...
@@ -81,7 +81,7 @@ class TestHybridDPTP(CustomTestCase):
                "--cuda-graph-max-bs",
                "128",
                "--max-running-requests",
-               "128",
+               "256",
            ],
        )
...
@@ -170,7 +170,7 @@ class TestNoGatherdBuffer(CustomTestCase):
                "--cuda-graph-max-bs",
                "32",
                "--max-running-requests",
-               "128",
+               "512",
            ],
        )
...
@@ -217,7 +217,7 @@ class TestTBO(CustomTestCase):
                "--cuda-graph-max-bs",
                "128",
                "--max-running-requests",
-               "128",
+               "512",
            ],
        )
...
@@ -273,7 +273,7 @@ class TestMTP(CustomTestCase):
                "--cuda-graph-max-bs",
                "32",
                "--max-running-requests",
-               "32",
+               "64",
            ],
        )
...
@@ -343,7 +343,7 @@ class TestMTPWithTBO(CustomTestCase):
                "--cuda-graph-max-bs",
                "32",
                "--max-running-requests",
-               "32",
+               "128",
            ],
        )
...
test/srt/test_hybrid_dp_ep_tp_mtp.py
View file @ c0fb25e9
This diff is collapsed.