change / sglang / Commits / 66581596
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "d7bc19a46a18ef1e7662bfa2efa34ad24d1580a0"
Commit 66581596 (unverified)
Enable cuda graph by default (#612)

Authored Jul 13, 2024 by Lianmin Zheng; committed by GitHub on Jul 13, 2024
Parent: 396a6924

Showing 10 changed files with 317 additions and 70 deletions (+317 -70)
benchmark/latency_throughput/bench_one.py (+58 -24)
python/sglang/bench_latency.py (+0 -1)
python/sglang/global_config.py (+21 -17)
python/sglang/srt/managers/controller/cuda_graph_runner.py (+173 -0)
python/sglang/srt/managers/controller/infer_batch.py (+13 -6)
python/sglang/srt/managers/controller/model_runner.py (+30 -11)
python/sglang/srt/managers/controller/tp_worker.py (+8 -6)
python/sglang/srt/memory_pool.py (+9 -1)
python/sglang/srt/server.py (+1 -0)
python/sglang/srt/server_args.py (+4 -4)
benchmark/latency_throughput/bench_one.py

 """
 Usage:
 python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
 """
 import argparse
+import json
 import time

+import numpy as np
 import requests


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=None)
-    parser.add_argument("--backend", type=str, default="srt")
-    parser.add_argument("--batch-size", type=int, default=1)
-    parser.add_argument("--max-tokens", type=int, default=256)
-    args = parser.parse_args()
-
-    if args.port is None:
-        if args.backend == "srt":
-            args.port = 30000
-        elif args.backend == "vllm":
-            args.port = 21000
-        elif args.backend == "lightllm":
-            args.port = 22000
-        elif args.backend == "ginfer":
-            args.port = 9988
-        else:
-            raise ValueError(f"Invalid backend: {args.backend}")
-
+def run_one_batch_size(bs):
     url = f"{args.host}:{args.port}"
     max_new_tokens = args.max_tokens

     a = 20
     prompt = f"{a,}"

     tic = time.time()
     if args.backend == "srt":
+        if args.input_len:
+            inputs = {"input_ids": [
+                [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
+                for _ in range(bs)
+            ]}
+        else:
+            inputs = {"text": [f"{i,}" for i in range(bs)]}
+
         response = requests.post(
             url + "/generate",
             json={
-                "text": [prompt] * args.batch_size,
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": max_new_tokens,
                     "ignore_eos": True,
                 },
+                **inputs,
             },
         )
     elif args.backend == "lightllm":
...
@@ -91,5 +89,41 @@ if __name__ == "__main__":
     ret = response.json()
     print(ret)

-    speed = args.batch_size * max_new_tokens / latency
-    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
+    output_throughput = bs * max_new_tokens / latency
+    print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")
+
+    with open("tmp_output.txt", "a") as fout:
+        res = {
+            "input_len": args.input_len,
+            "output_len": args.max_tokens,
+            "batch_size": bs,
+            "latency": latency,
+            "output_throughput": output_throughput,
+        }
+        fout.write(json.dumps(res) + "\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=None)
+    parser.add_argument("--backend", type=str, default="srt")
+    parser.add_argument("--input-len", type=int, default=None)
+    parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
+    parser.add_argument("--max-tokens", type=int, default=256)
+    args = parser.parse_args()
+
+    if args.port is None:
+        if args.backend == "srt":
+            args.port = 30000
+        elif args.backend == "vllm":
+            args.port = 21000
+        elif args.backend == "lightllm":
+            args.port = 22000
+        elif args.backend == "ginfer":
+            args.port = 9988
+        else:
+            raise ValueError(f"Invalid backend: {args.backend}")
+
+    for bs in args.batch_size:
+        run_one_batch_size(bs)
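Since each run now appends one JSON line to tmp_output.txt, a quick way to look at a batch-size sweep afterwards is a short reader script. This is a usage sketch, not part of the commit; it only assumes the keys written by run_one_batch_size above.

# Sketch (not in this commit): summarize the JSON lines bench_one.py appends to tmp_output.txt.
import json

with open("tmp_output.txt") as fin:
    rows = [json.loads(line) for line in fin if line.strip()]

for r in sorted(rows, key=lambda r: r["batch_size"]):
    print(f'bs={r["batch_size"]:>4}  latency={r["latency"]:.2f}s  '
          f'throughput={r["output_throughput"]:.1f} token/s')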
python/sglang/bench_latency.py

@@ -30,7 +30,6 @@ import argparse
 import dataclasses
 import logging
 import multiprocessing
-import os
 import time
...
python/sglang/global_config.py

@@ -8,36 +8,40 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0

+        # Default backend of the language
         self.default_backend = None

-        # Output configs
+        # Runtime constants: Request dependency time due to network delay
+        self.request_dependency_delay = 0.02
+        self.wait_for_new_request_delay = 0.0006
+
+        # Runtime constants: New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
+
+        # Runtime constants: Flashinfer
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True

-        # Optimization configs
+        # Interpreter optimization configs
         self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True

+        # Deprecated
         # Choices: ["no_adjust", "adjust_cache"]
         # no_adjust: Do not adjust the position embedding of KV cache.
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"

-        # Request dependency time due to network delay
-        self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.0006
-
-        # New generation token ratio estimation
-        self.base_new_token_ratio = 0.4
-        self.base_min_new_token_ratio = 0.2
-        self.new_token_ratio_decay = 0.0001
-        self.new_token_ratio_recovery = 0.05
-
-        # The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192


 global_config = GlobalConfig()
python/sglang/srt/managers/controller/cuda_graph_runner.py (new file, mode 100644)

"""Run the model with cuda graph."""

import bisect

import torch
from vllm.distributed.parallel_state import graph_capture

from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
    Batch, ForwardMode, InputMetadata, init_flashinfer_args
)


class CudaGraphRunner:
    def __init__(self, model_runner, max_batch_size_to_capture):
        self.model_runner = model_runner
        self.graphs = {}
        self.input_buffers = {}
        self.output_buffers = {}
        self.flashinfer_handlers = {}
        self.graph_memory_pool = None

        # Common inputs
        self.max_bs = max_batch_size_to_capture
        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")

        # FlashInfer inputs
        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
        self.flashinfer_kv_indptr = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.flashinfer_kv_indices = torch.zeros(
            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
        )
        self.flashinfer_kv_last_page_len = torch.ones(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )

    def can_run(self, batch_size):
        return batch_size < self.max_bs

    def capture(self, batch_size_list):
        self.batch_size_list = batch_size_list
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            for bs in batch_size_list:
                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
                self.graphs[bs] = graph
                self.input_buffers[bs] = input_buffers
                self.output_buffers[bs] = output_buffers
                self.flashinfer_handlers[bs] = flashinfer_handler

    def capture_one_batch_size(self, bs):
        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

        graph = torch.cuda.CUDAGraph()
        stream = self.stream

        # Common inputs
        input_ids = self.input_ids[:bs]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        position_ids_offsets = self.position_ids_offsets[:bs]
        out_cache_loc = self.out_cache_loc[:bs]

        # FlashInfer inputs
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False

        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffer,
            "NHD",
            use_cuda_graph=True,
            use_tensor_cores=use_tensor_cores,
            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
            paged_kv_indices_buffer=self.flashinfer_kv_indices,
            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
        )
        init_flashinfer_args(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            flashinfer_decode_wrapper,
        )

        # Run and capture
        def run_once():
            input_metadata = InputMetadata.create(
                self.model_runner,
                forward_mode=ForwardMode.DECODE,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens,
                prefix_lens=None,
                position_ids_offsets=position_ids_offsets,
                out_cache_loc=out_cache_loc,
                out_cache_cont_start=None,
                out_cache_cont_end=None,
                return_logprob=False,
                top_logprobs_nums=0,
                skip_flashinfer_init=True,
            )
            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
            return self.model_runner.model.forward(input_ids, input_metadata.positions, input_metadata)

        for _ in range(2):
            run_once()

        torch.cuda.synchronize()
        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
            out = run_once()
        torch.cuda.synchronize()
        self.graph_memory_pool = graph.pool()

        return graph, None, out, flashinfer_decode_wrapper

    def replay(self, batch: Batch):
        assert batch.out_cache_loc is not None
        assert not batch.return_logprob
        raw_bs = len(batch.reqs)

        # Pad
        index = bisect.bisect_left(self.batch_size_list, raw_bs)
        bs = self.batch_size_list[index]
        if bs != raw_bs:
            self.seq_lens.fill_(1)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_bs] = batch.input_ids
        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
        self.seq_lens[:raw_bs] = batch.seq_lens
        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
        self.out_cache_loc[:raw_bs] = batch.out_cache_loc

        # FlashInfer inputs
        init_flashinfer_args(
            ForwardMode.DECODE,
            self.model_runner,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            None,
            self.flashinfer_handlers[bs],
        )

        # Replay
        self.graphs[bs].replay()
        output = self.output_buffers[bs]

        # Unpad
        if bs == raw_bs:
            return output
        else:
            output = LogitProcessorOutput(
                next_token_logits=output.next_token_logits[:raw_bs],
                next_token_logprobs=output.next_token_logprobs[:raw_bs]
                if output.next_token_logprobs is not None
                else None,
                normalized_prompt_logprobs=None,
                prefill_token_logprobs=None,
                prefill_top_logprobs=None,
                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
                if output.decode_top_logprobs is not None
                else None,
            )
            return output
\ No newline at end of file
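One detail worth calling out in replay above is the padding step: a decode batch whose size was not captured is rounded up with bisect to the nearest captured size, and the padded rows fall back to seq_len 1 and out_cache_loc 0. A minimal standalone illustration of that rounding, using the batch size list that init_cuda_graphs passes in model_runner.py below:

# Sketch (not in this commit): the round-up used by CudaGraphRunner.replay.
import bisect

batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]  # sizes captured in model_runner.py

def padded_batch_size(raw_bs):
    # Same rounding as replay: bisect_left finds the nearest captured size >= raw_bs.
    return batch_size_list[bisect.bisect_left(batch_size_list, raw_bs)]

assert padded_batch_size(3) == 4
assert padded_batch_size(9) == 16
assert padded_batch_size(120) == 120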
python/sglang/srt/managers/controller/infer_batch.py

@@ -675,7 +675,11 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
...
@@ -757,9 +761,11 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        skip_flashinfer_init=False,
     ):
-        if not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
+                                 prefix_lens, model_runner.flashinfer_decode_wrapper)

         batch_size = len(req_pool_indices)
...
@@ -826,7 +832,8 @@ class InputMetadata:
         return ret


-def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                         flashinfer_decode_wrapper):
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
...
@@ -857,8 +864,8 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )

     if forward_mode == ForwardMode.DECODE:
-        model_runner.flashinfer_decode_wrapper.end_forward()
-        model_runner.flashinfer_decode_wrapper.begin_forward(
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
            kv_indptr,
            kv_indices,
            kv_last_page_len,
...
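The sampling hunk above wraps torch.multinomial in a try/except; when it fails, the code warns and substitutes an index tensor of ones, which picks the token in position 1 of each row's top-p/top-k-sorted candidates. A small standalone illustration of what that fallback gather does, with toy numbers that are not from the commit:

# Sketch (not in this commit): behavior of the multinomial fallback path.
import torch

probs_sort = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]])   # sorted probabilities per row
probs_idx = torch.tensor([[42, 7, 99], [3, 11, 5]])             # matching token ids per row

# The fallback: an all-ones index instead of a multinomial draw.
sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64)
next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
print(next_token_ids)  # tensor([ 7, 11])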
python/sglang/srt/managers/controller/model_runner.py

@@ -15,6 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

+from sglang.global_config import global_config
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
...
@@ -90,6 +91,9 @@ class ModelRunner:
         self.init_cublas()
         self.init_flash_infer()

+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
...
@@ -203,29 +207,46 @@ class ModelRunner:
         else:
             use_tensor_cores = False

-        workspace_buffers = torch.empty(
-            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
         )
         self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            workspace_buffers[0], "NHD"
+            self.flashinfer_workspace_buffers[0], "NHD"
         )
         self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffers[1], "NHD"
+            self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
         )

+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner.capture(batch_size_list)
+
     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.EXTEND,
+            forward_mode=ForwardMode.DECODE,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
...
@@ -234,17 +255,15 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.DECODE,
+            forward_mode=ForwardMode.EXTEND,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
...
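To make the dispatch in forward_decode concrete: graphs are captured for batch sizes 1, 2, 4, 8, 16, ..., 120, capture is skipped entirely when cuda graph or flashinfer is disabled, and can_run uses a strict less-than against the largest captured size. The sketch below just restates that predicate from the diff; it is not new behavior.

# Sketch (not in this commit): which decode batches take the cuda-graph replay path.
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
max_bs = max(batch_size_list)  # 120

def uses_cuda_graph(num_reqs, disable_cuda_graph=False, disable_flashinfer=False):
    # Mirrors the init_cuda_graphs guard plus CudaGraphRunner.can_run (batch_size < max_bs).
    if disable_cuda_graph or disable_flashinfer:
        return False
    return num_reqs < max_bs

print(uses_cuda_graph(1), uses_cuda_graph(119), uses_cuda_graph(120))  # True True False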
python/sglang/srt/managers/controller/tp_worker.py

@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            4096
+            8192
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
...
@@ -314,11 +314,9 @@ class ModelTpServer:
         self.forward_queue.append(req)

     def get_new_fill_batch(self) -> Optional[Batch]:
-        if (
-            self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_running_requests
-        ):
-            return None
+        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        if running_bs > self.max_running_requests:
+            return None

         # Compute matched prefix length
         for req in self.forward_queue:
...
@@ -394,6 +392,10 @@ class ModelTpServer:
                 new_batch_input_tokens += req.extend_input_len
             else:
                 break
+
+            if running_bs + len(can_run_list) > self.max_running_requests:
+                break
+
         if len(can_run_list) == 0:
             return None
...
python/sglang/srt/memory_pool.py

@@ -38,7 +38,10 @@ class ReqToTokenPool:

 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
-        self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
+        self.size = size
+        # mem_state is the reference counter.
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
         self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer
...
@@ -47,6 +50,8 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]

+        self.clear()
+
     def get_key_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 0]
...
@@ -101,3 +106,6 @@ class TokenToKVPool:
     def clear(self):
         self.mem_state.fill_(0)
         self.total_ref_ct = 0
+
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
\ No newline at end of file
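The extra slot added here pairs with the padding in cuda_graph_runner.py: padded decode rows get out_cache_loc 0, so the pool keeps one slot that is never handed out and lets those rows write dummy KV output safely. A minimal sketch of the idea, assuming free slots are the ones whose reference count is zero (the actual allocator lives elsewhere in this file):

# Sketch (not in this commit): why pinning slot 0 keeps it out of circulation.
import torch

size = 8
mem_state = torch.zeros((size + 1,), dtype=torch.int16)  # size + 1 reference counters
mem_state[0] += 1          # analogous to clear() calling add_refs(torch.tensor([0]))

free_slots = (mem_state == 0).nonzero().flatten()
assert 0 not in free_slots.tolist()   # slot 0 stays reserved for padded/dummy writes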
python/sglang/srt/server.py

@@ -146,6 +146,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"

     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
...
python/sglang/srt/server_args.py

@@ -29,7 +29,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
-    schedule_conservativeness: float = 1.0
+    schedule_conservativeness: float = 0.8

     # Other runtime options
     tp_size: int = 1
...
@@ -68,13 +68,13 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 8:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.80
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.90
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
...
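The default mem_fraction_static values are nudged down for tp_size 1, 4, and 8, presumably to leave headroom for the memory held by the captured cuda graphs, though the commit does not say so. Restated as a standalone helper for quick reference (the real logic lives in ServerArgs above):

# Sketch (not in this commit): the new mem_fraction_static defaults from this diff.
def default_mem_fraction_static(tp_size: int) -> float:
    # 0.85 for tp_size >= 2 is unchanged; the other tiers drop slightly.
    if tp_size >= 8:
        return 0.78
    elif tp_size >= 4:
        return 0.80
    elif tp_size >= 2:
        return 0.85
    return 0.88

assert default_mem_fraction_static(1) == 0.88
assert default_mem_fraction_static(8) == 0.78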