Unverified commit 562b8857, authored by Lianmin Zheng and committed by GitHub

Improve error handling (#433)

parent 04c0b214
......@@ -29,7 +29,7 @@ from sglang.lang.ir import (
    SglVarScopeBegin,
    SglVarScopeEnd,
)
from sglang.utils import encode_image_base64
from sglang.utils import encode_image_base64, get_exception_traceback
def run_internal(state, program, func_args, func_kwargs, sync):
......@@ -195,6 +195,7 @@ class StreamExecutor:
        self.variable_event = {}  # Dict[name: str -> event: threading.Event]
        self.meta_info = {}  # Dict[name: str -> info: str]
        self.is_finished = False
        self.error = None

        # For completion
        self.text_ = ""  # The full text
......@@ -310,17 +311,39 @@ class StreamExecutor:
        self.backend.end_program(self)

    def _thread_worker_func(self):
        error = None

        while True:
            expr = self.queue.get()
            if expr is None:
                self.queue.task_done()
                break

            try:
                self._execute(expr)
            except Exception as e:
                print(f"Error in stream_executor: {get_exception_traceback()}")
                error = e
                break

            self.queue.task_done()
            if self.stream_text_event:
                self.stream_text_event.set()

        # Clean the queue and events
        if error is not None:
            try:
                while True:
                    self.queue.task_done()
                    self.queue.get_nowait()
            except queue.Empty:
                pass
            for name in self.variable_event:
                self.variable_event[name].set()
            if self.stream_var_event:
                for name in self.stream_var_event:
                    self.stream_var_event[name].set()
            self.error = error

        if self.stream_text_event:
            self.stream_text_event.set()
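Editor's note: the cleanup path above follows a general pattern for a failing worker thread: drain whatever is still queued so queue.join() can return, then set every pending event so no consumer blocks forever. A minimal standalone sketch of that pattern, with hypothetical names rather than the sglang classes, is:

import queue
import threading

def drain_and_release(task_queue: queue.Queue, events: "dict[str, threading.Event]"):
    # Empty the queue, accounting for each pending item so join() can return.
    try:
        while True:
            task_queue.get_nowait()
            task_queue.task_done()
    except queue.Empty:
        pass
    # Wake every consumer that is still waiting on an unfilled result.
    for event in events.values():
        event.set()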
......@@ -679,7 +702,9 @@ class ProgramState:
        return self.stream_executor.messages()

    def sync(self):
        ret = self.stream_executor.sync()
        self.error = self.stream_executor.error
        return ret

    def text_iter(self, var_name: Optional[str] = None):
        if self.stream_executor.stream:
......@@ -769,6 +794,9 @@ class ProgramState:
    def __setitem__(self, name, value):
        self.set_var(name, value)

    def __contains__(self, name):
        return name in self.stream_executor.variables

    def __del__(self):
        self.stream_executor.end()
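Editor's note: with the executor's error now mirrored onto the program state in sync(), calling code can detect a failed worker thread instead of silently reading empty variables. A hedged sketch of caller-side handling follows; the program, backend, and URL are invented for illustration, and only sync() and the new error attribute come from this commit.

import sglang as sgl

# Assumes a local sglang server; adjust the URL for your deployment.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def qa(s):
    s += "Q: What is 1 + 1?\n"
    s += "A:" + sgl.gen("answer", max_tokens=8)

state = qa.run()
state.sync()                 # joins the stream executor's worker thread
if state.error is not None:  # copied from StreamExecutor.error by sync()
    print("generation failed:", state.error)
else:
    print(state["answer"])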
......
"""
Usage:
python3 -m sglang.srt.flush_cache --url http://localhost:30000
"""
import argparse
import requests
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://localhost:30000")
    args = parser.parse_args()

    response = requests.get(args.url + "/flush_cache")
    assert response.status_code == 200
\ No newline at end of file
......@@ -135,6 +135,8 @@ class ModelRpcServer:
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
......@@ -211,6 +213,7 @@ class ModelRpcServer:
        if self.running_batch is not None:
            # Run a few decode batches continuously for reducing overhead
            for _ in range(10):
                self.num_generated_tokens += len(self.running_batch.reqs)
                self.forward_decode_batch(self.running_batch)

                if self.running_batch.is_empty():
......@@ -226,10 +229,14 @@ class ModelRpcServer:
                        self.token_to_kv_pool.available_size()
                        + self.tree_cache.evictable_size()
                    )
                    throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
                    self.num_generated_tokens = 0
                    self.last_stats_tic = time.time()
                    logger.info(
                        f"#running-req: {len(self.running_batch.reqs)}, "
                        f"#token: {num_used}, "
                        f"token usage: {num_used / self.max_total_num_token:.2f}, "
                        f"gen throughput (token/s): {throughput:.2f}, "
                        f"#queue-req: {len(self.forward_queue)}"
                    )
        else:
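Editor's note: the throughput figure is plain arithmetic, tokens generated since the last log line divided by the wall-clock seconds elapsed, after which both counters reset. A small self-contained sketch of the same windowed accounting, using a hypothetical helper class that is not part of the commit:

import time

class ThroughputMeter:
    """Tokens-per-second over the interval since the last report."""

    def __init__(self):
        self.num_tokens = 0
        self.last_tic = time.time()

    def add(self, n: int):
        self.num_tokens += n

    def report(self) -> float:
        now = time.time()
        elapsed = max(now - self.last_tic, 1e-8)  # guard against a zero-length window
        tokens_per_s = self.num_tokens / elapsed
        self.num_tokens = 0   # reset the window, as the server loop does
        self.last_tic = now
        return tokens_per_s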
......
......@@ -17,8 +17,8 @@ from vllm.distributed import initialize_model_parallel
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory
from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
QUANTIZATION_CONFIG_MAPPING = {
"awq": AWQConfig,
......
......@@ -4,9 +4,7 @@ import base64
import os
import random
import socket
import sys
import time
import traceback
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import List, Optional
......@@ -20,6 +18,8 @@ from packaging import version as pkg_version
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from sglang.utils import get_exception_traceback
show_time_cost = False
time_infos = {}
......@@ -90,6 +90,32 @@ def calculate_time(show=False, min_cost_ms=0.0):
    return wrapper
def get_available_gpu_memory(gpu_id, distributed=True):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.
    """
    num_gpus = torch.cuda.device_count()
    assert gpu_id < num_gpus

    if torch.cuda.current_device() != gpu_id:
        print(
            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
            "which may cause useless memory allocation for torch CUDA context.",
        )

    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    if distributed:
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
            torch.device("cuda", gpu_id)
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)
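Editor's note: a usage sketch for the relocated helper, assuming CUDA device 0 is visible; the distributed branch additionally requires an initialized torch.distributed process group.

import torch

if torch.cuda.is_available():
    torch.cuda.set_device(0)  # avoid the warning about a mismatched current device
    free_gb = get_available_gpu_memory(0, distributed=False)
    print(f"available GPU memory: {free_gb:.2f} GiB")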
def set_random_seed(seed: int) -> None:
    random.seed(seed)
......@@ -158,12 +184,6 @@ def allocate_init_ports(
return port, additional_ports
def get_exception_traceback():
    etype, value, tb = sys.exc_info()
    err_str = "".join(traceback.format_exception(etype, value, tb))
    return err_str
def get_int_token_logit_bias(tokenizer, vocab_size):
    # work around a bug when the model's vocab size > tokenizer.vocab_size
    vocab_size = tokenizer.vocab_size
......@@ -314,4 +334,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel):
    if IS_PYDANTIC_1:
        return obj.json(ensure_ascii=False)
    return obj.model_dump_json()
\ No newline at end of file
    return obj.model_dump_json()
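Editor's note: the shim above keeps serialization working across pydantic major versions: v1 exposes .json() while v2 renamed it to .model_dump_json(). A hedged usage sketch; the Message model is invented for illustration and jsonify_pydantic_model is the helper defined above.

from pydantic import BaseModel

class Message(BaseModel):
    role: str
    content: str

msg = Message(role="user", content="héllo")
# Non-ASCII text is preserved on both pydantic v1 and v2 code paths.
print(jsonify_pydantic_model(msg))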
......@@ -2,7 +2,9 @@
import base64
import json
import sys
import threading
import traceback
import urllib.request
from io import BytesIO
from json import dumps
......@@ -10,32 +12,10 @@ from json import dumps
import requests
def get_available_gpu_memory(gpu_id, distributed=True):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.
    """
    import torch

    num_gpus = torch.cuda.device_count()
    assert gpu_id < num_gpus

    if torch.cuda.current_device() != gpu_id:
        print(
            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
            "which may cause useless memory allocation for torch CUDA context.",
        )

    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    if distributed:
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
            torch.device("cuda", gpu_id)
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)
def get_exception_traceback():
    etype, value, tb = sys.exc_info()
    err_str = "".join(traceback.format_exception(etype, value, tb))
    return err_str
def is_same_type(values):
......@@ -190,4 +170,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
    if not ret_value:
        raise RuntimeError()
    return ret_value[0]
    return ret_value[0]
\ No newline at end of file