"docs/source/en/api/pipelines/controlnet.mdx" did not exist on "d0f258206d9cdcfb8685447d18b6881aed63143e"
Unverified commit 62bbd343, authored by Lianmin Zheng and committed by GitHub

Revert "Extract generation_manager from tokenizer_manager" (#3829)

parent f2388f6b
@@ -463,5 +463,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"])
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, scheduler_info
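
For context, a minimal before/after sketch of this call site in _launch_subprocesses, based only on the two lines in the hunk above: the revert drops the forwarding setter introduced by the extraction and stores the value directly on the manager again.

# With the extraction in place (removed by this commit):
tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"])

# After the revert (restored original form):
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]

The file that follows (sglang/srt/managers/generation_manager.py, judging from the GenerationManager import removed from the tokenizer manager further down) is deleted in its entirety by this revert.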
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import asyncio
import copy
import dataclasses
import logging
import os
import pickle
import time
from datetime import datetime
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import fastapi
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
ConfigureLoggingReq,
EmbeddingReqInput,
GenerateReqInput,
SessionParams,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import dataclass_to_string_truncated
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class _MetricReqState:
created_time: float
first_token_time: Optional[float] = None
@dataclasses.dataclass
class _ReqState:
"""Store the state a request."""
out_list: List
finished: bool
event: asyncio.Event
obj: Any
metric: _MetricReqState
# For streaming output
last_output_offset: int = 0
class GenerationManager:
def __init__(
self,
server_args: ServerArgs,
on_request: Callable,
):
self.server_args = server_args
self.on_request = on_request
self.model_config = _create_model_config_from_server_args(server_args)
self.generation_converter = GenerationConverter(server_args=server_args)
self.rid_to_state: Dict[str, _ReqState] = {}
# Metrics
if server_args.enable_metrics:
self._metric_manager = _MetricManager(
server_args=server_args,
)
else:
self._metric_manager = None
self._request_logger = _RequestLogger(server_args)
self._request_dumper = _RequestDumper()
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
if isinstance(obj, EmbeddingReqInput) and self.model_config.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
obj.normalize_batch_and_arguments()
self._request_logger.log_generation(obj)
is_single = obj.is_single
if is_single:
tokenized_obj = await self.generation_converter.tokenize_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, request):
yield response
else:
async for response in self._handle_batch_request(
obj, request, created_time
):
yield response
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
created_time: Optional[float] = None,
):
event = asyncio.Event()
state = _ReqState(
[], False, event, obj, metric=_MetricReqState(created_time=created_time)
)
self.rid_to_state[obj.rid] = state
self.on_request(tokenized_obj)
async def _wait_one_response(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
"""Wait for the response of one request."""
state = self.rid_to_state[obj.rid]
while True:
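# Wake up every few seconds even when no output arrives, so a disconnected client is detected and the request aborted.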
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")
continue
out = state.out_list[-1]
state.out_list = []
if state.finished:
self._request_logger.log_response(obj, out)
del self.rid_to_state[obj.rid]
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
):
raise ValueError(finish_reason["message"])
yield out
break
state.event.clear()
if obj.stream:
yield out
else:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")
async def _handle_batch_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
created_time: Optional[float] = None,
):
batch_size = obj.batch_size
generators = []
rids = []
if getattr(obj, "parallel_sample_num", 1) == 1:
# Send all requests
for i in range(batch_size):
tmp_obj = obj[i]
tokenized_obj = await self.generation_converter.tokenize_request(
tmp_obj
)
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
logger.warning(
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
"The performance might be better if you just duplicate the requests n times or use "
"many threads to send them one by one with parallel sampling (n > 1)."
)
# Tokenize all requests
objs = [obj[i] for i in range(batch_size)]
tokenized_objs = await asyncio.gather(
*(self.generation_converter.tokenize_request(obj) for obj in objs)
)
# Cache the common prefix for parallel sampling
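# A prefill-only copy (max_new_tokens=0) of each prompt is sent and awaited first, so the shared prefix is cached before the n sampled copies below reuse it.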
for i in range(batch_size):
tmp_obj = copy.copy(objs[i])
tokenized_obj = copy.copy(tokenized_objs[i])
tokenized_obj.rid = tmp_obj.regenerate_rid()
tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
tokenized_obj.sampling_params.max_new_tokens = 0
tokenized_obj.stream = False
self._send_one_request(tmp_obj, tokenized_obj, created_time)
await self._wait_one_response(tmp_obj, request).__anext__()
# Expand requests, assign new rids for them, and send them
for i in range(batch_size):
for _ in range(obj.parallel_sample_num):
tmp_obj = copy.copy(objs[i])
tokenized_obj = copy.copy(tokenized_objs[i])
tokenized_obj.rid = tmp_obj.regenerate_rid()
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
# Wait for all requests
is_stream = hasattr(obj, "stream") and obj.stream
if not is_stream:
outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
yield outputs
else:
rid_to_index = {rid: i for i, rid in enumerate(rids)}
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
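# Fan-in: yield whichever request produces output first, then re-arm a new __anext__ task for that generator until it is exhausted.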
while task_map:
done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
gen = task_map.pop(task)
try:
result = task.result()
result["index"] = rid_to_index[result["meta_info"]["id"]]
yield result
new_task = asyncio.create_task(gen.__anext__())
task_map[new_task] = gen
except StopAsyncIteration:
pass
def handle_batch_output(
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
):
for index, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
out_dict = self.generation_converter.postprocess_response(
recv_obj, index, state.obj
)
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reasons[index] is not None
state.event.set()
if self._metric_manager:
self._metric_manager.handle_batch_output_metrics(
recv_obj,
index,
state.metric,
finished=state.finished,
stream=state.obj.stream if hasattr(state.obj, "stream") else None,
)
self._request_dumper.maybe_dump_requests(state=state, out_dict=out_dict)
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.on_request(req)
@property
def tokenizer(self):
return self.generation_converter.tokenizer
def configure_logging(self, obj: ConfigureLoggingReq):
self._request_logger.configure(
log_requests=obj.log_requests, log_requests_level=obj.log_requests_level
)
self._request_dumper.configure(
dump_requests_folder=obj.dump_requests_folder,
dump_requests_threshold=obj.dump_requests_threshold,
)
logging.info(f"Config logging: {obj=}")
class GenerationConverter:
"""Preprocessors and postprocessors for generation"""
def __init__(
self,
server_args: ServerArgs,
):
self.server_args = server_args
self.model_config = _create_model_config_from_server_args(server_args)
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
# Set after scheduler is initialized
self.max_req_input_len = None
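# Filled in later by the server launcher (_launch_subprocesses) once the schedulers report their limits.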
# Create tokenizer
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor(
self.model_config.hf_config, server_args, self.processor
)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
async def tokenize_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
):
"""Tokenize one request."""
# Tokenize
input_embeds = None
input_text = obj.text
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
"input_embeds is provided while disable_radix_cache is False. "
"Please add `--disable-radix-cache` when you launch the server "
"if you want to use input_embeds as inputs."
)
input_embeds = obj.input_embeds
input_ids = obj.input_ids
elif obj.input_ids is not None:
input_ids = obj.input_ids
else:
if self.tokenizer is None:
raise ValueError(
"The engine initialized with skip_tokenizer_init=True cannot "
"accept text prompts. Please provide input_ids or re-initialize "
"the engine with skip_tokenizer_init=False."
)
input_ids = self.tokenizer.encode(input_text)
if self.model_config.is_generation:
# TODO: also support getting embeddings for multimodal models
image_inputs: Dict = await self.image_processor.process_images_async(
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
input_token_num = len(input_ids) if input_ids is not None else 0
if input_token_num >= self.model_config.context_len:
raise ValueError(
f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.model_config.context_len} tokens)."
)
if (
obj.sampling_params.get("max_new_tokens") is not None
and obj.sampling_params.get("max_new_tokens") + input_token_num
>= self.model_config.context_len
):
raise ValueError(
f"Requested token count exceeds the model's maximum context length "
f"of {self.model_config.context_len} tokens. You requested a total of "
f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
f"tokens: {input_token_num} tokens from the input messages and "
f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
f"completion. Please reduce the number of tokens in the input "
f"messages or the completion to fit within the limit."
)
# Parse sampling parameters
sampling_params = SamplingParams(**obj.sampling_params)
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
# Build return object
if isinstance(obj, GenerateReqInput):
return TokenizedGenerateReqInput(
obj.rid,
input_text,
input_ids,
image_inputs,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
)
elif isinstance(obj, EmbeddingReqInput):
return TokenizedEmbeddingReqInput(
obj.rid,
input_text,
input_ids,
sampling_params,
)
else:
raise NotImplementedError
def tokenize_requests(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
objs = [obj[i] for i in range(obj.batch_size)]
loop = asyncio.get_event_loop()
return loop.run_until_complete(
asyncio.gather(*(self.tokenize_request(obj) for obj in objs))
)
def postprocess_response(
self,
recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut],
index: int,
req_obj: Union[GenerateReqInput, EmbeddingReqInput],
) -> Dict[str, Any]:
meta_info = self._compute_meta_info(index, recv_obj, req_obj)
if isinstance(recv_obj, BatchStrOut):
return {
"text": recv_obj.output_strs[index],
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchTokenIDOut):
return {
"token_ids": recv_obj.output_ids[index],
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchEmbeddingOut):
return {
"embedding": recv_obj.embeddings[index],
"meta_info": meta_info,
}
else:
raise NotImplementedError
def _compute_meta_info(self, index, recv_obj, req_obj):
meta_info = {
"id": recv_obj.rids[index],
"finish_reason": recv_obj.finished_reasons[index],
"prompt_tokens": recv_obj.prompt_tokens[index],
}
if getattr(req_obj, "return_logprob", False):
self._convert_logprob_style(
meta_info,
req_obj.top_logprobs_num,
req_obj.return_text_in_logprobs,
recv_obj,
index,
)
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[index]
if not isinstance(recv_obj, BatchEmbeddingOut):
meta_info.update(
{
"completion_tokens": recv_obj.completion_tokens[index],
"cached_tokens": recv_obj.cached_tokens[index],
}
)
if (
hasattr(recv_obj, "output_hidden_states")
and len(recv_obj.output_hidden_states[index]) > 0
):
meta_info["hidden_states"] = recv_obj.output_hidden_states[index]
return meta_info
def _convert_logprob_style(
self,
meta_info: dict,
top_logprobs_num: int,
return_text_in_logprobs: bool,
recv_obj: BatchStrOut,
recv_obj_index: int,
):
meta_info["input_token_logprobs"] = self._detokenize_logprob_tokens(
recv_obj.input_token_logprobs_val[recv_obj_index],
recv_obj.input_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_token_logprobs"] = self._detokenize_logprob_tokens(
recv_obj.output_token_logprobs_val[recv_obj_index],
recv_obj.output_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
if top_logprobs_num > 0:
meta_info["input_top_logprobs"] = self._detokenize_top_logprobs_tokens(
recv_obj.input_top_logprobs_val[recv_obj_index],
recv_obj.input_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self._detokenize_top_logprobs_tokens(
recv_obj.output_top_logprobs_val[recv_obj_index],
recv_obj.output_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
def _detokenize_logprob_tokens(
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
if not decode_to_text:
return [
(logprob, token_id, None)
for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
]
else:
assert self.tokenizer is not None
token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
def _detokenize_top_logprobs_tokens(
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
# We should batch all top-k tokens in all positions.
ret = []
for i in range(len(token_logprobs_val)):
if token_logprobs_val[i]:
ret.append(
self._detokenize_logprob_tokens(
token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
)
)
else:
ret.append(None)
return ret
def _create_model_config_from_server_args(server_args: ServerArgs):
return ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
class _MetricManager:
def __init__(self, server_args: ServerArgs):
self.metrics_collector = TokenizerMetricsCollector(
labels={
"model_name": server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
def handle_batch_output_metrics(
self,
recv_obj,
i: int,
state: _MetricReqState,
finished: bool,
stream: Optional[bool],
):
completion_tokens = (
recv_obj.completion_tokens[i]
if getattr(recv_obj, "completion_tokens", None)
else 0
)
if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
# Compute time_per_output_token for the streaming case
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time) / (completion_tokens - 1)
)
if finished:
self.metrics_collector.observe_one_finished_request(
recv_obj.prompt_tokens[i], completion_tokens
)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
# Compute time_per_output_token for the non-streaming case
if stream is not None and not stream and completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
class _RequestLogger:
def __init__(self, server_args: ServerArgs):
self.log_requests = server_args.log_requests
self.log_requests_level = 0
def configure(self, log_requests, log_requests_level):
if log_requests is not None:
self.log_requests = log_requests
if log_requests_level is not None:
self.log_requests_level = log_requests_level
def log_generation(self, obj):
if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
logger.info(
f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
)
def log_response(self, obj, out):
if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
logger.info(msg)
class _RequestDumper:
def __init__(self):
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
def configure(self, dump_requests_folder, dump_requests_threshold):
if dump_requests_folder is not None:
self.dump_requests_folder = dump_requests_folder
if dump_requests_threshold is not None:
self.dump_requests_threshold = dump_requests_threshold
def maybe_dump_requests(self, state: _ReqState, out_dict: dict):
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
self._dump_requests(state, out_dict)
def _dump_requests(self, state: _ReqState, out_dict: dict):
self.dump_request_list.append(
(state.obj, out_dict, state.metric.created_time, time.time())
)
if len(self.dump_request_list) >= self.dump_requests_threshold:
filename = os.path.join(
self.dump_requests_folder,
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
)
logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
to_dump = self.dump_request_list
self.dump_request_list = []
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
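
Both dump paths (the _RequestDumper above and TokenizerManager.dump_requests in the restored code below) hand the blocking pickle write to a worker thread and never await it, so the event loop keeps serving requests while the file is written. A self-contained sketch of that pattern, with hypothetical names:

import asyncio
import os
import pickle
from datetime import datetime

async def dump_records_in_background(folder: str, records: list) -> None:
    # Hypothetical helper; mirrors the fire-and-forget dump used above.
    filename = os.path.join(folder, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl")

    def write_file():
        # Runs inside a worker thread via asyncio.to_thread, so blocking
        # file I/O does not stall the event loop.
        os.makedirs(folder, exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump(records, f)

    # Schedule the write without awaiting it; the caller returns immediately.
    asyncio.create_task(asyncio.to_thread(write_file))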
@@ -14,13 +14,19 @@
"""TokenizerManager is a process that tokenizes the text."""
import asyncio
import copy
import dataclasses
import logging
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from typing import Awaitable, Generic, List, Optional, Tuple, TypeVar, Union
from datetime import datetime
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
import fastapi
import uvloop
@@ -29,8 +35,14 @@ import zmq.asyncio
from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.managers.generation_manager import GenerationManager
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
@@ -50,6 +62,9 @@ from sglang.srt.managers.io_struct import (
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SessionParams,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
@@ -57,8 +72,14 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.srt.utils import (
dataclass_to_string_truncated,
get_zmq_socket,
kill_process_tree,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -66,6 +87,23 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ReqState:
"""Store the state a request."""
out_list: List
finished: bool
event: asyncio.Event
obj: Any
# For metrics
created_time: float
first_token_time: Optional[float] = None
# For streaming output
last_output_offset: int = 0
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
@@ -78,6 +116,8 @@ class TokenizerManager:
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests
self.log_requests_level = 0
# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -91,9 +131,56 @@ class TokenizerManager:
# Read model args
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation
self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
# Create tokenizer
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor(
self.model_config.hf_config, server_args, self.processor
)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
# Store states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
@@ -105,11 +192,6 @@ class TokenizerManager:
# For session info
self.session_futures = {} # session_id -> asyncio event
self._generation_manager = GenerationManager(
server_args=server_args,
on_request=self.send_to_scheduler.send_pyobj,
)
# Others
self.gracefully_exit = False
self.init_weights_update_group_communicator = _Communicator(
@@ -130,12 +212,23 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
# Set after scheduler is initialized
self.max_req_input_len = None
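# Assigned directly by _launch_subprocesses from scheduler_info["max_req_input_len"] after the schedulers start.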
# Metrics
if self.enable_metrics:
self.metrics_collector = TokenizerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
self._result_dispatcher = TypeBasedDispatcher(
[
(
(BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
self._generation_manager.handle_batch_output,
self._handle_batch_output,
),
(OpenSessionReqOutput, self._handle_open_session_req_output),
(
@@ -174,17 +267,280 @@
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
self.auto_create_handle_loop()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
obj.normalize_batch_and_arguments()
if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
logger.info(
f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
)
async with self.model_update_lock.reader_lock:
async for value in self._generation_manager.generate_request(obj, request):
yield value
is_single = obj.is_single
if is_single:
tokenized_obj = await self._tokenize_one_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, request):
yield response
else:
async for response in self._handle_batch_request(
obj, request, created_time
):
yield response
async def _tokenize_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
):
"""Tokenize one request."""
# Tokenize
input_embeds = None
input_text = obj.text
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
"input_embeds is provided while disable_radix_cache is False. "
"Please add `--disable-radix-cache` when you launch the server "
"if you want to use input_embeds as inputs."
)
input_embeds = obj.input_embeds
input_ids = obj.input_ids
elif obj.input_ids is not None:
input_ids = obj.input_ids
else:
if self.tokenizer is None:
raise ValueError(
"The engine initialized with skip_tokenizer_init=True cannot "
"accept text prompts. Please provide input_ids or re-initialize "
"the engine with skip_tokenizer_init=False."
)
input_ids = self.tokenizer.encode(input_text)
if self.is_generation:
# TODO: also support getting embeddings for multimodal models
image_inputs: Dict = await self.image_processor.process_images_async(
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
input_token_num = len(input_ids) if input_ids is not None else 0
if input_token_num >= self.context_len:
raise ValueError(
f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
)
if (
obj.sampling_params.get("max_new_tokens") is not None
and obj.sampling_params.get("max_new_tokens") + input_token_num
>= self.context_len
):
raise ValueError(
f"Requested token count exceeds the model's maximum context length "
f"of {self.context_len} tokens. You requested a total of "
f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
f"tokens: {input_token_num} tokens from the input messages and "
f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
f"completion. Please reduce the number of tokens in the input "
f"messages or the completion to fit within the limit."
)
# Parse sampling parameters
sampling_params = SamplingParams(**obj.sampling_params)
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
# Build return object
if isinstance(obj, GenerateReqInput):
tokenized_obj = TokenizedGenerateReqInput(
obj.rid,
input_text,
input_ids,
image_inputs,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
obj.rid,
input_text,
input_ids,
sampling_params,
)
return tokenized_obj
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
created_time: Optional[float] = None,
):
event = asyncio.Event()
state = ReqState([], False, event, obj, created_time=created_time)
self.rid_to_state[obj.rid] = state
self.send_to_scheduler.send_pyobj(tokenized_obj)
async def _wait_one_response(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
"""Wait for the response of one request."""
state = self.rid_to_state[obj.rid]
while True:
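# Wake up every few seconds even when no output arrives, so a disconnected client is detected and the request aborted.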
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")
continue
out = state.out_list[-1]
state.out_list = []
if state.finished:
if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
logger.info(msg)
del self.rid_to_state[obj.rid]
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
):
raise ValueError(finish_reason["message"])
yield out
break
state.event.clear()
if obj.stream:
yield out
else:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")
async def _handle_batch_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
created_time: Optional[float] = None,
):
batch_size = obj.batch_size
generators = []
rids = []
if getattr(obj, "parallel_sample_num", 1) == 1:
# Send all requests
for i in range(batch_size):
tmp_obj = obj[i]
tokenized_obj = await self._tokenize_one_request(tmp_obj)
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
logger.warning(
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
"The performance might be better if you just duplicate the requests n times or use "
"many threads to send them one by one with parallel sampling (n > 1)."
)
# Tokenize all requests
objs = [obj[i] for i in range(batch_size)]
tokenized_objs = await asyncio.gather(
*(self._tokenize_one_request(obj) for obj in objs)
)
# Cache the common prefix for parallel sampling
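# A prefill-only copy (max_new_tokens=0) of each prompt is sent and awaited first, so the shared prefix is cached before the n sampled copies below reuse it.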
for i in range(batch_size):
tmp_obj = copy.copy(objs[i])
tokenized_obj = copy.copy(tokenized_objs[i])
tokenized_obj.rid = tmp_obj.regenerate_rid()
tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
tokenized_obj.sampling_params.max_new_tokens = 0
tokenized_obj.stream = False
self._send_one_request(tmp_obj, tokenized_obj, created_time)
await self._wait_one_response(tmp_obj, request).__anext__()
# Expand requests, assign new rids for them, and send them
for i in range(batch_size):
for _ in range(obj.parallel_sample_num):
tmp_obj = copy.copy(objs[i])
tokenized_obj = copy.copy(tokenized_objs[i])
tokenized_obj.rid = tmp_obj.regenerate_rid()
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
# Wait for all requests
is_stream = hasattr(obj, "stream") and obj.stream
if not is_stream:
outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
yield outputs
else:
rid_to_index = {rid: i for i, rid in enumerate(rids)}
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
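# Fan-in: yield whichever request produces output first, then re-arm a new __anext__ task for that generator until it is exhausted.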
while task_map:
done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
gen = task_map.pop(task)
try:
result = task.result()
result["index"] = rid_to_index[result["meta_info"]["id"]]
yield result
new_task = asyncio.create_task(gen.__anext__())
task_map[new_task] = gen
except StopAsyncIteration:
pass
def flush_cache(self):
req = FlushCacheReq()
self.send_to_scheduler.send_pyobj(req)
def abort_request(self, rid: str):
self._generation_manager.abort_request(rid)
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_scheduler.send_pyobj(req)
def start_profile(self):
req = ProfileReq.START_PROFILE
@@ -332,7 +688,15 @@ class TokenizerManager:
await self.send_to_scheduler.send_pyobj(obj)
def configure_logging(self, obj: ConfigureLoggingReq):
self._generation_manager.configure_logging(obj)
if obj.log_requests is not None:
self.log_requests = obj.log_requests
if obj.log_requests_level is not None:
self.log_requests_level = obj.log_requests_level
if obj.dump_requests_folder is not None:
self.dump_requests_folder = obj.dump_requests_folder
if obj.dump_requests_threshold is not None:
self.dump_requests_threshold = obj.dump_requests_threshold
logging.info(f"Config logging: {obj=}")
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
@@ -379,7 +743,7 @@ class TokenizerManager:
# Drain requests
while True:
remain_num_req = len(self._generation_manager.rid_to_state)
remain_num_req = len(self.rid_to_state)
logger.info(
f"Gracefully exiting... remaining number of requests {remain_num_req}"
)
@@ -398,6 +762,198 @@ class TokenizerManager:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj)
def _handle_batch_output(
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
meta_info = {
"id": rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
}
if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style(
meta_info,
state.obj.top_logprobs_num,
state.obj.return_text_in_logprobs,
recv_obj,
i,
)
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
if not isinstance(recv_obj, BatchEmbeddingOut):
meta_info.update(
{
"completion_tokens": recv_obj.completion_tokens[i],
"cached_tokens": recv_obj.cached_tokens[i],
}
)
if (
hasattr(recv_obj, "output_hidden_states")
and len(recv_obj.output_hidden_states[i]) > 0
):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": meta_info,
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": meta_info,
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reasons[i] is not None
state.event.set()
if self.enable_metrics and state.obj.log_metrics:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
self.dump_requests(state, out_dict)
def convert_logprob_style(
self,
meta_info: dict,
top_logprobs_num: int,
return_text_in_logprobs: bool,
recv_obj: BatchStrOut,
recv_obj_index: int,
):
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
recv_obj.input_token_logprobs_val[recv_obj_index],
recv_obj.input_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
recv_obj.output_token_logprobs_val[recv_obj_index],
recv_obj.output_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
if top_logprobs_num > 0:
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.input_top_logprobs_val[recv_obj_index],
recv_obj.input_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.output_top_logprobs_val[recv_obj_index],
recv_obj.output_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
def detokenize_logprob_tokens(
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
if not decode_to_text:
return [
(logprob, token_id, None)
for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
]
else:
assert self.tokenizer is not None
token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
def detokenize_top_logprobs_tokens(
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
# We should batch all top-k tokens in all positions.
ret = []
for i in range(len(token_logprobs_val)):
if token_logprobs_val[i]:
ret.append(
self.detokenize_logprob_tokens(
token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
)
)
else:
ret.append(None)
return ret
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
completion_tokens = (
recv_obj.completion_tokens[i]
if getattr(recv_obj, "completion_tokens", None)
else 0
)
if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
# Compute time_per_output_token for the streaming case
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time) / (completion_tokens - 1)
)
if state.finished:
self.metrics_collector.observe_one_finished_request(
recv_obj.prompt_tokens[i], completion_tokens
)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
# Compute time_per_output_token for the non-streaming case
if (
hasattr(state.obj, "stream")
and not state.obj.stream
and completion_tokens >= 1
):
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
def dump_requests(self, state: ReqState, out_dict: dict):
self.dump_request_list.append(
(state.obj, out_dict, state.created_time, time.time())
)
if len(self.dump_request_list) >= self.dump_requests_threshold:
filename = os.path.join(
self.dump_requests_folder,
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
)
logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
to_dump = self.dump_request_list
self.dump_request_list = []
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
def _handle_open_session_req_output(self, recv_obj):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id if recv_obj.success else None
@@ -412,23 +968,6 @@
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
@property
def is_generation(self):
return self._generation_manager.model_config.is_generation
@property
def tokenizer(self):
return self._generation_manager.tokenizer
@property
def image_token_id(self):
return self._generation_manager.model_config.image_token_id
def configure_max_req_input_len(self, max_req_input_len):
self._generation_manager.generation_converter.max_req_input_len = (
max_req_input_len
)
async def print_exception_wrapper(func):
"""
......