sglang / Commits / 62832bb2

Support cuda graph for DP attention (#2061)

Commit 62832bb2 · Authored Nov 18, 2024 by Ke Bao · Committed by GitHub Nov 17, 2024
Parent: 11f881d1

Showing 9 changed files with 88 additions and 26 deletions (+88 -26)
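In brief: enabling DP attention previously forced CUDA graphs off entirely. After this commit the scheduler gathers each rank's token count over the CPU process group and MIN-all-reduces a 0/1 forward-mode state, so every rank agrees when a step is decode or idle; the resulting can_run_dp_cuda_graph flag travels through ScheduleBatch, ModelWorkerBatch, and ForwardBatch; CudaGraphRunner pads to the largest per-rank token count and pre-allocates a (max_bs * tp_size, hidden_size) gathered buffer; and forward_idle replays the captured graph so ranks with no local work stay in lockstep.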
python/sglang/srt/managers/schedule_batch.py            +10  -0
python/sglang/srt/managers/scheduler.py                 +23  -9
python/sglang/srt/managers/tp_worker.py                  +0  -3
python/sglang/srt/managers/tp_worker_overlap_thread.py   +1  -4
python/sglang/srt/model_executor/cuda_graph_runner.py   +44  -6
python/sglang/srt/model_executor/forward_batch_info.py   +2  -0
python/sglang/srt/model_executor/model_runner.py         +3  -0
python/sglang/srt/server_args.py                         +3  -2
scripts/playground/reference_hf.py                       +2  -2
python/sglang/srt/managers/schedule_batch.py  (+10 -0)

@@ -455,6 +455,7 @@ class ScheduleBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False

     # For processing logprobs
     return_logprob: bool = False
@@ -891,6 +892,13 @@ class ScheduleBatch:
         self.seq_lens = torch.empty(0, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens_sum = 0
         self.extend_num_tokens = 0

     def prepare_for_decode(self, enable_overlap: bool = False):
@@ -1032,6 +1040,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool

     # For extend
     extend_num_tokens: Optional[int]
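The new flag has to travel from the scheduler down to the model forward. A minimal sketch of that flow, with the three dataclasses trimmed to just the fields this commit touches (illustrative, not the full sglang definitions):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ScheduleBatch:                      # scheduler-side batch
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

@dataclass
class ModelWorkerBatch:                   # handed to the TP worker
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

@dataclass
class ForwardBatch:                       # consumed by the model runner
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

# The scheduler decides once per step; downstream batches just copy the flag.
sb = ScheduleBatch(global_num_tokens=[4, 4], can_run_dp_cuda_graph=True)
mwb = ModelWorkerBatch(sb.global_num_tokens, sb.can_run_dp_cuda_graph)
fb = ForwardBatch(mwb.global_num_tokens, mwb.can_run_dp_cuda_graph)
assert fb.can_run_dp_cuda_graph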
python/sglang/srt/managers/scheduler.py  (+23 -9)

@@ -337,7 +337,7 @@ class Scheduler:
             kill_parent_process()

-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
         self.last_batch = None
@@ -375,7 +375,7 @@ class Scheduler:
         self.last_batch = batch

-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
@@ -411,16 +411,12 @@ class Scheduler:
         else:
             num_tokens = local_batch.extend_num_tokens

-        local_num_tokens = torch.tensor(
-            num_tokens, dtype=torch.int64, device=self.device
-        )
-        global_num_tokens = torch.empty(
-            self.tp_size, dtype=torch.int64, device=self.device
-        )
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
         torch.distributed.all_gather_into_tensor(
             global_num_tokens,
             local_num_tokens,
-            group=self.tp_worker.get_tp_device_group(),
+            group=self.tp_cpu_group,
         )

         if local_batch is None and global_num_tokens.max().item() > 0:
@@ -429,6 +425,24 @@ class Scheduler:
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens.tolist()

+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

         return local_batch

     def get_idle_batch(self):
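The two collectives above do all the coordination: an all-gather so every rank learns every other rank's token count, and a MIN all-reduce over {0, 1} that acts as a logical AND ("all ranks are in decode or idle mode"). A single-process sketch that runs on the gloo backend; the function name and the one-rank world are illustrative, not sglang code:

import os
import torch
import torch.distributed as dist

def dp_attn_sync(num_tokens: int, is_decode_or_idle: bool, tp_size: int):
    # 1) Every rank contributes its token count; afterwards each rank knows
    #    the batch size on all DP ranks (used to pick one padded graph size).
    local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
    global_num_tokens = torch.empty(tp_size, dtype=torch.int64)
    dist.all_gather_into_tensor(global_num_tokens, local_num_tokens)

    # 2) MIN-reduce over {0, 1} acts as a logical AND: the CUDA graph is
    #    eligible only if *every* rank is in decode or idle mode.
    state = torch.tensor(1 if is_decode_or_idle else 0, dtype=torch.int32)
    dist.all_reduce(state, op=dist.ReduceOp.MIN)
    return global_num_tokens.tolist(), state.item() == 1

if __name__ == "__main__":
    # Single-process "world" just to make the sketch executable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(dp_attn_sync(num_tokens=8, is_decode_or_idle=True, tp_size=1))
    dist.destroy_process_group()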
python/sglang/srt/managers/tp_worker.py  (+0 -3)

@@ -128,9 +128,6 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group

-    def get_tp_device_group(self):
-        return self.model_runner.tp_group.device_group
-
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
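With the token-count gather now issued on the CPU (gloo) group, get_tp_device_group() has no remaining callers, so the accessor is removed here and from the overlap-thread client below.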
python/sglang/srt/managers/tp_worker_overlap_thread.py  (+1 -4)

@@ -83,9 +83,6 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()

-    def get_tp_device_group(self):
-        return self.worker.get_tp_device_group()
-
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -96,7 +93,7 @@ class TpModelWorkerClient:
         with torch.cuda.stream(self.forward_stream):
             self.forward_thread_func_()

-    @torch.inference_mode()
+    @torch.no_grad()
     def forward_thread_func_(self):
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
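The same @torch.inference_mode() to @torch.no_grad() swap recurs throughout this commit. A plausible reason, from PyTorch semantics rather than anything stated in the diff: tensors created under inference_mode cannot be mutated in place outside it, which is hazardous for buffers that outlive the decorated scope, such as persistent CUDA-graph inputs; no_grad still skips gradient tracking but produces ordinary tensors. A small demonstration:

import torch

with torch.inference_mode():
    x = torch.ones(2)   # x is an "inference tensor"
try:
    x.add_(1)           # in-place update outside inference_mode raises
except RuntimeError as e:
    print("inference tensor:", e)

with torch.no_grad():
    y = torch.ones(2)   # ordinary tensor, just no autograd history
y.add_(1)               # fine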
python/sglang/srt/model_executor/cuda_graph_runner.py  (+44 -6)

@@ -111,6 +111,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size

         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +167,16 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None

+        if self.enable_dp_attention:
+            self.global_num_tokens = [0] * self.tp_size
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
@@ -190,6 +202,16 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
             is_bs_supported = (
                 forward_batch.batch_size in self.graphs
                 if self.disable_padding
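Restating the eligibility rule added to can_run above as a standalone helper: with padding enabled, only the largest rank's token count must fit under the captured maximum; with padding disabled, every rank must have the same count and a graph must exist for exactly that size. An illustrative reimplementation (not sglang code):

from typing import Dict, List

def dp_graph_eligible(
    global_num_tokens: List[int],
    can_run_dp_cuda_graph: bool,
    graphs: Dict[int, object],   # captured graphs keyed by batch size
    max_bs: int,
    disable_padding: bool,
) -> bool:
    min_tokens, max_tokens = min(global_num_tokens), max(global_num_tokens)
    if disable_padding:
        # Exact match required: same count on all ranks, graph captured for it.
        supported = min_tokens == max_tokens and max_tokens in graphs
    else:
        # Padding evens the ranks out, so only the largest count must fit.
        supported = max_tokens <= max_bs
    return can_run_dp_cuda_graph and supported

graphs = {1: None, 2: None, 4: None, 8: None}
assert dp_graph_eligible([3, 5], True, graphs, 8, disable_padding=False)
assert not dp_graph_eligible([3, 5], True, graphs, 8, disable_padding=True)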
@@ -239,6 +261,13 @@ class CudaGraphRunner:
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]

+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            self.global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -265,6 +294,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
+            global_num_tokens=self.global_num_tokens,
+            gathered_buffer=gathered_buffer,
         )

         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
         return logits_output.next_token_logits
@@ -295,6 +326,11 @@ class CudaGraphRunner:
         raw_bs = forward_batch.batch_size

         # Pad
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
@@ -310,6 +346,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
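During replay, the target size is the max token count across DP ranks (or the raw batch size without DP attention), and bisect_left over the sorted capture sizes picks the smallest captured graph that fits. A toy sketch with made-up capture sizes:

import bisect
from typing import List, Optional

capture_bs = [1, 2, 4, 8, 16, 32]  # example sizes, not sglang's defaults

def padded_graph_size(
    raw_bs: int,
    global_num_tokens: Optional[List[int]] = None,
    enable_dp_attention: bool = False,
) -> int:
    target = max(global_num_tokens) if enable_dp_attention else raw_bs
    # Smallest captured batch size that is >= the target.
    return capture_bs[bisect.bisect_left(capture_bs, target)]

assert padded_graph_size(raw_bs=3) == 4
assert padded_graph_size(raw_bs=3, global_num_tokens=[3, 5],
                         enable_dp_attention=True) == 8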
python/sglang/srt/model_executor/forward_batch_info.py  (+2 -0)

@@ -138,6 +138,7 @@ class ForwardBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False

     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
@@ -221,6 +222,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
         )
python/sglang/srt/model_executor/model_runner.py  (+3 -0)

@@ -592,6 +592,9 @@ class ModelRunner:
         )

     def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
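Design note: with DP attention, a rank whose local batch is empty still has to run an "idle" forward so the collectives inside the model stay in lockstep with the other ranks; letting forward_idle replay the captured graph keeps idle ranks on the same CUDA-graph path as decoding ranks.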
python/sglang/srt/server_args.py  (+3 -2)

@@ -191,11 +191,12 @@ class ServerArgs:
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.disable_cuda_graph = True
+            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
+                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
+                "Data parallel size is adjusted to be the same as tensor parallel size."
             )

         if self.enable_overlap_schedule:
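Design note: instead of disabling CUDA graphs outright, the DP-attention branch now only caps cuda_graph_max_bs at 96, presumably to bound the gathered_buffer allocated in CudaGraphRunner, whose size scales with max_bs * tp_size; the warning text is updated to match (fixing the old "is adjust to" wording along the way).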
scripts/playground/reference_hf.py  (+2 -2)

@@ -31,7 +31,7 @@ from transformers import AutoModelForCausalLM
 from sglang.srt.hf_transformers_utils import get_tokenizer

-@torch.inference_mode()
+@torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
     m = AutoModelForCausalLM.from_pretrained(
@@ -69,7 +69,7 @@ def normal_text(args):
     print(output_str)

-@torch.inference_mode()
+@torch.no_grad()
 def synthetic_tokens(args):
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True