Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
722530fa
Unverified
Commit
722530fa
authored
Nov 20, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 20, 2024
Browse files
Enable overlap scheduler by default for the triton attention backend (#2105)
parent
56a347f7
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
21 additions
and
24 deletions
+21
-24
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+1
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-12
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+11
-0
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+3
-6
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-5
scripts/killall_sglang.sh
scripts/killall_sglang.sh
+1
-0
No files found.
python/sglang/srt/layers/attention/triton_backend.py
View file @
722530fa
...
@@ -53,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -53,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
start_loc
=
torch
.
zeros_like
(
forward_batch
.
seq_lens
,
dtype
=
torch
.
int32
)
start_loc
=
torch
.
zeros_like
(
forward_batch
.
seq_lens
,
dtype
=
torch
.
int32
)
start_loc
[
1
:]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
[:
-
1
],
dim
=
0
)
start_loc
[
1
:]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
[:
-
1
],
dim
=
0
)
total_num_tokens
=
torch
.
sum
(
forward_batch
.
seq_lens
).
item
()
total_num_tokens
=
forward_batch
.
seq_lens
_sum
attn_logits
=
torch
.
empty
(
attn_logits
=
torch
.
empty
(
(
self
.
num_head
,
total_num_tokens
),
(
self
.
num_head
,
total_num_tokens
),
dtype
=
self
.
reduce_dtype
,
dtype
=
self
.
reduce_dtype
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
722530fa
...
@@ -170,18 +170,9 @@ class Scheduler:
...
@@ -170,18 +170,9 @@ class Scheduler:
if
not
self
.
is_generation
:
if
not
self
.
is_generation
:
self
.
enable_overlap
=
False
self
.
enable_overlap
=
False
logger
.
info
(
"Overlap scheduler is disabled for embedding models."
)
logger
.
info
(
"Overlap scheduler is disabled for embedding models."
)
if
(
server_args
.
attention_backend
==
"triton"
if
self
.
enable_overlap
:
or
server_args
.
enable_double_sparsity
self
.
disable_jump_forward
=
True
or
(
self
.
model_config
.
attention_arch
==
AttentionArch
.
MLA
and
not
self
.
server_args
.
disable_mla
)
):
self
.
enable_overlap
=
False
logger
.
info
(
"Overlap scheduler is disabled if using triton attention backend."
)
# Launch a tensor parallel worker
# Launch a tensor parallel worker
if
self
.
enable_overlap
:
if
self
.
enable_overlap
:
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
722530fa
...
@@ -94,10 +94,21 @@ class TpModelWorkerClient:
...
@@ -94,10 +94,21 @@ class TpModelWorkerClient:
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward_thread_func_
(
self
):
def
forward_thread_func_
(
self
):
batch_pt
=
0
batch_lists
=
[
None
]
*
2
while
True
:
while
True
:
model_worker_batch
,
future_token_ids_ct
=
self
.
input_queue
.
get
()
model_worker_batch
,
future_token_ids_ct
=
self
.
input_queue
.
get
()
if
not
model_worker_batch
:
if
not
model_worker_batch
:
break
break
# Keep a reference of model_worker_batch by storing it into a list.
# Otherwise, the tensor members of model_worker_batch will be released
# by pytorch and cause CUDA illegal memory access errors.
batch_lists
[
batch_pt
%
2
]
=
model_worker_batch
batch_pt
+=
1
# Create event
self
.
launch_done
=
threading
.
Event
()
self
.
launch_done
=
threading
.
Event
()
copy_done
=
torch
.
cuda
.
Event
()
copy_done
=
torch
.
cuda
.
Event
()
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
722530fa
...
@@ -170,7 +170,6 @@ class CudaGraphRunner:
...
@@ -170,7 +170,6 @@ class CudaGraphRunner:
self
.
encoder_lens
=
None
self
.
encoder_lens
=
None
if
self
.
enable_dp_attention
:
if
self
.
enable_dp_attention
:
self
.
global_num_tokens
=
[
0
]
*
self
.
tp_size
self
.
gathered_buffer
=
torch
.
zeros
(
self
.
gathered_buffer
=
torch
.
zeros
(
(
(
self
.
max_bs
*
self
.
tp_size
,
self
.
max_bs
*
self
.
tp_size
,
...
@@ -264,10 +263,10 @@ class CudaGraphRunner:
...
@@ -264,10 +263,10 @@ class CudaGraphRunner:
mrope_positions
=
self
.
mrope_positions
[:,
:
bs
]
mrope_positions
=
self
.
mrope_positions
[:,
:
bs
]
if
self
.
enable_dp_attention
:
if
self
.
enable_dp_attention
:
self
.
global_num_tokens
[:]
=
[
bs
]
*
self
.
tp_size
global_num_tokens
=
[
bs
]
*
self
.
tp_size
gathered_buffer
=
self
.
gathered_buffer
[:
bs
*
self
.
tp_size
]
gathered_buffer
=
self
.
gathered_buffer
[:
bs
*
self
.
tp_size
]
else
:
else
:
self
.
global_num_tokens
=
None
global_num_tokens
=
None
gathered_buffer
=
None
gathered_buffer
=
None
# Attention backend
# Attention backend
...
@@ -296,7 +295,7 @@ class CudaGraphRunner:
...
@@ -296,7 +295,7 @@ class CudaGraphRunner:
top_logprobs_nums
=
[
0
]
*
bs
,
top_logprobs_nums
=
[
0
]
*
bs
,
positions
=
clamp_position
(
seq_lens
),
positions
=
clamp_position
(
seq_lens
),
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
global_num_tokens
=
self
.
global_num_tokens
,
global_num_tokens
=
global_num_tokens
,
gathered_buffer
=
gathered_buffer
,
gathered_buffer
=
gathered_buffer
,
)
)
logits_output
=
forward
(
input_ids
,
forward_batch
.
positions
,
forward_batch
)
logits_output
=
forward
(
input_ids
,
forward_batch
.
positions
,
forward_batch
)
...
@@ -348,8 +347,6 @@ class CudaGraphRunner:
...
@@ -348,8 +347,6 @@ class CudaGraphRunner:
self
.
encoder_lens
[:
raw_bs
].
copy_
(
forward_batch
.
encoder_lens
)
self
.
encoder_lens
[:
raw_bs
].
copy_
(
forward_batch
.
encoder_lens
)
if
forward_batch
.
mrope_positions
is
not
None
:
if
forward_batch
.
mrope_positions
is
not
None
:
self
.
mrope_positions
[:,
:
raw_bs
].
copy_
(
forward_batch
.
mrope_positions
)
self
.
mrope_positions
[:,
:
raw_bs
].
copy_
(
forward_batch
.
mrope_positions
)
if
self
.
enable_dp_attention
:
self
.
global_num_tokens
[:]
=
[
bs
]
*
self
.
tp_size
# Attention backend
# Attention backend
self
.
model_runner
.
attn_backend
.
init_forward_metadata_replay_cuda_graph
(
self
.
model_runner
.
attn_backend
.
init_forward_metadata_replay_cuda_graph
(
...
...
python/sglang/srt/server_args.py
View file @
722530fa
...
@@ -174,17 +174,17 @@ class ServerArgs:
...
@@ -174,17 +174,17 @@ class ServerArgs:
self
.
cuda_graph_max_bs
=
4
self
.
cuda_graph_max_bs
=
4
logger
.
info
(
"Automatically adjust --chunked-prefill-size for small GPUs."
)
logger
.
info
(
"Automatically adjust --chunked-prefill-size for small GPUs."
)
# Choose kernel backends
if
not
is_flashinfer_available
():
if
not
is_flashinfer_available
():
self
.
attention_backend
=
"triton"
self
.
attention_backend
=
"triton"
self
.
sampling_backend
=
"pytorch"
self
.
sampling_backend
=
"pytorch"
# Default kernel backends
if
self
.
attention_backend
is
None
:
if
self
.
attention_backend
is
None
:
self
.
attention_backend
=
"flashinfer"
self
.
attention_backend
=
"flashinfer"
if
self
.
sampling_backend
is
None
:
if
self
.
sampling_backend
is
None
:
self
.
sampling_backend
=
"flashinfer"
self
.
sampling_backend
=
"flashinfer"
# Others
if
self
.
enable_dp_attention
:
if
self
.
enable_dp_attention
:
self
.
dp_size
=
self
.
tp_size
self
.
dp_size
=
self
.
tp_size
self
.
chunked_prefill_size
=
self
.
chunked_prefill_size
//
2
self
.
chunked_prefill_size
=
self
.
chunked_prefill_size
//
2
...
@@ -205,9 +205,6 @@ class ServerArgs:
...
@@ -205,9 +205,6 @@ class ServerArgs:
)
)
self
.
disable_overlap_schedule
=
True
self
.
disable_overlap_schedule
=
True
if
not
self
.
disable_overlap_schedule
:
self
.
disable_jump_forward
=
True
@
staticmethod
@
staticmethod
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
# Model and port args
# Model and port args
...
...
scripts/killall_sglang.sh
View file @
722530fa
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
kill
-9
$(
ps aux |
grep
'multiprocessing.spawn'
|
grep
-v
'grep'
|
awk
'{print $2}'
)
kill
-9
$(
ps aux |
grep
'multiprocessing.spawn'
|
grep
-v
'grep'
|
awk
'{print $2}'
)
kill
-9
$(
ps aux |
grep
'sglang.launch_server'
|
grep
-v
'grep'
|
awk
'{print $2}'
)
kill
-9
$(
ps aux |
grep
'sglang.launch_server'
|
grep
-v
'grep'
|
awk
'{print $2}'
)
kill
-9
$(
ps aux |
grep
'sglang.bench'
|
grep
-v
'grep'
|
awk
'{print $2}'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment