Unverified Commit 08ab2a16 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

JSON Decode && Multi-Turns (#4)

parent f652494d
import asyncio import asyncio
import logging import logging
from typing import List, Tuple
import uvloop import uvloop
import zmq import zmq
...@@ -8,6 +7,7 @@ import zmq.asyncio ...@@ -8,6 +7,7 @@ import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback from sglang.srt.utils import get_exception_traceback
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...@@ -28,6 +28,9 @@ class RouterManager: ...@@ -28,6 +28,9 @@ class RouterManager:
self.model_client = model_client self.model_client = model_client
self.recv_reqs = [] self.recv_reqs = []
# Init Some Configs
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
async def loop_for_forward(self): async def loop_for_forward(self):
while True: while True:
next_step_input = list(self.recv_reqs) next_step_input = list(self.recv_reqs)
...@@ -37,7 +40,12 @@ class RouterManager: ...@@ -37,7 +40,12 @@ class RouterManager:
for obj in out_pyobjs: for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj) self.send_to_detokenizer.send_pyobj(obj)
# await for a while to accept input requests # async sleep for recving the subsequent request, and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:
await asyncio.sleep(self.extend_dependency_time)
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
async def loop_for_recv_requests(self): async def loop_for_recv_requests(self):
......
...@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner ...@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
get_exception_traceback, get_exception_traceback,
...@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service): ...@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
if self.running_batch.is_empty(): if self.running_batch.is_empty():
self.running_batch = None self.running_batch = None
break break
else:
# check the available size
available_size = (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_token:
logger.warning(
"Warning: "
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
"KV cache pool leak detected!"
)
if self.running_batch is not None and self.tp_rank == 0: if self.running_batch is not None and self.tp_rank == 0:
if self.decode_forward_ct >= 20: if self.decode_forward_ct >= 20:
...@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service): ...@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
token_ids = tuple(req.input_ids + req.output_ids) token_ids = tuple(req.input_ids + req.output_ids)
seq_len = len(token_ids) - 1 seq_len = len(token_ids) - 1
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len] indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
prefix_len = self.tree_cache.insert(token_ids, indices.clone()) prefix_len = self.tree_cache.insert(
token_ids[:seq_len], indices.clone()
)
self.token_to_kv_pool.free(indices[:prefix_len]) self.token_to_kv_pool.free(indices[:prefix_len])
self.req_to_token_pool.free(req_pool_idx) self.req_to_token_pool.free(req_pool_idx)
......
...@@ -18,7 +18,7 @@ class Scheduler: ...@@ -18,7 +18,7 @@ class Scheduler:
self.tree_cache = tree_cache self.tree_cache = tree_cache
def new_token_estimation_ratio(self): def new_token_estimation_ratio(self):
return 0.4 if self.schedule_heuristic != "fcfs" else 0.5 return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
def get_priority_queue(self, forward_queue): def get_priority_queue(self, forward_queue):
if self.schedule_heuristic == "lpm": if self.schedule_heuristic == "lpm":
......
...@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6 ...@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
class SamplingParams: class SamplingParams:
def __init__( def __init__(
self, self,
max_new_tokens: int = 16,
stop: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
stop: Optional[Union[str, List[str]]] = None,
max_new_tokens: int = 16,
ignore_eos: bool = False, ignore_eos: bool = False,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
dtype: Optional[str] = None, dtype: Optional[str] = None,
......
...@@ -24,6 +24,8 @@ class ServerArgs: ...@@ -24,6 +24,8 @@ class ServerArgs:
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
self.tokenizer_path = self.model_path self.tokenizer_path = self.model_path
if self.tp_size > 1:
self.mem_fraction_static = 0.8
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
......
...@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1): ...@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
return pred return pred
def call_generate_outlines(
    prompt, temperature, max_tokens, url, stop=None, regex=None, n=1
):
    """POST a generation request to an outlines-style HTTP server and
    return the completion(s) with the echoed prompt stripped off.

    Args:
        prompt: Input text prompt.
        temperature: Sampling temperature forwarded to the server.
        max_tokens: Maximum number of tokens to generate.
        url: Endpoint URL to POST the JSON payload to.
        stop: Optional list of stop strings; defaults to an empty list.
        regex: Optional regex constraint for guided decoding.
        n: Number of completions to request.

    Returns:
        A single completion string when n == 1, otherwise a list of
        completion strings.
    """
    # Bug fix: the original used a mutable default (`stop=[]`), which is
    # shared across calls (flake8-bugbear B006). Bind a fresh list here so
    # the serialized payload is unchanged for default callers.
    if stop is None:
        stop = []
    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
        "regex": regex,
        "n": n,
    }
    res = requests.post(url, json=data)
    # NOTE(review): assert is stripped under `python -O`; kept to preserve
    # the original failure mode (AssertionError) for existing callers.
    assert res.status_code == 200
    # The server response includes the prompt as a prefix of each
    # generated text; slice it off to return only the completion.
    if n == 1:
        pred = res.json()["text"][0][len(prompt) :]
    else:
        pred = [x[len(prompt) :] for x in res.json()["text"]]
    return pred
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url): def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
data = { data = {
"text": prompt, "text": prompt,
......
...@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"): ...@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
if isinstance(s, str): if isinstance(s, str):
pass pass
elif isinstance(s, ProgramState): elif isinstance(s, ProgramState):
s = s.text().strip() s = s.text()
else: else:
s = str(s) s = str(s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment