Unverified Commit 5ae5ed1e authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Consolidate prompt arguments to LLM engines (#4328)


Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 290f4ada
import time import time
from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, ...@@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus) SequenceGroup, SequenceStatus)
@dataclass
class CompletionOutput: class CompletionOutput:
"""The output data of one completion output of a request. """The output data of one completion output of a request.
...@@ -24,25 +26,14 @@ class CompletionOutput: ...@@ -24,25 +26,14 @@ class CompletionOutput:
lora_request: The LoRA request that was used to generate the output. lora_request: The LoRA request that was used to generate the output.
""" """
def __init__( index: int
self, text: str
index: int, token_ids: List[int]
text: str, cumulative_logprob: float
token_ids: List[int], logprobs: Optional[SampleLogprobs]
cumulative_logprob: float, finish_reason: Optional[str] = None
logprobs: Optional[SampleLogprobs], stop_reason: Union[int, str, None] = None
finish_reason: Optional[str] = None, lora_request: Optional[LoRARequest] = None
stop_reason: Union[int, str, None] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
def finished(self) -> bool: def finished(self) -> bool:
return self.finish_reason is not None return self.finish_reason is not None
...@@ -57,6 +48,7 @@ class CompletionOutput: ...@@ -57,6 +48,7 @@ class CompletionOutput:
f"stop_reason={self.stop_reason})") f"stop_reason={self.stop_reason})")
@dataclass
class EmbeddingOutput: class EmbeddingOutput:
"""The output data of one completion output of a request. """The output data of one completion output of a request.
...@@ -65,15 +57,11 @@ class EmbeddingOutput: ...@@ -65,15 +57,11 @@ class EmbeddingOutput:
length of vector depends on the model as listed in the embedding guide. length of vector depends on the model as listed in the embedding guide.
""" """
def __init__( embedding: List[float]
self,
embedding: List[float],
) -> None:
self.embedding = embedding
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"EmbeddingOutput(" return (f"EmbeddingOutput("
f"embedding={len(self.embedding)}") f"embedding={len(self.embedding)})")
class RequestOutput: class RequestOutput:
...@@ -93,7 +81,7 @@ class RequestOutput: ...@@ -93,7 +81,7 @@ class RequestOutput:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
prompt: str, prompt: Optional[str],
prompt_token_ids: List[int], prompt_token_ids: List[int],
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
...@@ -183,7 +171,7 @@ class EmbeddingRequestOutput: ...@@ -183,7 +171,7 @@ class EmbeddingRequestOutput:
finished (bool): A flag indicating whether the embedding is completed. finished (bool): A flag indicating whether the embedding is completed.
""" """
def __init__(self, request_id: str, outputs: 'EmbeddingOutput', def __init__(self, request_id: str, outputs: "EmbeddingOutput",
prompt_token_ids: List[int], finished: bool): prompt_token_ids: List[int], finished: bool):
self.request_id = request_id self.request_id = request_id
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
......
...@@ -6,6 +6,7 @@ from dataclasses import dataclass, field ...@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -210,8 +211,7 @@ class Sequence: ...@@ -210,8 +211,7 @@ class Sequence:
Args: Args:
seq_id: The ID of the sequence. seq_id: The ID of the sequence.
prompt: The prompt of the sequence. inputs: The inputs of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
lora_request: LoRA request. lora_request: LoRA request.
...@@ -220,25 +220,24 @@ class Sequence: ...@@ -220,25 +220,24 @@ class Sequence:
def __init__( def __init__(
self, self,
seq_id: int, seq_id: int,
prompt: str, inputs: LLMInputs,
prompt_token_ids: List[int],
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.prompt = prompt self.inputs = inputs
self.block_size = block_size self.block_size = block_size
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.data: SequenceData = SequenceData(prompt_token_ids) self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
self.output_text = "" self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = [] self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids. # Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(prompt_token_ids) self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
...@@ -248,6 +247,18 @@ class Sequence: ...@@ -248,6 +247,18 @@ class Sequence:
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[List[str]] = None
@property
def prompt(self) -> Optional[str]:
return self.inputs["prompt"]
@property
def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"]
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs["multi_modal_data"]
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
...@@ -415,7 +426,6 @@ class SequenceGroup: ...@@ -415,7 +426,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request. arrival_time: The arrival time of the request.
lora_request: LoRA request. lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model. for an embedding model.
pooling_params: The pooling parameters used to generate the pooling pooling_params: The pooling parameters used to generate the pooling
...@@ -429,7 +439,6 @@ class SequenceGroup: ...@@ -429,7 +439,6 @@ class SequenceGroup:
arrival_time: float, arrival_time: float,
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
embeddings: Optional[List[float]] = None, embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
) -> None: ) -> None:
...@@ -444,12 +453,11 @@ class SequenceGroup: ...@@ -444,12 +453,11 @@ class SequenceGroup:
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState() self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
self.embeddings = embeddings self.embeddings = embeddings
self.pooling_params = pooling_params self.pooling_params = pooling_params
@property @property
def prompt(self) -> str: def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt. # All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt return next(iter(self.seqs_dict.values())).prompt
...@@ -458,7 +466,13 @@ class SequenceGroup: ...@@ -458,7 +466,13 @@ class SequenceGroup:
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt. # All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional[MultiModalData]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
......
...@@ -11,7 +11,7 @@ import threading ...@@ -11,7 +11,7 @@ import threading
import uuid import uuid
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from functools import lru_cache, partial from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar, Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
...@@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None: ...@@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None:
filename) filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True) os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path) enable_trace_function_call(log_path)
def identity(value: T) -> T:
return value
F = TypeVar('F', bound=Callable[..., Any])
def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None) -> Callable[[F], F]:
deprecated_kws = set(kws)
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
def wrapper(fn: F) -> F:
@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_kwargs = kwargs.keys() & deprecated_kws
if deprecated_kwargs:
msg = (
f"The keyword arguments {deprecated_kwargs} are "
"deprecated and will be removed in a future update.")
if additional_message is not None:
msg += f" {additional_message}"
warnings.warn(
DeprecationWarning(msg),
stacklevel=3, # The inner function takes up one level
)
return fn(*args, **kwargs)
return inner # type: ignore
return wrapper
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment