Unverified Commit fd95e026 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Core] Subclass ModelRunner to support cross-attention & encoder sequences...


[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: default avatarAndrew Feldman <afeld2012@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill@us.ibm.com>
parent 660470e5
...@@ -69,7 +69,7 @@ class EngineArgs: ...@@ -69,7 +69,7 @@ class EngineArgs:
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: Optional[bool] = None
max_context_len_to_capture: Optional[int] = None max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192 max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
......
This diff is collapsed.
...@@ -121,12 +121,21 @@ class LLM: ...@@ -121,12 +121,21 @@ class LLM:
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: int = 4, swap_space: int = 4,
cpu_offload_gb: float = 0, cpu_offload_gb: float = 0,
enforce_eager: bool = False, enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
'''
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size", removed_vision_keys = ("image_token_id", "image_feature_size",
...@@ -297,8 +306,8 @@ class LLM: ...@@ -297,8 +306,8 @@ class LLM:
""" """
if self.llm_engine.model_config.embedding_mode: if self.llm_engine.model_config.embedding_mode:
raise ValueError( raise ValueError(
"LLM.generate() is only supported for generation models " "LLM.generate() is only supported for (conditional) generation "
"(XForCausalLM).") "models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None: if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( inputs = self._convert_v1_inputs(
...@@ -631,3 +640,9 @@ class LLM: ...@@ -631,3 +640,9 @@ class LLM:
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id)) return sorted(outputs, key=lambda x: int(x.request_id))
def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
TextPrompt, TokensPrompt, parse_and_batch_prompt) ParsedTokens, PromptInputs, SingletonPromptInputs,
TextPrompt, TokensPrompt, get_prompt_type,
is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
...@@ -12,7 +14,18 @@ See also: ...@@ -12,7 +14,18 @@ See also:
""" """
__all__ = [ __all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", "ParsedText",
"TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY", "ParsedTokens",
"InputContext", "InputRegistry" "parse_and_batch_prompt",
"TextPrompt",
"TokensPrompt",
"PromptInputs",
"LLMInputs",
"INPUT_REGISTRY",
"InputContext",
"InputRegistry",
"get_prompt_type",
"is_valid_encoder_decoder_llm_inputs",
"ExplicitEncoderDecoderPrompt",
"SingletonPromptInputs",
] ]
...@@ -92,15 +92,114 @@ class TokensPrompt(TypedDict): ...@@ -92,15 +92,114 @@ class TokensPrompt(TypedDict):
""" """
PromptInputs = Union[str, TextPrompt, TokensPrompt] SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
""" """
The inputs to the LLM, which can take one of the following forms: Set of possible schemas for a single LLM input:
- A text prompt (:class:`str` or :class:`TextPrompt`) - A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
A prompt of type SingletonPromptInputs may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
""" """
class ExplicitEncoderDecoderPrompt(TypedDict):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
SingletonPromptInputs schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an ExplicitEncoderDecoderPrompt may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure may not themselves
must be SingletonPromptInputs instances.
"""
encoder_prompt: SingletonPromptInputs
decoder_prompt: SingletonPromptInputs
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`)
"""
def _has_required_keys(
d: dict,
required_keys: set,
) -> bool:
return required_keys.issubset(d.keys())
def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
"""
Get the type-name of the prompt argument instance, given that
isinstance() cannot apply to TypedDict subclasses directly.
If the prompt is None, return 'None' as the type name.
Arguments:
* prompt: LLM input prompt or None
Returns:
* String representation of prompt type
"""
if prompt is None:
return 'None'
required_keys_dict = {
'TextPrompt': {'prompt'},
'TokensPrompt': {'prompt_token_ids'},
'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
}
if isinstance(prompt, dict):
for (ptype, required_keys) in required_keys_dict.items():
# Ignore type checking in the conditional below because type
# checker does not understand that is_dict(prompt) narrows
# down the possible types
if _has_required_keys(
prompt, # type: ignore
required_keys):
return ptype
raise ValueError(f"Invalid prompt {prompt}, valid types are "
"required_keys_dict={required_keys_dict}")
if isinstance(prompt, str):
return "str"
raise ValueError(f"Invalid prompt {prompt}")
class LLMInputs(TypedDict): class LLMInputs(TypedDict):
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
...@@ -114,8 +213,29 @@ class LLMInputs(TypedDict): ...@@ -114,8 +213,29 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available. The original prompt text corresponding to the token IDs, if available.
""" """
encoder_prompt_token_ids: NotRequired[List[int]]
"""The token IDs of the encoder prompt."""
encoder_prompt: NotRequired[Optional[str]]
"""
The original encoder prompt text corresponding to the token IDs, if
available.
"""
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
""" """
def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
"""
Return True if the LLMInputs instance has the correct configuration
for encoder/decoder.
"""
# True if encoder prompt token ids field exists &
# is not None
return ('encoder_prompt_token_ids' in inputs
and inputs['encoder_prompt_token_ids'] is not None)
...@@ -83,7 +83,16 @@ _EMBEDDING_MODELS = { ...@@ -83,7 +83,16 @@ _EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
} }
_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} _CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
_MODELS = {
**_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_CONDITIONAL_GENERATION_MODELS
}
# Architecture -> type. # Architecture -> type.
# out of tree models # out of tree models
......
This diff is collapsed.
...@@ -70,12 +70,20 @@ class RequestOutput: ...@@ -70,12 +70,20 @@ class RequestOutput:
Args: Args:
request_id: The unique ID of the request. request_id: The unique ID of the request.
prompt: The prompt string of the request. prompt: The prompt string of the request.
For encoder/decoder models, this is the
decoder input prompt.
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
For encoder/decoder models, this is the
decoder input prompt token ids.
prompt_logprobs: The log probabilities to return per prompt token. prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request. outputs: The output sequences of the request.
finished: Whether the whole request is finished. finished: Whether the whole request is finished.
metrics: Metrics associated with the request. metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output. lora_request: The LoRA request that was used to generate the output.
encoder_prompt: The encoder prompt string of the request;
None if decoder-only
encoder_prompt_token_ids: The token IDs of the encoder prompt;
None if decoder-only
""" """
def __init__( def __init__(
...@@ -88,6 +96,8 @@ class RequestOutput: ...@@ -88,6 +96,8 @@ class RequestOutput:
finished: bool, finished: bool,
metrics: Optional[RequestMetrics] = None, metrics: Optional[RequestMetrics] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[List[int]] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
...@@ -97,6 +107,8 @@ class RequestOutput: ...@@ -97,6 +107,8 @@ class RequestOutput:
self.finished = finished self.finished = finished
self.metrics = metrics self.metrics = metrics
self.lora_request = lora_request self.lora_request = lora_request
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod @classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
...@@ -137,6 +149,8 @@ class RequestOutput: ...@@ -137,6 +149,8 @@ class RequestOutput:
# Every sequence in the sequence group should have the same prompt. # Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished() finished = seq_group.is_finished()
finished_time = time.time() if finished else None finished_time = time.time() if finished else None
...@@ -148,12 +162,16 @@ class RequestOutput: ...@@ -148,12 +162,16 @@ class RequestOutput:
outputs, outputs,
finished, finished,
seq_group.metrics, seq_group.metrics,
lora_request=seq_group.lora_request) lora_request=seq_group.lora_request,
encoder_prompt=encoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, " return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, " f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
f"encoder_prompt={self.encoder_prompt!r}, "
f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, "
f"outputs={self.outputs}, " f"outputs={self.outputs}, "
f"finished={self.finished}, " f"finished={self.finished}, "
......
...@@ -7,10 +7,11 @@ from array import array ...@@ -7,10 +7,11 @@ from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union) Union, cast)
import torch import torch
from vllm.inputs import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -244,13 +245,26 @@ class SequenceData: ...@@ -244,13 +245,26 @@ class SequenceData:
class Sequence: class Sequence:
"""Stores the data, status, and block information of a sequence. """Stores the data, status, and block information of a sequence.
The sequence is constructed from the LLMInputs instance passed
in through the `inputs` constructor argument.
For encoder/decoder models, LLMInputs encapsulates both a
decoder and encoder prompt, creating an ambiguity about which
prompt to construct the sequence from. The `from_decoder_prompt`
constructor argument signals whether to construct the Sequence
from the LLMInputs decoder prompt, or encoder prompt.
Args: Args:
seq_id: The ID of the sequence. seq_id: The ID of the sequence.
inputs: The inputs of the sequence. inputs: The inputs of the sequence.
block_size: The block size of the sequence. Should be the same as the block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request. lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
(True) or encoder prompt (False.) Must be True
for decoder-only model.
""" """
...@@ -261,7 +275,8 @@ class Sequence: ...@@ -261,7 +275,8 @@ class Sequence:
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
from_decoder_prompt: bool = True,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.inputs = inputs self.inputs = inputs
...@@ -269,6 +284,36 @@ class Sequence: ...@@ -269,6 +284,36 @@ class Sequence:
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.from_decoder_prompt = from_decoder_prompt
self._prompt: Optional[str] = None
self._prompt_token_ids: Optional[List[int]] = None
# For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.)
#
# For encoder/decoder models the same `inputs`
# instance could be utilized to construct either an
# encoder sequence or a decoder sequence, because
# `LLMInputs` has both decoder- and encoder-oriented
# member variables (i.e. it encapsulates both an encoder
# and a decoder prompt.) The decision of which type of sequence
# to generate is determined by the `from_decoder_prompt` argument.
#
# When constructing a encoder sequence
# (`from_decoder_prompt` False) it matters that
# the `LLMInputs` instance stored in `inputs` is valid
# in the sense that its encoder-related member variables are
# populated; below, an exception is raised if this is
# not the case.
#
# When constructing a decoder sequence (`from_decoder_prompt` True)
# it does not matter whether `inputs` has its encoder-related
# member variables populated.
if not (from_decoder_prompt
or is_valid_encoder_decoder_llm_inputs(inputs)):
raise ValueError("Cannot extract encoder input prompt from "
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")
self.data = SequenceData(self.prompt_token_ids) self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
...@@ -289,11 +334,35 @@ class Sequence: ...@@ -289,11 +334,35 @@ class Sequence:
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
return self.inputs.get("prompt") if self._prompt is not None:
# Reuse precomputed prompt string
return self._prompt
# Select decoder or encoder input prompt str,
# as appropriate
prompt_key: str = ("prompt"
if self.from_decoder_prompt else "encoder_prompt")
# Cache prompt
self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
return self._prompt
@property @property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"] if self._prompt_token_ids is not None:
# Reuse precomputed prompt token ids
return self._prompt_token_ids
# Select decoder or encoder input prompt
# token ids, as appropriate
prompt_token_ids_key: str = ("prompt_token_ids"
if self.from_decoder_prompt else
"encoder_prompt_token_ids")
# Cache computed prompt token ids
self._prompt_token_ids = cast(List[int],
self.inputs.get(prompt_token_ids_key))
return self._prompt_token_ids
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
...@@ -472,6 +541,22 @@ class SequenceGroup: ...@@ -472,6 +541,22 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return self.seqs[0].prompt_token_ids return self.seqs[0].prompt_token_ids
@property
def encoder_prompt(self) -> Optional[str]:
# There are either 0 or 1 encoder sequences
# If one is present, its prompt is distinct
# from the decoder's.
return (self.encoder_seq.prompt
if self.encoder_seq is not None else None)
@property
def encoder_prompt_token_ids(self) -> Optional[List[int]]:
# There are either 0 or 1 encoder sequences
# If one is present, its prompt token ids are
# distinct from the decoder's.
return (self.encoder_seq.prompt_token_ids
if self.encoder_seq is not None else None)
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
# All sequences in the group should have the same multi-modal data. # All sequences in the group should have the same multi-modal data.
......
...@@ -27,10 +27,93 @@ from typing_extensions import ParamSpec ...@@ -27,10 +27,93 @@ from typing_extensions import ParamSpec
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
SingletonPromptInputs)
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
"Chunked prefill for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
"Models with logits_soft_cap "
"require FlashInfer backend, which is "
"currently not supported for encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently "
"supported with encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
"currently supported with "
"encoder/decoder models.")
STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
"supported with encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"decoder models.")
# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
"STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
"STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
"STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
"STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
"STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
"STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}
# Constants related to forcing the attention backend selection
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half, "half": torch.half,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
...@@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, ...@@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
"""Utility function to run async task in a lock""" """Utility function to run async task in a lock"""
async with lock: async with lock:
return await task(*args, **kwargs) return await task(*args, **kwargs)
def is_encoder_decoder_model_config(model_config) -> bool:
'''
Extract the HF encoder/decoder model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
getattr(model_config.hf_config,
"is_encoder_decoder",
False)
def is_embedding_model_config(model_config) -> bool:
'''
Extract the embedding model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
model_config.embedding_mode
def build_explicit_enc_dec_prompt(
encoder_prompt: SingletonPromptInputs,
decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
def zip_enc_dec_prompt_lists(
enc_prompt_list: List[SingletonPromptInputs],
dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt,
decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
return [(enc_dec_prompt['encoder_prompt'],
enc_dec_prompt['decoder_prompt'])
for enc_dec_prompt in enc_dec_prompts]
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, cast
import torch
import torch.distributed
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInput":
return cast(
EncoderDecoderModelInput,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
_model_input_cls: Type[EncoderDecoderModelInput] = (
EncoderDecoderModelInput)
_builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
):
'''
EncoderDecoderModelRunner constructor.
`lora_config`, `multimodal_config`, and prompt_adapter_config are
unused (since these features are not yet supported for encoder/decoder
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend()
super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=None,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self)
def _maybe_force_supported_attention_backend(self):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
'''
def raise_backend_err():
# The user has specified an attention backend override
# which is invalid for encoder/decoder models
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
maybe_env_var_forced_backend = get_env_variable_attn_backend()
maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None
if not (is_forced_by_global or is_forced_by_env_var):
# The user has not already specified an attention backend
# override
logger.info("EncoderDecoderModelRunner requires "
"XFormers backend; overriding backend "
"auto-selection and forcing XFormers.")
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if maybe_global_forced_backend != _Backend.XFORMERS:
raise_backend_err()
elif is_forced_by_env_var:
# Backend override enforced by vLLM backend
# environment variable
if maybe_env_var_forced_backend != _Backend.XFORMERS:
raise_backend_err()
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
@torch.inference_mode()
def execute_model(
self,
model_input: EncoderDecoderModelInput,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner")
model_executable = self.model
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**seqlen_agnostic_kwargs)
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if not self.is_driver_worker:
return []
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInput:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
Since chunked prefill is not supported for encoder/decoder models,
`input_tokens` is assumed to be either entirely prefill tokens or
entirely decode tokens.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input))
# Inject attn_metadata encoder/cross-attention fields &
# encoder input tokens/positions into model_input.
# Frozen dataclass fields cannot be modified, so use
# dataclasses.replace to construct a new model input
# instance.
model_input = dataclasses.replace(
model_input,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt,
virtual_engine=virtual_engine)
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
model_config = self.model_config
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, _ = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
encoder_seq_data=seq_data,
cross_block_table=None,
)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInput,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
# Since we are not supporting chunked prefill either the entire
# batch is prefill or it is decode
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
is_profile_run = (seq_group_metadata.block_tables is None)
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
else:
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
# Convert cross-attention block tables to encoder input tensor
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max(
len(block_table) for block_table in cross_block_tables),
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)
'''
Worker-related helper functions.
'''
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
from vllm.worker.model_runner import GPUModelRunnerBase
def assert_enc_dec_mr_supported_scenario(
enc_dec_mr: GPUModelRunnerBase) -> None:
'''
Asserted that the provided encoder/decoder model runner instance reflects
a supported scenario.
'''
if enc_dec_mr.cache_config.enable_prefix_caching:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])
if enc_dec_mr.sliding_window is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])
if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])
if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
None) is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
)
if enc_dec_mr.lora_config is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])
if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.multimodal_config is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
if not enc_dec_mr.model_config.enforce_eager:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'])
if enc_dec_mr.prompt_adapter_config is not None:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])
...@@ -19,8 +19,11 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ...@@ -19,8 +19,11 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import (is_embedding_model_config,
is_encoder_decoder_model_config)
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
...@@ -85,8 +88,10 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -85,8 +88,10 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None: if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls ModelRunnerClass = model_runner_cls
elif self.model_config.embedding_mode: elif self._is_embedding_model():
ModelRunnerClass = EmbeddingModelRunner ModelRunnerClass = EmbeddingModelRunner
elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass( self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config, model_config,
parallel_config, parallel_config,
...@@ -107,6 +112,12 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -107,6 +112,12 @@ class Worker(LocalOrDistributedWorkerBase):
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
def _is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config)
def _is_embedding_model(self):
return is_embedding_model_config(self.model_config)
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until # torch.distributed.all_reduce does not free the input tensor until
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment