Unverified Commit 8210ec60 authored by Lianmin Zheng, committed by GitHub

Improve error handling & abort disconnected requests (#449)

parent 5be9eb8a
......@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
self._assert_success(res)
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
......@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
auth_token=self.auth_token,
verify=self.verify,
)
return res.status_code == 200
self._assert_success(res)
def get_server_args(self):
res = http_request(
......@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
auth_token=self.auth_token,
verify=self.verify,
)
self._assert_success(res)
return res.json()
def get_chat_template(self):
......@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
self._assert_success(res)
def commit_lazy_operations(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
......@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
self._assert_success(res)
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
......@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
self._assert_success(res)
def generate(
self,
......@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
......@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
data["stream"] = True
self._add_images(s, data)
response = http_request(
res = http_request(
self.base_url + "/generate",
json=data,
stream=True,
......@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
pos = 0
incomplete_text = ""
for chunk in response.iter_lines(decode_unicode=False):
for chunk in res.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
......@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
self._assert_success(res)
prompt_len = res.json()["meta_info"]["prompt_tokens"]
# Compute logprob
......@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
self._assert_success(res)
obj = res.json()
normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
......@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
self._assert_success(res)
def _add_images(self, s: StreamExecutor, data):
if s.images_:
assert len(s.images_) == 1, "Only support one image."
data["image_data"] = s.images_[0][1]
def _assert_success(self, res):
if res.status_code != 200:
raise RuntimeError(res.json())
\ No newline at end of file
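
The repeated change in this file swaps bare status-code asserts for the _assert_success helper, which re-raises the server's JSON error body. A minimal, self-contained sketch of that pattern follows (FakeResponse is a hypothetical stand-in for the HTTP response wrapper, not sglang's class):

# Sketch of the error-surfacing pattern used above; FakeResponse is hypothetical.
class FakeResponse:
    def __init__(self, status_code, payload):
        self.status_code = status_code
        self._payload = payload

    def json(self):
        return self._payload


def assert_success(res):
    # Instead of `assert res.status_code == 200` (a bare AssertionError, and
    # skipped entirely under `python -O`), propagate the server's JSON error
    # body so callers can see why the call failed.
    if res.status_code != 200:
        raise RuntimeError(res.json())


ok = FakeResponse(200, {"text": "hello"})
assert_success(ok)  # passes silently

bad = FakeResponse(400, {"error": {"message": "The input is too long."}})
try:
    assert_success(bad)
except RuntimeError as e:
    print("request failed:", e)  # prints the structured error from the server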
......@@ -191,7 +191,7 @@ class StreamExecutor:
self.variable_event = {} # Dict[name: str -> event: threading.Event]
self.meta_info = {} # Dict[name: str -> info: str]
self.is_finished = False
self.error = None
self.error_ = None
# For completion
self.text_ = "" # The full text
......@@ -300,6 +300,10 @@ class StreamExecutor:
self.sync()
return self.messages_
def error(self):
self.sync()
return self.error_
def end(self):
if self.use_thread:
if self.worker.is_alive():
......@@ -338,7 +342,7 @@ class StreamExecutor:
if self.stream_var_event:
for name in self.stream_var_event:
self.stream_var_event[name].set()
self.error = error
self.error_ = error
if self.stream_text_event:
self.stream_text_event.set()
......@@ -713,7 +717,7 @@ class ProgramState:
return self.stream_executor.sync()
def error(self):
return self.stream_executor.error
return self.stream_executor.error()
def text_iter(self, var_name: Optional[str] = None):
if self.stream_executor.stream:
......
......@@ -31,12 +31,9 @@ class GenerateReqInput:
def post_init(self):
if self.text is None:
assert (
self.input_ids is not None
), "Either text or input_ids should be provided"
else:
assert self.input_ids is None, "Either text or input_ids should be provided"
if ((self.text is None and self.input_ids is None) or
(self.text is not None and self.input_ids is not None)):
raise ValueError("Either text or input_ids should be provided.")
if self.text is not None:
is_single = isinstance(self.text, str)
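
The rewritten post_init is an exclusive-or check: exactly one of text and input_ids must be provided, and a violation now raises ValueError (which the HTTP layer can map to a 400) instead of tripping an assert. A trimmed-down sketch under that assumption (GenerateReqInputSketch and its two fields are hypothetical; the real dataclass has many more):

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class GenerateReqInputSketch:
    # Hypothetical, trimmed-down stand-in for GenerateReqInput.
    text: Optional[Union[str, List[str]]] = None
    input_ids: Optional[Union[List[int], List[List[int]]]] = None

    def post_init(self):
        # Exactly one of text / input_ids must be provided.
        if (self.text is None) == (self.input_ids is None):
            raise ValueError("Either text or input_ids should be provided.")


GenerateReqInputSketch(text="Hello").post_init()         # ok
GenerateReqInputSketch(input_ids=[1, 2, 3]).post_init()  # ok
try:
    GenerateReqInputSketch().post_init()                 # neither -> error
except ValueError as e:
    print(e)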
......@@ -71,7 +68,8 @@ class GenerateReqInput:
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)]
else:
assert isinstance(self.rid, list)
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
if self.return_logprob is None:
self.return_logprob = [False] * num
......@@ -129,6 +127,11 @@ class FlushCacheReq:
pass
@dataclass
class AbortReq:
rid: str
@dataclass
class DetokenizeReqInput:
input_ids: List[int]
......@@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
......@@ -110,6 +111,8 @@ class ModelRpcServer:
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
set_random_seed(server_args.random_seed)
# Print info
logger.info(
f"Rank {self.tp_rank}: "
f"max_total_num_token={self.max_total_num_token}, "
......@@ -160,24 +163,6 @@ class ModelRpcServer:
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
def flush_cache(self):
if len(self.forward_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.regex_fsm_cache.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
else:
warnings.warn(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.forward_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
recv_reqs = obtain(recv_reqs)
......@@ -189,6 +174,8 @@ class ModelRpcServer:
self.handle_generate_request(recv_req)
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
else:
raise ValueError(f"Invalid request: {recv_req}")
......@@ -207,9 +194,8 @@ class ModelRpcServer:
new_batch = self.get_new_fill_batch()
if new_batch is not None:
# Run new fill batch
# Run a new fill batch
self.forward_fill_batch(new_batch)
self.cache_filled_batch(new_batch)
if not new_batch.is_empty():
......@@ -225,14 +211,8 @@ class ModelRpcServer:
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)
if self.running_batch.is_empty():
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.reqs[0].stream:
break
if self.running_batch is not None and self.tp_rank == 0:
# Print stats
if self.tp_rank == 0:
if self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_token - (
self.token_to_kv_pool.available_size()
......@@ -250,8 +230,15 @@ class ModelRpcServer:
f"gen throughput (token/s): {throuhgput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
if self.running_batch.is_empty():
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.reqs[0].stream:
break
else:
# check the available size
# Check the available size
available_size = (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
......@@ -295,7 +282,7 @@ class ModelRpcServer:
req.sampling_params.regex
)
# Truncate long prompts
# Truncate prompts that are too long
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
req.sampling_params.max_new_tokens,
......@@ -311,6 +298,7 @@ class ModelRpcServer:
):
return None
# Compute matched prefix length
for req in self.forward_queue:
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_logprob:
......@@ -383,6 +371,7 @@ class ModelRpcServer:
if len(can_run_list) == 0:
return None
# Print stats
if self.tp_rank == 0:
running_req = (
0 if self.running_batch is None else len(self.running_batch.reqs)
......@@ -410,6 +399,7 @@ class ModelRpcServer:
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
# )
# Return the new batch
new_batch = Batch.init_new(
can_run_list,
self.req_to_token_pool,
......@@ -487,7 +477,7 @@ class ModelRpcServer:
self.handle_finished_requests(batch)
def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
......@@ -671,6 +661,34 @@ class ModelRpcServer:
else:
batch.reqs = []
def flush_cache(self):
if len(self.forward_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.regex_fsm_cache.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
else:
warnings.warn(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.forward_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
def abort_request(self, recv_req):
to_del = None
for i, req in enumerate(self.forward_queue):
if req.rid == recv_req.rid:
to_del = i
break
if to_del is not None:
del self.forward_queue[to_del]
class ModelRpcService(rpyc.Service):
exposed_ModelRpcServer = ModelRpcServer
......
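
On the scheduler side, aborting is handled like any other control message: exposed_step dispatches on the request type and abort_request drops the matching entry from the waiting queue by rid. A toy, in-process sketch of that dispatch-and-remove flow (no rpyc/ZMQ; MiniScheduler and PendingReq are made up for illustration):

from dataclasses import dataclass


@dataclass
class FlushCacheReq:
    pass


@dataclass
class AbortReq:
    rid: str


@dataclass
class PendingReq:
    # Hypothetical minimal stand-in for a queued request.
    rid: str


class MiniScheduler:
    # Hypothetical miniature of ModelRpcServer's control-message handling.
    def __init__(self):
        self.forward_queue = []

    def step(self, recv_reqs):
        for recv_req in recv_reqs:
            if isinstance(recv_req, FlushCacheReq):
                print("flush cache requested")
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def abort_request(self, recv_req):
        # Only requests still waiting in the queue are dropped; anything
        # already in the running batch finishes normally.
        to_del = None
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break
        if to_del is not None:
            del self.forward_queue[to_del]


sched = MiniScheduler()
sched.forward_queue = [PendingReq("a"), PendingReq("b")]
sched.step([AbortReq(rid="a")])
print([r.rid for r in sched.forward_queue])  # ['b']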
......@@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer,
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchStrOut,
FlushCacheReq,
GenerateReqInput,
......@@ -42,52 +43,6 @@ class ReqState:
event: asyncio.Event
global global_processor
def init_global_processor(server_args: ServerArgs):
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
def get_pixel_values(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
try:
processor = processor or global_processor
image, image_size = load_image(image_data)
if image_size != None:
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in processor.image_processor.image_mean),
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
class TokenizerManager:
def __init__(
self,
......@@ -154,10 +109,11 @@ class TokenizerManager:
image_data, aspect_ratio, grid_pinpoints, self.processor
)
async def generate_request(self, obj: GenerateReqInput):
async def generate_request(self, obj: GenerateReqInput, request=None):
if self.to_create_loop:
await self.create_handle_loop()
self.create_handle_loop()
obj.post_init()
is_single = obj.is_single
if is_single:
rid = obj.rid
......@@ -170,7 +126,7 @@ class TokenizerManager:
if len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)"
f"model's context length ({self.context_len} tokens)."
)
sampling_params = SamplingParams(**obj.sampling_params)
......@@ -208,7 +164,14 @@ class TokenizerManager:
self.rid_to_state[rid] = state
while True:
await event.wait()
try:
await asyncio.wait_for(event.wait(), timeout=5)
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
self.abort_request(rid)
raise ValueError(f"Abort request {rid}")
continue
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
......@@ -226,7 +189,8 @@ class TokenizerManager:
break
event.clear()
else:
assert obj.stream is False
if obj.stream:
raise ValueError("Do not support stream for batch mode.")
if obj.input_ids is None:
bs = len(obj.text)
......@@ -276,7 +240,18 @@ class TokenizerManager:
for i in range(bs):
rid = obj.rid[i]
state = self.rid_to_state[rid]
await state.event.wait()
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=5)
break
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
for rid in obj.rid:
self.abort_request(rid)
raise ValueError(f"Abort request {rid}")
continue
output_list.append(
self.convert_logprob_style(
state.out_list[-1],
......@@ -290,11 +265,16 @@ class TokenizerManager:
yield output_list
async def flush_cache(self):
flush_cache_req = FlushCacheReq()
self.send_to_router.send_pyobj(flush_cache_req)
def flush_cache(self):
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)
async def create_handle_loop(self):
def abort_request(self, rid):
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)
def create_handle_loop(self):
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
......@@ -305,17 +285,20 @@ class TokenizerManager:
if isinstance(recv_obj, BatchStrOut):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state = self.rid_to_state[rid]
state.out_list.append(out_dict)
state.finished = recv_obj.finished[i]
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}")
raise ValueError(f"Invalid object: {recv_obj}.")
def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
......@@ -356,3 +339,50 @@ class TokenizerManager:
if t:
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
return top_logprobs
global global_processor
def init_global_processor(server_args: ServerArgs):
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
def get_pixel_values(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
try:
processor = processor or global_processor
image, image_size = load_image(image_data)
if image_size != None:
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in processor.image_processor.image_mean),
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
\ No newline at end of file
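
The disconnect handling added to generate_request reduces to a polling loop: wait on the per-request event with a timeout, and on every timeout ask the transport whether the client is still connected; if not, abort and raise. A standalone asyncio sketch of that loop (is_disconnected and abort are hypothetical placeholders for Starlette's Request.is_disconnected and TokenizerManager.abort_request):

import asyncio


async def wait_or_abort(event, rid, is_disconnected, abort, poll_interval=5):
    # Wait for the worker to signal a result, but wake up every poll_interval
    # seconds to check whether the client is still connected. If it has gone
    # away, abort the request instead of leaving it queued forever.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=poll_interval)
            return
        except asyncio.TimeoutError:
            if await is_disconnected():
                abort(rid)
                raise ValueError(f"Abort request {rid}")


async def demo():
    event = asyncio.Event()  # never set: simulates a request still queued server-side

    async def client_gone():
        return True  # pretend the client has hung up

    def abort(rid):
        print("aborting", rid)

    try:
        await wait_or_abort(event, "req-1", client_gone, abort, poll_interval=0.1)
    except ValueError as e:
        print(e)


asyncio.run(demo())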
......@@ -335,7 +335,7 @@ def to_openai_style_logprobs(
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(logprob)
# Not Supported yet
# Not supported yet
ret_logprobs.text_offset.append(-1)
def append_top_logprobs(top_logprobs):
......
......@@ -10,6 +10,7 @@ import sys
import threading
import time
from typing import List, Optional, Union
from http import HTTPStatus
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -73,7 +74,7 @@ async def get_server_args():
@app.get("/flush_cache")
async def flush_cache():
await tokenizer_manager.flush_cache()
tokenizer_manager.flush_cache()
return Response(
content="Cache flushed.\nPlease check backend logs for more details. "
"(When there are running or waiting requests, the operation will not be performed.)\n",
......@@ -81,24 +82,25 @@ async def flush_cache():
)
async def generate_request(obj: GenerateReqInput):
obj.post_init()
async def generate_request(obj: GenerateReqInput, request: Request):
if obj.stream:
async def stream_results():
async for out in tokenizer_manager.generate_request(obj):
try:
async for out in tokenizer_manager.generate_request(obj, request):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
except ValueError as e:
out = {"error": {"message": str(e)}}
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream")
try:
ret = await tokenizer_manager.generate_request(obj).__anext__()
return ret
except ValueError as e:
print(f"Error: {e}")
return JSONResponse({"error": str(e)}, status_code=400)
else:
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse({"error": {"message": str(e)}},
status_code=HTTPStatus.BAD_REQUEST)
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
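
Once a streaming response has started, the HTTP status can no longer change, so the endpoint above reports failures in-band: the generator body is wrapped in try/except and a final error event is emitted before data: [DONE]. A hedged sketch of that shape (the produce generator and its error are invented for the demo):

import asyncio
import json


async def stream_results(produce):
    # Any ValueError raised after streaming has started is reported in-band as
    # one last SSE "data:" event, since the HTTP status is already committed.
    try:
        async for out in produce():
            yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
    except ValueError as e:
        out = {"error": {"message": str(e)}}
        yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"


async def produce():
    # Hypothetical generator that fails midway through a stream.
    yield {"text": "partial output"}
    raise ValueError("Abort request req-1")


async def demo():
    async for chunk in stream_results(produce):
        print(chunk, end="")


asyncio.run(demo())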
......@@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
# Send a warmup request
def _wait_and_warmup():
headers = {}
url = server_args.url()
......@@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
t = threading.Thread(target=_wait_and_warmup)
t.start()
# Listen for requests
try:
uvicorn.run(
app,
......
......@@ -9,7 +9,7 @@ import requests
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
from sglang.srt.utils import get_exception_traceback
from sglang.utils import get_exception_traceback
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
......
......@@ -93,8 +93,12 @@ def http_request(
data = None
else:
data = bytes(dumps(json), encoding="utf-8")
resp = urllib.request.urlopen(req, data=data, cafile=verify)
return HttpResponse(resp)
try:
resp = urllib.request.urlopen(req, data=data, cafile=verify)
return HttpResponse(resp)
except urllib.error.HTTPError as e:
return HttpResponse(e)
def encode_image_base64(image_path):
......
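
The http_request change leans on a property of urllib: urllib.error.HTTPError is itself a file-like response with a status code and a readable body, so wrapping it in the same response type lets _assert_success read the server's JSON error instead of the transport layer raising first. A small sketch under that assumption (this HttpResponse is a stand-in, not sglang's exact wrapper, and it omits the verify/cafile handling):

import json
import urllib.error
import urllib.request


class HttpResponse:
    # Hypothetical thin wrapper; mirrors the idea, not sglang's exact class.
    def __init__(self, resp):
        self.resp = resp

    @property
    def status_code(self):
        # http.client.HTTPResponse exposes .status; urllib.error.HTTPError
        # exposes .code (and .status on newer Pythons).
        return getattr(self.resp, "status", None) or self.resp.code

    def json(self):
        return json.loads(self.resp.read().decode("utf-8"))


def http_request(url, data=None):
    req = urllib.request.Request(url)
    try:
        return HttpResponse(urllib.request.urlopen(req, data=data))
    except urllib.error.HTTPError as e:
        # HTTPError carries the error response: it has a status code and a
        # readable body, so return it like any other response and let the
        # caller decide how to surface it.
        return HttpResponse(e)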