Unverified commit d18c6b33, authored by Lianmin Zheng and committed by GitHub

Support incremental streaming of logprob/token_ids between scheduler and detokenizer (#6225)


Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent f1c89600
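In plain terms, the commit makes the scheduler, detokenizer, and tokenizer manager stream their per-request outputs incrementally: each sender keeps an offset recording how much of a growing value (token ids, logprobs, decoded text) has already been shipped and sends only the new suffix, and the receiver accumulates the chunks back into the full value. A minimal sketch of this pattern, with illustrative names (`StreamState`, `push_chunk`) that are not from the sglang codebase:

```python
# Minimal sketch of incremental streaming with a per-request send offset.
# StreamState / push_chunk are illustrative names, not real sglang classes.
from dataclasses import dataclass, field
from typing import List


@dataclass
class StreamState:
    output_ids: List[int] = field(default_factory=list)  # full history on the producer
    sent_offset: int = 0  # how much has already been shipped downstream


def push_chunk(state: StreamState, new_ids: List[int]) -> List[int]:
    """Append newly generated ids and return only the not-yet-sent suffix."""
    state.output_ids.extend(new_ids)
    chunk = state.output_ids[state.sent_offset:]
    state.sent_offset = len(state.output_ids)
    return chunk


state = StreamState()
assert push_chunk(state, [1, 2, 3]) == [1, 2, 3]
assert push_chunk(state, [4]) == [4]  # only the increment is sent the second time
```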
@@ -41,7 +41,7 @@ class BaseGrammarObject:
         raise NotImplementedError()

     def is_terminated(self):
-        raise NotImplementedError()
+        return False

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
......
@@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalDecodeReq,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
 )
@@ -60,6 +61,8 @@ class DecodeStatus:
     decode_ids: List[int]
     surr_offset: int
     read_offset: int
+    # Offset that's sent to tokenizer for incremental update.
+    sent_offset: int = 0

 class DetokenizerManager:
@@ -151,7 +154,7 @@ class DetokenizerManager:
                 self.decode_status[rid] = s
             else:
                 s = self.decode_status[rid]
-                s.decode_ids = recv_obj.decode_ids[i]
+                s.decode_ids.extend(recv_obj.decode_ids[i])

             read_ids.append(
                 self.trim_matched_stop(
@@ -199,13 +202,15 @@ class DetokenizerManager:
             else:
                 new_text = find_printable_text(new_text)

-            output_strs.append(
-                self.trim_matched_stop(
-                    s.decoded_text + new_text,
-                    recv_obj.finished_reasons[i],
-                    recv_obj.no_stop_trim[i],
-                )
-            )
+            output_str = self.trim_matched_stop(
+                s.decoded_text + new_text,
+                recv_obj.finished_reasons[i],
+                recv_obj.no_stop_trim[i],
+            )
+            # Incrementally send text.
+            incremental_output = output_str[s.sent_offset :]
+            s.sent_offset = len(output_str)
+            output_strs.append(incremental_output)

         return BatchStrOut(
             rids=recv_obj.rids,
@@ -232,7 +237,15 @@ class DetokenizerManager:
         )

     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
-        raise NotImplementedError()
+        outputs = self.tokenizer.detokenize(recv_obj)
+        return BatchMultimodalOut(
+            rids=recv_obj.rids,
+            finished_reasons=recv_obj.finished_reasons,
+            outputs=outputs,
+            prompt_tokens=recv_obj.prompt_tokens,
+            completion_tokens=recv_obj.completion_tokens,
+            cached_tokens=recv_obj.cached_tokens,
+        )

 class LimitedCapacityDict(OrderedDict):
......
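The detokenizer-side change above keeps the full decoded text per request and, after stop-string trimming, sends only the part past `sent_offset`. A minimal standalone sketch of that step, assuming a stripped-down `DecodeStatus` with just the fields used here:

```python
# Sketch of the detokenizer-side incremental text send, under a simplified
# DecodeStatus (only the two fields used here; not the real sglang class).
from dataclasses import dataclass


@dataclass
class DecodeStatus:
    decoded_text: str = ""   # text already surfaced in earlier steps
    sent_offset: int = 0     # length of text already sent downstream


def emit_increment(s: DecodeStatus, full_text: str) -> str:
    """Return only the part of full_text that has not been sent yet."""
    increment = full_text[s.sent_offset:]
    s.sent_offset = len(full_text)
    return increment


s = DecodeStatus()
assert emit_increment(s, "Hel") == "Hel"
assert emit_increment(s, "Hello, wo") == "lo, wo"   # only the new suffix
assert emit_increment(s, "Hello, world") == "rld"
```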
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMi
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -436,6 +437,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+        self.lora_path = lora_path

         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -487,6 +489,13 @@ class Req:
         # For retraction
         self.is_retracted = False

+        # Incremental streaming
+        self.send_token_offset: int = 0
+        self.send_decode_id_offset: int = 0
+        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+        # because the decode server does not have the first output token logprobs
+        self.send_output_token_logprobs_offset: int = 0

         # Logprobs (arguments)
         self.return_logprob = return_logprob
         # Start index to compute logprob from.
@@ -496,11 +505,9 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False

-        # Latency Breakdown
-        self.queue_time_start = None
-        self.queue_time_end = None

         # Logprobs (return values)
+        # True means the input logprobs have already been sent to the detokenizer.
+        self.input_logprob_sent: bool = False
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -515,8 +522,10 @@ class Req:
         self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
+            # shape: (bs, 1)
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
+            # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
             self.output_token_ids_logprobs_val = []
@@ -543,7 +552,12 @@ class Req:
         # The number of verification forward passes in the speculative decoding.
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-        self.lora_path = lora_path
+
+        # For metrics
+        self.time_stats: TimeStats = TimeStats()
+        self.has_log_time_stats: bool = False
+        self.queue_time_start = None
+        self.queue_time_end = None

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -562,8 +576,8 @@ class Req:
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1
         self.metadata_buffer_index: int = -1

         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
@@ -656,6 +670,11 @@ class Req:
             )
             return

+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
         last_token_id = self.output_ids[-1]

         if not self.sampling_params.ignore_eos:
@@ -713,6 +732,18 @@ class Req:
         token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
         del self.kv_cache_cpu

+    def log_time_stats(self):
+        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+        if self.has_log_time_stats is True:
+            return
+
+        if self.bootstrap_room is not None:
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        else:
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+
+        logger.info(f"{prefix}: {self.time_stats}")
+        self.has_log_time_stats = True
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
......
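The `Req` changes above add three per-request offsets (`send_token_offset`, `send_decode_id_offset`, `send_output_token_logprobs_offset`) so the scheduler can stream only the unsent suffix of each growing list. A rough illustration using a hypothetical `MiniReq` stand-in (only the streaming-related fields; not the real `Req` class):

```python
# Rough illustration of per-request send offsets: slice off the unsent suffix
# and advance the offset. MiniReq is a stand-in, not sglang's Req.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class MiniReq:
    output_ids: List[int] = field(default_factory=list)
    output_token_logprobs_val: List[float] = field(default_factory=list)
    send_token_offset: int = 0
    send_output_token_logprobs_offset: int = 0


def take_unsent(req: MiniReq) -> Tuple[List[int], List[float]]:
    """Return the not-yet-streamed suffix of each list and advance the offsets."""
    ids = req.output_ids[req.send_token_offset:]
    req.send_token_offset = len(req.output_ids)

    logprobs = req.output_token_logprobs_val[req.send_output_token_logprobs_offset:]
    req.send_output_token_logprobs_offset = len(req.output_token_logprobs_val)
    return ids, logprobs


req = MiniReq(output_ids=[5, 7], output_token_logprobs_val=[-0.1, -0.3])
print(take_unsent(req))   # ([5, 7], [-0.1, -0.3])
req.output_ids.append(9)
req.output_token_logprobs_val.append(-0.2)
print(take_unsent(req))   # ([9], [-0.2]) -- only the increments
```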
@@ -530,10 +530,6 @@ class Scheduler(
         )

     def init_metrics(self):
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -1122,9 +1118,6 @@ class Scheduler(
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        self._largest_prefill_len = max(
-            self._largest_prefill_len, adder.log_input_tokens
-        )

         num_new_seq = len(can_run_list)
         f = (
@@ -1601,14 +1594,9 @@ class Scheduler(
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
-                if batch.next_batch_sampling_info:
-                    batch.next_batch_sampling_info.update_regex_vocab_mask()
-                    self.current_stream.synchronize()
-                    batch.next_batch_sampling_info.sampling_info_done.set()
+                self.set_next_batch_sampling_info_done(batch)
         elif batch.forward_mode.is_dummy_first():
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         if self.return_health_check_ct:
             # Return some signal for the health check.
@@ -1776,6 +1764,13 @@ class Scheduler(
         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+        if batch.next_batch_sampling_info:
+            if batch.next_batch_sampling_info.grammars is not None:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
......
 from __future__ import annotations

+import logging
 import threading
+import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
@@ -15,6 +18,8 @@ if TYPE_CHECKING:
         Scheduler,
     )

+logger = logging.getLogger(__name__)
+
 DEFAULT_FORCE_STREAM_INTERVAL = 50
@@ -83,6 +88,7 @@ class SchedulerOutputProcessorMixin:
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
+                    req.time_stats.completion_time = time.time()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)
@@ -149,10 +155,7 @@ class SchedulerOutputProcessorMixin:
                         )
                         logprob_pt += num_input_logprobs

-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         else:  # embedding or reward model
             embeddings, bid = result.embeddings, result.bid
@@ -233,6 +236,7 @@ class SchedulerOutputProcessorMixin:
             req.check_finished()
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                req.time_stats.completion_time = time.time()

             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
@@ -262,13 +266,8 @@ class SchedulerOutputProcessorMixin:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()

-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+        self.set_next_batch_sampling_info_done(batch)

         self.stream_output(batch.reqs, batch.return_logprob)

         self.token_to_kv_pool_allocator.free_group_end()
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
@@ -530,16 +529,27 @@ class SchedulerOutputProcessorMixin:
             )

             if should_output:
+                send_token_offset = req.send_token_offset
+                send_output_token_logprobs_offset = (
+                    req.send_output_token_logprobs_offset
+                )
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
                 )
                 decoded_texts.append(req.decoded_text)
                 decode_ids, read_offset = req.init_incremental_detokenize()
-                decode_ids_list.append(decode_ids)
+                if self.model_config.is_multimodal_gen:
+                    decode_ids_list.append(decode_ids)
+                else:
+                    decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
+                    req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
                 if self.skip_tokenizer_init:
-                    output_ids.append(req.output_ids)
+                    output_ids.append(req.output_ids[send_token_offset:])
+                    req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
@@ -553,36 +563,90 @@ class SchedulerOutputProcessorMixin:
                     spec_verify_ct.append(req.spec_verify_ct)

                 if return_logprob:
-                    input_token_logprobs_val.append(req.input_token_logprobs_val)
-                    input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                    output_token_logprobs_val.append(req.output_token_logprobs_val)
-                    output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                    input_top_logprobs_val.append(req.input_top_logprobs_val)
-                    input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                    output_top_logprobs_val.append(req.output_top_logprobs_val)
-                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                    input_token_ids_logprobs_val.append(
-                        req.input_token_ids_logprobs_val
-                    )
-                    input_token_ids_logprobs_idx.append(
-                        req.input_token_ids_logprobs_idx
-                    )
-                    output_token_ids_logprobs_val.append(
-                        req.output_token_ids_logprobs_val
-                    )
-                    output_token_ids_logprobs_idx.append(
-                        req.output_token_ids_logprobs_idx
-                    )
+                    if (
+                        req.return_logprob
+                        and not req.input_logprob_sent
+                        # Decode server does not send input logprobs
+                        and self.disaggregation_mode != DisaggregationMode.DECODE
+                    ):
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        input_token_ids_logprobs_val.append(
+                            req.input_token_ids_logprobs_val
+                        )
+                        input_token_ids_logprobs_idx.append(
+                            req.input_token_ids_logprobs_idx
+                        )
+                        req.input_logprob_sent = True
+                    else:
+                        input_token_logprobs_val.append([])
+                        input_token_logprobs_idx.append([])
+                        input_top_logprobs_val.append([])
+                        input_top_logprobs_idx.append([])
+                        input_token_ids_logprobs_val.append([])
+                        input_token_ids_logprobs_idx.append([])
+
+                    if req.return_logprob:
+                        output_token_logprobs_val.append(
+                            req.output_token_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_logprobs_idx.append(
+                            req.output_token_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_val.append(
+                            req.output_top_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_idx.append(
+                            req.output_top_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_val.append(
+                            req.output_token_ids_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_idx.append(
+                            req.output_token_ids_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        req.send_output_token_logprobs_offset = len(
+                            req.output_token_logprobs_val
+                        )
+                    else:
+                        output_token_logprobs_val.append([])
+                        output_token_logprobs_idx.append([])
+                        output_top_logprobs_val.append([])
+                        output_top_logprobs_idx.append([])
+                        output_token_ids_logprobs_val.append([])
+                        output_token_ids_logprobs_idx.append([])

                 if req.return_hidden_states:
                     if output_hidden_states is None:
                         output_hidden_states = []
                     output_hidden_states.append(req.hidden_states)

+                if (
+                    req.finished()
+                    and self.tp_rank == 0
+                    and self.server_args.enable_request_time_stats_logging
+                ):
+                    req.log_time_stats()
+
         # Send to detokenizer
         if rids:
             if self.model_config.is_multimodal_gen:
                 return
             self.send_to_detokenizer.send_pyobj(
                 BatchTokenIDOut(
                     rids,
......
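In the output processor above, input logprobs are attached exactly once per request (after which `input_logprob_sent` is set), while output logprobs are sliced from the per-request offset so each outgoing batch carries only new entries; the empty-list placeholders keep the batched, index-aligned lists in sync. A simplified sketch of that decision, using a hypothetical `req` object that only shares field names with sglang's `Req`:

```python
# Simplified sketch of the scheduler-side choice: input logprobs go out once,
# output logprobs go out as per-request increments. "req" is hypothetical.
from typing import Any, Dict


def collect_logprobs_for_stream(req: Any, out_batch: Dict[str, list]) -> None:
    # Input logprobs are only attached the first time this request is streamed.
    if req.return_logprob and not req.input_logprob_sent:
        out_batch["input_token_logprobs_val"].append(req.input_token_logprobs_val)
        req.input_logprob_sent = True
    else:
        # Placeholder keeps every batched list index-aligned across requests.
        out_batch["input_token_logprobs_val"].append([])

    # Output logprobs are sliced from the last position that was sent.
    if req.return_logprob:
        off = req.send_output_token_logprobs_offset
        out_batch["output_token_logprobs_val"].append(req.output_token_logprobs_val[off:])
        req.send_output_token_logprobs_offset = len(req.output_token_logprobs_val)
    else:
        out_batch["output_token_logprobs_val"].append([])
```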
@@ -125,10 +125,10 @@ logger = logging.getLogger(__name__)
 class ReqState:
     """Store the state a request."""

-    out_list: List
+    out_list: List[Dict[Any, Any]]
     finished: bool
     event: asyncio.Event
-    obj: Any
+    obj: Union[GenerateReqInput, EmbeddingReqInput]

     # For metrics
     created_time: float
@@ -139,6 +139,21 @@ class ReqState:
     # For streaming output
     last_output_offset: int = 0

+    # For incremental state update.
+    text: str = ""
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)

 class TokenizerManager:
@@ -1065,9 +1080,11 @@ class TokenizerManager:
         if getattr(state.obj, "return_logprob", False):
             self.convert_logprob_style(
                 meta_info,
+                state,
                 state.obj.top_logprobs_num,
                 state.obj.token_ids_logprob,
-                state.obj.return_text_in_logprobs,
+                state.obj.return_text_in_logprobs
+                and not self.server_args.skip_tokenizer_init,
                 recv_obj,
                 i,
             )
@@ -1084,18 +1101,19 @@ class TokenizerManager:
             meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

         if isinstance(recv_obj, BatchStrOut):
+            state.text += recv_obj.output_strs[i]
             out_dict = {
-                "text": recv_obj.output_strs[i],
+                "text": state.text,
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
             if self.server_args.stream_output and state.obj.stream:
-                output_token_ids = recv_obj.output_ids[i][
-                    state.last_output_offset :
-                ]
-                state.last_output_offset = len(recv_obj.output_ids[i])
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids[state.last_output_offset :]
+                state.last_output_offset = len(state.output_ids)
             else:
-                output_token_ids = recv_obj.output_ids[i]
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids

             out_dict = {
                 "output_ids": output_token_ids,
@@ -1130,45 +1148,85 @@ class TokenizerManager:
     def convert_logprob_style(
         self,
         meta_info: dict,
+        state: ReqState,
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if len(recv_obj.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                recv_obj.input_token_logprobs_val[recv_obj_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                recv_obj.input_token_logprobs_idx[recv_obj_index]
+            )
+        state.output_token_logprobs_val.extend(
+            recv_obj.output_token_logprobs_val[recv_obj_index]
+        )
+        state.output_token_logprobs_idx.extend(
+            recv_obj.output_token_logprobs_idx[recv_obj_index]
+        )
         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.input_token_logprobs_val[recv_obj_index],
-            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
             return_text_in_logprobs,
         )
         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.output_token_logprobs_val[recv_obj_index],
-            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
             return_text_in_logprobs,
         )

         if top_logprobs_num > 0:
+            if len(recv_obj.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    recv_obj.input_top_logprobs_val[recv_obj_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    recv_obj.input_top_logprobs_idx[recv_obj_index]
+                )
+            state.output_top_logprobs_val.extend(
+                recv_obj.output_top_logprobs_val[recv_obj_index]
+            )
+            state.output_top_logprobs_idx.extend(
+                recv_obj.output_top_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_top_logprobs_val[recv_obj_index],
-                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.output_top_logprobs_val[recv_obj_index],
-                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
                 return_text_in_logprobs,
             )

         if token_ids_logprob is not None:
+            if len(recv_obj.input_token_ids_logprobs_val) > 0:
+                state.input_token_ids_logprobs_val.extend(
+                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.input_token_ids_logprobs_idx.extend(
+                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
+                )
+            state.output_token_ids_logprobs_val.extend(
+                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
+            )
+            state.output_token_ids_logprobs_idx.extend(
+                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
-                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_token_ids_logprobs"] = (
                 self.detokenize_top_logprobs_tokens(
-                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
-                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
                     return_text_in_logprobs,
                 )
             )
......
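On the tokenizer-manager side, `ReqState` now accumulates the incremental chunks so API responses still expose the full text and logprob lists. A small sketch of that accumulation, with a hypothetical `MiniState` in place of the real `ReqState`:

```python
# Sketch of the tokenizer-manager side: accumulate incremental chunks in a
# per-request state so the response still exposes the full values.
# MiniState mirrors a few ReqState fields; it is not the real sglang class.
import dataclasses
from typing import Dict, List


@dataclasses.dataclass
class MiniState:
    text: str = ""
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)


def merge_chunk(state: MiniState, text_chunk: str, logprob_chunk: List[float]) -> Dict:
    state.text += text_chunk                              # chunks arrive as suffixes
    state.output_token_logprobs_val.extend(logprob_chunk)
    return {
        "text": state.text,                               # full text so far
        "output_token_logprobs": list(state.output_token_logprobs_val),
    }


s = MiniState()
merge_chunk(s, "Hel", [-0.1])
out = merge_chunk(s, "lo", [-0.2])
assert out["text"] == "Hello" and out["output_token_logprobs"] == [-0.1, -0.2]
```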
@@ -127,10 +127,12 @@ class TpModelWorkerClient:
         batch_lists = [None] * 2

         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break

+            sync_event.wait()
+
             # Keep a reference of model_worker_batch by storing it into a list.
             # Otherwise, the tensor members of model_worker_batch will be released
             # by pytorch and cause CUDA illegal memory access errors.
@@ -214,10 +216,11 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        self.scheduler_stream.synchronize()
+        sync_event = torch.get_device_module(self.device).Event()
+        sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
......
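The overlap worker above stops calling a blocking `self.scheduler_stream.synchronize()` on the scheduler thread; instead it records an event on the scheduler stream and lets the forward thread wait on it. A minimal sketch of the same hand-off pattern (assumes a CUDA device is available; everything except the torch APIs is illustrative, and the real sglang code uses its own queue and batch objects):

```python
# Minimal sketch: hand GPU work between threads with a CUDA event instead of a
# blocking stream.synchronize() on the producer side. Requires a CUDA device.
import queue
import threading
import torch

task_queue: "queue.Queue" = queue.Queue()
producer_stream = torch.cuda.Stream()


def producer() -> None:
    with torch.cuda.stream(producer_stream):
        x = torch.randn(1024, 1024, device="cuda")
        y = x @ x                      # async GPU work on the producer stream
    evt = torch.cuda.Event()
    evt.record(producer_stream)        # mark the point the consumer must wait for
    task_queue.put((y, evt))           # return immediately; no host-side sync here


def consumer() -> None:
    y, evt = task_queue.get()
    evt.synchronize()                  # block this thread until the GPU work is done
    print(float(y.sum()))


t = threading.Thread(target=consumer)
t.start()
producer()
t.join()
```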
@@ -307,5 +307,5 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.cat([self_val, other_val]))

-        self.is_all_greedy |= other.is_all_greedy
+        self.is_all_greedy &= other.is_all_greedy
         self.need_min_p_sampling |= other.need_min_p_sampling
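The `&=` fix above matters because a merged sampling batch is all-greedy only if both halves are; combining the flags with `or` could mark a mixed batch as greedy and skip real sampling. A tiny check of the intended truth table:

```python
# The merged batch is all-greedy only when both sides are all-greedy, so the
# flag must be combined with "and" (matching `self.is_all_greedy &= other.is_all_greedy`).
def merge_all_greedy(a_all_greedy: bool, b_all_greedy: bool) -> bool:
    return a_all_greedy and b_all_greedy


assert merge_all_greedy(True, True) is True
assert merge_all_greedy(True, False) is False   # the old `|=` would wrongly give True here
```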
@@ -98,6 +98,7 @@ class ServerArgs:
     show_time_cost: bool = False
     enable_metrics: bool = False
     decode_log_interval: int = 40
+    enable_request_time_stats_logging: bool = False

     # API related
     api_key: Optional[str] = None
@@ -785,6 +786,12 @@ class ServerArgs:
             default=ServerArgs.decode_log_interval,
             help="The log interval of decode batch.",
         )
+        parser.add_argument(
+            "--enable-request-time-stats-logging",
+            action="store_true",
+            default=ServerArgs.enable_request_time_stats_logging,
+            help="Enable per request time stats logging",
+        )

         # API related
         parser.add_argument(
......