Unverified Commit ff7ec82c authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)

parent 200a2ffa
import warnings import warnings
from dataclasses import dataclass, field
from typing import Optional from typing import Optional
import msgspec
from vllm.adapter_commons.request import AdapterRequest from vllm.adapter_commons.request import AdapterRequest
@dataclass class LoRARequest(
class LoRARequest(AdapterRequest): msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
""" """
Request for a LoRA adapter. Request for a LoRA adapter.
...@@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest): ...@@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest):
lora_int_id must be globally unique for a given adapter. lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM. This is currently not enforced in vLLM.
""" """
__metaclass__ = AdapterRequest
lora_name: str lora_name: str
lora_int_id: int lora_int_id: int
lora_path: str = "" lora_path: str = ""
lora_local_path: Optional[str] = field(default=None, repr=False) lora_local_path: Optional[str] = msgspec.field(default=None)
long_lora_max_len: Optional[int] = None long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__ __hash__ = AdapterRequest.__hash__
def __post_init__(self): def __post_init__(self):
if 'lora_local_path' in self.__dict__: if 'lora_local_path' in self.__struct_fields__:
warnings.warn( warnings.warn(
"The 'lora_local_path' attribute is deprecated " "The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. " "and will be removed in a future version. "
......
"""Minimal implementation of BlipVisionModel intended to be only used """Minimal implementation of BlipVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from array import array
from typing import Optional, Union from typing import Optional, Union
import torch import torch
...@@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer, from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens) repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
...@@ -53,8 +54,10 @@ def dummy_seq_data_for_blip( ...@@ -53,8 +54,10 @@ def dummy_seq_data_for_blip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
token_ids += [0] * (seq_len - image_feature_size) [image_token_id]) * image_feature_size
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size)
return SequenceData(token_ids) return SequenceData(token_ids)
......
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
...@@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2( ...@@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
token_ids += [0] * (seq_len - image_feature_size * num_images) [image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids)
......
from array import array
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict) Tuple, TypedDict)
...@@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer, from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens) repeat_and_pad_image_tokens)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
...@@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon( ...@@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
token_ids += [0] * (seq_len - image_feature_size * num_images) [image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids)
......
"""Minimal implementation of CLIPVisionModel intended to be only used """Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from array import array
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
...@@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer, from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens) repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
...@@ -53,8 +54,10 @@ def dummy_seq_data_for_clip( ...@@ -53,8 +54,10 @@ def dummy_seq_data_for_clip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
token_ids += [0] * (seq_len - image_feature_size * num_images) [image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
import torch import torch
...@@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_image_processor, from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer) cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings from .utils import merge_multimodal_embeddings
...@@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): ...@@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol, nrow = get_max_fuyu_image_feature_size() ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx) image_feature_size = get_max_fuyu_image_tokens(ctx)
image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow image_token_ids = (
token_ids = image_token_ids * num_images array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
token_ids += [0] * (seq_len - image_feature_size * num_images) array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids)
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math import math
import re import re
from array import array
from functools import partial from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor, from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer) cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
...@@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): ...@@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
token_ids = [0] * seq_len token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
return SequenceData(token_ids) return SequenceData(token_ids)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
within a vision language model.""" within a vision language model."""
import math import math
from array import array
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
...@@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer, from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens) repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
...@@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip( ...@@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
token_ids += [0] * (seq_len - image_feature_size * num_images) [image_token_id]) * image_feature_size
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size)
return SequenceData(token_ids) return SequenceData(token_ids)
......
...@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple ...@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (PyObjectCache, async_tensor_h2d, from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad, is_pin_memory_available, make_tensor_with_pad,
...@@ -505,9 +506,11 @@ class SamplingTensors: ...@@ -505,9 +506,11 @@ class SamplingTensors:
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
prefill_len = len(seq_group.prompt_logprob_indices) prefill_len = len(seq_group.prompt_logprob_indices)
prompt_tokens.extend( prompt_tokens.extend(
array('l') for _ in range(prefill_len)) array(VLLM_TOKEN_ID_ARRAY_TYPE)
for _ in range(prefill_len))
output_tokens.extend( output_tokens.extend(
array('l') for _ in range(prefill_len)) array(VLLM_TOKEN_ID_ARRAY_TYPE)
for _ in range(prefill_len))
if seq_group.do_sample: if seq_group.do_sample:
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
......
from typing import Any, Optional from typing import Any, Optional
import msgspec
class PoolingParams:
class PoolingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Pooling parameters for pooling. """Pooling parameters for pooling.
Attributes: Attributes:
additional_data: Any additional data needed for pooling. additional_data: Any additional data needed for pooling.
""" """
additional_data: Optional[Any] = None
def __init__(self, additional_data: Optional[Any] = None):
self.additional_data = additional_data
def clone(self) -> "PoolingParams": def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance.""" """Returns a deep copy of the PoolingParams instance."""
......
from dataclasses import dataclass import msgspec
from vllm.adapter_commons.request import AdapterRequest from vllm.adapter_commons.request import AdapterRequest
@dataclass class PromptAdapterRequest(
class PromptAdapterRequest(AdapterRequest): msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
frozen=True): # type: ignore[call-arg]
""" """
Request for a Prompt adapter. Request for a Prompt adapter.
""" """
__metaclass__ = AdapterRequest
prompt_adapter_name: str prompt_adapter_name: str
prompt_adapter_id: int prompt_adapter_id: int
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
import copy import copy
from enum import IntEnum from enum import IntEnum
from functools import cached_property from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
import msgspec
import torch import torch
from pydantic import Field
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits ...@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
to sample from.""" to sample from."""
class SamplingParams: class SamplingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True): # type: ignore[call-arg]
"""Sampling parameters for text generation. """Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion Overall, we follow the sampling parameters from the OpenAI text completion
...@@ -112,87 +116,73 @@ class SamplingParams: ...@@ -112,87 +116,73 @@ class SamplingParams:
(i.e., no truncation). (i.e., no truncation).
""" """
def __init__( n: int = 1
self, best_of: Optional[int] = None
n: int = 1, presence_penalty: float = 0.0
best_of: Optional[int] = None, frequency_penalty: float = 0.0
presence_penalty: float = 0.0, repetition_penalty: float = 1.0
frequency_penalty: float = 0.0, temperature: float = 1.0
repetition_penalty: float = 1.0, top_p: float = 1.0
temperature: float = 1.0, top_k: int = -1
top_p: float = 1.0, min_p: float = 0.0
top_k: int = -1, seed: Optional[int] = None
min_p: float = 0.0, use_beam_search: bool = False
seed: Optional[int] = None, length_penalty: float = 1.0
use_beam_search: bool = False, early_stopping: Union[bool, str] = False
length_penalty: float = 1.0, stop: Optional[Union[str, List[str]]] = None
early_stopping: Union[bool, str] = False, stop_token_ids: Optional[List[int]] = None
stop: Optional[Union[str, List[str]]] = None, ignore_eos: bool = False
stop_token_ids: Optional[List[int]] = None, max_tokens: Optional[int] = 16
include_stop_str_in_output: bool = False, min_tokens: int = 0
ignore_eos: bool = False, logprobs: Optional[int] = None
max_tokens: Optional[int] = 16, prompt_logprobs: Optional[int] = None
min_tokens: int = 0, # NOTE: This parameter is only exposed at the engine level for now.
logprobs: Optional[int] = None, # It is not exposed in the OpenAI API server, as the OpenAI API does
prompt_logprobs: Optional[int] = None, # not support returning only a list of token IDs.
detokenize: bool = True, detokenize: bool = True
skip_special_tokens: bool = True, skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True
logits_processors: Optional[List[LogitsProcessor]] = None, # Optional[List[LogitsProcessor]] type. We use Any here because
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, # Optional[List[LogitsProcessor]] type is not supported by msgspec.
) -> None: logits_processors: Optional[Any] = None
self.n = n include_stop_str_in_output: bool = False
self.best_of = best_of if best_of is not None else n truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty # The below fields are not supposed to be used as an input.
self.repetition_penalty = repetition_penalty # They are set in post_init.
if 0 < temperature < _MAX_TEMP: output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
def __post_init__(self) -> None:
self.best_of = self.best_of or self.n
if 0 < self.temperature < _MAX_TEMP:
logger.warning( logger.warning(
"temperature %s is less than %s, which may cause numerical " "temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s.", "errors nan or inf in tensors. We have maxed it out to %s.",
temperature, _MAX_TEMP, _MAX_TEMP) self.temperature, _MAX_TEMP, _MAX_TEMP)
temperature = max(temperature, _MAX_TEMP) self.temperature = max(self.temperature, _MAX_TEMP)
self.temperature = temperature if self.seed == -1:
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
if seed == -1:
self.seed = None self.seed = None
else: else:
self.seed = seed self.seed = self.seed
self.use_beam_search = use_beam_search if self.stop is None:
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.stop = [] self.stop = []
elif isinstance(stop, str): elif isinstance(self.stop, str):
self.stop = [stop] self.stop = [self.stop]
else: else:
self.stop = list(stop) self.stop = list(self.stop)
if stop_token_ids is None: if self.stop_token_ids is None:
self.stop_token_ids = [] self.stop_token_ids = []
else: else:
self.stop_token_ids = list(stop_token_ids) self.stop_token_ids = list(self.stop_token_ids)
self.ignore_eos = ignore_eos self.logprobs = 1 if self.logprobs is True else self.logprobs
self.max_tokens = max_tokens self.prompt_logprobs = (1 if self.prompt_logprobs is True else
self.min_tokens = min_tokens self.prompt_logprobs)
self.logprobs = 1 if logprobs is True else logprobs
self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self.detokenize = detokenize
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation # Number of characters to hold back for stop string evaluation
# until sequence is finished. # until sequence is finished.
if self.stop and not include_stop_str_in_output: if self.stop and not self.include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.output_text_buffer_length = 0
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
...@@ -206,11 +196,12 @@ class SamplingParams: ...@@ -206,11 +196,12 @@ class SamplingParams:
self.min_p = 0.0 self.min_p = 0.0
self._verify_greedy_sampling() self._verify_greedy_sampling()
# eos_token_id is added to this by the engine # eos_token_id is added to this by the engine
self.all_stop_token_ids = set(self.stop_token_ids) self._all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.n < 1: if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.") raise ValueError(f"n must be at least 1, got {self.n}.")
assert isinstance(self.best_of, int)
if self.best_of < self.n: if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, " raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.") f"got n={self.n} and best_of={self.best_of}.")
...@@ -257,6 +248,7 @@ class SamplingParams: ...@@ -257,6 +248,7 @@ class SamplingParams:
and self.truncate_prompt_tokens < 1): and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, " raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}") f"got {self.truncate_prompt_tokens}")
assert isinstance(self.stop, list)
if any(not stop_str for stop_str in self.stop): if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.") raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize: if self.stop and not self.detokenize:
...@@ -290,6 +282,7 @@ class SamplingParams: ...@@ -290,6 +282,7 @@ class SamplingParams:
"default value of 1.0 when not using beam search.") "default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None: def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int)
if self.best_of > 1: if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling." raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.") f"Got {self.best_of}.")
...@@ -303,7 +296,7 @@ class SamplingParams: ...@@ -303,7 +296,7 @@ class SamplingParams:
if model_eos_token_id is not None: if model_eos_token_id is not None:
# Add the eos token id into the sampling_params to support # Add the eos token id into the sampling_params to support
# min_tokens processing. # min_tokens processing.
self.all_stop_token_ids.add(model_eos_token_id) self._all_stop_token_ids.add(model_eos_token_id)
# Update eos_token_id for generation # Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None: if (eos_ids := generation_config.get("eos_token_id")) is not None:
...@@ -315,7 +308,7 @@ class SamplingParams: ...@@ -315,7 +308,7 @@ class SamplingParams:
# purposes. # purposes.
eos_ids.discard(model_eos_token_id) eos_ids.discard(model_eos_token_id)
if eos_ids: if eos_ids:
self.all_stop_token_ids.update(eos_ids) self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos: if not self.ignore_eos:
eos_ids.update(self.stop_token_ids) eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids) self.stop_token_ids = list(eos_ids)
...@@ -330,6 +323,10 @@ class SamplingParams: ...@@ -330,6 +323,10 @@ class SamplingParams:
return SamplingType.RANDOM_SEED return SamplingType.RANDOM_SEED
return SamplingType.RANDOM return SamplingType.RANDOM
@property
def all_stop_token_ids(self) -> Set[int]:
return self._all_stop_token_ids
def clone(self) -> "SamplingParams": def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects. """Deep copy excluding LogitsProcessor objects.
......
This diff is collapsed.
from array import array
from itertools import chain, count from itertools import chain, count
from typing import Iterator, List, Tuple from typing import Iterator, List, Tuple
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
SequenceGroupMetadata, get_all_seq_ids) SamplerOutput, SequenceData, SequenceGroupMetadata,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
...@@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
input sequence. input sequence.
""" """
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids() prompt_token_ids = seq_data.prompt_token_ids_array
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
new_seq_data_dict = { new_seq_data_dict = {
target_seq_id: target_seq_id:
SequenceData( SequenceData(
prompt_token_ids=prompt_token_ids, prompt_token_ids,
output_token_ids=new_output_token_ids, _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids),
), ),
} }
# This is a hack. Technically, spec decoding should compute # This is a hack. Technically, spec decoding should compute
......
import time import time
from dataclasses import dataclass
from typing import Callable, Optional from typing import Callable, Optional
import msgspec
import torch import torch
from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.model_executor.layers.spec_decode_base_sampler import (
...@@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( ...@@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
@dataclass class SpecDecodeWorkerMetrics(
class SpecDecodeWorkerMetrics: msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker. """Dataclass holding metrics emitted from the spec decode worker.
""" """
......
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import os import os
from typing import List, Optional, Set, Tuple, Type from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
import torch.distributed import torch.distributed
...@@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed ...@@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput, SequenceGroupMetadata,
SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
...@@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase):
self.cache_engine: List[CacheEngine] self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
def _is_encoder_decoder_model(self): def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model return self.model_config.is_encoder_decoder_model
...@@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase):
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def _get_cached_seq_group_metadata(
self,
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]],
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
"""Return a list of cached Sequence Group Metadata after updating its
state.
It is used because scheduler only sends delta to workers to reduce
the data payload size. The function also cleans up cache based on
a given `finished_request_ids`.
"""
new_seq_group_metadata_list = []
for metadata_or_delta in seq_group_metadata_list:
request_id = metadata_or_delta.request_id
if request_id not in self._seq_group_metadata_cache:
# The first prefill.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[request_id] = metadata_or_delta
else:
# The first prefill is already cached.
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
self._seq_group_metadata_cache[request_id].apply_delta(
metadata_or_delta)
else:
# If metadata snapshot is sent again, it is
# preempted. Reset the cache because we need to start
# from scratch.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[
request_id] = metadata_or_delta
new_seq_group_metadata_list.append(
self._seq_group_metadata_cache[request_id])
# Clean up finished ids
for finished_id in finished_request_ids:
del self._seq_group_metadata_cache[finished_id]
return new_seq_group_metadata_list
def _execute_model_spmd(
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Optional[List[SamplerOutput]]:
if execute_model_req is not None:
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
execute_model_req.seq_group_metadata_list,
execute_model_req.finished_requests_ids)
execute_model_req.seq_group_metadata_list = (
new_seq_group_metadata_list)
output = super()._execute_model_spmd(execute_model_req,
intermediate_tensors)
return output
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment