change / sglang / Commits / 62832bb2

Unverified commit 62832bb2, authored Nov 18, 2024 by Ke Bao, committed by GitHub Nov 17, 2024
Support cuda graph for DP attention (#2061)
Parent: 11f881d1

Showing 9 changed files with 88 additions and 26 deletions (+88, -26)
python/sglang/srt/managers/schedule_batch.py (+10, -0)
python/sglang/srt/managers/scheduler.py (+23, -9)
python/sglang/srt/managers/tp_worker.py (+0, -3)
python/sglang/srt/managers/tp_worker_overlap_thread.py (+1, -4)
python/sglang/srt/model_executor/cuda_graph_runner.py (+44, -6)
python/sglang/srt/model_executor/forward_batch_info.py (+2, -0)
python/sglang/srt/model_executor/model_runner.py (+3, -0)
python/sglang/srt/server_args.py (+3, -2)
scripts/playground/reference_hf.py (+2, -2)
python/sglang/srt/managers/schedule_batch.py
@@ -455,6 +455,7 @@ class ScheduleBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False

     # For processing logprobs
     return_logprob: bool = False
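The two fields above are the whole DP-attention handshake at the batch level. A standalone sketch (my own minimal dataclass, not the sglang class) of how they default and travel with a batch:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class BatchSketch:
        # tokens contributed by each TP rank; None while DP attention is off
        global_num_tokens: Optional[List[int]] = None
        # flips to True only after every rank agrees it is in decode/idle mode
        can_run_dp_cuda_graph: bool = False

    b = BatchSketch()
    print(b.can_run_dp_cuda_graph)  # False: graph replay stays off by default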
@@ -891,6 +892,13 @@ class ScheduleBatch:
+        self.seq_lens = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
+        self.seq_lens_sum = 0
+        self.extend_num_tokens = 0

     def prepare_for_decode(self, enable_overlap: bool = False):
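Zero-length tensors are still valid device tensors, which lets an idle (request-free) batch flow through the normal batch plumbing unchanged. A tiny CPU stand-in (my illustration, not sglang code):

    import torch

    # An idle batch carries no requests, so every per-request tensor is empty.
    seq_lens = torch.empty(0, dtype=torch.int32)
    out_cache_loc = torch.empty(0, dtype=torch.int32)
    print(seq_lens.numel(), out_cache_loc.shape)  # 0 torch.Size([0])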
@@ -1032,6 +1040,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool

     # For extend
     extend_num_tokens: Optional[int]
python/sglang/srt/managers/scheduler.py
@@ -337,7 +337,7 @@ class Scheduler:
         kill_parent_process()

-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
         self.last_batch = None
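One practical difference behind the inference_mode-to-no_grad swap (my illustration, not from the commit): tensors created under torch.inference_mode() are permanently marked as inference tensors and cannot later participate in autograd-tracking code, which is stricter than a long-lived scheduler loop needs.

    import torch

    x = torch.ones(3)
    with torch.inference_mode():
        y = x * 2
    print(y.is_inference())  # True: y is permanently excluded from autograd

    with torch.no_grad():
        z = x * 2
    print(z.is_inference())  # False: z behaves like an ordinary tensor later on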
@@ -375,7 +375,7 @@ class Scheduler:
         self.last_batch = batch

-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
@@ -411,16 +411,12 @@ class Scheduler:
         else:
             num_tokens = local_batch.extend_num_tokens

-        local_num_tokens = torch.tensor(
-            num_tokens, dtype=torch.int64, device=self.device
-        )
-        global_num_tokens = torch.empty(
-            self.tp_size, dtype=torch.int64, device=self.device
-        )
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
         torch.distributed.all_gather_into_tensor(
             global_num_tokens,
             local_num_tokens,
-            group=self.tp_worker.get_tp_device_group(),
+            group=self.tp_cpu_group,
         )

         if local_batch is None and global_num_tokens.max().item() > 0:
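Moving the gather onto the CPU (gloo) group means the scheduler no longer allocates a device tensor or touches the GPU just to exchange token counts. A single-process sketch of the same collective (the setup values are assumptions for illustration):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)  # gloo backs CPU groups

    local_num_tokens = torch.tensor([7], dtype=torch.int64)  # this rank's count
    global_num_tokens = torch.empty(1, dtype=torch.int64)    # one slot per rank
    dist.all_gather_into_tensor(global_num_tokens, local_num_tokens)
    print(global_num_tokens.tolist())  # [7]
    dist.destroy_process_group()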
@@ -429,6 +425,24 @@ class Scheduler:
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens.tolist()

+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
         return local_batch

     def get_idle_batch(self):
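The ReduceOp.MIN over a 0/1 state acts as a logical AND across ranks: the result is 1 only if every rank reported decode or idle mode. A toy, collective-free illustration (the rank states are made up):

    states = [1, 1, 0, 1]  # 1 = rank in decode/idle mode, 0 = still prefilling
    can_run_dp_cuda_graph = min(states) == 1
    print(can_run_dp_cuda_graph)  # False: one rank is prefilling, so no replay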
python/sglang/srt/managers/tp_worker.py
@@ -128,9 +128,6 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group

-    def get_tp_device_group(self):
-        return self.model_runner.tp_group.device_group
-
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -83,9 +83,6 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()

-    def get_tp_device_group(self):
-        return self.worker.get_tp_device_group()
-
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -96,7 +93,7 @@ class TpModelWorkerClient:
         with torch.cuda.stream(self.forward_stream):
             self.forward_thread_func_()

-    @torch.inference_mode()
+    @torch.no_grad()
     def forward_thread_func_(self):
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -111,6 +111,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size

         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +167,16 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None

+        if self.enable_dp_attention:
+            self.global_num_tokens = [0] * self.tp_size
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
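The buffer is sized once for the worst case (largest captured batch on every TP rank) and then sliced per capture, so replay never reallocates. A toy CPU sketch with made-up dimensions:

    import torch

    max_bs, tp_size, hidden_size = 8, 4, 16  # assumed toy sizes
    gathered_buffer = torch.zeros((max_bs * tp_size, hidden_size))

    bs = 3                                    # batch size being captured
    view = gathered_buffer[: bs * tp_size]    # per-capture slice, no new memory
    print(view.shape)  # torch.Size([12, 16])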
@@ -190,6 +202,16 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
             is_bs_supported = (
                 forward_batch.batch_size in self.graphs
                 if self.disable_padding
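In words: with DP attention, a graph can be replayed only when every rank agreed on the forward mode and the largest per-rank token count fits a captured size (exactly, if padding is disabled). A pure-Python sketch of that decision with toy values:

    graphs = {1, 2, 4, 8}              # batch sizes captured at startup
    max_bs = 8
    disable_padding = False

    global_num_tokens = [3, 5, 5, 4]   # decode tokens per TP rank (made up)
    can_run_dp_cuda_graph = True

    min_t, max_t = min(global_num_tokens), max(global_num_tokens)
    if disable_padding:
        ok = can_run_dp_cuda_graph and (min_t == max_t and max_t in graphs)
    else:
        ok = can_run_dp_cuda_graph and max_t <= max_bs
    print(ok)  # True: every rank pads up to a shared captured size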
@@ -239,6 +261,13 @@ class CudaGraphRunner:
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]

+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            self.global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -265,6 +294,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
+            global_num_tokens=self.global_num_tokens,
+            gathered_buffer=gathered_buffer,
         )
         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
         return logits_output.next_token_logits
@@ -295,6 +326,11 @@ class CudaGraphRunner:
         raw_bs = forward_batch.batch_size

         # Pad
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
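bisect_left picks the smallest captured batch size that can hold the live batch; under DP attention the search key is the largest token count across ranks, so all ranks pad to the same graph. A toy example:

    import bisect

    capture_bs = [1, 2, 4, 8, 16]  # sizes captured at startup
    raw_bs = 5                      # live batch size
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
    print(bs)  # 8: the batch is padded from 5 to 8 before replay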
@@ -310,6 +346,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
python/sglang/srt/model_executor/forward_batch_info.py
@@ -138,6 +138,7 @@ class ForwardBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False

     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
@@ -221,6 +222,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
         )
python/sglang/srt/model_executor/model_runner.py
@@ -592,6 +592,9 @@ class ModelRunner:
         )

     def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
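forward_idle now follows the same replay-or-eager pattern as the other forward paths. A standalone toy version of the pattern (the stub class and shapes are invented for illustration):

    class GraphRunnerStub:
        def can_run(self, batch):            # was a graph captured for this shape?
            return batch["bs"] in {1, 2, 4, 8}

        def replay(self, batch):
            return f"graph replay, bs={batch['bs']}"

    def forward_idle(runner, batch):
        if runner and runner.can_run(batch):
            return runner.replay(batch)      # fast path: captured CUDA graph
        return f"eager forward, bs={batch['bs']}"  # fallback: normal execution

    print(forward_idle(GraphRunnerStub(), {"bs": 4}))  # graph replay, bs=4
    print(forward_idle(GraphRunnerStub(), {"bs": 5}))  # eager forward, bs=5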
python/sglang/srt/server_args.py
@@ -191,11 +191,12 @@ class ServerArgs:
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.disable_cuda_graph = True
+            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
+                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
+                "Data parallel size is adjusted to be the same as tensor parallel size."
             )

         if self.enable_overlap_schedule:
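Net effect of the change: enabling DP attention no longer disables CUDA graphs outright; it only caps the graph capture batch size at 96. A standalone sketch of the adjustments with assumed starting values (not sglang defaults):

    tp_size = 8
    chunked_prefill_size = 8192
    cuda_graph_max_bs = 160

    dp_size = tp_size                               # DP size mirrors TP size
    chunked_prefill_size //= 2                      # smaller chunks for MoE load
    cuda_graph_max_bs = min(cuda_graph_max_bs, 96)  # cap instead of disable
    print(dp_size, chunked_prefill_size, cuda_graph_max_bs)  # 8 4096 96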
scripts/playground/reference_hf.py
@@ -31,7 +31,7 @@ from transformers import AutoModelForCausalLM
 from sglang.srt.hf_transformers_utils import get_tokenizer


-@torch.inference_mode()
+@torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
     m = AutoModelForCausalLM.from_pretrained(
@@ -69,7 +69,7 @@ def normal_text(args):
     print(output_str)


-@torch.inference_mode()
+@torch.no_grad()
 def synthetic_tokens(args):
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True