Unverified Commit bea2bb9e authored by Lianmin Zheng, committed by GitHub

Improve multi-node stability (#1171)

parent cd10654e
"""Launch the inference server.""" """Launch the inference server."""
import argparse import argparse
import os
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -11,4 +13,9 @@ if __name__ == "__main__": ...@@ -11,4 +13,9 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args) server_args = ServerArgs.from_cli_args(args)
try:
launch_server(server_args) launch_server(server_args)
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
@@ -233,6 +233,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"
 
+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -246,14 +248,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
 
+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._default_allowed_special |= {"<|separator|>"}
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
 
         def encode_patched(
             self,
@@ -270,14 +276,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=disallowed_special,
+                disallowed_special=(),
             )
 
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
         )
...
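Note on the `disallowed_special=()` change above: tiktoken raises by default when the input text contains something that looks like a special token, unless the token is explicitly allowed or the check is disabled. A small illustration using the public tiktoken API and the stock cl100k_base encoding (not the custom tokenizer dict loaded above):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# Default behavior: special-token text in the input raises a ValueError.
try:
    enc.encode("hello <|endoftext|>")
except ValueError as e:
    print("raised:", e)

# disallowed_special=() turns that check off, so the special-token text is
# tokenized as ordinary text -- this is what the patched encode opts into.
print(enc.encode("hello <|endoftext|>", disallowed_special=()))

# allowed_special maps the listed special tokens to their reserved ids instead.
print(enc.encode("hello <|endoftext|>", allowed_special={"<|endoftext|>"}))
```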
@@ -212,6 +212,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()

@@ -167,6 +167,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()

@@ -16,7 +16,6 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union
@@ -270,7 +269,7 @@ class Req:
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            logging.warning(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
@@ -753,7 +752,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
-    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
+    def sample(self, logits: torch.Tensor):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
@@ -791,7 +790,7 @@ class ScheduleBatch:
             )
             if not torch.all(success):
-                logging.warning("Sampling failed, fallback to top_k=1 strategy")
+                logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
                 probs = probs.masked_fill(torch.isnan(probs), 0.0)
                 argmax_ids = torch.argmax(probs, dim=-1)
                 batch_next_token_ids = torch.where(
@@ -808,16 +807,6 @@ class ScheduleBatch:
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
 
-        if is_multi_node_tp:
-            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
-            # that can cause the TP workers to generate different tokens, so we need to
-            # sync here
-            torch.distributed.all_reduce(
-                batch_next_token_ids,
-                op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
-            )
-
         return batch_next_token_ids
@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
             (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
         )
...
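The block removed from `ScheduleBatch.sample` was a workaround for cross-node nondeterminism: each TP rank samples locally, and if the ranks ever disagreed, an all-reduce forced them back to a common token. A minimal sketch of that idea is below (the process-group setup is assumed, not taken from the diff); the commit drops this in favor of syncing the random seed across ranks so they sample identically in the first place.

```python
import torch
import torch.distributed as dist


def sync_sampled_tokens(batch_next_token_ids: torch.Tensor, tp_group) -> torch.Tensor:
    """Force all tensor-parallel ranks to agree on the sampled token ids.

    Mirrors the removed workaround: take the element-wise minimum across ranks,
    so even if ranks sampled different tokens they end up identical.
    """
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=tp_group)
    return batch_next_token_ids
```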
@@ -133,6 +133,13 @@ class ModelTpServer:
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
 
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)
 
         # Print info
@@ -474,9 +481,7 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(
-                    output.next_token_logits, self.model_runner.is_multi_node_tp
-                )
+                next_token_ids = batch.sample(output.next_token_logits)
 
                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:
@@ -636,9 +641,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(
-            output.next_token_logits, self.model_runner.is_multi_node_tp
-        )
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -879,6 +882,7 @@ def broadcast_recv_input(
         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
...
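The `# Sync random seed` addition is the replacement for the removed all-reduce: rank 0's seed is broadcast to every TP rank before `set_random_seed`, so all ranks draw the same samples. `broadcast_recv_input` is sglang's pickle-over-`torch.distributed` helper (the `return data` fix above is what lets rank 0 read back its own value). A self-contained sketch of the same size-then-payload pattern, with illustrative names rather than the exact sglang implementation:

```python
import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank: int, group):
    """Broadcast an arbitrary picklable object (e.g. a random seed) from rank 0
    to all ranks over a CPU (gloo) process group."""
    if rank == 0:
        payload = torch.tensor(list(pickle.dumps(data)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(payload, src=0, group=group)
        return data  # sender also returns the value, like the fixed helper
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=group)
        return pickle.loads(bytes(payload.tolist()))
```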
@@ -84,13 +84,20 @@ def set_torch_compile_config():
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding
 
         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -142,6 +149,9 @@ class CudaGraphRunner:
             set_torch_compile_config()
 
     def can_run(self, batch_size):
-        return batch_size <= self.max_bs
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs
 
     def capture(self, batch_size_list):
...
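For context on `can_run`: with padding enabled (the default), a batch smaller than the largest captured size is padded up to the nearest captured batch size and replayed through a CUDA graph; with the new `--disable-cuda-graph-padding` flag, only batch sizes that were captured exactly use a graph, and everything else falls back to the regular eager path. An illustrative decision helper, not part of the diff:

```python
def pick_execution_path(batch_size, captured_sizes, max_bs, disable_padding):
    """Return which path a batch of `batch_size` requests would take."""
    if disable_padding:
        # Only exact matches replay a captured graph: no padding overhead,
        # but fewer batch sizes benefit from CUDA graphs.
        return "cuda_graph" if batch_size in captured_sizes else "eager"
    # With padding, anything up to the largest captured size is padded to the
    # next captured size and replayed through that graph.
    return "cuda_graph" if batch_size <= max_bs else "eager"


# Example: graphs captured for batch sizes 1, 2, 4, 8.
captured = {1, 2, 4, 8}
print(pick_execution_path(3, captured, max_bs=8, disable_padding=False))  # cuda_graph (padded to 4)
print(pick_execution_path(3, captured, max_bs=8, disable_padding=True))   # eager
```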
@@ -465,6 +465,7 @@ class ModelRunner:
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
...
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -301,12 +300,9 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
 
     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
-            ]
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
         tp_rank_range = list(
             range(
                 server_args.node_rank * tp_size_local,
@@ -320,8 +316,13 @@ def launch_server(
             ports[3],
             model_overide_args,
         )
-        while True:
-            pass
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -356,15 +357,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        print(
-            f"Initialization failed. controller_init_state: {controller_init_state}",
-            flush=True,
-        )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
-        )
-        sys.exit(1)
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
+        )
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     # Add api key authorization
@@ -373,12 +370,12 @@ def launch_server(
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
     )
     t.start()
 
-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     )
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception as e:
+    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
...
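The worker-node path (nnodes > 1, node_rank != 0) now joins its local TP server processes instead of spinning in `while True: pass`, and reaps any leftover children on the way out. A self-contained sketch of that join-then-cleanup pattern, assuming sglang is installed so the patched `kill_child_process` helper can be imported; the `_tp_worker` body is a stand-in, not sglang code:

```python
import multiprocessing as mp
import os
import time

from sglang.srt.utils import kill_child_process  # helper patched in this commit


def _tp_worker(rank: int):
    time.sleep(1)  # stand-in for a TP server loop


if __name__ == "__main__":
    procs = [mp.Process(target=_tp_worker, args=(r,)) for r in range(4)]
    for p in procs:
        p.start()
    try:
        # Join instead of busy-waiting: no wasted CPU, and the node exits as
        # soon as its local TP servers exit (or crash).
        for p in procs:
            p.join()
    finally:
        # Reap anything the workers may have spawned, but keep this process
        # alive so the try/finally in launch_server.py can finish normally.
        kill_child_process(os.getpid(), including_parent=False)
```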
@@ -79,6 +79,7 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
...
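The new flag plugs into the existing `disable_*` family: on the CLI it is passed as `--disable-cuda-graph-padding` alongside the other server options, and programmatically it is just another `ServerArgs` field. A small illustration (the model path is a placeholder, not part of this commit):

```python
from sglang.srt.server_args import ServerArgs

# Placeholder model path; the new field is the point here.
server_args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",
    disable_cuda_graph_padding=True,
)
print(server_args.disable_cuda_graph_padding)  # True
```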
@@ -369,14 +369,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = parent_process.children(recursive=True)
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
 
 
-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
@@ -384,6 +381,8 @@ def kill_child_process(pid, including_parent=True):
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
...
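The `skip_pid` parameter exists because `kill_parent_process` is called from a child that wants to take down its parent and siblings without killing itself mid-cleanup. A reduced, standalone sketch of that traversal with psutil (illustrative rewrite, not the sglang function itself):

```python
import psutil


def kill_tree_except(root_pid, skip_pid=None, including_parent=True):
    """Kill `root_pid`'s whole subtree, optionally sparing one pid (the caller)."""
    try:
        parent = psutil.Process(root_pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        if child.pid == skip_pid:
            continue  # don't kill the process that is doing the cleanup
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    if including_parent:
        try:
            parent.kill()
        except psutil.NoSuchProcess:
            pass


# kill_parent_process() above boils down to:
# kill_tree_except(psutil.Process().parent().pid, skip_pid=psutil.Process().pid)
```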