Commit 45badd05 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Core] Set pooling params based on task and model (#21128)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 4adc66f6
......@@ -3,7 +3,7 @@
import bisect
import gc
import time
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast, get_args
from unittest.mock import patch
import numpy as np
......@@ -25,10 +25,12 @@ from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available, prev_power_of_2)
......@@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin):
def get_model(self) -> nn.Module:
    """Return the model object stored on this runner as ``self.model``."""
    loaded_model = self.model
    return loaded_model
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """List the pooling tasks that the loaded model can serve.

    Returns:
        An empty list when ``is_pooling_model`` rejects the model returned
        by ``self.get_model()``; otherwise every member of ``PoolingTask``
        for which the model's pooler returns a non-``None`` value from
        ``get_pooling_updates``.
    """
    model = self.get_model()
    if not is_pooling_model(model):
        return []

    # Compare against None explicitly rather than relying on truthiness:
    # None is what marks a task as unsupported, so any returned update
    # object must count as support even if it happened to be falsy.
    return [
        task for task in get_args(PoolingTask)
        if model.pooler.get_pooling_updates(task) is not None
    ]
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
......
......@@ -19,6 +19,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -275,6 +276,9 @@ class TPUWorker:
def get_model(self) -> nn.Module:
    """Return the model held by this worker's ``model_runner``."""
    runner = self.model_runner
    return runner.get_model()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """Return the pooling tasks reported by this worker's ``model_runner``."""
    runner = self.model_runner
    supported_tasks = runner.get_supported_pooling_tasks()
    return supported_tasks
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """Return the KV cache spec mapping produced by ``model_runner``."""
    runner = self.model_runner
    kv_cache_spec = runner.get_kv_cache_spec()
    return kv_cache_spec
......
......@@ -4,7 +4,7 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
TypeVar, get_args)
import torch
import torch.nn as nn
......@@ -12,6 +12,8 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING:
......@@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]):
def get_model(self) -> nn.Module:
    """Return the underlying model.

    This base implementation always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    raise NotImplementedError()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """List the pooling tasks that the loaded model can serve.

    Returns:
        An empty list when ``is_pooling_model`` rejects the model returned
        by ``self.get_model()``; otherwise every member of ``PoolingTask``
        for which the model's pooler returns a non-``None`` value from
        ``get_pooling_updates``.
    """
    model = self.get_model()
    if not is_pooling_model(model):
        return []

    # Compare against None explicitly rather than relying on truthiness:
    # None is what marks a task as unsupported, so any returned update
    # object must count as support even if it happened to be falsy.
    return [
        task for task in get_args(PoolingTask)
        if model.pooler.get_pooling_updates(task) is not None
    ]
def execute_model(
self,
model_input: T,
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
import torch
......@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
......@@ -195,7 +196,20 @@ class PoolingModelRunner(
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None
assert pooling_params.task is not None, (
"You did not set `task` in the API")
to_update = (cast(VllmModelForPooling,
self.model).pooler.get_pooling_updates(
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment