sglang · commit 61970b08 (unverified)

Let `bench_one_batch` support `enable_dp_attention` (#4058)

Authored by fzyzcjy on Apr 09, 2025; committed via GitHub on Apr 08, 2025.
Parent: 76c48a09
Showing 2 changed files with 49 additions and 7 deletions.
python/sglang/bench_one_batch.py: +20 / -0
python/sglang/srt/managers/scheduler.py: +29 / -7
python/sglang/bench_one_batch.py:
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -184,6 +185,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
     return input_ids, reqs
@@ -199,6 +201,7 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs
@@ -220,6 +223,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
     return reqs
@@ -238,6 +242,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +254,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +262,20 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
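Taken together, the benchmark's prefill path now reads roughly as follows. This is a condensed sketch assembled from the hunks above, not a verbatim excerpt; run_extend_step is a hypothetical name, batch construction is left to the caller, and _maybe_prepare_dp_attn_batch is the module-level helper added in this commit:

from sglang.srt.model_executor.forward_batch_info import ForwardBatch

def run_extend_step(batch, model_runner):
    # One prefill (extend) step as bench_one_batch performs it after this change.
    batch.prepare_for_extend()
    # New: when enable_dp_attention is set, synchronizes per-rank token counts
    # across DP ranks before the forward pass; otherwise a no-op.
    _maybe_prepare_dp_attn_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    return model_runner.forward(forward_batch)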
python/sglang/srt/managers/scheduler.py:
@@ -1466,14 +1466,36 @@ class Scheduler(
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_dp_attn_batch_raw(
+            local_batch,
+            dp_size=self.server_args.dp_size,
+            attn_tp_size=self.attn_tp_size,
+            tp_cpu_group=self.tp_cpu_group,
+            get_idle_batch=self.get_idle_batch,
+            disable_cuda_graph=self.server_args.disable_cuda_graph,
+            spec_algorithm=self.spec_algorithm,
+            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+        )
+
+    @staticmethod
+    def prepare_dp_attn_batch_raw(
+        local_batch: ScheduleBatch,
+        dp_size,
+        attn_tp_size: int,
+        tp_cpu_group,
+        get_idle_batch,
+        disable_cuda_graph: bool,
+        spec_algorithm,
+        speculative_num_draft_tokens,
+    ):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
             global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
-                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+                num_tokens = num_tokens * speculative_num_draft_tokens
             global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
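The point of the refactor is visible in the signatures: everything prepare_dp_attn_batch used to read from self is now an explicit parameter of the static prepare_dp_attn_batch_raw, so callers without a Scheduler instance (such as bench_one_batch) can reuse it. A hedged usage sketch mirroring the benchmark's call in the previous file (the variable names and values are illustrative, not prescriptive):

# Standalone call without a Scheduler instance; mirrors _maybe_prepare_dp_attn_batch.
local_batch, is_extend_in_batch = Scheduler.prepare_dp_attn_batch_raw(
    batch,
    dp_size=model_runner.server_args.dp_size,
    attn_tp_size=1,                   # the benchmark runs without attention-TP splitting
    tp_cpu_group=model_runner.tp_group.cpu_group,
    get_idle_batch=None,              # the benchmark never needs an idle batch
    disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
    spec_algorithm=SpeculativeAlgorithm.NONE,
    speculative_num_draft_tokens=None,
)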
@@ -1492,7 +1514,7 @@ class Scheduler(
         else:
             can_cuda_graph = 0

-        if not self.spec_algorithm.is_none():
+        if not spec_algorithm.is_none():
             # TODO(sang): Support cuda graph when idle batch is there.
             if local_batch is None or local_batch.forward_mode.is_idle():
                 can_cuda_graph = 0
@@ -1510,13 +1532,13 @@ class Scheduler(
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 4),
+            (dp_size, attn_tp_size, 4),
            dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=self.tp_cpu_group,
+            group=tp_cpu_group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
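The per-rank counters are synchronized with a single all-gather on the CPU process group. Below is a self-contained sketch of the same pattern (not the scheduler code itself): each of the dp_size * attn_tp_size ranks contributes a 4-element int64 vector and reads back the full (dp_size, attn_tp_size, 4) view. Column 2 is not visible in this hunk and is assumed here to carry the logprob token count; gather_counters is a hypothetical name.

import torch
import torch.distributed as dist

def gather_counters(num_tokens, can_cuda_graph, num_tokens_for_logprob,
                    is_extend_in_batch, dp_size, attn_tp_size, cpu_group):
    # Per-rank row of counters; int64 to match the scheduler's tensors.
    local_info = torch.tensor(
        [num_tokens, can_cuda_graph, num_tokens_for_logprob, int(is_extend_in_batch)],
        dtype=torch.int64,
    )
    # One row per rank in the group (group world size == dp_size * attn_tp_size).
    global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=cpu_group)
    global_num_tokens = global_info[:, 0, 0].tolist()        # first attn-TP rank of each DP group
    can_run_cuda_graph = min(global_info[:, 0, 1].tolist())  # all ranks must agree
    any_extend = any(global_info[:, 0, 3].tolist())
    return global_num_tokens, can_run_cuda_graph, any_extend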
@@ -1524,14 +1546,14 @@ class Scheduler(
         is_extend_in_batch = global_info[:, 0, 3].tolist()

         if local_batch is None and max(global_num_tokens) > 0:
-            local_batch = self.get_idle_batch()
+            local_batch = get_idle_batch()

         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens
             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

             # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
+            if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph

         return local_batch, any(is_extend_in_batch)