Commit b87aacb5 (Unverified)
Authored by Cheng Wan on Aug 13, 2025; committed via GitHub on Aug 13, 2025

[DP Attention] Refactor: adding some utility functions (#9136)
Parent: b3363cc1

Showing 20 changed files with 212 additions and 151 deletions.
python/sglang/srt/layers/communicator.py (+7, −7)
python/sglang/srt/layers/dp_attention.py (+114, −27)
python/sglang/srt/layers/logits_processor.py (+12, −18)
python/sglang/srt/layers/sampler.py (+5, −2)
python/sglang/srt/managers/schedule_batch.py (+0, −1)
python/sglang/srt/model_executor/cuda_graph_runner.py (+8, −21)
python/sglang/srt/model_executor/forward_batch_info.py (+8, −10)
python/sglang/srt/model_executor/model_runner.py (+2, −6)
python/sglang/srt/models/deepseek_nextn.py (+2, −1)
python/sglang/srt/models/deepseek_v2.py (+5, −3)
python/sglang/srt/models/glm4_moe.py (+2, −2)
python/sglang/srt/models/glm4_moe_nextn.py (+2, −1)
python/sglang/srt/models/gpt_oss.py (+2, −1)
python/sglang/srt/models/llama4.py (+2, −2)
python/sglang/srt/models/qwen2.py (+2, −2)
python/sglang/srt/models/qwen2_moe.py (+2, −1)
python/sglang/srt/models/step3_vl.py (+6, −2)
python/sglang/srt/operations.py (+17, −2)
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py (+7, −21)
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+7, −21)
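The common thread across these files: the pre-allocated, worst-case `gathered_buffer` tensors that used to live on `ForwardBatch` (and get sliced at every DP-attention call site) are replaced by length bookkeeping in `sglang.srt.layers.dp_attention`, which hands out correctly sized scratch tensors on demand. A minimal, self-contained sketch of that idea in plain PyTorch (toy sizes on CPU; the real helpers use the model's hidden size, dtype, and a CUDA device):

import torch

hidden_size, max_tokens, num_tokens = 8, 64, 6
dtype, device = torch.float16, torch.device("cpu")  # the real code targets CUDA

# Before: a worst-case gathered_buffer lived on ForwardBatch and call sites sliced it.
gathered_buffer = torch.zeros((max_tokens, hidden_size), dtype=dtype, device=device)
old_scratch = gathered_buffer[:num_tokens]

# After: only the lengths are recorded (set_dp_buffer_len); call sites ask a helper
# such as get_local_dp_buffer() for a tensor of exactly the right size.
new_scratch = torch.empty((num_tokens, hidden_size), dtype=dtype, device=device)

assert old_scratch.shape == new_scratch.shape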
python/sglang/srt/layers/communicator.py (+7, −7)

@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
         context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(

@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
         ):
             if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
                 residual, local_residual = (
-                    torch.empty_like(
-                        forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
-                    ),
+                    get_local_dp_buffer(),
                     residual,
                 )
                 attn_tp_all_gather_into_tensor(residual, local_residual)

@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             residual = hidden_states
             hidden_states = layernorm(hidden_states)
             hidden_states, local_hidden_states = (
-                torch.empty_like(forward_batch.gathered_buffer),
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)

@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():

@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
             hidden_states += residual
             residual = None
             hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                get_local_dp_buffer(),
                 hidden_states,
             )
             attn_tp_all_gather_into_tensor(
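Every change in this file follows the same call-site pattern: the destination tensor for an attention-TP all-gather is now requested from the dp_attention helpers instead of being sliced out of `forward_batch.gathered_buffer`. A sketch of that shape, assuming it runs inside a server process where `initialize_dp_attention` and the per-batch `set_dp_buffer_len` have already been called; `gather_hidden_states` is a hypothetical wrapper, not sglang code:

import torch

from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather_into_tensor,
    get_local_dp_buffer,
)

def gather_hidden_states(local_hidden_states: torch.Tensor) -> torch.Tensor:
    # The destination now comes pre-sized from the helper, replacing the old
    # forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]] slice.
    hidden_states = get_local_dp_buffer()
    attn_tp_all_gather_into_tensor(hidden_states, local_hidden_states)
    return hidden_states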
python/sglang/srt/layers/dp_attention.py (+114, −27)

@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 import triton

@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False
 
 
-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):
 
     # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
     MAX_LEN = auto()

@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()
 
     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN
 
     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)

@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
             return cls.SUM_LEN
 
     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN
 
 
+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+
+def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(global_dp_buffer_len, local_dp_buffer_len)
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0

@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
 def initialize_dp_attention(
-    enable_dp_attention: bool,
-    tp_rank: int,
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )

@@ -135,38 +212,48 @@ def initialize_dp_attention(
         group_name="attention_tp",
     )
 
+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device("cuda"),
+    )
+
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
+
 
-def get_attention_tp_group():
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP
 
 
-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK
 
 
-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE
 
 
-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK
 
 
-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK
 
 
-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE
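The new `_DpGatheredBufferWrapper` is the single owner of buffer metadata: `initialize_dp_attention` records hidden size, dtype, and device once, each batch publishes its global/local token counts via `set_dp_buffer_len`, and call sites materialize scratch tensors with `get_global_dp_buffer` / `get_local_dp_buffer`. A sketch that exercises the helpers directly, assuming sglang with this commit is importable; it pokes the private wrapper with toy CPU values, whereas a real server sets the metadata inside `initialize_dp_attention` from the model config and a CUDA device:

import torch

from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    get_global_dp_buffer,
    get_global_dp_buffer_len,
    get_local_dp_buffer,
    set_dp_buffer_len,
)

# A real server does this inside initialize_dp_attention(server_args, model_config).
_DpGatheredBufferWrapper.set_metadata(
    hidden_size=16, dtype=torch.bfloat16, device=torch.device("cpu")
)

# Per batch: publish the global (all DP ranks) and local token counts.
set_dp_buffer_len(global_dp_buffer_len=48, local_dp_buffer_len=12)

global_buf = get_global_dp_buffer()   # (48, 16) scratch for dp_gather_* calls
local_buf = get_local_dp_buffer()     # (12, 16) scratch for attn-TP all-gather
print(global_buf.shape, local_buf.shape, get_global_dp_buffer_len())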
python/sglang/srt/layers/logits_processor.py (+12, −18)

@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     attn_tp_all_gather,
     attn_tp_all_gather_into_tensor,
     dp_gather_replicate,

@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_global_dp_buffer,
     get_local_attention_dp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -108,14 +110,12 @@ class LogitsMetadata:
     # The start position of local hidden states.
     dp_local_start_pos: Optional[torch.Tensor] = None
     dp_local_num_tokens: Optional[torch.Tensor] = None
-    gathered_buffer: Optional[torch.Tensor] = None
-    # Buffer to gather logits from all ranks.
-    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The gather mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for padding
     padded_static_len: int = -1

@@ -164,11 +164,10 @@ class LogitsMetadata:
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-            gathered_buffer=forward_batch.gathered_buffer,
-            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.SUM_LEN,
+            dp_padding_mode=DpPaddingMode.SUM_LEN,
         )
 
     def compute_dp_attention_metadata(self):

@@ -188,16 +187,11 @@ class LogitsMetadata:
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
-            self.gathered_buffer = torch.empty(
-                (
-                    sum(self.global_num_tokens_for_logprob_cpu),
-                    self.gathered_buffer.shape[1],
-                ),
-                dtype=self.gathered_buffer.dtype,
-                device=self.gathered_buffer.device,
-            )
+            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
-            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+            self.global_dp_buffer_len = self.global_dp_buffer_len
+
+        set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
 
 
 class LogitsProcessor(nn.Module):

@@ -443,7 +437,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
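A worked example of the peak-memory point in `compute_dp_attention_metadata`: with per-rank logprob token counts of [3, 5, 2, 4] (illustrative numbers), the gathered buffer only needs sum(...) = 14 rows. The new code records just that length through `set_dp_buffer_len`; the tensor itself is created later by `get_global_dp_buffer()` at the `dp_gather_replicate` call site.

# Illustrative numbers: tokens needing logprobs on each attention-DP rank.
global_num_tokens_for_logprob_cpu = [3, 5, 2, 4]

# New code records only the length ...
global_dp_buffer_len = sum(global_num_tokens_for_logprob_cpu)   # 14

# ... where the old code allocated a (14, hidden_size) tensor right here with
# torch.empty and kept it alive on the metadata object.
print(global_dp_buffer_len)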
python/sglang/srt/layers/sampler.py (+5, −2)

@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn
 
 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group
 
-        if global_server_args_dict["enable_dp_attention"]:
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
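This is the first of several files that swap the `global_server_args_dict["enable_dp_attention"]` lookup (the key is also dropped from `GLOBAL_SERVER_ARGS_KEYS` below) for the new `is_dp_attention_enabled()` accessor, which reads the flag recorded by `initialize_dp_attention`. A sketch of the resulting pattern; `pick_token_sync_group` is a hypothetical helper, not sglang code:

from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    is_dp_attention_enabled,
)

def pick_token_sync_group(default_group):
    # Mirrors the Sampler change: sync sampled token ids over the attention-TP
    # group only when DP attention was enabled at initialization time.
    if is_dp_attention_enabled():
        return get_attention_tp_group().device_group
    return default_group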
python/sglang/srt/managers/schedule_batch.py (+0, −1)

@@ -84,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_dp_attention",
     "enable_two_batch_overlap",
     "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
python/sglang/srt/model_executor/cuda_graph_runner.py (+8, −21)

@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_tp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache

@@ -349,30 +350,15 @@ class CudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         self.custom_mask = torch.ones(
             (

@@ -556,7 +542,7 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(

@@ -572,9 +558,9 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = self.get_spec_info(num_tokens)
         if self.capture_hidden_mode != CaptureHiddenMode.FULL:

@@ -607,8 +593,8 @@ class CudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,

@@ -637,6 +623,7 @@ class CudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             kwargs = {}
             if (
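During CUDA graph capture the runner no longer slices a persistent `gathered_buffer`; it derives an integer `global_dp_buffer_len` and re-registers it inside `run_once()` with `set_dp_buffer_len`, so the captured layers allocate scratch tensors of exactly that size. A sketch of the length bookkeeping with made-up values and paraphrased flag names (plain local variables here, not runner attributes):

# Illustrative capture-time bookkeeping (values made up, flag names paraphrased).
dp_size = 4
num_tokens = 128                 # tokens captured for this graph size
needs_dp_gather = True           # batch is gathered across all DP ranks
needs_attn_tp_gather = False     # only gathered inside the attention-TP group

if needs_dp_gather:
    global_dp_buffer_len = num_tokens * dp_size   # 512 rows for dp_gather_*
elif needs_attn_tp_gather:
    global_dp_buffer_len = num_tokens
else:
    global_dp_buffer_len = None

# Inside run_once() the runner then calls, as in the hunk above:
#     set_dp_buffer_len(global_dp_buffer_len, num_tokens)
print(global_dp_buffer_len)      # 512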
python/sglang/srt/model_executor/forward_batch_info.py (+8, −10)

@@ -40,9 +40,10 @@ import triton.language as tl
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (

@@ -274,13 +275,13 @@ class ForwardBatch:
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-    gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None

@@ -628,7 +629,7 @@ class ForwardBatch:
                     (global_num_tokens[i] - 1) // attn_tp_size + 1
                 ) * attn_tp_size
 
-        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():

@@ -642,17 +643,14 @@ class ForwardBatch:
         else:
             buffer_len = sum(global_num_tokens)
 
-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens)
+
         bs = self.batch_size
         if self.forward_mode.is_decode():
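Roughly, how the per-batch lengths are derived before being published: the padding mode picks whichever layout communicates fewer rows (pad-to-max across ranks vs. concatenate), and the local length is this rank's own token count. A simplified, illustrative stand-in for `DpPaddingMode.get_dp_padding_mode` plus the bookkeeping above, with made-up numbers; the real code also rounds counts up to multiples of the attention-TP size, which is omitted here:

# Simplified stand-in for the padding-mode choice and buffer-length bookkeeping.
global_num_tokens = [7, 12, 5, 9]      # tokens per attention-DP rank (made up)
attn_dp_rank = 2                       # pretend this process is rank 2

max_len_rows = max(global_num_tokens) * len(global_num_tokens)  # pad-to-max: 48
sum_len_rows = sum(global_num_tokens)                           # concatenate: 33

buffer_len = min(max_len_rows, sum_len_rows)   # rows of the global DP buffer
num_tokens = global_num_tokens[attn_dp_rank]   # this rank's local rows

# ForwardBatch now stores only the length and publishes both values:
#     self.global_dp_buffer_len = buffer_len
#     set_dp_buffer_len(buffer_len, num_tokens)
print(buffer_len, num_tokens)   # 33 5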
python/sglang/srt/model_executor/model_runner.py (+2, −6)

@@ -603,12 +603,8 @@ class ModelRunner:
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
-            enable_dp_attention=self.server_args.enable_dp_attention,
-            tp_rank=self.tp_rank,
-            tp_size=self.tp_size,
-            dp_size=self.server_args.dp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-            pp_size=self.server_args.pp_size,
+            server_args=self.server_args,
+            model_config=self.model_config,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
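The initializer now takes the whole `server_args` and `model_config` instead of six scalar arguments, since it needs the model's hidden size and dtype for the DP buffers anyway. A sketch of the new call shape; `runner` and `init_dp_attention_for` are stand-ins, not sglang code:

from sglang.srt.layers.dp_attention import initialize_dp_attention

def init_dp_attention_for(runner) -> None:
    # `runner` is a stand-in for a ModelRunner-like object whose tensor-parallel
    # groups already exist; the helper now reads tp/dp/pp sizes from server_args
    # itself and takes hidden_size/dtype from model_config for the DP buffers.
    initialize_dp_attention(
        server_args=runner.server_args,
        model_config=runner.model_config,
    )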
python/sglang/srt/models/deepseek_nextn.py (+2, −1)

@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -56,7 +57,7 @@ class DeepseekModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
python/sglang/srt/models/deepseek_v2.py (+5, −3)

@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -1797,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
         self.is_nextn = is_nextn

@@ -1917,7 +1917,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+            and not (
+                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            )
             and not self.is_nextn
         )

@@ -2047,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
python/sglang/srt/models/glm4_moe.py (+2, −2)

@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,

@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
python/sglang/srt/models/glm4_moe_nextn.py (+2, −1)

@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
python/sglang/srt/models/gpt_oss.py (+2, −1)

@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -565,7 +566,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
python/sglang/srt/models/llama4.py (+2, −2)

@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,

@@ -466,7 +466,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
python/sglang/srt/models/qwen2.py (+2, −2)

@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,

@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
python/sglang/srt/models/qwen2_moe.py (+2, −1)

@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -420,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
python/sglang/srt/models/step3_vl.py (+6, −2)

@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,

@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
python/sglang/srt/operations.py (+17, −2)

+from __future__ import annotations
+
 import os
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
 
 import torch
 
+from sglang.srt.layers.dp_attention import set_dp_buffer_len
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
 _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
 
 if _ENABLE_PROFILE:

@@ -66,18 +73,26 @@ Stage = List[ExecutionOperation]
 class _StageExecutor:
-    def __init__(self, debug_name: str, stages: List[Stage], inputs):
+    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
         self._debug_name = debug_name
         self._stages = stages
         self._index = 0
         self._stage_state = _StateDict()
         self._stage_output = inputs
 
+        # handling DP attention
+        forward_batch: ForwardBatch = inputs["forward_batch"]
+        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
+        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+
     def next(self):
         assert not self.done
         stage = self._stages[self._index]
 
+        if self._global_dp_buffer_len is not None:
+            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
                 with _annotate_region(debug_name=op.debug_name):
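`_StageExecutor` (used by two-batch overlap) now caches the batch's global/local buffer lengths at construction and re-publishes them before every stage, presumably so that two interleaved executors never run a stage under the other batch's lengths. A toy model of that bookkeeping; the class and callback names below are illustrative, not sglang code:

# Toy version of the per-stage re-publishing that _StageExecutor now performs.
class TinyStageExecutor:
    def __init__(self, global_len, local_len, publish):
        self._global_len = global_len
        self._local_len = local_len
        self._publish = publish   # stands in for set_dp_buffer_len

    def next(self):
        # Re-publish this batch's buffer lengths before running the stage, so
        # two interleaved executors do not see each other's values.
        if self._global_len is not None:
            self._publish(self._global_len, self._local_len)

published = []
a = TinyStageExecutor(48, 12, lambda g, l: published.append(("A", g, l)))
b = TinyStageExecutor(30, 10, lambda g, l: published.append(("B", g, l)))
a.next(); b.next(); a.next()
print(published)   # [('A', 48, 12), ('B', 30, 10), ('A', 48, 12)]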
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py (+7, −21)

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,

@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         # Capture
         try:

@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(

@@ -211,11 +196,11 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
-            gathered_buffer = None
+            global_dp_buffer_len = None
             global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(

@@ -239,8 +224,8 @@ class EAGLEDraftCudaGraphRunner:
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(

@@ -258,6 +243,7 @@ class EAGLEDraftCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+7, −21)

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,

@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         if hasattr(
             self.model_runner.model_config.hf_config, "draft_vocab_size"

@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(

@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,

@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,

@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc