Commit 042af0c8 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Model][1/N] Support multiple poolers at model level (#21227)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 378d33c3
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import dataclasses import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar, get_args) TypeVar)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -230,10 +230,7 @@ class ModelRunnerBase(ABC, Generic[T]):
if not is_pooling_model(model): if not is_pooling_model(model):
return [] return []
return [ return list(model.pooler.get_supported_tasks())
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def execute_model( def execute_model(
self, self,
......
...@@ -199,15 +199,11 @@ class PoolingModelRunner( ...@@ -199,15 +199,11 @@ class PoolingModelRunner(
pooling_params = seq_group_metadata.pooling_params pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None assert pooling_params is not None
assert pooling_params.task is not None, ( assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API") "You did not set `task` in the API")
to_update = (cast(VllmModelForPooling, model = cast(VllmModelForPooling, self.model)
self.model).pooler.get_pooling_updates( to_update = model.pooler.get_pooling_updates(task)
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update.apply(pooling_params) to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params)) seq_groups.append((seq_ids, pooling_params))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment