sglang · Commit 66581596 (Unverified)

Enable cuda graph by default (#612)

Authored Jul 13, 2024 by Lianmin Zheng; committed by GitHub on Jul 13, 2024.
Parent: 396a6924
Showing 10 changed files with 317 additions and 70 deletions (+317, -70):

  benchmark/latency_throughput/bench_one.py                     +58  -24
  python/sglang/bench_latency.py                                 +0   -1
  python/sglang/global_config.py                                +21  -17
  python/sglang/srt/managers/controller/cuda_graph_runner.py   +173   -0
  python/sglang/srt/managers/controller/infer_batch.py          +13   -6
  python/sglang/srt/managers/controller/model_runner.py         +30  -11
  python/sglang/srt/managers/controller/tp_worker.py             +8   -6
  python/sglang/srt/memory_pool.py                               +9   -1
  python/sglang/srt/server.py                                    +1   -0
  python/sglang/srt/server_args.py                               +4   -4
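With this change, ModelRunner captures CUDA graphs for the decode pass at startup unless the disable switches are set (the ServerArgs fields disable_cuda_graph and disable_flashinfer checked in init_cuda_graphs; the exact CLI flag names are inferred, not shown in these hunks). A CUDA graph records the kernel launches of one decoding step once and replays them later, removing most per-step CPU launch overhead; the price is that all inputs and outputs must live in fixed buffers that are overwritten before each replay. The snippet below is a minimal, self-contained sketch of that capture/replay pattern, not sglang code; decode_step and the buffer names are placeholders.

import torch

# Minimal capture/replay sketch (requires a CUDA device).
# `decode_step` stands in for one decode forward pass; it is a placeholder.
def decode_step(x, w):
    return torch.relu(x @ w)

w = torch.randn(64, 64, device="cuda")
static_x = torch.zeros(8, 64, device="cuda")   # fixed input buffer read by the graph

# Warm up on a side stream, then capture one step into a graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        decode_step(static_x, w)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = decode_step(static_x, w)      # fixed output buffer written by the graph

# Per step: copy fresh inputs into the static buffer, replay, read the static output.
static_x.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()
result = static_out.clone()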
benchmark/latency_throughput/bench_one.py (+58, -24)

 """
 Usage:
 python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
 """
 import argparse
 import json
 import time

 import numpy as np
 import requests


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=None)
-    parser.add_argument("--backend", type=str, default="srt")
-    parser.add_argument("--batch-size", type=int, default=1)
-    parser.add_argument("--max-tokens", type=int, default=256)
-    args = parser.parse_args()
-
-    if args.port is None:
-        if args.backend == "srt":
-            args.port = 30000
-        elif args.backend == "vllm":
-            args.port = 21000
-        elif args.backend == "lightllm":
-            args.port = 22000
-        elif args.backend == "ginfer":
-            args.port = 9988
-        else:
-            raise ValueError(f"Invalid backend: {args.backend}")
-
+def run_one_batch_size(bs):
     url = f"{args.host}:{args.port}"

-    a = 20
     max_new_tokens = args.max_tokens
+
+    a = 20
     prompt = f"{a,}"

     tic = time.time()
     if args.backend == "srt":
+        if args.input_len:
+            inputs = {"input_ids": [
+                [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
+                for _ in range(bs)
+            ]}
+        else:
+            inputs = {"text": [f"{i,}" for i in range(bs)]}
+
         response = requests.post(
             url + "/generate",
             json={
-                "text": [prompt] * args.batch_size,
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": max_new_tokens,
                     "ignore_eos": True,
                 },
+                **inputs,
             },
         )
     elif args.backend == "lightllm":

@@ -91,5 +89,41 @@ if __name__ == "__main__":
     ret = response.json()
     print(ret)
-    speed = args.batch_size * max_new_tokens / latency
-    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
+    output_throughput = bs * max_new_tokens / latency
+    print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")
+
+    with open("tmp_output.txt", "a") as fout:
+        res = {
+            "input_len": args.input_len,
+            "output_len": args.max_tokens,
+            "batch_size": bs,
+            "latency": latency,
+            "output_throughput": output_throughput,
+        }
+        fout.write(json.dumps(res) + "\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=None)
+    parser.add_argument("--backend", type=str, default="srt")
+    parser.add_argument("--input-len", type=int, default=None)
+    parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
+    parser.add_argument("--max-tokens", type=int, default=256)
+    args = parser.parse_args()
+
+    if args.port is None:
+        if args.backend == "srt":
+            args.port = 30000
+        elif args.backend == "vllm":
+            args.port = 21000
+        elif args.backend == "lightllm":
+            args.port = 22000
+        elif args.backend == "ginfer":
+            args.port = 9988
+        else:
+            raise ValueError(f"Invalid backend: {args.backend}")
+
+    for bs in args.batch_size:
+        run_one_batch_size(bs)
python/sglang/bench_latency.py (+0, -1)

@@ -30,7 +30,6 @@ import argparse
 import dataclasses
 import logging
 import multiprocessing
 import os
 import time
python/sglang/global_config.py (+21, -17)

@@ -8,36 +8,40 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0

         # Default backend of the language
         self.default_backend = None

-        # Output configs
+        # Runtime constants: Request dependency time due to network delay
+        self.request_dependency_delay = 0.02
+        self.wait_for_new_request_delay = 0.0006
+
+        # Runtime constants: New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
+
+        # Runtime constants: Flashinfer
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True

-        # Optimization configs
+        # Interpreter optimization configs
         self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True

         # Deprecated
         # Choices: ["no_adjust", "adjust_cache"]
         # no_adjust: Do not adjust the position embedding of KV cache.
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"

-        # Request dependency time due to network delay
-        self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.0006
-
-        # New generation token ratio estimation
-        self.base_new_token_ratio = 0.4
-        self.base_min_new_token_ratio = 0.2
-        self.new_token_ratio_decay = 0.0001
-        self.new_token_ratio_recovery = 0.05
-
-        # The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192


 global_config = GlobalConfig()
python/sglang/srt/managers/controller/cuda_graph_runner.py (new file, mode 100644, +173)

"""Run the model with cuda graph."""

import bisect

import torch
from vllm.distributed.parallel_state import graph_capture

from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
    Batch, ForwardMode, InputMetadata, init_flashinfer_args
)


class CudaGraphRunner:
    def __init__(self, model_runner, max_batch_size_to_capture):
        self.model_runner = model_runner
        self.graphs = {}
        self.input_buffers = {}
        self.output_buffers = {}
        self.flashinfer_handlers = {}
        self.graph_memory_pool = None

        # Common inputs
        self.max_bs = max_batch_size_to_capture
        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")

        # FlashInfer inputs
        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
        self.flashinfer_kv_indptr = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.flashinfer_kv_indices = torch.zeros(
            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
        )
        self.flashinfer_kv_last_page_len = torch.ones(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )

    def can_run(self, batch_size):
        return batch_size < self.max_bs

    def capture(self, batch_size_list):
        self.batch_size_list = batch_size_list
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            for bs in batch_size_list:
                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
                self.graphs[bs] = graph
                self.input_buffers[bs] = input_buffers
                self.output_buffers[bs] = output_buffers
                self.flashinfer_handlers[bs] = flashinfer_handler

    def capture_one_batch_size(self, bs):
        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

        graph = torch.cuda.CUDAGraph()
        stream = self.stream

        # Common inputs
        input_ids = self.input_ids[:bs]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        position_ids_offsets = self.position_ids_offsets[:bs]
        out_cache_loc = self.out_cache_loc[:bs]

        # FlashInfer inputs
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False
        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffer, "NHD",
            use_cuda_graph=True,
            use_tensor_cores=use_tensor_cores,
            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs + 1],
            paged_kv_indices_buffer=self.flashinfer_kv_indices,
            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
        )
        init_flashinfer_args(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            flashinfer_decode_wrapper,
        )

        # Run and capture
        def run_once():
            input_metadata = InputMetadata.create(
                self.model_runner,
                forward_mode=ForwardMode.DECODE,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens,
                prefix_lens=None,
                position_ids_offsets=position_ids_offsets,
                out_cache_loc=out_cache_loc,
                out_cache_cont_start=None,
                out_cache_cont_end=None,
                return_logprob=False,
                top_logprobs_nums=0,
                skip_flashinfer_init=True,
            )
            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
            return self.model_runner.model.forward(input_ids, input_metadata.positions, input_metadata)

        for _ in range(2):
            run_once()

        torch.cuda.synchronize()
        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
            out = run_once()
        torch.cuda.synchronize()
        self.graph_memory_pool = graph.pool()
        return graph, None, out, flashinfer_decode_wrapper

    def replay(self, batch: Batch):
        assert batch.out_cache_loc is not None
        assert not batch.return_logprob
        raw_bs = len(batch.reqs)

        # Pad
        index = bisect.bisect_left(self.batch_size_list, raw_bs)
        bs = self.batch_size_list[index]
        if bs != raw_bs:
            self.seq_lens.fill_(1)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_bs] = batch.input_ids
        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
        self.seq_lens[:raw_bs] = batch.seq_lens
        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
        self.out_cache_loc[:raw_bs] = batch.out_cache_loc

        # FlashInfer inputs
        init_flashinfer_args(
            ForwardMode.DECODE,
            self.model_runner,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            None,
            self.flashinfer_handlers[bs],
        )

        # Replay
        self.graphs[bs].replay()
        output = self.output_buffers[bs]

        # Unpad
        if bs == raw_bs:
            return output
        else:
            output = LogitProcessorOutput(
                next_token_logits=output.next_token_logits[:raw_bs],
                next_token_logprobs=output.next_token_logprobs[:raw_bs]
                if output.next_token_logprobs is not None
                else None,
                normalized_prompt_logprobs=None,
                prefill_token_logprobs=None,
                prefill_top_logprobs=None,
                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
                if output.decode_top_logprobs is not None
                else None,
            )
        return output
\ No newline at end of file
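Graphs are only captured for a fixed list of batch sizes, so replay() pads a real batch up to the nearest captured size: it looks the size up with bisect, copies the live inputs into the leading slots of the static buffers, replays, and slices the outputs back down. A small illustrative sketch of that lookup-and-unpad step (the names and the vocab size below are placeholders, not sglang code):

import bisect
import torch

# The captured sizes mirror ModelRunner.init_cuda_graphs: 1, 2, 4, 8, 16, ..., 120.
captured_sizes = [1, 2, 4] + [i * 8 for i in range(1, 16)]

def pick_padded_size(raw_bs):
    # Smallest captured batch size >= raw_bs (same idea as bisect_left in replay()).
    return captured_sizes[bisect.bisect_left(captured_sizes, raw_bs)]

raw_bs = 11
bs = pick_padded_size(raw_bs)              # -> 16
logits_buffer = torch.zeros(bs, 32000)     # stand-in for the graph's static output buffer
real_logits = logits_buffer[:raw_bs]       # unpad: keep only the rows for real requests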
python/sglang/srt/managers/controller/infer_batch.py (+13, -6)

@@ -675,7 +675,11 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)

@@ -757,9 +761,11 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        skip_flashinfer_init=False,
     ):
-        if not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(
+                forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )

         batch_size = len(req_pool_indices)

@@ -826,7 +832,8 @@ class InputMetadata:
         return ret


-def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                         flashinfer_decode_wrapper):
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim

@@ -857,8 +864,8 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )

     if forward_mode == ForwardMode.DECODE:
-        model_runner.flashinfer_decode_wrapper.end_forward()
-        model_runner.flashinfer_decode_wrapper.begin_forward(
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
             kv_indices,
             kv_last_page_len,
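The new try/except around torch.multinomial guards the whole batch against a single bad probability row (for example all-zero or NaN after top-p/top-k filtering), which would otherwise raise and abort the decoding step. A toy, CPU-runnable sketch of the same guard; the tensors here are made up for illustration and are not sglang code:

import warnings
import torch

# Toy probabilities: the second row is deliberately broken (NaN) to trip the guard.
probs_sort = torch.tensor([[0.7, 0.2, 0.1], [float("nan"), float("nan"), float("nan")]])
probs_idx = torch.tensor([[5, 9, 2], [4, 1, 0]])

try:
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError as e:
    warnings.warn(f"Ignore errors in sampling: {e}")
    # Fall back to a fixed column index so decoding can continue.
    sampled_index = torch.ones(
        probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs_sort.device
    )

next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
print(next_token_ids)  # falls back to column 1 of probs_idx for every row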
python/sglang/srt/managers/controller/model_runner.py (+30, -11)

@@ -15,6 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

+from sglang.global_config import global_config
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs

@@ -90,6 +91,9 @@ class ModelRunner:
         self.init_cublas()
         self.init_flash_infer()

+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
     def load_model(self):
         logger.info(f"[gpu_id={self.gpu_id}] Load weight begin. "

@@ -203,29 +207,46 @@ class ModelRunner:
         else:
             use_tensor_cores = False

-        workspace_buffers = torch.empty(
-            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-        )
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
+        )
         self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            workspace_buffers[0], "NHD"
+            self.flashinfer_workspace_buffers[0], "NHD"
         )
         self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffers[1], "NHD"
+            self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
         )

+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner.capture(batch_size_list)
+
     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.EXTEND,
+            forward_mode=ForwardMode.DECODE,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )

@@ -234,17 +255,15 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.DECODE,
+            forward_mode=ForwardMode.EXTEND,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
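forward_decode now tries graph replay first and only falls back to the eager FlashInfer path when no captured graph covers the batch, or when capture was skipped at startup. A toy, stand-alone illustration of that dispatch; the classes and strings below are made-up stand-ins, not sglang code:

# Toy stand-ins to illustrate the replay-or-eager dispatch in forward_decode.
class ToyGraphRunner:
    def __init__(self, max_bs):
        self.max_bs = max_bs

    def can_run(self, bs):
        # Mirrors CudaGraphRunner.can_run: only batches below the largest captured size qualify.
        return bs < self.max_bs

    def replay(self, batch):
        return f"replayed graph for bs={len(batch)}"


def forward_decode(cuda_graph_runner, batch):
    if cuda_graph_runner and cuda_graph_runner.can_run(len(batch)):
        return cuda_graph_runner.replay(batch)
    return f"eager decode for bs={len(batch)}"


runner = ToyGraphRunner(max_bs=120)
print(forward_decode(runner, ["req"] * 16))    # replayed graph for bs=16
print(forward_decode(runner, ["req"] * 200))   # eager decode for bs=200 (too large)
print(forward_decode(None, ["req"] * 16))      # eager decode for bs=16 (graphs disabled)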
python/sglang/srt/managers/controller/tp_worker.py (+8, -6)

@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            4096
+            8192
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )

@@ -314,11 +314,9 @@ class ModelTpServer:
             self.forward_queue.append(req)

     def get_new_fill_batch(self) -> Optional[Batch]:
-        if (
-            self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_running_requests
-        ):
-            return None
+        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        if running_bs > self.max_running_requests:
+            return

         # Compute matched prefix length
         for req in self.forward_queue:

@@ -394,6 +392,10 @@ class ModelTpServer:
                 new_batch_input_tokens += req.extend_input_len
             else:
                 break
+
+            if running_bs + len(can_run_list) > self.max_running_requests:
+                break

         if len(can_run_list) == 0:
             return None
python/sglang/srt/memory_pool.py (+9, -1)

@@ -38,7 +38,10 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
-        self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
         self.size = size

+        # mem_state is the reference counter.
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
         self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer

@@ -47,6 +50,8 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]

+        self.clear()
+
     def get_key_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 0]

@@ -101,3 +106,6 @@ class TokenToKVPool:
     def clear(self):
         self.mem_state.fill_(0)
         self.total_ref_ct = 0
+
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
\ No newline at end of file
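The extra slot exists because a replayed graph always runs at a padded batch size: the padded, non-existent requests still execute the KV-write kernels, so they need a scratch location (index 0) whose contents are never read and whose reference count clear() pins so the allocator never hands it to a real request. A small, CPU-only illustration of the idea, not the TokenToKVPool implementation:

import torch

# Reference-counted slot pool with one extra "dummy" slot pinned at index 0.
size = 8
mem_state = torch.zeros((size + 1,), dtype=torch.int16)  # slot 0 is the dummy slot

def add_refs(token_locs):
    mem_state[token_locs] += 1

def clear():
    mem_state.fill_(0)
    add_refs(torch.tensor([0], dtype=torch.int32))  # pin the dummy slot

def alloc(n):
    # Hand out n currently-unreferenced slots; slot 0 never appears here.
    free = torch.nonzero(mem_state == 0).flatten()[:n]
    add_refs(free)
    return free

clear()
print(alloc(3))  # e.g. tensor([1, 2, 3]) -- real requests never get slot 0
# Padded tokens in a replayed graph write their (ignored) KV entries to slot 0.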
python/sglang/srt/server.py (+1, -0)

@@ -146,6 +146,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"

     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
python/sglang/srt/server_args.py (+4, -4)

@@ -29,7 +29,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
-    schedule_conservativeness: float = 1.0
+    schedule_conservativeness: float = 0.8

     # Other runtime options
     tp_size: int = 1

@@ -68,13 +68,13 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 8:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.80
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.90
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
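The lower schedule_conservativeness and mem_fraction_static defaults leave headroom for the memory the captured graphs and the larger FlashInfer workspace now consume. Since capture is on by default, opting out goes through the disable switches that ModelRunner.init_cuda_graphs checks; a hedged usage sketch (field defaults beyond what the hunks above show are assumptions):

from sglang.srt.server_args import ServerArgs

# Default server args: ModelRunner.init_cuda_graphs() will capture decode graphs
# at startup because neither disable switch is set.
args = ServerArgs(model_path="path/to/your/model")
print(args.schedule_conservativeness, args.mem_fraction_static)  # 0.8, value picked by tp_size

# Opt out of graph capture (e.g. while debugging) by flipping either prerequisite.
args_no_graph = ServerArgs(model_path="path/to/your/model", disable_cuda_graph=True)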