change / sglang · Commits · bea2bb9e

Improve multi-node stability (#1171)

Commit bea2bb9e, authored Aug 20, 2024 by Lianmin Zheng, committed via GitHub on Aug 20, 2024.
Parent: cd10654e

Showing 11 changed files with 94 additions and 76 deletions.
Changed files:
  python/sglang/launch_server.py                         +8  -1
  python/sglang/srt/hf_transformers_utils.py             +11 -5
  python/sglang/srt/managers/controller_multi.py         +0  -2
  python/sglang/srt/managers/controller_single.py        +0  -2
  python/sglang/srt/managers/schedule_batch.py           +5  -15
  python/sglang/srt/managers/tp_worker.py                +10 -6
  python/sglang/srt/model_executor/cuda_graph_runner.py  +12 -2
  python/sglang/srt/model_executor/model_runner.py       +1  -0
  python/sglang/srt/server.py                            +36 -37
  python/sglang/srt/server_args.py                       +6  -0
  python/sglang/srt/utils.py                             +5  -6
python/sglang/launch_server.py  (+8 -1)

 """Launch the inference server."""

 import argparse
+import os

 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

@@ -11,4 +13,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args)
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
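The launcher change wraps the blocking launch_server call in try/finally so that child processes are always reaped, even when startup raises. A minimal sketch of the same pattern, assuming a made-up fake_launch_server and using multiprocessing.active_children() in place of sglang's psutil-based kill_child_process (which also covers grandchildren):

# Sketch of the new launcher shape: run the (blocking, possibly crashing)
# server entry point and always reap child processes on the way out.
import time
import multiprocessing as mp


def fake_launch_server() -> None:
    # Stand-in for launch_server(server_args): starts a worker, then fails.
    mp.Process(target=time.sleep, args=(3600,)).start()
    raise RuntimeError("startup failed")


if __name__ == "__main__":
    try:
        fake_launch_server()
    except Exception as e:
        raise e
    finally:
        # Without this cleanup, the sleeping worker would outlive the launcher.
        for child in mp.active_children():
            child.kill()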
python/sglang/srt/hf_transformers_utils.py  (+11 -5)

@@ -233,6 +233,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"

+        default_allowed_special = None
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),

@@ -246,14 +248,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
         tokenizer._default_allowed_special |= {"<|separator|>"}
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

         def encode_patched(
             self,

@@ -270,14 +276,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=disallowed_special,
+                disallowed_special=(),
             )

         tokenizer.encode = functools.partial(encode_patched, tokenizer)

         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
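The patched encode above now forces disallowed_special=() rather than forwarding the caller's value. With tiktoken, input text that happens to contain a special-token string raises a ValueError under the default disallowed_special="all"; passing an empty tuple makes such text encode as ordinary tokens. A small sketch of that behavior against a stock cl100k_base encoding (not the tokenizer configured above):

# Sketch: why disallowed_special=() matters when user text contains
# special-token strings. Uses a stock encoding, not sglang's tokenizer.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "user pasted <|endoftext|> into the prompt"

try:
    enc.encode(text)  # default disallowed_special="all" -> raises
except ValueError as e:
    print("rejected:", e)

# Treat the special-looking substring as plain text instead of failing.
ids = enc.encode(text, disallowed_special=())
print(len(ids), "tokens")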
python/sglang/srt/managers/controller_multi.py  (+0 -2)

@@ -212,6 +212,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()
python/sglang/srt/managers/controller_single.py  (+0 -2)

@@ -167,6 +167,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()
python/sglang/srt/managers/schedule_batch.py  (+5 -15)

@@ -16,7 +16,6 @@ limitations under the License.
 """Meta data for requests and batches"""

 import logging
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union

@@ -270,7 +269,7 @@ class Req:
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            logging.warning(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False

@@ -753,7 +752,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

-    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
+    def sample(self, logits: torch.Tensor):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()

@@ -791,7 +790,7 @@ class ScheduleBatch:
             )
             if not torch.all(success):
-                logging.warning("Sampling failed, fallback to top_k=1 strategy")
+                logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
                 probs = probs.masked_fill(torch.isnan(probs), 0.0)
                 argmax_ids = torch.argmax(probs, dim=-1)
                 batch_next_token_ids = torch.where(

@@ -808,16 +807,6 @@ class ScheduleBatch:
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

-        if is_multi_node_tp:
-            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
-            # that can cause the TP workers to generate different tokens, so we need to
-            # sync here
-            torch.distributed.all_reduce(
-                batch_next_token_ids,
-                op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
-            )
-
         return batch_next_token_ids

@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
             (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
         )
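The hunks above drop the per-step all_reduce that forced TP workers to agree on sampled tokens (its replacement is the random-seed sync added in tp_worker.py below) and make the torch sampling fallback log the failing logits and the multinomial error. A self-contained sketch of that fallback path, with scalar top_k/top_p instead of the per-request tensors the real function takes:

# Sketch of a torch-only top-k/top-p sampling fallback with the same
# failure handling as top_k_top_p_sampling_from_probs_torch: if
# torch.multinomial rejects the distribution, fall back to all-zero ids.
import torch


def sample_top_k_top_p(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    cum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens outside the nucleus (top_p) and outside top_k.
    probs_sort[(cum - probs_sort) > top_p] = 0.0
    probs_sort[:, top_k:] = 0.0
    # Renormalize so multinomial sees a valid distribution.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    try:
        sampled = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        print(f"Sampling error: {e}")
        return torch.zeros((probs.shape[0],), dtype=torch.int32, device=probs.device)
    # Map the index in sorted order back to the original vocab id.
    return torch.gather(probs_idx, dim=-1, index=sampled).view(-1)


if __name__ == "__main__":
    logits = torch.randn(4, 32000)
    ids = sample_top_k_top_p(torch.softmax(logits, dim=-1), top_k=50, top_p=0.9)
    print(ids)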
python/sglang/srt/managers/tp_worker.py  (+10 -6)

@@ -133,6 +133,13 @@ class ModelTpServer:
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )

+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
+        set_random_seed(server_args.random_seed)

         # Print info

@@ -474,9 +481,7 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(
-                    output.next_token_logits, self.model_runner.is_multi_node_tp
-                )
+                next_token_ids = batch.sample(output.next_token_logits)

                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:

@@ -636,9 +641,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(
-            output.next_token_logits, self.model_runner.is_multi_node_tp
-        )
+        next_token_ids = batch.sample(output.next_token_logits)

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:

@@ -879,6 +882,7 @@ def broadcast_recv_input(
         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
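The new block in ModelTpServer broadcasts random_seed from TP rank 0 to every rank via broadcast_recv_input, so all workers seed their samplers identically instead of relying on a post-hoc token sync. A standalone gloo sketch of that idea, assuming identical probabilities on every rank; this is an illustration, not sglang's broadcast_recv_input:

# Sketch: keep samplers in lockstep across TP ranks by broadcasting the
# random seed from rank 0, in the spirit of the seed sync added above.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank starts with its own seed; rank 0's value wins after the broadcast.
    seed = torch.tensor([12345 if rank == 0 else 999 + rank], dtype=torch.long)
    dist.broadcast(seed, src=0)
    torch.manual_seed(int(seed.item()))

    # With identical seeds (and identical probs), every rank samples the same token.
    probs = torch.full((1, 8), 1 / 8)
    token = torch.multinomial(probs, num_samples=1)
    print(f"rank {rank}: seed={int(seed)} token={int(token)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)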
python/sglang/srt/model_executor/cuda_graph_runner.py  (+12 -2)

@@ -84,13 +84,20 @@ def set_torch_compile_config():
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding

         # Common inputs
         self.max_bs = max_batch_size_to_capture

@@ -142,6 +149,9 @@ class CudaGraphRunner:
             set_torch_compile_config()

     def can_run(self, batch_size):
-        return batch_size <= self.max_bs
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs

     def capture(self, batch_size_list):
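can_run now depends on the new disable_padding flag: with padding allowed, any batch up to max_bs can be replayed; with --disable-cuda-graph-padding, only exactly-captured batch sizes qualify. A small sketch of the gate; the padded-size selection shown is an illustrative assumption, not code from this commit:

# Sketch of the can_run gating added above, plus one plausible way a
# padded replay size could be chosen (assumed strategy).
import bisect


class GraphGate:
    def __init__(self, batch_size_list, disable_padding: bool):
        self.captured = sorted(batch_size_list)
        self.max_bs = max(batch_size_list)
        self.disable_padding = disable_padding

    def can_run(self, batch_size: int) -> bool:
        if self.disable_padding:
            # Only replay graphs for batch sizes that were captured exactly.
            return batch_size in self.captured
        # Otherwise any smaller batch can be padded up to a captured size.
        return batch_size <= self.max_bs

    def padded_size(self, batch_size: int) -> int:
        # Smallest captured size that fits the batch (only valid if can_run).
        return self.captured[bisect.bisect_left(self.captured, batch_size)]


gate = GraphGate([1, 2, 4, 8, 16], disable_padding=False)
print(gate.can_run(3), gate.padded_size(3))   # True 4
gate = GraphGate([1, 2, 4, 8, 16], disable_padding=True)
print(gate.can_run(3), gate.can_run(4))       # False True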
python/sglang/srt/model_executor/model_runner.py  (+1 -0)

@@ -465,6 +465,7 @@ class ModelRunner:
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
python/sglang/srt/server.py  (+36 -37)

@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus

@@ -301,12 +300,9 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
-            ]
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
         tp_rank_range = list(
             range(
                 server_args.node_rank * tp_size_local,

@@ -320,8 +316,13 @@ def launch_server(
                 ports[3],
                 model_overide_args,
             )

-        while True:
-            pass
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return

     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)

@@ -356,15 +357,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        print(
-            f"Initialization failed. controller_init_state: {controller_init_state}",
-            flush=True,
-        )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
-        )
-        sys.exit(1)
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
+        )
     assert proc_controller.is_alive() and proc_detoken.is_alive()

     # Add api key authorization

@@ -373,12 +370,12 @@ def launch_server(
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
     )
     t.start()

-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,

@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     )


-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:

@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"

@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception as e:
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
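On nodes with node_rank != 0, launch_server now only starts the local TP workers, joins them, and cleans up its children instead of spinning in while True: pass. A sketch of the per-node rank and GPU-id math from the multi-node branch; the upper bound of tp_rank_range is truncated in the hunk, so (node_rank + 1) * tp_size_local is assumed here:

# Sketch of the per-node TP rank / GPU-id layout used in the multi-node branch.
def local_tp_layout(tp_size: int, nnodes: int, node_rank: int):
    tp_size_local = tp_size // nnodes
    gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
    tp_rank_range = list(
        range(node_rank * tp_size_local, (node_rank + 1) * tp_size_local)  # assumed upper bound
    )
    return tp_size_local, gpu_ids, tp_rank_range


# 2 nodes x 8 GPUs, second node: global TP ranks 8..15 map onto local GPUs 0..7.
print(local_tp_layout(tp_size=16, nnodes=2, node_rank=1))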
python/sglang/srt/server_args.py  (+6 -0)

@@ -79,6 +79,7 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False

@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
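The new flag is both a ServerArgs field defaulting to False and a store_true CLI switch. A tiny sketch of how such a flag flows from argparse into a dataclass, using a simplified stand-in for ServerArgs and from_cli_args rather than the real class:

# Sketch: a store_true flag like --disable-cuda-graph-padding mapped onto
# a dataclass field (argparse turns dashes into underscores).
import argparse
from dataclasses import dataclass, fields


@dataclass
class MiniServerArgs:
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False


parser = argparse.ArgumentParser()
parser.add_argument("--disable-cuda-graph", action="store_true")
parser.add_argument("--disable-cuda-graph-padding", action="store_true")
args = parser.parse_args(["--disable-cuda-graph-padding"])

mini = MiniServerArgs(**{f.name: getattr(args, f.name) for f in fields(MiniServerArgs)})
print(mini)  # disable_cuda_graph=False, disable_cuda_graph_padding=True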
python/sglang/srt/utils.py  (+5 -6)

@@ -369,14 +369,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = parent_process.children(recursive=True)
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)


-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
     """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:

@@ -384,6 +381,8 @@ def kill_child_process(pid, including_parent=True):
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
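kill_parent_process now delegates to kill_child_process with the new skip_pid argument, so a worker can reap its siblings without killing itself. A usage sketch of that pattern; kill_children below mirrors the helper in utils.py rather than importing it:

# Usage sketch for the skip_pid parameter: a child reaps its sibling
# processes but survives, because it skips its own pid.
import os
import time
import multiprocessing as mp

import psutil


def kill_children(pid, skip_pid=None):
    """Kill every child of `pid` except `skip_pid` (mirrors utils.py)."""
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        if child.pid == skip_pid:
            continue  # do not kill the caller itself
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass


def cleanup_worker():
    # Same role skip_pid plays in kill_parent_process above.
    kill_children(os.getppid(), skip_pid=os.getpid())


if __name__ == "__main__":
    siblings = [mp.Process(target=time.sleep, args=(3600,)) for _ in range(2)]
    for p in siblings:
        p.start()
    w = mp.Process(target=cleanup_worker)
    w.start()
    w.join()
    time.sleep(0.5)  # give SIGKILL a moment to land
    print([p.is_alive() for p in siblings])  # expected: [False, False]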