"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "61acfc45bcf44de6e5a14c82906bbb7a33940443"
Unverified Commit 58170d65 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Hardware][CPU] Add embedding models support for CPU backend (#10193)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent 9804ac7c
...@@ -25,8 +25,7 @@ function cpu_tests() { ...@@ -25,8 +25,7 @@ function cpu_tests() {
decord einops librosa peft Pillow sentence-transformers soundfile \ decord einops librosa peft Pillow sentence-transformers soundfile \
transformers_stream_generator matplotlib datamodel_code_generator transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu pip install torchvision --index-url https://download.pytorch.org/whl/cpu
# Embedding models are not supported for CPU yet pytest -v -s tests/models/embedding/language
# pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/encoder_decoder/language pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language/test_models.py pytest -v -s tests/models/decoder_only/language/test_models.py
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
......
...@@ -32,8 +32,7 @@ function cpu_tests() { ...@@ -32,8 +32,7 @@ function cpu_tests() {
decord einops librosa peft Pillow sentence-transformers soundfile \ decord einops librosa peft Pillow sentence-transformers soundfile \
transformers_stream_generator matplotlib datamodel_code_generator transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu pip install torchvision --index-url https://download.pytorch.org/whl/cpu
# Embedding models are not supported for CPU yet pytest -v -s tests/models/embedding/language
# pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/encoder_decoder/language pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language/test_models.py pytest -v -s tests/models/decoder_only/language/test_models.py
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
......
...@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`. ...@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
""" """
import pytest import pytest
from vllm.utils import current_platform
from ..utils import check_embeddings_close from ..utils import check_embeddings_close
# Model, Guard # Model, Guard
...@@ -21,15 +23,14 @@ ENCODER_ONLY = [ ...@@ -21,15 +23,14 @@ ENCODER_ONLY = [
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
def test_models( def test_models(
monkeypatch,
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
model, model,
dtype: str, dtype: str,
) -> None: ) -> None:
if model in ENCODER_ONLY: if model not in ENCODER_ONLY and current_platform.is_cpu():
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") pytest.skip("Skip large embedding models test on CPU.")
# The example_prompts has ending "\n", for example: # The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n" # "Write a short story about a robot that dreams for the first time.\n"
......
...@@ -158,7 +158,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -158,7 +158,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
* Appropriate sequence lengths tensor for key & value * Appropriate sequence lengths tensor for key & value
''' '''
if attn_type == AttentionType.DECODER: if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
seq_lens_q = self.seq_lens seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER: elif attn_type == AttentionType.ENCODER:
...@@ -189,7 +190,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -189,7 +190,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
* Appropriate attention bias value given the attention type * Appropriate attention bias value given the attention type
''' '''
if attn_type == AttentionType.DECODER: if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return self.attn_bias return self.attn_bias
elif attn_type == AttentionType.ENCODER: elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias return self.encoder_attn_bias
...@@ -215,7 +217,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -215,7 +217,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
encoder/decoder cross-attention encoder/decoder cross-attention
''' '''
if attn_type == AttentionType.DECODER: if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
self.attn_bias = attn_bias self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER: elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias self.encoder_attn_bias = attn_bias
...@@ -252,7 +255,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -252,7 +255,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
* Appropriate block tables (or None) * Appropriate block tables (or None)
''' '''
if attn_type == AttentionType.DECODER: if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
# Decoder self-attention # Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run # Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len, return (self.seq_lens_tensor, self.max_decode_seq_len,
...@@ -420,6 +424,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -420,6 +424,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
"Torch SDPA backend doesn't support prefix decoding.") "Torch SDPA backend doesn't support prefix decoding.")
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
# Decoding run. # Decoding run.
( (
seq_lens_arg, seq_lens_arg,
......
...@@ -5,7 +5,6 @@ from torch import nn ...@@ -5,7 +5,6 @@ from torch import nn
from transformers import BertConfig from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -218,11 +217,6 @@ class BertSelfAttention(nn.Module): ...@@ -218,11 +217,6 @@ class BertSelfAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn") prefix=f"{prefix}.attn")
if not isinstance(self.attn.impl, XFormersImpl):
raise ValueError(
"Encoder-only models currently require XFORMERS attention "
"backend. Set VLLM_ATTENTION_BACKEND=XFORMERS to use BERT.")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
ModelInputForCPUBuilder)
@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
"""
Used by the CPUEmbeddingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class CPUEmbeddingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForCPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
"intermediate_tensors":
intermediate_tensors,
}
hidden_states = model_executable(**execute_model_kwargs)
return [
self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForCPUWithPoolingMetadata:
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
return dataclasses.replace(model_input,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata
...@@ -8,7 +8,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput ...@@ -8,7 +8,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunner, from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
ModelInputForCPUBuilder, ModelInputForCPUBuilder,
ModelInputForCPUWithSamplingMetadata) ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
...@@ -50,7 +50,8 @@ class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata): ...@@ -50,7 +50,8 @@ class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class CPUEncoderDecoderModelRunner(CPUModelRunner): class CPUEncoderDecoderModelRunner(
CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = ( _model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
EncoderDecoderModelInputForCPU) EncoderDecoderModelInputForCPU)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
...@@ -87,10 +88,8 @@ class CPUEncoderDecoderModelRunner(CPUModelRunner): ...@@ -87,10 +88,8 @@ class CPUEncoderDecoderModelRunner(CPUModelRunner):
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInputForCPU: ) -> EncoderDecoderModelInputForCPU:
model_input = super().prepare_model_input(seq_group_metadata_list, model_input = self._prepare_model_input_tensors(
virtual_engine, seq_group_metadata_list, finished_requests_ids)
finished_requests_ids)
model_input = cast(EncoderDecoderModelInputForCPU, model_input)
( (
attn_metadata, attn_metadata,
encoder_input_tokens_tensor, encoder_input_tokens_tensor,
......
...@@ -2,7 +2,8 @@ import dataclasses ...@@ -2,7 +2,8 @@ import dataclasses
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -31,6 +32,7 @@ if TYPE_CHECKING: ...@@ -31,6 +32,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
_PAD_SLOT_ID = -1 _PAD_SLOT_ID = -1
...@@ -60,10 +62,10 @@ class ModelInputForCPU(ModelRunnerInputBase): ...@@ -60,10 +62,10 @@ class ModelInputForCPU(ModelRunnerInputBase):
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
cls: Type["ModelInputForCPU"], cls: Type[TModelInputForCPU],
tensor_dict: Dict[str, Any], tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None attn_backend: Optional["AttentionBackend"] = None
) -> "ModelInputForCPU": ) -> TModelInputForCPU:
if attn_backend is not None: if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict( tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict) attn_backend, tensor_dict)
...@@ -255,11 +257,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -255,11 +257,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping.append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
continue continue
block_number = block_table[i // # For encoder-only models, the block_table is None,
self.block_size] # type: ignore # and there is no need to initialize the slot_mapping.
block_offset = i % self.block_size # type: ignore if block_table is not None:
slot = block_number * self.block_size + block_offset block_number = block_table[i //
slot_mapping.append(slot) self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if any(input_mrope_positions): if any(input_mrope_positions):
input_positions = None # type: ignore input_positions = None # type: ignore
...@@ -402,10 +407,12 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -402,10 +407,12 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
) )
class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( """
ModelInputForCPUWithSamplingMetadata) Helper class for shared methods between CPU model runners.
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder """
_model_input_cls: Type[TModelInputForCPU]
_builder_cls: Type[ModelInputForCPUBuilder]
def __init__( def __init__(
self, self,
...@@ -448,20 +455,11 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -448,20 +455,11 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)
def _prepare_model_input_tensors( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata: ) -> TModelInputForCPU:
"""Helper method to prepare the model input based on a given sequence """Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling. metadata for possible additional steps, e.g., sampling.
...@@ -473,6 +471,21 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -473,6 +471,21 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
return builder.build() # type: ignore return builder.build() # type: ignore
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
......
...@@ -14,8 +14,9 @@ from vllm.logger import init_logger ...@@ -14,8 +14,9 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase, LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput) WorkerInput)
...@@ -150,21 +151,20 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -150,21 +151,20 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else: else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank] self.local_omp_cpuid = omp_cpuids.split("|")[rank]
ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.task == "embedding": if self.model_config.task == "embedding":
raise NotImplementedError( ModelRunnerClass = CPUEmbeddingModelRunner
"Embedding models are not supported for CPU backend")
# ModelRunnerClass = CPUEmbeddingModelRunner
elif self.model_config.is_encoder_decoder: elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunner = ModelRunnerClass( self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
vllm_config=vllm_config, vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: List[CPUCacheEngine] self.cache_engine: List[CPUCacheEngine]
self.cpu_cache: List[List[torch.Tensor]] # Initialize cpu_cache as embedding models don't initialize kv_caches
self.cpu_cache: Optional[List[List[torch.Tensor]]] = None
# Torch profiler. Enabled and configured through env vars: # Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment