Unverified Commit 45badd05 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Set pooling params based on task and model (#21128)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 4adc66f6
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import bisect import bisect
import gc import gc
import time import time
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast, get_args
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
...@@ -25,10 +25,12 @@ from vllm.logger import init_logger ...@@ -25,10 +25,12 @@ from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange) PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available, prev_power_of_2) is_pin_memory_available, prev_power_of_2)
...@@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin):
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model return self.model
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
""" """
Generates the KVCacheSpec by parsing the kv cache format from each Generates the KVCacheSpec by parsing the kv cache format from each
......
...@@ -19,6 +19,7 @@ from vllm.logger import init_logger ...@@ -19,6 +19,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -275,6 +276,9 @@ class TPUWorker: ...@@ -275,6 +276,9 @@ class TPUWorker:
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
return self.model_runner.get_supported_pooling_tasks()
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec() return self.model_runner.get_kv_cache_spec()
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import dataclasses import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar) TypeVar, get_args)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -12,6 +12,8 @@ import torch.nn as nn ...@@ -12,6 +12,8 @@ import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]):
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
raise NotImplementedError raise NotImplementedError
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def execute_model( def execute_model(
self, self,
model_input: T, model_input: T,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
import torch import torch
...@@ -10,6 +10,7 @@ from vllm.config import VllmConfig ...@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -195,7 +196,20 @@ class PoolingModelRunner( ...@@ -195,7 +196,20 @@ class PoolingModelRunner(
seq_groups: List[Tuple[List[int], PoolingParams]] = [] seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list): for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None
assert pooling_params.task is not None, (
"You did not set `task` in the API")
to_update = (cast(VllmModelForPooling,
self.model).pooler.get_pooling_updates(
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params)) seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {} seq_data: Dict[int, SequenceData] = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment