Unverified Commit ff7ec82c authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)

parent 200a2ffa
import warnings
from dataclasses import dataclass, field
from typing import Optional
import msgspec
from vllm.adapter_commons.request import AdapterRequest
@dataclass
class LoRARequest(AdapterRequest):
class LoRARequest(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""
Request for a LoRA adapter.
......@@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest):
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
__metaclass__ = AdapterRequest
lora_name: str
lora_int_id: int
lora_path: str = ""
lora_local_path: Optional[str] = field(default=None, repr=False)
lora_local_path: Optional[str] = msgspec.field(default=None)
long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__
def __post_init__(self):
if 'lora_local_path' in self.__dict__:
if 'lora_local_path' in self.__struct_fields__:
warnings.warn(
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
......
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from array import array
from typing import Optional, Union
import torch
......@@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
......@@ -53,8 +54,10 @@ def dummy_seq_data_for_blip(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size)
return SequenceData(token_ids)
......
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
......@@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
......
from array import array
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict)
......@@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal
......@@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
......
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from array import array
from typing import Iterable, Optional, Tuple
import torch
......@@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
......@@ -53,8 +54,10 @@ def dummy_seq_data_for_clip(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
......
......@@ -16,6 +16,7 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
import torch
......@@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
......@@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx)
image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
token_ids = image_token_ids * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
image_token_ids = (
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
......
......@@ -23,6 +23,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from array import array
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .idefics2_vision_model import Idefics2VisionTransformer
......@@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
token_ids = [0] * seq_len
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
return SequenceData(token_ids)
......
......@@ -2,6 +2,7 @@
within a vision language model."""
import math
from array import array
from typing import Iterable, Optional, Tuple
import torch
......@@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
......@@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size)
return SequenceData(token_ids)
......
......@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple
import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad,
......@@ -505,9 +506,11 @@ class SamplingTensors:
and sampling_params.prompt_logprobs is not None):
prefill_len = len(seq_group.prompt_logprob_indices)
prompt_tokens.extend(
array('l') for _ in range(prefill_len))
array(VLLM_TOKEN_ID_ARRAY_TYPE)
for _ in range(prefill_len))
output_tokens.extend(
array('l') for _ in range(prefill_len))
array(VLLM_TOKEN_ID_ARRAY_TYPE)
for _ in range(prefill_len))
if seq_group.do_sample:
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
......
from typing import Any, Optional
import msgspec
class PoolingParams:
class PoolingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Pooling parameters for pooling.
Attributes:
additional_data: Any additional data needed for pooling.
"""
def __init__(self, additional_data: Optional[Any] = None):
self.additional_data = additional_data
additional_data: Optional[Any] = None
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
......
from dataclasses import dataclass
import msgspec
from vllm.adapter_commons.request import AdapterRequest
@dataclass
class PromptAdapterRequest(AdapterRequest):
class PromptAdapterRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
frozen=True): # type: ignore[call-arg]
"""
Request for a Prompt adapter.
"""
__metaclass__ = AdapterRequest
prompt_adapter_name: str
prompt_adapter_id: int
......
......@@ -2,10 +2,10 @@
import copy
from enum import IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
import msgspec
import torch
from pydantic import Field
from typing_extensions import Annotated
from vllm.logger import init_logger
......@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
to sample from."""
class SamplingParams:
class SamplingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True): # type: ignore[call-arg]
"""Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion
......@@ -112,87 +116,73 @@ class SamplingParams:
(i.e., no truncation).
"""
def __init__(
self,
n: int = 1,
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
if 0 < temperature < _MAX_TEMP:
n: int = 1
best_of: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
seed: Optional[int] = None
use_beam_search: bool = False
length_penalty: float = 1.0
early_stopping: Union[bool, str] = False
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
ignore_eos: bool = False
max_tokens: Optional[int] = 16
min_tokens: int = 0
logprobs: Optional[int] = None
prompt_logprobs: Optional[int] = None
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
detokenize: bool = True
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
# Optional[List[LogitsProcessor]] type. We use Any here because
# Optional[List[LogitsProcessor]] type is not supported by msgspec.
logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
def __post_init__(self) -> None:
self.best_of = self.best_of or self.n
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s.",
temperature, _MAX_TEMP, _MAX_TEMP)
temperature = max(temperature, _MAX_TEMP)
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
if seed == -1:
self.temperature, _MAX_TEMP, _MAX_TEMP)
self.temperature = max(self.temperature, _MAX_TEMP)
if self.seed == -1:
self.seed = None
else:
self.seed = seed
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.seed = self.seed
if self.stop is None:
self.stop = []
elif isinstance(stop, str):
self.stop = [stop]
elif isinstance(self.stop, str):
self.stop = [self.stop]
else:
self.stop = list(stop)
if stop_token_ids is None:
self.stop = list(self.stop)
if self.stop_token_ids is None:
self.stop_token_ids = []
else:
self.stop_token_ids = list(stop_token_ids)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.min_tokens = min_tokens
self.logprobs = 1 if logprobs is True else logprobs
self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self.detokenize = detokenize
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
self.stop_token_ids = list(self.stop_token_ids)
self.logprobs = 1 if self.logprobs is True else self.logprobs
self.prompt_logprobs = (1 if self.prompt_logprobs is True else
self.prompt_logprobs)
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if self.stop and not include_stop_str_in_output:
if self.stop and not self.include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.output_text_buffer_length = 0
self._verify_args()
if self.use_beam_search:
......@@ -206,11 +196,12 @@ class SamplingParams:
self.min_p = 0.0
self._verify_greedy_sampling()
# eos_token_id is added to this by the engine
self.all_stop_token_ids = set(self.stop_token_ids)
self._all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None:
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
assert isinstance(self.best_of, int)
if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
......@@ -257,6 +248,7 @@ class SamplingParams:
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
assert isinstance(self.stop, list)
if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
......@@ -290,6 +282,7 @@ class SamplingParams:
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int)
if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.")
......@@ -303,7 +296,7 @@ class SamplingParams:
if model_eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self.all_stop_token_ids.add(model_eos_token_id)
self._all_stop_token_ids.add(model_eos_token_id)
# Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None:
......@@ -315,7 +308,7 @@ class SamplingParams:
# purposes.
eos_ids.discard(model_eos_token_id)
if eos_ids:
self.all_stop_token_ids.update(eos_ids)
self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)
......@@ -330,6 +323,10 @@ class SamplingParams:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
@property
def all_stop_token_ids(self) -> Set[int]:
"""Return the `_all_stop_token_ids` set itself (not a copy).

The set is initialised from `stop_token_ids`; EOS token ids are added
to it elsewhere in this class.
"""
return self._all_stop_token_ids
def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.
......
......@@ -4,10 +4,11 @@ import enum
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union, cast)
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
import msgspec
import numpy
import torch
......@@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal import MultiModalDataDict
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
"""Infos for supporting OpenAI compatible logprobs and token ranks.
......@@ -112,7 +118,23 @@ class RequestMetrics:
model_execute_time: Optional[float] = None
class SequenceData:
class SequenceDataDelta(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Delta SequenceData to send to workers per step.

Built by `SequenceData.get_delta_and_reset` and consumed by
`SequenceData.apply_delta`. It is constructed positionally there, so the
field order below is part of the contract.
"""
# Output tokens appended since the previous delta; the receiver extends
# its existing output token ids with them.
new_output_token_ids: List[int]
# Overwrites (does not add to) the receiver's cumulative logprob.
new_cumulative_logprob: float
# Overwrites the receiver's number of computed tokens.
new_num_computed_tokens: int
# Overwrites the receiver's stage.
new_stage: SequenceStage
class SequenceData(msgspec.Struct,
omit_defaults=True): # type: ignore[call-arg]
"""Data associated with a sequence.
Args:
......@@ -125,40 +147,57 @@ class SequenceData:
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
def __init__(
self,
prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None,
) -> None:
self._prompt_token_ids = array('l', prompt_token_ids)
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
self._output_token_ids = array(
'l', output_token_ids if output_token_ids is not None else [])
self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
# NOTE: we cannot use Union[List, array] because msgspec cannot support
# union of 2 list types.
_prompt_token_ids: array
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list)
def __post_init__(self) -> None:
    """Validate the token arrays and build the derived caches.

    Both token arrays must use the module-wide token id typecode
    (`VLLM_TOKEN_ID_ARRAY_TYPE`). The check goes through that constant,
    not a repeated literal, so it stays consistent with every place that
    constructs these arrays.

    Raises:
        AssertionError: if either array has a different typecode.
    """
    assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
    assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
    # Tuple copy of the prompt ids, returned by the `prompt_token_ids`
    # property.
    self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
        self._prompt_token_ids)
    # Build the combined prompt + output token list.
    self._update_cached_all_tokens()
def _update_cached_all_tokens(self):
"""Rebuild `_cached_all_token_ids` as prompt ids followed by output ids."""
# The asserts pin both fields to `array` before they are concatenated
# with `+` below.
assert isinstance(self._prompt_token_ids, array)
assert isinstance(self._output_token_ids, array)
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._output_token_ids)
@property
def cumulative_logprob(self) -> float:
"""Return the stored `_cumulative_logprob` value."""
return self._cumulative_logprob
@property
def prompt_token_ids(self) -> Tuple[int, ...]:
return self._prompt_token_ids_tuple
@prompt_token_ids.setter
def prompt_token_ids(self, new_prompt_token_ids) -> None:
self._prompt_token_ids = array('l', new_prompt_token_ids)
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
self._update_cached_all_tokens()
raise NotImplementedError
@property
def prompt_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.

The underlying array is returned directly, not a copy. Its typecode is
"l" (asserted in `__post_init__`), i.e. a C signed long whose item size
is platform dependent, so it is not guaranteed to match torch.long
(int64). So beware of the usage.
"""
return self._prompt_token_ids
@property
......@@ -166,18 +205,26 @@ class SequenceData:
return tuple(self._output_token_ids)
@output_token_ids.setter
def output_token_ids(self, new_output_token_ids) -> None:
self._output_token_ids = array('l', new_output_token_ids)
def output_token_ids(self, new_output_token_ids: List[int]) -> None:
self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids)
self._update_cached_all_tokens()
@property
def output_token_ids_array(self) -> array:
"""Return the output token ids in array type.

The underlying array is returned directly, not a copy. Its typecode is
"l" (asserted in `__post_init__`), i.e. a C signed long whose item size
is platform dependent, so it is not guaranteed to match torch.long
(int64). So beware of the usage.
"""
assert isinstance(self._output_token_ids, array)
return self._output_token_ids
def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self.cumulative_logprob += logprob
self._cumulative_logprob += logprob
def get_len(self) -> int:
"""Return the number of output token ids plus prompt token ids."""
return len(self._output_token_ids) + len(self._prompt_token_ids)
......@@ -222,6 +269,7 @@ class SequenceData:
"""
self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL
self._new_appended_tokens = []
def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefill tokens that are not computed."""
......@@ -241,6 +289,21 @@ class SequenceData:
def get_output_token_ids(self) -> Tuple[int, ...]:
"""Return the value of the `output_token_ids` property."""
return self.output_token_ids
def get_delta_and_reset(self) -> SequenceDataDelta:
"""Return the pending delta and clear the pending-token list.

The delta carries the tokens accumulated in `_new_appended_tokens`
since it was last reset, together with the current cumulative logprob,
number of computed tokens and stage.
"""
delta = SequenceDataDelta(self._new_appended_tokens,
self._cumulative_logprob,
self.get_num_computed_tokens(), self.stage)
# Reset delta state.
# Rebinding to a fresh list (rather than clearing in place) leaves the
# list already handed to `delta` untouched.
self._new_appended_tokens = []
return delta
def apply_delta(self, delta: SequenceDataDelta):
"""Update this object in place from `delta`.

The computed-token count, cumulative logprob and stage are overwritten
with the delta's values; the delta's output token ids are appended.
"""
self._num_computed_tokens = delta.new_num_computed_tokens
self._cumulative_logprob = delta.new_cumulative_logprob
self._stage = delta.new_stage
# Extend the cached full token list alongside the output ids so the two
# stay in step.
self._output_token_ids.extend(delta.new_output_token_ids)
self._cached_all_token_ids.extend(delta.new_output_token_ids)
@property
def stage(self) -> SequenceStage:
"""Return the stored `_stage` value."""
return self._stage
......@@ -248,8 +311,9 @@ class SequenceData:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"output_token_ids={self._output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})")
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()}")
class Sequence:
......@@ -325,7 +389,8 @@ class Sequence:
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")
self.data = SequenceData(self.prompt_token_ids)
self.data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
......@@ -490,8 +555,8 @@ class Sequence:
f"num_blocks={self.n_blocks}, ")
@dataclass
class SequenceGroupState:
class SequenceGroupState(msgspec.Struct,
omit_defaults=True): # type: ignore[call-arg]
"""Mutable state tied to a specific sequence group"""
# for multi-step decoding
......@@ -647,14 +712,19 @@ class SequenceGroup:
if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return self.sampling_params.best_of
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
return best_of
else:
if (self.sampling_params
and self.sampling_params.best_of > self.num_seqs()):
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
return self.sampling_params.best_of
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
if best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
......@@ -757,7 +827,32 @@ class SequenceGroup:
f"num_seqs={len(self.seqs)})")
class SequenceGroupMetadata:
class SequenceGroupMetadataDelta(
msgspec.Struct,
tag=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Delta of SequenceGroupMetadata.

After sending the first SequenceGroupMetadata, the vLLM scheduler
only sends deltas to reduce the data payload size.
"""
# Per-sequence deltas; each key is looked up in the target metadata's
# `seq_data` when the delta is applied.
seq_data_delta: Dict[int, SequenceDataDelta]
# Asserted to equal the `request_id` of the metadata the delta is applied
# to.
request_id: str
# The four fields below are copied onto the target metadata by
# `SequenceGroupMetadata.apply_delta`.
block_tables: Dict[int, List[int]]
is_prompt: bool
do_sample: bool = True
token_chunk_size: Optional[int] = None
# NOTE(review): the two fields below are not read by
# `SequenceGroupMetadata.apply_delta` as visible here -- confirm where
# they are consumed.
computed_block_nums: Optional[List[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
class SequenceGroupMetadata(
msgspec.Struct,
tag=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Metadata for a sequence group. Used to create `AttentionMetadata`.
Args:
......@@ -789,52 +884,39 @@ class SequenceGroupMetadata:
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
self,
request_id: str,
is_prompt: bool,
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
do_sample: bool = True,
pooling_params: Optional[PoolingParams] = None,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
self.pooling_params = pooling_params
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
# The number of speculative tokens adopted in this request.
# None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reason.
# TODO: We should maintain this state outside of the sequence group.
self.num_speculative_tokens = None
if seq_data is not None and self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = next(iter(
seq_data.values())).get_len()
request_id: str
is_prompt: bool
seq_data: Dict[int, SequenceData]
sampling_params: SamplingParams
block_tables: Dict[int, List[int]]
do_sample: bool = True
pooling_params: Optional[PoolingParams] = None
lora_request: Optional[LoRARequest] = None
computed_block_nums: Optional[List[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
# Actually of type Optional["MultiModalDataDict"]. We have to use Any
# because msgspec does not allow a union of 2 different dict types.
multi_modal_data: Optional[Any] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None
### Stateful fields that are lazily defined. ###
# The number of speculative tokens adopted in this request.
# None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reason.
# TODO: We should maintain this state outside of the sequence group.
num_speculative_tokens: Optional[int] = None
def __post_init__(self):
if self.seq_data is not None and self.token_chunk_size is None:
if self.is_prompt:
self.token_chunk_size = next(iter(
self.seq_data.values())).get_len()
else:
self._token_chunk_size = 1
self.token_chunk_size = 1
@property
def lora_int_id(self) -> int:
......@@ -850,18 +932,26 @@ class SequenceGroupMetadata:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0
@property
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""
assert self._token_chunk_size is not None
return self._token_chunk_size
def apply_delta(self,
                sequence_group_metadata_delta: SequenceGroupMetadataDelta):
    """Apply a per-step delta to this metadata in place.

    Each entry of the delta's `seq_data_delta` is forwarded to the
    matching entry of `self.seq_data`. Then `block_tables`,
    `token_chunk_size`, `do_sample` and `is_prompt` are overwritten with
    the delta's values.

    Args:
        sequence_group_metadata_delta: the delta to apply; its
            `request_id` must equal `self.request_id`.

    Raises:
        AssertionError: if the delta belongs to a different request.
        KeyError: if the delta names a sequence id missing from
            `self.seq_data`.
    """
    # Validate before mutating anything, so a delta for the wrong request
    # cannot leave the per-sequence data partially updated.
    assert self.request_id == sequence_group_metadata_delta.request_id
    seq_data_delta = sequence_group_metadata_delta.seq_data_delta
    for seq_id, delta in seq_data_delta.items():
        self.seq_data[seq_id].apply_delta(delta)
    self.block_tables = sequence_group_metadata_delta.block_tables
    self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
    self.do_sample = sequence_group_metadata_delta.do_sample
    self.is_prompt = sequence_group_metadata_delta.is_prompt
def finish_step(self) -> None:
    """Advance this group's step counter by one.

    Raises:
        AssertionError: If no state is attached, or if ``current_step``
            has already reached ``num_steps``.
    """
    state = self.state
    assert state is not None
    assert state.current_step < state.num_steps
    state.current_step += 1
class SequenceOutput:
class SequenceOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The model output associated with a sequence.
Args:
......@@ -871,16 +961,9 @@ class SequenceOutput:
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""
def __init__(
self,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, Logprob],
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
parent_seq_id: int
output_token: int
logprobs: Dict[int, Logprob]
def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
......@@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC):
pass
class CompletionSequenceGroupOutput(SequenceGroupOutput):
class CompletionSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
__metaclass__ = SequenceGroupOutput
"""The model output associated with a completion sequence group."""
def __init__(
self,
samples: List[SequenceOutput],
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs
samples: List[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
......@@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput):
and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
class EmbeddingSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
):
"""The model output associated with an embedding sequence group."""
def __init__(
self,
embeddings: List[float],
) -> None:
self.embeddings = embeddings
__metaclass__ = SequenceGroupOutput
# Embedding values are floating point (the previous constructor took
# List[float]); an int element type misdescribes the data and would fail
# typed decoding of non-integral values.
embeddings: List[float]
def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput("
......@@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings
@dataclass
class IntermediateTensors:
class IntermediateTensors(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
......@@ -978,8 +1061,10 @@ class IntermediateTensors:
return f"IntermediateTensors(tensors={self.tensors})"
@dataclass
class SamplerOutput:
class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
......@@ -1000,7 +1085,7 @@ class SamplerOutput:
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
......@@ -1039,12 +1124,14 @@ class SamplerOutput:
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
@dataclass
class PoolerOutput:
class PoolerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the embedding model."""
outputs: List[EmbeddingSequenceGroupOutput]
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
def __getitem__(self, idx: int):
return self.outputs[idx]
......@@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids(
return seq_ids, request_id_seq_ids_mapping
class HiddenStates:
class HiddenStates(msgspec.Struct, array_like=True,
omit_defaults=True): # type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from
the target model to the proposer model in the subsequent step.
......@@ -1091,42 +1179,53 @@ class HiddenStates:
seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor"""
def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor):
assert len(seq_group_metadata_list) == len(hidden_states)
self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
self.hidden_states: torch.Tensor = hidden_states
seq_group_metadata_list: List[SequenceGroupMetadata]
hidden_states: torch.Tensor
_seq_ids: List[int] = msgspec.field(default_factory=list)
def __post_init__(self):
    """Populate ``_seq_ids`` from the metadata list and check batch size."""
    metadata_list = self.seq_group_metadata_list
    self._seq_ids = get_all_seq_ids(metadata_list)
    # There must be exactly one hidden-state entry per sequence group.
    assert len(metadata_list) == len(self.hidden_states)
@property
def seq_ids(self) -> List[int]:
    """Return the stored sequence-id list (the ``_seq_ids`` object itself)."""
    ids = self._seq_ids
    return ids
def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor) -> None:
"""Update hidden states from target model invocation."""
assert len(seq_group_metadata_list) == len(hidden_states)
self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids."""
seq_ids = get_all_seq_ids(seq_group_metadata_list)
if seq_ids != self.seq_ids:
if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences.
index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
self.hidden_states = self.hidden_states[index]
self.seq_ids = seq_ids
self._seq_ids = seq_ids
@dataclass
class ExecuteModelRequest:
class ExecuteModelRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_swap_in: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_swap_out: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
......@@ -1138,7 +1237,7 @@ class ExecuteModelRequest:
# The number of forward steps to run.
num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
......@@ -1148,6 +1247,7 @@ class ExecuteModelRequest:
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
return first_seq_group.state.current_step == 0
@property
......@@ -1156,6 +1256,7 @@ class ExecuteModelRequest:
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1
......@@ -1165,10 +1266,13 @@ class ExecuteModelRequest:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
return self.seq_group_metadata_list[0].state.current_step
state = self.seq_group_metadata_list[0].state
assert state is not None
return state.current_step
def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
return ExecuteModelRequest(
......
from array import array
from itertools import chain, count
from typing import Iterator, List, Tuple
import torch
from vllm import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
SamplerOutput, SequenceData, SequenceGroupMetadata,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
......@@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
input sequence.
"""
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_token_ids = seq_data.prompt_token_ids_array
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
new_seq_data_dict = {
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
prompt_token_ids,
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids),
),
}
# This is a hack. Technically, spec decoding should compute
......
import time
from dataclasses import dataclass
from typing import Callable, Optional
import msgspec
import torch
from vllm.model_executor.layers.spec_decode_base_sampler import (
......@@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
from vllm.utils import is_pin_memory_available
@dataclass
class SpecDecodeWorkerMetrics:
class SpecDecodeWorkerMetrics(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""
......
"""A GPU worker class."""
import gc
import os
from typing import List, Optional, Set, Tuple, Type
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
import torch.distributed
......@@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput, SequenceGroupMetadata,
SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
......@@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase):
self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
......@@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase):
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def _get_cached_seq_group_metadata(
        self,
        seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                            SequenceGroupMetadataDelta]],
        finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
    """Return a list of cached Sequence Group Metadata after updating its
    state.

    It is used because scheduler only sends delta to workers to reduce
    the data payload size. The function also cleans up cache based on
    a given `finished_request_ids`.

    Args:
        seq_group_metadata_list: Full metadata snapshots and/or deltas,
            each carrying a ``request_id``.
        finished_request_ids: Request ids whose cache entries should be
            dropped. Ids that are not in the cache are ignored.

    Returns:
        The cached (and now up-to-date) metadata object for every entry
        of ``seq_group_metadata_list``, in the same order.
    """
    cache = self._seq_group_metadata_cache
    new_seq_group_metadata_list = []
    for metadata_or_delta in seq_group_metadata_list:
        request_id = metadata_or_delta.request_id
        if request_id not in cache:
            # The first prefill.
            assert isinstance(metadata_or_delta, SequenceGroupMetadata)
            cache[request_id] = metadata_or_delta
        elif isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
            # The first prefill is already cached; bring it up to date.
            cache[request_id].apply_delta(metadata_or_delta)
        else:
            # If metadata snapshot is sent again, it is
            # preempted. Reset the cache because we need to start
            # from scratch.
            assert isinstance(metadata_or_delta, SequenceGroupMetadata)
            cache[request_id] = metadata_or_delta

        new_seq_group_metadata_list.append(cache[request_id])

    # Clean up finished ids. A finished id may name a request that was
    # never cached on this worker, so drop entries tolerantly rather than
    # letting `del` raise KeyError and abort the whole step.
    for finished_id in finished_request_ids:
        cache.pop(finished_id, None)

    return new_seq_group_metadata_list
def _execute_model_spmd(
    self,
    execute_model_req: ExecuteModelRequest,
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Optional[List[SamplerOutput]]:
    """Resolve metadata deltas against the cache, then run the base step.

    When a request is given, its ``seq_group_metadata_list`` is replaced
    in place by the list from ``_get_cached_seq_group_metadata`` before
    the parent class's ``_execute_model_spmd`` is invoked. A ``None``
    request is passed through to the parent unchanged.
    """
    if execute_model_req is not None:
        execute_model_req.seq_group_metadata_list = (
            self._get_cached_seq_group_metadata(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.finished_requests_ids))
    return super()._execute_model_spmd(execute_model_req,
                                       intermediate_tensors)
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Hand the LoRA request to the model runner and return its result."""
    runner = self.model_runner
    return runner.add_lora(lora_request)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment