Commit 583a90e0 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Separate sequence and token pooling types (#32026)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 52d42829
......@@ -3,11 +3,11 @@
from typing import Literal, get_args
GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask)
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
PoolingTask = Literal[
"embed", "classify", "score", "token_embed", "token_classify", "plugin"
]
POOLING_TASKS = get_args(PoolingTask)
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask]
......@@ -10,9 +10,7 @@ from pathlib import Path
from typing import Any, Literal, TypeAlias
import huggingface_hub
from huggingface_hub import (
get_safetensors_metadata,
)
from huggingface_hub import get_safetensors_metadata
from packaging.version import Version
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import get_image_processor_config
......@@ -742,7 +740,10 @@ def get_config(
@cache
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
def get_pooling_config(
model: str,
revision: str | None = "main",
) -> dict[str, Any] | None:
"""
This function gets the pooling and normalize
config from the model - only applies to
......@@ -793,38 +794,40 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
)
if pooling:
pooling_file_name = "{}/config.json".format(pooling["path"])
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
pooling_type_name = next(
(item for item, val in pooling_dict.items() if val is True), None
)
from vllm.config.pooler import SEQ_POOLING_TYPES, TOK_POOLING_TYPES
if pooling_type_name is not None:
pooling_type_name = get_pooling_config_name(pooling_type_name)
pooling_file_name = "{}/config.json".format(pooling["path"])
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) or {}
logger.info("Found pooling configuration.")
return {"pooling_type": pooling_type_name, "normalize": normalize}
config: dict[str, Any] = {"normalize": normalize}
for key, val in pooling_dict.items():
if val is True:
pooling_type = parse_pooling_type(key)
if pooling_type in SEQ_POOLING_TYPES:
config["seq_pooling_type"] = pooling_type
elif pooling_type in TOK_POOLING_TYPES:
config["tok_pooling_type"] = pooling_type
else:
logger.debug("Skipping unrelated field: %r=%r", key, val)
return config
return None
def get_pooling_config_name(pooling_name: str) -> str | None:
def parse_pooling_type(pooling_name: str):
if "pooling_mode_" in pooling_name:
pooling_name = pooling_name.replace("pooling_mode_", "")
if "_" in pooling_name:
pooling_name = pooling_name.split("_")[0]
pooling_name = pooling_name.split("_", 1)[0]
if "lasttoken" in pooling_name:
pooling_name = "last"
supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
pooling_type_name = pooling_name.upper()
if pooling_type_name in supported_pooling_types:
return pooling_type_name
raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
return pooling_name.upper()
@cache
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment