Commit 042af0c8 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Model][1/N] Support multiple poolers at model level (#21227)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 378d33c3
......@@ -4,7 +4,7 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar, get_args)
TypeVar)
import torch
import torch.nn as nn
......@@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]):
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
return list(model.pooler.get_supported_tasks())
def execute_model(
self,
......
......@@ -199,15 +199,11 @@ class PoolingModelRunner(
pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None
assert pooling_params.task is not None, (
assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API")
to_update = (cast(VllmModelForPooling,
self.model).pooler.get_pooling_updates(
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
model = cast(VllmModelForPooling, self.model)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment