zhaoyu6 / sglang / Commits

Commit 62832bb2 (Unverified)
Authored Nov 18, 2024 by Ke Bao; committed by GitHub on Nov 17, 2024
Parent: 11f881d1

    Support cuda graph for DP attention (#2061)

Showing 9 changed files with 88 additions and 26 deletions.
python/sglang/srt/managers/schedule_batch.py             +10  -0
python/sglang/srt/managers/scheduler.py                  +23  -9
python/sglang/srt/managers/tp_worker.py                   +0  -3
python/sglang/srt/managers/tp_worker_overlap_thread.py    +1  -4
python/sglang/srt/model_executor/cuda_graph_runner.py    +44  -6
python/sglang/srt/model_executor/forward_batch_info.py    +2  -0
python/sglang/srt/model_executor/model_runner.py          +3  -0
python/sglang/srt/server_args.py                          +3  -2
scripts/playground/reference_hf.py                        +2  -2
python/sglang/srt/managers/schedule_batch.py  (view file @ 62832bb2)

@@ -455,6 +455,7 @@ class ScheduleBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False

     # For processing logprobs
     return_logprob: bool = False

@@ -891,6 +892,13 @@ class ScheduleBatch:
         self.seq_lens = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0

     def prepare_for_decode(self, enable_overlap: bool = False):

@@ -1032,6 +1040,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,

@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool

     # For extend
     extend_num_tokens: Optional[int]
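The idle-batch reset shown in the second hunk builds zero-length int32 tensors and moves them to the device with non_blocking=True, so a data-parallel rank that has no requests of its own can still join a collective forward pass. A minimal, self-contained sketch of that pattern follows; the IdleBatchSketch class and its fields are illustrative stand-ins, not sglang's actual ScheduleBatch.

import torch

class IdleBatchSketch:
    def __init__(self, device: str):
        self.device = device

    def prepare_for_idle(self):
        # Zero-length tensors keep downstream shapes consistent while
        # carrying no tokens for this rank.
        self.seq_lens = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

batch = IdleBatchSketch("cuda" if torch.cuda.is_available() else "cpu")
batch.prepare_for_idle()
print(batch.seq_lens.shape)  # torch.Size([0])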
python/sglang/srt/managers/scheduler.py  (view file @ 62832bb2)

@@ -337,7 +337,7 @@ class Scheduler:
         kill_parent_process()

-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
         self.last_batch = None

@@ -375,7 +375,7 @@ class Scheduler:
         self.last_batch = batch

-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

@@ -411,16 +411,12 @@ class Scheduler:
         else:
             num_tokens = local_batch.extend_num_tokens

-        local_num_tokens = torch.tensor(
-            num_tokens, dtype=torch.int64, device=self.device
-        )
-        global_num_tokens = torch.empty(
-            self.tp_size, dtype=torch.int64, device=self.device
-        )
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
         torch.distributed.all_gather_into_tensor(
             global_num_tokens,
             local_num_tokens,
-            group=self.tp_worker.get_tp_device_group(),
+            group=self.tp_cpu_group,
         )

         if local_batch is None and global_num_tokens.max().item() > 0:

@@ -429,6 +425,24 @@ class Scheduler:
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens.tolist()

+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
         return local_batch

     def get_idle_batch(self):
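The scheduler hunks above gather every rank's token count over the CPU (gloo) group and then take an elementwise MIN over a 0/1 "decode-or-idle" flag, so a CUDA graph is only replayed when all DP ranks agree. Below is a minimal runnable sketch of that consensus pattern using a single-process gloo group; the MASTER_ADDR/MASTER_PORT values and the world size of one are assumptions for the demo, not part of the commit.

import os
import torch
import torch.distributed as dist

# Single-process "world" on CPU just to exercise the collectives.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# 1) Gather each rank's local token count so every rank sees the global picture.
local_num_tokens = torch.tensor([7], dtype=torch.int64)
global_num_tokens = torch.empty(dist.get_world_size(), dtype=torch.int64)
dist.all_gather_into_tensor(global_num_tokens, local_num_tokens)

# 2) Agree on whether every rank is in a graph-friendly mode (decode or idle):
#    each rank contributes 1 if so, 0 otherwise; a MIN across ranks of 1 means
#    everyone can replay a captured CUDA graph for this step.
forward_mode_state = torch.tensor(1, dtype=torch.int32)
dist.all_reduce(forward_mode_state, op=dist.ReduceOp.MIN)
can_run_dp_cuda_graph = forward_mode_state.item() == 1

print(global_num_tokens.tolist(), can_run_dp_cuda_graph)
dist.destroy_process_group()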
python/sglang/srt/managers/tp_worker.py  (view file @ 62832bb2)

@@ -128,9 +128,6 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group

-    def get_tp_device_group(self):
-        return self.model_runner.tp_group.device_group
-
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
python/sglang/srt/managers/tp_worker_overlap_thread.py  (view file @ 62832bb2)

@@ -83,9 +83,6 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()

-    def get_tp_device_group(self):
-        return self.worker.get_tp_device_group()
-
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,

@@ -96,7 +93,7 @@ class TpModelWorkerClient:
         with torch.cuda.stream(self.forward_stream):
             self.forward_thread_func_()

-    @torch.inference_mode()
+    @torch.no_grad()
     def forward_thread_func_(self):
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
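Several hunks in this commit swap @torch.inference_mode() for @torch.no_grad(). The commit itself does not state the motivation, so the following is only a generic PyTorch illustration of the observable difference between the two decorators, not code from this repository.

import torch

@torch.no_grad()
def make_no_grad():
    return torch.ones(3)

@torch.inference_mode()
def make_inference():
    return torch.ones(3)

a, b = make_no_grad(), make_inference()
# Neither tensor tracks gradients, but only the inference_mode tensor is an
# "inference tensor", which PyTorch forbids from being saved for backward or
# mutated inside autograd-recording code later on.
print(a.requires_grad, b.requires_grad)    # False False
print(a.is_inference(), b.is_inference())  # False True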
python/sglang/srt/model_executor/cuda_graph_runner.py  (view file @ 62832bb2)

@@ -111,6 +111,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size

         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:

@@ -165,6 +167,16 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None

+        if self.enable_dp_attention:
+            self.global_num_tokens = [0] * self.tp_size
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
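With DP attention enabled, the runner preallocates a gathered_buffer of shape (max_bs * tp_size, hidden_size) once, and later hands out a per-step slice of it instead of allocating new memory during capture or replay. A minimal sketch of that buffer-slicing idea follows; the max_bs, tp_size, hidden_size, and dtype values are made up for the demo.

import torch

# Illustrative sizes only.
max_bs, tp_size, hidden_size = 8, 4, 16
dtype = torch.float16

# Allocate once at the maximum size a captured graph may ever need ...
gathered_buffer = torch.zeros((max_bs * tp_size, hidden_size), dtype=dtype)

# ... then, for a given capture/replay batch size, hand out a view of the
# prefix. CUDA graphs need stable memory addresses, which a slice of a
# preallocated buffer provides, unlike a fresh allocation per step.
bs = 3
step_buffer = gathered_buffer[: bs * tp_size]
print(step_buffer.shape, step_buffer.data_ptr() == gathered_buffer.data_ptr())
# torch.Size([12, 16]) True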
@@ -190,11 +202,21 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        is_bs_supported = (
-            forward_batch.batch_size in self.graphs
-            if self.disable_padding
-            else forward_batch.batch_size <= self.max_bs
-        )
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
+            is_bs_supported = (
+                forward_batch.batch_size in self.graphs
+                if self.disable_padding
+                else forward_batch.batch_size <= self.max_bs
+            )

         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
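The can_run logic above only allows a graph replay when every DP rank's token count fits the captured graphs: with padding disabled, all ranks must have the same token count and that count must have been captured; with padding enabled, it is enough that the largest rank fits under max_bs. The batch must also carry the can_run_dp_cuda_graph consensus flag. A small standalone restatement of that decision, with toy values:

def can_run_dp_graph_sketch(
    global_num_tokens: list,
    can_run_dp_cuda_graph: bool,
    captured_sizes: set,
    max_bs: int,
    disable_padding: bool,
) -> bool:
    """Sketch of the eligibility check mirrored from CudaGraphRunner.can_run."""
    min_tokens, max_tokens = min(global_num_tokens), max(global_num_tokens)
    if disable_padding:
        fits = min_tokens == max_tokens and max_tokens in captured_sizes
    else:
        fits = max_tokens <= max_bs
    return can_run_dp_cuda_graph and fits

# All ranks decode 4 tokens and bs=4 was captured -> replay is allowed.
print(can_run_dp_graph_sketch([4, 4, 4, 4], True, {1, 2, 4, 8}, 8, True))   # True
# Unequal token counts with padding disabled -> fall back to eager.
print(can_run_dp_graph_sketch([4, 3, 4, 4], True, {1, 2, 4, 8}, 8, True))   # False
# With padding enabled we only need the max to fit under max_bs.
print(can_run_dp_graph_sketch([4, 3, 4, 4], True, {1, 2, 4, 8}, 8, False))  # True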
@@ -239,6 +261,13 @@ class CudaGraphRunner:
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]

+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            self.global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,

@@ -265,6 +294,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
+            global_num_tokens=self.global_num_tokens,
+            gathered_buffer=gathered_buffer,
         )

         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
         return logits_output.next_token_logits

@@ -295,7 +326,12 @@ class CudaGraphRunner:
         raw_bs = forward_batch.batch_size

         # Pad
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(1)

@@ -310,6 +346,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
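During replay, the batch is padded up to the nearest captured size; with DP attention the lookup key is the maximum token count across ranks rather than the local batch size. A short runnable sketch of that bisect-based selection, with illustrative capture sizes:

import bisect

# Capture sizes must be sorted so bisect_left finds the smallest captured
# size that is >= the requested size. Values are illustrative.
capture_bs = [1, 2, 4, 8, 16, 32]

def padded_capture_size(global_num_tokens: list) -> int:
    """Pick the captured graph size to replay for a DP-attention batch (sketch)."""
    need = max(global_num_tokens)  # every rank pads to the largest rank
    index = bisect.bisect_left(capture_bs, need)
    return capture_bs[index]

print(padded_capture_size([3, 1, 2, 3]))   # 4  -> replay the bs=4 graph, all ranks padded
print(padded_capture_size([16, 9, 5, 1]))  # 16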
python/sglang/srt/model_executor/forward_batch_info.py  (view file @ 62832bb2)

@@ -138,6 +138,7 @@ class ForwardBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False

     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch

@@ -221,6 +222,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
         )
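The same can_run_dp_cuda_graph flag travels from the scheduler's batch through ModelWorkerBatch into ForwardBatch, so the CUDA-graph runner can read it on the worker side without another collective. A minimal sketch of that hand-off using stand-in dataclasses (the class names here are illustrative, not sglang's real classes):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class WorkerBatchSketch:
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

@dataclass
class ForwardBatchSketch:
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    @classmethod
    def init_new(cls, batch: WorkerBatchSketch) -> "ForwardBatchSketch":
        # Copy the DP-attention metadata decided by the scheduler so the
        # graph runner can consult it directly.
        return cls(
            global_num_tokens=batch.global_num_tokens,
            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
        )

fb = ForwardBatchSketch.init_new(WorkerBatchSketch([4, 4, 4, 4], True))
print(fb.can_run_dp_cuda_graph)  # True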
python/sglang/srt/model_executor/model_runner.py  (view file @ 62832bb2)

@@ -592,6 +592,9 @@ class ModelRunner:
         )

     def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
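forward_idle now follows the same "replay the captured graph when eligible, otherwise run eagerly" dispatch that decode already uses. A tiny standalone sketch of that pattern; the runner and forward functions below are stand-ins, not sglang's API.

from typing import Optional, Set

class GraphRunnerSketch:
    def __init__(self, captured_sizes: Set[int]):
        self.captured_sizes = captured_sizes

    def can_run(self, batch_size: int) -> bool:
        return batch_size in self.captured_sizes

    def replay(self, batch_size: int) -> str:
        return f"replayed captured graph for bs={batch_size}"

def forward_idle_sketch(runner: Optional[GraphRunnerSketch], batch_size: int) -> str:
    # Prefer the captured CUDA graph when one exists for this shape,
    # otherwise fall back to the eager model forward.
    if runner and runner.can_run(batch_size):
        return runner.replay(batch_size)
    return f"eager forward for bs={batch_size}"

runner = GraphRunnerSketch({1, 2, 4, 8})
print(forward_idle_sketch(runner, 4))  # replayed captured graph for bs=4
print(forward_idle_sketch(runner, 3))  # eager forward for bs=3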
python/sglang/srt/server_args.py  (view file @ 62832bb2)

@@ -191,11 +191,12 @@ class ServerArgs:
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.disable_cuda_graph = True
+            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
+                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
+                "Data parallel size is adjusted to be the same as tensor parallel size."
             )
         if self.enable_overlap_schedule:
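With this change, enabling DP attention no longer disables CUDA graphs outright; instead the derived settings are adjusted (dp_size mirrors tp_size, the chunked prefill size is halved, cuda_graph_max_bs is capped at 96, and overlap scheduling is turned off). A small sketch of that post-init adjustment; the starting values below are made up for the demo and are not sglang's defaults.

from dataclasses import dataclass

@dataclass
class ServerArgsSketch:
    tp_size: int = 8
    dp_size: int = 1
    chunked_prefill_size: int = 8192
    cuda_graph_max_bs: int = 160
    enable_dp_attention: bool = True
    enable_overlap_schedule: bool = True

    def adjust_for_dp_attention(self):
        # Mirrors the derived-argument tweaks applied when DP attention is on.
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            self.chunked_prefill_size //= 2
            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
            self.enable_overlap_schedule = False

args = ServerArgsSketch()
args.adjust_for_dp_attention()
print(args.dp_size, args.chunked_prefill_size, args.cuda_graph_max_bs)  # 8 4096 96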
scripts/playground/reference_hf.py  (view file @ 62832bb2)

@@ -31,7 +31,7 @@ from transformers import AutoModelForCausalLM
 from sglang.srt.hf_transformers_utils import get_tokenizer

-@torch.inference_mode()
+@torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
     m = AutoModelForCausalLM.from_pretrained(

@@ -69,7 +69,7 @@ def normal_text(args):
     print(output_str)

-@torch.inference_mode()
+@torch.no_grad()
 def synthetic_tokens(args):
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,