"src/diffusers/commands/diffusers_cli.py" did not exist on "dc6324d44bc189a0bf63018145617a736e7a38ff"
Unverified commit 6f560c76 authored by Lianmin Zheng, committed by GitHub

Improve the control of streaming and improve the first token latency in streaming (#117)

parent cd687233
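As a reading aid before the diff: the scheduler changes below add up to a single flush rule for streamed requests. A minimal sketch of that rule, assuming the field names shown in the diff (`decode_forward_ct`, `stream_interval`, `output_ids`); the helper function itself is illustrative and does not exist in the code:

    # Illustrative helper; the real check is written inline in ModelRpcServer below.
    def should_flush(finished: bool, stream: bool, decode_forward_ct: int,
                     stream_interval: int, num_output_tokens: int) -> bool:
        # Flush when the request finished, on the very first generated token
        # (the first-token-latency improvement), or every `stream_interval`
        # decode steps thereafter.
        return finished or (
            stream
            and (decode_forward_ct % stream_interval == 0 or num_output_tokens == 1)
        )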
@@ -21,14 +21,17 @@ class FinishReason(Enum):
 class Req:
-    def __init__(self, rid):
+    def __init__(self, rid, input_text, input_ids):
         self.rid = rid
-        self.input_text = None
-        self.input_ids = []
+        self.input_text = input_text
+        self.input_ids = input_ids
         self.output_ids = []
+
+        # For vision input
         self.pixel_values = None
         self.image_size = None
         self.image_offset = 0
+
         self.sampling_params = None
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -46,7 +49,7 @@ class Req:
         self.logprob = None
         self.normalized_logprob = None

-        # for constrained decoding
+        # For constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.fast_forward_map = None
......
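With the constructor change above, callers now pass the prompt text and token ids when creating a Req instead of assigning the fields afterwards. A minimal usage sketch (the token ids are placeholders; the import path matches the one used elsewhere in this diff):

    from sglang.srt.managers.router.infer_batch import Req

    rid = 0
    input_text = "The capital of France is"
    input_ids = [1, 450, 7483, 310, 3444, 338]  # placeholder token ids

    # Before this commit: Req(rid) followed by attribute assignment.
    # After this commit: text and token ids go through the constructor.
    req = Req(rid, input_text, input_ids)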
@@ -40,7 +40,7 @@ class RouterManager:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-            # async sleep for recving the subsequent request, and avoiding cache miss
+            # async sleep for receiving the subsequent request and avoiding cache miss
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
......
@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
-    TokenizedGenerateReqInput,
     FlushCacheReq,
+    TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
                 if self.running_batch.is_empty():
                     self.running_batch = None
                     break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
             else:
                 # check the available size
                 available_size = (
@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
            )

            if self.running_batch is not None and self.tp_rank == 0:
-                if self.decode_forward_ct >= 20:
-                    self.decode_forward_ct = 0
+                if self.decode_forward_ct % 20 == 0:
                    num_used = self.max_total_num_token - (
                        self.token_to_kv_pool.available_size()
                        + self.tree_cache.evictable_size()
@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid)
-        req.input_text = recv_req.input_text
-        req.input_ids = recv_req.input_ids
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.pixel_values = recv_req.pixel_values
-        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,
@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                 req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
+            req.image_size = recv_req.image_size
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
            ):
+                # Undo the insertion
                delta = self.tree_cache.dec_ref_counter(req.last_node)
                available_size += delta
            else:
+                # Add this request to the running batch
                self.token_to_kv_pool.add_refs(req.prefix_indices)
                can_run_list.append(req)
                new_batch_total_tokens += (
@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
            return

        # Update batch tensors
-        self.decode_forward_ct += 1
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
                unfinished_indices.append(i)

            if req.finished or (
-                req.stream and self.decode_forward_ct % self.stream_interval == 0
+                (
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
+                    )
+                )
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
......
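Two details in the hunks above are easy to miss. First, the prefill loop now breaks out as soon as there is buffered stream output (`self.out_pyobjs`), so the first generated tokens reach the detokenizer without waiting for further decode iterations. Second, `decode_forward_ct` is no longer reset to zero; it wraps at 2**30, which keeps the periodic modulo checks valid while bounding the counter. A small self-contained illustration of the wrapping counter (values are arbitrary):

    # The modulo checks `ct % 20 == 0` (stats logging) and
    # `ct % stream_interval == 0` (stream flush) behave exactly as with an
    # ever-growing counter, but the value stays bounded.
    ct = 0
    stream_interval = 8
    flush_steps = []
    for step in range(1, 41):
        ct = (ct + 1) % (1 << 30)
        if ct % stream_interval == 0:
            flush_steps.append(step)
    print(flush_steps)  # [8, 16, 24, 32, 40]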
@@ -7,7 +7,6 @@ from typing import List
 import numpy as np
 import torch
-import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

+import sglang
+
 logger = logging.getLogger("model_runner")
......
@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
-    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
......
@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                     num_patch_height, num_patch_width, height, width, -1
                 )
             else:
-                raise NotImplementedError
+                raise NotImplementedError()
             if "unpad" in self.mm_patch_merge_type:
                 image_feature = image_feature.permute(
                     4, 0, 2, 1, 3
......
@@ -19,7 +19,7 @@ class ServerArgs:
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     random_seed: int = 42
-    stream_interval: int = 2
+    stream_interval: int = 8
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -132,7 +132,7 @@ class ServerArgs:
             "--stream-interval",
             type=int,
             default=ServerArgs.stream_interval,
-            help="The interval in terms of token length for streaming",
+            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
         parser.add_argument(
             "--log-level",
......
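Since the default `stream_interval` moves from 2 to 8 tokens, deployments that prefer smoother streaming over peak throughput can pass the flag explicitly when launching the server, for example:

    python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 --stream-interval 2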
@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
     reqs = []
     for i in range(len(prompts)):
-        req = Req(i)
+        req = Req(i, None, None)
         req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
         req.sampling_params = sampling_params
         reqs.append(req)
......
@@ -112,6 +112,7 @@ def test_generate_worker(
     prefill_params = (
         torch.tensor(np.array(input_ids)).cuda(),
         np.array(pixel_values),
+        [None],
         [offset],
         *params,
     )
......
""" """
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode.py
Output: Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
......
""" """
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode_stream.py
Output: Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
......
""" """
Usage:
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
python3 test_httpserver_llava.py
Output: Output:
The image features a man standing on the back of a yellow taxi cab, holding The image features a man standing on the back of a yellow taxi cab, holding
...@@ -64,9 +66,12 @@ def test_streaming(args): ...@@ -64,9 +66,12 @@ def test_streaming(args):
) )
prev = 0 prev = 0
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(decode_unicode=False):
if chunk: chunk = chunk.decode("utf-8")
data = json.loads(chunk.decode()) if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip() output = data["text"].strip()
print(output[prev:], end="", flush=True) print(output[prev:], end="", flush=True)
prev = len(output) prev = len(output)
......
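For reference, a compact standalone version of the streaming client loop used in the test above. Only the parsing of the `data:` lines is taken from this diff; the endpoint and the request payload shape (`text`, `sampling_params`, `stream`) are assumptions about the sglang HTTP API and are not shown here:

    import json

    import requests

    # Assumed endpoint and payload shape; adjust to the actual server API.
    response = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 64},
            "stream": True,
        },
        stream=True,
    )

    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print()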