Unverified commit f54f8512, authored by wang.yuqi and committed by GitHub
Browse files

[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)


Signed-off-by: wang.yuqi <noooop@126.com>
parent d4d1a602
...@@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py ...@@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py python examples/offline_inference/pooling/embed_matryoshka_fy.py
``` ```
## Multi vector retrieval usage
```bash
python examples/offline_inference/pooling/multi_vector_retrieval.py
```
## Named Entity Recognition (NER) usage ## Named Entity Recognition (NER) usage
```bash ```bash
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_args():
    """Build the CLI for the offline multi-vector retrieval example.

    The parser is populated through ``EngineArgs.add_cli_args`` and its
    defaults are then overridden with the values this example needs.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    example_defaults = {
        "model": "BAAI/bge-m3",
        "runner": "pooling",
        "enforce_eager": True,
    }
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Set example specific arguments
    parser.set_defaults(**example_defaults)
    return parser.parse_args()
def main(args: Namespace):
    """Run the example: one embedding per prompt, then one per token.

    Args:
        args: Parsed CLI namespace; its attributes are forwarded to ``LLM``
            as keyword arguments.
    """
    separator = "-" * 60

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    pooled_outputs = llm.embed(prompts)
    print("\nGenerated Outputs:\n" + separator)
    for _prompt, pooled in zip(prompts, pooled_outputs):
        # Length of the single embedding returned for this prompt.
        print(len(pooled.outputs.embedding))

    # Generate embedding for each token. The output is a list of PoolingRequestOutput.
    token_outputs = llm.encode(prompts, pooling_task="token_embed")
    print("\nGenerated Outputs:\n" + separator)
    for _prompt, per_token in zip(prompts, token_outputs):
        # Shape of the multi-vector output for this prompt.
        print(per_token.outputs.data.shape)
if __name__ == "__main__":
    # Script entry point: parse the CLI first, then hand the namespace to main().
    args = parse_args()
    main(args)
...@@ -40,7 +40,7 @@ def main(): ...@@ -40,7 +40,7 @@ def main():
model_impl="terratorch", model_impl="terratorch",
) )
pooling_params = PoolingParams(task="encode", softmax=False) pooling_params = PoolingParams(task="token_classify", activation=False)
pooler_output = llm.encode( pooler_output = llm.encode(
img_prompt, img_prompt,
pooling_params=pooling_params, pooling_params=pooling_params,
......
...@@ -18,6 +18,12 @@ python examples/online_serving/pooling/embedding_embed_dtype_client.py ...@@ -18,6 +18,12 @@ python examples/online_serving/pooling/embedding_embed_dtype_client.py
python examples/online_serving/pooling/jinaai_rerank_client.py python examples/online_serving/pooling/jinaai_rerank_client.py
``` ```
## Multi vector retrieval usage
```bash
python examples/online_serving/pooling/multi_vector_retrieval_client.py
```
## Named Entity Recognition (NER) usage ## Named Entity Recognition (NER) usage
```bash ```bash
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Pooling API for multi vector retrieval.
Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.
vllm serve BAAI/bge-m3
"""
import argparse
import requests
import torch
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST ``prompt`` as a JSON body to ``api_url``.

    Args:
        prompt: Request payload, serialized as JSON.
        api_url: Full URL of the endpoint to call.

    Returns:
        The ``requests.Response`` object, unmodified.
    """
    return requests.post(
        api_url,
        json=prompt,
        headers={"User-Agent": "Test Client"},
    )
def parse_args():
    """Parse the client's command-line options.

    Returns:
        A namespace with ``host`` (str), ``port`` (int) and ``model`` (str).
    """
    option_specs = (
        ("--host", str, "localhost"),
        ("--port", int, 8000),
        ("--model", str, "BAAI/bge-m3"),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        parser.add_argument(flag, type=value_type, default=fallback)
    return parser.parse_args()
def main(args):
    """Request per-token embeddings from the server and print their shapes.

    Args:
        args: Parsed CLI namespace providing ``host``, ``port`` and ``model``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompt = {"model": model_name, "input": prompts}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    # Surface HTTP failures with their status code instead of letting an
    # error payload fail later with a confusing KeyError on "data".
    pooling_response.raise_for_status()

    for output in pooling_response.json()["data"]:
        multi_vector = torch.tensor(output["data"])
        print(multi_vector.shape)
if __name__ == "__main__":
    # Script entry point: parse the CLI first, then query the server.
    args = parse_args()
    main(args)
...@@ -1011,8 +1011,12 @@ class VllmRunner: ...@@ -1011,8 +1011,12 @@ class VllmRunner:
req_outputs = self.llm.embed(inputs, *args, **kwargs) req_outputs = self.llm.embed(inputs, *args, **kwargs)
return [req_output.outputs.embedding for req_output in req_outputs] return [req_output.outputs.embedding for req_output in req_outputs]
def encode(self, prompts: list[str]) -> list[list[float]]: def token_embed(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.encode(prompts) req_outputs = self.llm.encode(prompts, pooling_task="token_embed")
return [req_output.outputs.data for req_output in req_outputs]
def token_classify(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.encode(prompts, pooling_task="token_classify")
return [req_output.outputs.data for req_output in req_outputs] return [req_output.outputs.data for req_output in req_outputs]
def reward(self, prompts: list[str]) -> list[list[float]]: def reward(self, prompts: list[str]) -> list[list[float]]:
......
...@@ -63,7 +63,7 @@ def test_encode_api(llm: LLM): ...@@ -63,7 +63,7 @@ def test_encode_api(llm: LLM):
# chunked prefill does not support all pooling # chunked prefill does not support all pooling
err_msg = "pooling_task must be one of.+" err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, use_tqdm=False) llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_score_api(llm: LLM): def test_score_api(llm: LLM):
......
...@@ -35,6 +35,13 @@ def llm(): ...@@ -35,6 +35,13 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
    """Check the shape of the first ``token_embed`` output from ``LLM.encode``."""
    # Expected shape of the first prompt's output
    # (presumably tokens x hidden size; confirm against the fixture model).
    expected_shape = (11, 384)
    first_output = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)[0]
    assert first_output.outputs.data.shape == expected_shape
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(normalize): def get_outputs(normalize):
outputs = llm.embed( outputs = llm.embed(
......
...@@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM): ...@@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM):
] ]
# Multiple PoolingParams should be matched with each prompt # Multiple PoolingParams should be matched with each prompt
outputs = llm.encode(PROMPTS, pooling_params=pooling_params) outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed")
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
# Exception raised, if the size of params does not match the size of prompts # Exception raised, if the size of params does not match the size of prompts
with pytest.raises(ValueError): with pytest.raises(ValueError):
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) outputs = llm.encode(
PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed"
)
# Single PoolingParams should be applied to every prompt # Single PoolingParams should be applied to every prompt
single_pooling_params = PoolingParams() single_pooling_params = PoolingParams()
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) outputs = llm.encode(
PROMPTS, pooling_params=single_pooling_params, pooling_task="embed"
)
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
# pooling_params is None, default params should be applied # pooling_params is None, default params should be applied
outputs = llm.encode(PROMPTS, pooling_params=None) outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed")
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
......
...@@ -36,22 +36,23 @@ def llm(): ...@@ -36,22 +36,23 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(softmax): def get_outputs(activation):
outputs = llm.reward( outputs = llm.reward(
prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
) )
return torch.cat([x.outputs.data for x in outputs]) return torch.cat([x.outputs.data for x in outputs])
default = get_outputs(softmax=None) default = get_outputs(activation=None)
w_softmax = get_outputs(softmax=True) w_activation = get_outputs(activation=True)
wo_softmax = get_outputs(softmax=False) wo_activation = get_outputs(activation=False)
assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax." assert torch.allclose(default, w_activation, atol=1e-2), (
assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), ( "Default should use activation."
"wo_softmax should not use softmax."
) )
assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), ( assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
"w_softmax should be close to softmax(wo_softmax)." "wo_activation should not use activation."
)
assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), (
"w_activation should be close to activation(wo_activation)."
) )
...@@ -17,6 +17,7 @@ from tests.utils import RemoteOpenAIServer ...@@ -17,6 +17,7 @@ from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
EMBED_DTYPE_TO_TORCH_DTYPE, EMBED_DTYPE_TO_TORCH_DTYPE,
EmbeddingResponse, EmbeddingResponse,
PoolingResponse,
) )
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
...@@ -509,3 +510,20 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str): ...@@ -509,3 +510,20 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
"w_normal should be close to normal(wo_normal)." "w_normal should be close to normal(wo_normal)."
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """POST one sentence to the pooling endpoint and check the output sizes."""
    payload = {
        "model": model_name,
        "input": ["The chef prepared a delicious meal."],
        "encoding_format": "float",
    }
    raw_response = requests.post(server.url_for("pooling"), json=payload)
    parsed = PoolingResponse.model_validate(raw_response.json())

    # One input, so one result entry.
    assert len(parsed.data) == 1
    token_vectors = parsed.data[0].data
    # Nested list: 11 rows, each of length 384.
    assert len(token_vectors) == 11
    assert len(token_vectors[0]) == 384
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import RerankResponse from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
MODEL_NAME = "BAAI/bge-reranker-base" MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16" DTYPE = "bfloat16"
...@@ -159,3 +159,20 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str): ...@@ -159,3 +159,20 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
"w_activation should be close to activation(wo_activation)." "w_activation should be close to activation(wo_activation)."
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """POST one sentence to the pooling endpoint and check the output sizes."""
    payload = {
        "model": model_name,
        "input": ["The chef prepared a delicious meal."],
        "encoding_format": "float",
    }
    raw_response = requests.post(server.url_for("pooling"), json=payload)
    parsed = PoolingResponse.model_validate(raw_response.json())

    # One input, so one result entry.
    assert len(parsed.data) == 1
    token_rows = parsed.data[0].data
    # Nested list: 11 rows, each holding a single value.
    assert len(token_rows) == 11
    assert len(token_rows[0]) == 1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel
from tests.models.utils import check_embeddings_close
@pytest.mark.parametrize(
    "model",
    ["BAAI/bge-m3"],
)
@pytest.mark.parametrize("dtype", ["half"])
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str):
    """Compare vLLM ``token_embed`` outputs with the HF model's last hidden state.

    NOTE(review): ``dtype`` is parametrized but is not forwarded to either
    runner in this test; confirm whether that is intentional.
    """
    with vllm_runner(model, runner="pooling", max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.token_embed(example_prompts)

    with hf_runner(model, auto_cls=AutoModel) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        # Prompts are run one at a time, each as a batch of size one.
        for prompt in example_prompts:
            encoded = hf_model.wrap_device(tokenizer([prompt], return_tensors="pt"))
            hidden_states = hf_model.model(**encoded).last_hidden_state
            # Take the only sequence in the batch, cast to float, move to CPU.
            hf_outputs.append(hidden_states[0].float().cpu())

    for reference, candidate in zip(hf_outputs, vllm_outputs):
        check_embeddings_close(
            embeddings_0_lst=reference,
            embeddings_1_lst=candidate,
            name_0="hf",
            name_1="vllm",
            tol=1e-2,
        )
...@@ -93,7 +93,7 @@ def test_embed_models_using_normalize( ...@@ -93,7 +93,7 @@ def test_embed_models_using_normalize(
], ],
) )
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
def test_reward_models_using_softmax( def test_reward_models_using_activation(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
...@@ -104,22 +104,64 @@ def test_reward_models_using_softmax( ...@@ -104,22 +104,64 @@ def test_reward_models_using_softmax(
model, model,
max_model_len=1024, max_model_len=1024,
dtype=dtype, dtype=dtype,
pooler_config=PoolerConfig(softmax=False), pooler_config=PoolerConfig(activation=False),
) as vllm_model: ) as vllm_model:
wo_softmax = vllm_model.encode(example_prompts) wo_activation = vllm_model.reward(example_prompts)
with vllm_runner( with vllm_runner(
model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True) model,
max_model_len=1024,
dtype=dtype,
pooler_config=PoolerConfig(activation=True),
) as vllm_model: ) as vllm_model:
w_softmax = vllm_model.encode(example_prompts) w_activation = vllm_model.reward(example_prompts)
for wo, w in zip(wo_softmax, w_softmax): for wo, w in zip(wo_activation, w_activation):
wo = torch.tensor(wo) wo = torch.tensor(wo)
w = torch.tensor(w) w = torch.tensor(w)
assert not torch.allclose(wo, w, atol=1e-2), ( assert not torch.allclose(wo, w, atol=1e-2), (
"pooler_config softmax is not working" "pooler_config activation is not working"
) )
assert torch.allclose(softmax(wo), w, atol=1e-2), ( assert torch.allclose(softmax(wo), w, atol=1e-2), (
"w_softmax should be close to softmax(wo_softmax)." "w_activation should be close to activation(wo_activation)."
)
@pytest.mark.parametrize(
"model",
[
"intfloat/multilingual-e5-small",
],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_multi_vector_retrieval_models_using_normalize(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(normalize=False),
) as vllm_model:
wo_normalize = vllm_model.token_embed(example_prompts)
with vllm_runner(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(normalize=True),
) as vllm_model:
w_normalize = vllm_model.token_embed(example_prompts)
for wo, w in zip(wo_normalize, w_normalize):
assert not torch.allclose(wo, w, atol=1e-2), (
"pooler_config normalize is not working"
)
assert torch.allclose(F.normalize(wo, p=2, dim=-1), w, atol=1e-2), (
"w_normal should be close to normal(wo_normal)."
) )
...@@ -19,7 +19,7 @@ def test_bert_models( ...@@ -19,7 +19,7 @@ def test_bert_models(
dtype: str, dtype: str,
) -> None: ) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts) vllm_outputs = vllm_model.token_classify(example_prompts)
with hf_runner( with hf_runner(
model, dtype=dtype, auto_cls=AutoModelForTokenClassification model, dtype=dtype, auto_cls=AutoModelForTokenClassification
...@@ -50,7 +50,7 @@ def test_modernbert_models( ...@@ -50,7 +50,7 @@ def test_modernbert_models(
dtype: str, dtype: str,
) -> None: ) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts) vllm_outputs = vllm_model.token_classify(example_prompts)
with hf_runner( with hf_runner(
model, dtype=dtype, auto_cls=AutoModelForTokenClassification model, dtype=dtype, auto_cls=AutoModelForTokenClassification
......
...@@ -39,7 +39,7 @@ def _run_test( ...@@ -39,7 +39,7 @@ def _run_test(
max_num_seqs=32, max_num_seqs=32,
default_torch_num_threads=1, default_torch_num_threads=1,
) as vllm_model: ) as vllm_model:
vllm_model.encode(prompt) vllm_model.llm.encode(prompt, pooling_task="token_classify")
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
......
...@@ -30,7 +30,7 @@ class MyGemma2Embedding(nn.Module): ...@@ -30,7 +30,7 @@ class MyGemma2Embedding(nn.Module):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )
......
...@@ -93,7 +93,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): ...@@ -93,7 +93,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
out_data_format="b64_json", out_data_format="b64_json",
) )
pooling_params = PoolingParams(task="encode", softmax=False) pooling_params = PoolingParams(activation=False)
with vllm_runner( with vllm_runner(
model_name, model_name,
...@@ -108,8 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): ...@@ -108,8 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
io_processor_plugin="prithvi_to_tiff", io_processor_plugin="prithvi_to_tiff",
) as llm_runner: ) as llm_runner:
pooler_output = llm_runner.get_llm().encode( pooler_output = llm_runner.get_llm().encode(
img_prompt, img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
pooling_params=pooling_params,
) )
output = pooler_output[0].outputs output = pooler_output[0].outputs
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import pytest import pytest
from tests.models.utils import EmbedModelInfo from tests.models.utils import EmbedModelInfo
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig, PoolerConfig
EMBEDDING_MODELS = [ EMBEDDING_MODELS = [
EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
...@@ -15,6 +17,15 @@ EMBEDDING_MODELS = [ ...@@ -15,6 +17,15 @@ EMBEDDING_MODELS = [
), ),
] ]
classify_parameters = ["activation"]
embed_parameters = ["dimensions", "normalize"]
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
@dataclass
class MockModelConfig:
    """Minimal stand-in passed as ``model_config`` to ``PoolingParams.verify``."""

    # The only field this mock carries.
    pooler_config: PoolerConfig
def test_task(): def test_task():
pooling_params = PoolingParams() pooling_params = PoolingParams()
...@@ -24,25 +35,27 @@ def test_task(): ...@@ -24,25 +35,27 @@ def test_task():
pooling_params.verify(task="score") pooling_params.verify(task="score")
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params.verify(task="encode") pooling_params.verify(task="classify")
def test_embed(): def test_embed():
task = "embed" task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(normalize=None)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True) pooling_params = PoolingParams(normalize=True)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False) pooling_params = PoolingParams(normalize=False)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = ["activation", "softmax"] invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters: for p in invalid_parameters:
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
...@@ -73,35 +86,71 @@ def test_embed_dimensions(model_info: EmbedModelInfo): ...@@ -73,35 +86,71 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
@pytest.mark.parametrize("task", ["score", "classify"]) @pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task): def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
pooling_params = PoolingParams(activation=None) pooling_params = PoolingParams(activation=None)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(activation=True) pooling_params = PoolingParams(activation=True)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(activation=False) pooling_params = PoolingParams(activation=False)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters:
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
def test_token_embed(pooling_type: str):
task = "token_embed"
model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type)
)
pooling_params = PoolingParams(normalize=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False)
pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters
if pooling_type != "STEP":
invalid_parameters = classify_parameters + step_pooling_parameters
invalid_parameters = ["dimensions", "normalize", "softmax"]
for p in invalid_parameters: for p in invalid_parameters:
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
def test_encode(): @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
task = "encode" def test_token_classify(pooling_type: str):
pooling_params = PoolingParams(softmax=None) task = "token_classify"
pooling_params.verify(task=task) model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type)
)
pooling_params = PoolingParams(softmax=True) pooling_params = PoolingParams(activation=None)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(activation=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(activation=False)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(softmax=False) invalid_parameters = embed_parameters
pooling_params.verify(task=task) if pooling_type != "STEP":
invalid_parameters = embed_parameters + step_pooling_parameters
invalid_parameters = ["dimensions", "normalize", "activation"]
for p in invalid_parameters: for p in invalid_parameters:
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
...@@ -951,7 +951,7 @@ class LLM: ...@@ -951,7 +951,7 @@ class LLM:
truncate_prompt_tokens: int | None = None, truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input """Apply pooling to the hidden states corresponding to the input
...@@ -986,25 +986,24 @@ class LLM: ...@@ -986,25 +986,24 @@ class LLM:
instead pass them via the `inputs` parameter. instead pass them via the `inputs` parameter.
""" """
if self.supported_tasks == ["encode"] and pooling_task is None: error_str = (
pooling_task = "encode" "pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
if pooling_task is None: if pooling_task is None:
pooling_task = "embed" if "embed" in self.supported_tasks else "encode" raise ValueError(error_str)
logger.warning_once(
"`LLM.encode` is currently using `pooling_task = %s`.\n"
"Please use one of the more specific methods or set the "
"task directly when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="reward"`\n'
" - For similarity scores, use `LLM.score(...)`.",
pooling_task,
)
model_config = self.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
...@@ -1206,7 +1205,7 @@ class LLM: ...@@ -1206,7 +1205,7 @@ class LLM:
lora_request=lora_request, lora_request=lora_request,
pooling_params=pooling_params, pooling_params=pooling_params,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="encode", pooling_task="token_classify",
) )
def _embedding_score( def _embedding_score(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment