Commit 5ae5ed1e (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Core] Consolidate prompt arguments to LLM engines (#4328)


Co-authored-by: Roger Wang <ywang@roblox.com>
parent 290f4ada
import time
from dataclasses import dataclass
from typing import List, Optional, Union
from vllm.lora.request import LoRARequest
......@@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
@dataclass
class CompletionOutput:
"""The output data of one completion output of a request.
......@@ -24,25 +26,14 @@ class CompletionOutput:
lora_request: The LoRA request that was used to generate the output.
"""
def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprob: float,
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
stop_reason: Union[int, str, None] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
index: int
text: str
token_ids: List[int]
cumulative_logprob: float
logprobs: Optional[SampleLogprobs]
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
lora_request: Optional[LoRARequest] = None
def finished(self) -> bool:
return self.finish_reason is not None
......@@ -57,6 +48,7 @@ class CompletionOutput:
f"stop_reason={self.stop_reason})")
@dataclass
class EmbeddingOutput:
"""The output data of one completion output of a request.
......@@ -65,15 +57,11 @@ class EmbeddingOutput:
length of vector depends on the model as listed in the embedding guide.
"""
def __init__(
self,
embedding: List[float],
) -> None:
self.embedding = embedding
embedding: List[float]
def __repr__(self) -> str:
return (f"EmbeddingOutput("
f"embedding={len(self.embedding)}")
f"embedding={len(self.embedding)})")
class RequestOutput:
......@@ -93,7 +81,7 @@ class RequestOutput:
def __init__(
self,
request_id: str,
prompt: str,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
......@@ -183,7 +171,7 @@ class EmbeddingRequestOutput:
finished (bool): A flag indicating whether the embedding is completed.
"""
def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
def __init__(self, request_id: str, outputs: "EmbeddingOutput",
prompt_token_ids: List[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids
......
......@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
......@@ -210,8 +211,7 @@ class Sequence:
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
inputs: The inputs of the sequence.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
lora_request: LoRA request.
......@@ -220,25 +220,24 @@ class Sequence:
def __init__(
self,
seq_id: int,
prompt: str,
prompt_token_ids: List[int],
inputs: LLMInputs,
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.inputs = inputs
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.data: SequenceData = SequenceData(prompt_token_ids)
self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(prompt_token_ids)
self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None
......@@ -248,6 +247,18 @@ class Sequence:
# Input + output tokens
self.tokens: Optional[List[str]] = None
@property
def prompt(self) -> Optional[str]:
    """Return the ``"prompt"`` entry of this sequence's inputs."""
    prompt_text = self.inputs["prompt"]
    return prompt_text
@property
def prompt_token_ids(self) -> List[int]:
    """Return the ``"prompt_token_ids"`` entry of this sequence's inputs."""
    token_ids = self.inputs["prompt_token_ids"]
    return token_ids
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
    """Return the ``"multi_modal_data"`` entry of this sequence's inputs."""
    mm_data = self.inputs["multi_modal_data"]
    return mm_data
@property
def lora_int_id(self) -> int:
    """Return the integer ID of the attached LoRA request, or 0 if none."""
    request = self.lora_request
    if not request:
        # No LoRA request attached to this sequence.
        return 0
    return request.lora_int_id
......@@ -415,7 +426,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model.
pooling_params: The pooling parameters used to generate the pooling
......@@ -429,7 +439,6 @@ class SequenceGroup:
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
) -> None:
......@@ -444,12 +453,11 @@ class SequenceGroup:
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
self.embeddings = embeddings
self.pooling_params = pooling_params
@property
def prompt(self) -> str:
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt
......@@ -458,7 +466,13 @@ class SequenceGroup:
def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional[MultiModalData]:
    """Return the multi-modal data of the first sequence in the group."""
    # Every sequence in the group is expected to carry the same multi-modal
    # data, so reading it from the first one in ``seqs_dict`` is enough.
    first_seq = next(iter(self.seqs_dict.values()))
    return first_seq.multi_modal_data
@property
def lora_int_id(self) -> int:
......
......@@ -11,7 +11,7 @@ import threading
import uuid
import warnings
from collections import defaultdict
from functools import lru_cache, partial
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
......@@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None:
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
def identity(value: T) -> T:
    """Return ``value`` unchanged."""
    result = value
    return result
F = TypeVar('F', bound=Callable[..., Any])


def deprecate_kwargs(
        *kws: str,
        is_deprecated: Union[bool, Callable[[], bool]] = True,
        additional_message: Optional[str] = None) -> Callable[[F], F]:
    """Build a decorator that warns when deprecated keyword arguments are used.

    Args:
        *kws: Names of the keyword arguments that are deprecated.
        is_deprecated: Either a boolean, or a zero-argument callable that is
            evaluated on every call of the decorated function. The warning
            is only emitted when it is (or returns) a true value.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator. The decorated function emits a ``DeprecationWarning``
        whenever it is called with at least one of ``kws`` passed by
        keyword, then forwards all arguments to the original function and
        returns its result.
    """
    deprecated_kws = set(kws)

    # Normalize a plain boolean into a zero-argument callable so that the
    # wrapper below has a single code path.
    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                # Only arguments actually passed by keyword can be detected.
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # Level 1 is ``inner`` itself; level 2 is the code
                        # that called the decorated function, which is where
                        # the deprecated keyword was written.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment