Unverified Commit bea2bb9e authored by Lianmin Zheng, committed by GitHub

Improve multi-node stability (#1171)

parent cd10654e
"""Launch the inference server."""
import argparse
import os
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......@@ -11,4 +13,9 @@ if __name__ == "__main__":
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args)
try:
launch_server(server_args)
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
......@@ -233,6 +233,8 @@ class TiktokenTokenizer:
}
assert tok_dict["word_split"] == "V1"
default_allowed_special = None
kwargs = {
"name": name,
"pat_str": tok_dict.get("pat_str", PAT_STR_B),
......@@ -246,14 +248,18 @@ class TiktokenTokenizer:
for bytes_list in tok_dict["default_allowed_special"]
]
)
else:
default_allowed_special = None
if "vocab_size" in tok_dict:
kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": SEP, "eos": EOS}
tokenizer = tiktoken.Encoding(**kwargs)
tokenizer._default_allowed_special = default_allowed_special or set()
tokenizer._default_allowed_special |= {"<|separator|>"}
tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
def encode_patched(
self,
......@@ -270,14 +276,14 @@ class TiktokenTokenizer:
self,
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
disallowed_special=(),
)
tokenizer.encode = functools.partial(encode_patched, tokenizer)
# Convert to HF interface
self.tokenizer = tokenizer
self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
self.eos_token_id = tokenizer._special_tokens[EOS]
self.vocab_size = tokenizer.n_vocab
self.chat_template = Template(
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
......
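For illustration, a minimal sketch of how the Jinja chat template above renders a conversation. The message list is made up and the snippet only assumes the jinja2 package; it is not part of the commit.
from jinja2 import Template

chat_template = Template(
    "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
)
print(chat_template.render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
    add_generation_prompt=True,
))
# Prints:
# System: You are a helpful assistant.<|separator|>
#
# Human: Hello<|separator|>
#
# Assistant: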
......@@ -212,6 +212,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
finally:
for w in controller.workers:
os.kill(w.proc.pid, 9)
kill_parent_process()
......@@ -167,6 +167,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
for t in controller.tp_procs:
os.kill(t.pid, 9)
kill_parent_process()
......@@ -16,7 +16,6 @@ limitations under the License.
"""Meta data for requests and batches"""
import logging
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union
......@@ -270,7 +269,7 @@ class Req:
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
logging.warning(
logger.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
......@@ -753,7 +752,7 @@ class ScheduleBatch:
)
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
def sample(self, logits: torch.Tensor):
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
# Post process logits
logits = logits.contiguous()
......@@ -791,7 +790,7 @@ class ScheduleBatch:
)
if not torch.all(success):
logging.warning("Sampling failed, fallback to top_k=1 strategy")
logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
......@@ -808,16 +807,6 @@ class ScheduleBatch:
self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
if is_multi_node_tp:
# If the tensor parallelism spans across multiple nodes, there is some indeterminism
# that can cause the TP workers to generate different tokens, so we need to
# sync here
torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group,
)
return batch_next_token_ids
......@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError:
except RuntimeError as e:
logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros(
(probs_sort.shape[0],), dtype=torch.int32, device=probs.device
)
......
......@@ -133,6 +133,13 @@ class ModelTpServer:
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,
)
# Sync random seed
server_args.random_seed = broadcast_recv_input(
[server_args.random_seed],
self.tp_rank,
self.model_runner.tp_group.cpu_group,
)[0]
set_random_seed(server_args.random_seed)
# Print info
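For context, a minimal sketch of the seed-sync idea above using stock torch.distributed rather than the project's broadcast_recv_input helper. The process group is assumed to be initialized (e.g. with the gloo backend) and the source rank is assumed to be 0; the helper name is made up.
import random
import torch
import torch.distributed as dist

def sync_and_set_seed(seed, cpu_group=None):
    """Take the source rank's seed so every TP worker samples identically."""
    obj = [seed]
    # broadcast_object_list pickles the payload, so a plain int (or None) works
    dist.broadcast_object_list(obj, src=0, group=cpu_group)
    seed = obj[0]
    random.seed(seed)
    torch.manual_seed(seed)
    return seed
With every rank seeded identically, the TP workers should sample the same next tokens, which is presumably why the per-token all_reduce over batch_next_token_ids (removed from ScheduleBatch.sample above) is no longer needed.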
......@@ -474,9 +481,7 @@ class ModelTpServer:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(
output.next_token_logits, self.model_runner.is_multi_node_tp
)
next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu
if output.next_token_logprobs is not None:
......@@ -636,9 +641,7 @@ class ModelTpServer:
# Forward and sample the next tokens
output = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(
output.next_token_logits, self.model_runner.is_multi_node_tp
)
next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu
if output.next_token_logprobs is not None:
......@@ -879,6 +882,7 @@ def broadcast_recv_input(
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
return data
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
......
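The receiver branch of the broadcast_recv_input hunk above is truncated. For reference, a hedged sketch of the same size-then-payload pattern for broadcasting an arbitrary picklable object over a CPU process group; broadcast_pyobj is a hypothetical name, not the project's exact implementation.
import pickle
import torch
import torch.distributed as dist

def broadcast_pyobj(data, rank, dist_group):
    """Rank 0 serializes and broadcasts data; every other rank receives and unpickles it."""
    if rank == 0:
        payload = torch.tensor(bytearray(pickle.dumps(data)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        dist.broadcast(payload, src=0, group=dist_group)
        return data
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=dist_group)
        return pickle.loads(payload.numpy().tobytes())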
......@@ -84,13 +84,20 @@ def set_torch_compile_config():
class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
def __init__(
self,
model_runner,
max_batch_size_to_capture: int,
use_torch_compile: bool,
disable_padding: bool,
):
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
self.output_buffers = {}
self.flashinfer_handlers = {}
self.graph_memory_pool = None
self.disable_padding = disable_padding
# Common inputs
self.max_bs = max_batch_size_to_capture
......@@ -142,7 +149,10 @@ class CudaGraphRunner:
set_torch_compile_config()
def can_run(self, batch_size):
return batch_size <= self.max_bs
if self.disable_padding:
return batch_size in self.graphs
else:
return batch_size <= self.max_bs
def capture(self, batch_size_list):
self.batch_size_list = batch_size_list
......
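For illustration, a small self-contained sketch of what "padding" means in can_run above: a batch whose size was not exactly captured gets padded up to the nearest captured size before replaying the graph, and the new disable_padding option turns that off in favor of a regular forward pass. The helper below is hypothetical, not part of CudaGraphRunner.
import bisect

def padded_capture_size(batch_size, captured_sizes):
    """Return the smallest captured batch size >= batch_size (captured_sizes must be sorted)."""
    idx = bisect.bisect_left(captured_sizes, batch_size)
    return captured_sizes[idx]

# e.g. with graphs captured for [1, 2, 4, 8, 16], a batch of 5 replays the size-8 graph
assert padded_capture_size(5, [1, 2, 4, 8, 16]) == 8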
......@@ -465,6 +465,7 @@ class ModelRunner:
self,
max_batch_size_to_capture=max(batch_size_list),
use_torch_compile=self.server_args.enable_torch_compile,
disable_padding=self.server_args.disable_cuda_graph_padding,
)
try:
self.cuda_graph_runner.capture(batch_size_list)
......
......@@ -24,7 +24,6 @@ import json
import logging
import multiprocessing as mp
import os
import sys
import threading
import time
from http import HTTPStatus
......@@ -301,27 +300,29 @@ def launch_server(
server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
# Launch processes for multi-node tensor parallelism
if server_args.nnodes > 1:
if server_args.node_rank != 0:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [
i for _ in range(server_args.nnodes) for i in range(tp_size_local)
]
tp_rank_range = list(
range(
server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local,
)
if server_args.nnodes > 1 and server_args.node_rank != 0:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
tp_rank_range = list(
range(
server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local,
)
procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
ports[3],
model_overide_args,
)
while True:
pass
)
procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
ports[3],
model_overide_args,
)
try:
for p in procs:
p.join()
finally:
kill_child_process(os.getpid(), including_parent=False)
return
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
......@@ -356,15 +357,11 @@ def launch_server(
if controller_init_state != "init ok" or detoken_init_state != "init ok":
proc_controller.kill()
proc_detoken.kill()
print(
f"Initialization failed. controller_init_state: {controller_init_state}",
flush=True,
raise RuntimeError(
"Initialization failed. "
f"controller_init_state: {controller_init_state}, "
f"detoken_init_state: {detoken_init_state}"
)
print(
f"Initialization failed. detoken_init_state: {detoken_init_state}",
flush=True,
)
sys.exit(1)
assert proc_controller.is_alive() and proc_detoken.is_alive()
# Add api key authorization
......@@ -373,12 +370,12 @@ def launch_server(
# Send a warmup request
t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
)
t.start()
# Listen for requests
try:
# Listen for requests
uvicorn.run(
app,
host=server_args.host,
......@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
)
def _wait_and_warmup(server_args, pipe_finish_writer):
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
headers = {}
url = server_args.url()
if server_args.api_key:
......@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if not success:
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
return
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
......@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
timeout=600,
)
assert res.status_code == 200, f"{res}"
except Exception as e:
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
return
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
......
......@@ -79,6 +79,7 @@ class ServerArgs:
disable_radix_cache: bool = False
disable_regex_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
disable_disk_cache: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
......@@ -393,6 +394,11 @@ class ServerArgs:
action="store_true",
help="Disable cuda graph.",
)
parser.add_argument(
"--disable-cuda-graph-padding",
action="store_true",
help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
)
parser.add_argument(
"--disable-disk-cache",
action="store_true",
......
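As a usage note, and assuming the usual launch entry point, the new option above would be passed on the command line as: python -m sglang.launch_server --model-path <model> --disable-cuda-graph-padding (other arguments elided).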
......@@ -369,14 +369,11 @@ def kill_parent_process():
"""Kill the parent process and all children of the parent process."""
current_process = psutil.Process()
parent_process = current_process.parent()
children = parent_process.children(recursive=True)
for child in children:
if child.pid != current_process.pid:
os.kill(child.pid, 9)
os.kill(parent_process.pid, 9)
kill_child_process(parent_process.pid, skip_pid=current_process.pid)
def kill_child_process(pid, including_parent=True):
def kill_child_process(pid, including_parent=True, skip_pid=None):
"""Kill the process and all its children process."""
try:
parent = psutil.Process(pid)
except psutil.NoSuchProcess:
......@@ -384,6 +381,8 @@ def kill_child_process(pid, including_parent=True):
children = parent.children(recursive=True)
for child in children:
if child.pid == skip_pid:
continue
try:
child.kill()
except psutil.NoSuchProcess:
......