Unverified Commit 6f560c76 authored by Lianmin Zheng, committed by GitHub

Improve the control of streaming and improve the first token latency in streaming (#117)

parent cd687233
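The change below makes the router flush streaming output more eagerly: a request's tokens are sent to the detokenizer when it finishes, right after its first output token (to cut first-token latency), and otherwise every `stream_interval` decode steps. A minimal sketch of that rule, using the field names from the diff; the helper function itself is illustrative and not part of the commit:

# Illustrative only: mirrors the flush condition added in model_rpc.py below.
def should_flush_output(finished, stream, decode_forward_ct, stream_interval, num_output_tokens):
    return finished or (
        stream
        and (
            decode_forward_ct % stream_interval == 0  # periodic flush every few decode steps
            or num_output_tokens == 1                 # flush the very first token immediately
        )
    )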
......@@ -21,14 +21,17 @@ class FinishReason(Enum):
 class Req:
-    def __init__(self, rid):
+    def __init__(self, rid, input_text, input_ids):
         self.rid = rid
-        self.input_text = None
-        self.input_ids = []
+        self.input_text = input_text
+        self.input_ids = input_ids
         self.output_ids = []
+
+        # For vision input
         self.pixel_values = None
         self.image_size = None
         self.image_offset = 0
+
         self.sampling_params = None
         self.return_logprob = False
         self.logprob_start_len = 0
......@@ -46,7 +49,7 @@ class Req:
         self.logprob = None
         self.normalized_logprob = None

-        # for constrained decoding
+        # For constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.fast_forward_map = None
......
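With the constructor change above, callers now pass the input text and token ids directly instead of assigning them after construction. A small illustrative sketch (the values are made up; tests that fill these fields in later pass placeholders, as the test hunk further below shows):

from sglang.srt.managers.router.infer_batch import Req

# Illustrative only: the new Req construction pattern used throughout this commit.
req = Req(rid="req-0", input_text="The capital of France is", input_ids=[450, 7483, 310])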
......@@ -40,7 +40,7 @@ class RouterManager:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-            # async sleep for recving the subsequent request, and avoiding cache miss
+            # async sleep for receiving the subsequent request and avoiding cache miss
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
......
......@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
-    TokenizedGenerateReqInput,
     FlushCacheReq,
+    TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
......@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
                     if self.running_batch.is_empty():
                         self.running_batch = None
                         break
+
+                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                        break
             else:
                 # check the available size
                 available_size = (
......@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
                     )

         if self.running_batch is not None and self.tp_rank == 0:
-            if self.decode_forward_ct >= 20:
-                self.decode_forward_ct = 0
+            if self.decode_forward_ct % 20 == 0:
                 num_used = self.max_total_num_token - (
                     self.token_to_kv_pool.available_size()
                     + self.tree_cache.evictable_size()
......@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid)
-        req.input_text = recv_req.input_text
-        req.input_ids = recv_req.input_ids
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.pixel_values = recv_req.pixel_values
-        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,
......@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                 req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
+            req.image_size = recv_req.image_size
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
......@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
             ):
+                # Undo the insertion
                 delta = self.tree_cache.dec_ref_counter(req.last_node)
                 available_size += delta
             else:
+                # Add this request to the running batch
                 self.token_to_kv_pool.add_refs(req.prefix_indices)
                 can_run_list.append(req)
                 new_batch_total_tokens += (
......@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
             return

         # Update batch tensors
-        self.decode_forward_ct += 1
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()

         # Forward
......@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
                 unfinished_indices.append(i)

             if req.finished or (
-                req.stream and self.decode_forward_ct % self.stream_interval == 0
+                (
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
+                    )
+                )
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
......
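Two details above work together: the decode counter now wraps (`(self.decode_forward_ct + 1) % (1 << 30)`) instead of being reset to zero by the logging path, and the decode loop breaks out early when there is pending output for a streaming request, so buffered tokens reach the detokenizer sooner. A sketch of the loop shape these hunks imply; the `server` object and the step bound are placeholders, not taken from the diff:

# Illustrative decode loop: run a few decode steps back-to-back for efficiency,
# but stop early once there is pending output for a streaming request.
def run_decode_steps(server, max_steps=10):  # max_steps is an assumed bound
    for _ in range(max_steps):
        server.forward_decode_batch(server.running_batch)
        if server.running_batch.is_empty():
            server.running_batch = None
            break
        if server.out_pyobjs and server.running_batch.reqs[0].stream:
            break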
......@@ -7,7 +7,6 @@ from typing import List
 import numpy as np
 import torch
-import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
......@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
+
+import sglang

 logger = logging.getLogger("model_runner")
......
......@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
-    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
......
......@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                         num_patch_height, num_patch_width, height, width, -1
                     )
                 else:
-                    raise NotImplementedError
+                    raise NotImplementedError()
                 if "unpad" in self.mm_patch_merge_type:
                     image_feature = image_feature.permute(
                         4, 0, 2, 1, 3
......
......@@ -19,7 +19,7 @@ class ServerArgs:
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     random_seed: int = 42
-    stream_interval: int = 2
+    stream_interval: int = 8
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
......@@ -132,7 +132,7 @@ class ServerArgs:
"--stream-interval",
type=int,
default=ServerArgs.stream_interval,
help="The interval in terms of token length for streaming",
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
)
parser.add_argument(
"--log-level",
......
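The default interval grows from 2 to 8 tokens, so streaming requests are detokenized and sent in larger chunks, while the first-token special case earlier in this diff keeps time-to-first-token low. If smoother streaming matters more than throughput, the knob can be lowered, for example by passing `--stream-interval 2` to `sglang.launch_server`. A hedged sketch of the same setting on the dataclass itself (constructing `ServerArgs` directly like this is an assumption about its usage, not shown in the diff):

from sglang.srt.server_args import ServerArgs

# Assumes ServerArgs can be built directly; model_path is its required field,
# stream_interval is the knob changed in this commit.
args = ServerArgs(
    model_path="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
    stream_interval=2,  # smaller -> smoother streaming, larger -> higher throughput
)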
......@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
     reqs = []
     for i in range(len(prompts)):
-        req = Req(i)
+        req = Req(i, None, None)
         req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
         req.sampling_params = sampling_params
         reqs.append(req)
......
......@@ -112,6 +112,7 @@ def test_generate_worker(
     prefill_params = (
         torch.tensor(np.array(input_ids)).cuda(),
         np.array(pixel_values),
+        [None],
         [offset],
         *params,
     )
......
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
......
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode_stream.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
......
"""
Usage:
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
python3 test_httpserver_llava.py
Output:
The image features a man standing on the back of a yellow taxi cab, holding
......@@ -64,9 +66,12 @@ def test_streaming(args):
     )

     prev = 0
-    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode())
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)
......
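For reference, the `data:`-prefixed streaming protocol parsed above can be exercised with a standalone client. The URL and request payload below are illustrative assumptions (this diff does not show how the test builds its request); only the chunk-parsing loop mirrors the updated test:

import json
import requests

def stream_generate(url="http://127.0.0.1:30000/generate", prompt="The capital of France is"):
    # Assumed endpoint and payload; adjust to the server actually being tested.
    response = requests.post(
        url,
        json={"text": prompt, "sampling_params": {"max_new_tokens": 64}, "stream": True},
        stream=True,
    )
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print()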