sglang commit 2a754e57 (unverified)

2x performance improvement for large prefill & Fix workspace conflicts (#579)

Authored by Ying Sheng on Jul 03, 2024; committed via GitHub on Jul 03, 2024.
Parent: 96c503eb
Showing 6 changed files with 88 additions and 25 deletions (+88 -25):
- docs/test_process.md (+10 -0)
- python/sglang/bench_latency.py (+4 -4)
- python/sglang/global_config.py (+3 -0)
- python/sglang/srt/layers/radix_attention.py (+20 -2)
- python/sglang/srt/managers/controller/model_runner.py (+50 -18)
- python/sglang/srt/server.py (+1 -1)
docs/test_process.md
## SRT Unit Tests
### Latency Alignment
Make sure your changes do not slow down the following benchmarks
```
# single gpu
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 1 --input-len 512 --output-len 256
# multiple gpu
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 32 --input-len 8192 --output-len 1
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 1 --input-len 8100 --output-len 32
# moe model
python -m sglang.bench_latency --model-path databricks/dbrx-base --tp 8 --mem-fraction-static 0.6 --batch 4 --input-len 1024 --output-len 32
```
### High-level API
...
...
python/sglang/bench_latency.py

```diff
@@ -230,7 +230,7 @@ def latency_test(
         prefill_latency = time.time() - tic
         tot_latency += prefill_latency
         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(f"Prefill. latency: {prefill_latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
+        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")

         # Decode
         for i in range(output_len):
@@ -241,13 +241,13 @@ def latency_test(
             latency = time.time() - tic
             tot_latency += latency
             throughput = bench_args.batch_size / latency
             if i < 5:
-                rank_print(f"Decode. latency: {latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
+                rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
         avg_decode_latency = (tot_latency - prefill_latency) / output_len
         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} ms, avg throughput: {avg_decode_throughput:9.2f} token/s")
+        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")

         throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
-        rank_print(f"Total. latency: {tot_latency:6.3f} ms, throughput: {throughput:9.2f} token/s")
+        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")

     # Warm up
     run_once(4)
```
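The only functional content in this file is a unit fix: `time.time()` differences are wall-clock seconds, so the log labels change from `ms` to `s`. A minimal sketch of the same timing pattern, with a hypothetical dummy workload standing in for the model call (not sglang code):

```python
import time

def timed_throughput(fn, num_tokens):
    """Return (latency_s, tokens_per_second) for one call to fn."""
    tic = time.time()
    fn()
    latency = time.time() - tic  # seconds, hence the "s" label
    return latency, num_tokens / latency

# Hypothetical stand-in workload; batch 32, input length 512 as in the docs above.
latency, tput = timed_throughput(lambda: sum(range(1_000_000)), num_tokens=32 * 512)
print(f"Prefill. latency: {latency:6.5f} s, throughput: {tput:9.2f} token/s")
```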
python/sglang/global_config.py

```diff
@@ -35,5 +35,8 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.0001
         self.new_token_ratio_recovery = 0.05

+        # The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192


 global_config = GlobalConfig()
```
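For orientation, this is how the new knob is consumed (the real call site is in `radix_attention.py` below; the wrapper function here is only an illustrative sketch, not part of the repo). The sync fires only once a prefill batch reaches 8192 tokens, which is where the comment says it helps large-batch prefill speed, presumably by keeping the host from enqueueing too far ahead of the GPU.

```python
import torch
from sglang.global_config import global_config

def maybe_layer_sync(total_num_tokens: int) -> None:
    # Illustrative helper (not in sglang): sync per layer only for large prefill batches.
    if total_num_tokens >= global_config.layer_sync_threshold:  # defaults to 8192
        torch.cuda.synchronize()
```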
python/sglang/srt/layers/radix_attention.py

```diff
@@ -4,6 +4,7 @@ import numpy as np
 import torch
 from torch import nn

+from sglang.global_config import global_config
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
@@ -103,12 +104,29 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.flashinfer_prefill_wrapper.forward(
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
             logits_soft_cap=self.logit_cap,
         )

+        if input_metadata.no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            from flashinfer.cascade import merge_state
+
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
```
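The restructured prefill path splits attention into two pieces: the ragged wrapper computes the newly extended tokens attending to their own keys and values (causal), while the paged wrapper computes the same queries attending to the cached prefix (`causal=False`). Each call also returns its log-sum-exp, and `flashinfer.cascade.merge_state` combines the two partial results into the exact full-attention output. A simplified, hypothetical sketch of what that merge amounts to (ignoring the numerically careful implementation inside flashinfer):

```python
import torch

def merge_two_states(o1, s1, o2, s2):
    """Blend two partial attention outputs using their log-sum-exp statistics.

    o1, o2: [num_tokens, num_heads, head_dim] partial outputs
    s1, s2: [num_tokens, num_heads] log-sum-exp of the attention scores
    """
    w1 = torch.sigmoid(s1 - s2).unsqueeze(-1)  # = exp(s1) / (exp(s1) + exp(s2))
    return w1 * o1 + (1.0 - w1) * o2
```

The layer-wise `torch.cuda.synchronize()` at the end of the method is the consumer of the `layer_sync_threshold` added in `global_config.py` above.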
python/sglang/srt/managers/controller/model_runner.py

```diff
@@ -65,23 +65,33 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens
+
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
@@ -92,13 +102,24 @@ class InputMetadata:
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
             # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

-            self.flashinfer_prefill_wrapper.end_forward()
-            self.flashinfer_prefill_wrapper.begin_forward(
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
@@ -143,7 +164,8 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
-        flashinfer_prefill_wrapper=None,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
         flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
@@ -194,7 +216,8 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )
@@ -361,6 +384,7 @@ class ModelRunner:
     def init_flash_infer(self):
         if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
+                BatchPrefillWithRaggedKVCacheWrapper,
                 BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
             )
@@ -373,17 +397,21 @@ class ModelRunner:
             else:
                 use_tensor_cores = False

-            workspace_buffer = torch.empty(
-                128 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            workspace_buffers = torch.empty(
+                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
             )
-            self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffers[0], "NHD"
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
+                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
-            self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
+            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None

     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
@@ -398,7 +426,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
@@ -418,7 +447,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
@@ -440,7 +470,8 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
@@ -460,7 +491,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
```
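The "workspace conflicts" half of the commit title is addressed here: instead of all wrappers sharing one 128 MB `int8` buffer, each wrapper now owns a private 96 MB row of a single `uint8` tensor, so their scratch regions cannot overlap. A sketch of the resulting allocation pattern, mirroring the diff (the short variable names are mine):

```python
import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)

# One contiguous allocation, three private 96 MB scratch regions.
workspace_buffers = torch.empty(3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda")
ragged_prefill = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
paged_prefill = BatchPrefillWithPagedKVCacheWrapper(workspace_buffers[1], "NHD")
decode = BatchDecodeWithPagedKVCacheWrapper(workspace_buffers[2], "NHD")
```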
python/sglang/srt/server.py

```diff
@@ -152,7 +152,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if not server_args.disable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.7")
+        assert_pkg_version("flashinfer", "0.0.8")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
```
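The version floor moves from flashinfer 0.0.7 to 0.0.8, presumably so the APIs used above (`forward_return_lse`, `merge_state`, the ragged prefill wrapper) are guaranteed to be available. A hypothetical sketch of the kind of check `assert_pkg_version` performs (the real helper lives in sglang's utilities and may behave differently):

```python
from importlib.metadata import PackageNotFoundError, version
from packaging.version import parse

def check_min_version(pkg: str, min_ver: str) -> None:
    # Hypothetical re-implementation for illustration only.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed; install {pkg}>={min_ver}")
    if parse(installed) < parse(min_ver):
        raise RuntimeError(f"{pkg}>={min_ver} is required, found {installed}")

check_min_version("flashinfer", "0.0.8")
```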