Unverified Commit 2554b27b authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

[V0 Deprecation] Remove pooling model support in V0 (#23434)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 934bebf1
...@@ -118,6 +118,8 @@ class PPTestSettings: ...@@ -118,6 +118,8 @@ class PPTestSettings:
multi_node_only: bool = False, multi_node_only: bool = False,
load_format: Optional[str] = None, load_format: Optional[str] = None,
): ):
vllm_major_versions = ["1"] if runner == "pooling" else ["0"]
return PPTestSettings( return PPTestSettings(
parallel_setups=[ parallel_setups=[
ParallelSetup(tp_size=tp_base, ParallelSetup(tp_size=tp_base,
...@@ -126,7 +128,7 @@ class PPTestSettings: ...@@ -126,7 +128,7 @@ class PPTestSettings:
chunked_prefill=False), chunked_prefill=False),
], ],
distributed_backends=["mp"], distributed_backends=["mp"],
vllm_major_versions=["0"], vllm_major_versions=vllm_major_versions,
runner=runner, runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only, test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format), load_format=load_format),
...@@ -213,7 +215,9 @@ TEXT_GENERATION_MODELS = { ...@@ -213,7 +215,9 @@ TEXT_GENERATION_MODELS = {
EMBEDDING_MODELS = { # type: ignore[var-annotated] EMBEDDING_MODELS = { # type: ignore[var-annotated]
# [Text-only] # [Text-only]
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"), "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), # TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883
# is fixed
#"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast( "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
load_format="dummy", runner="pooling" load_format="dummy", runner="pooling"
), ),
......
...@@ -16,14 +16,6 @@ MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" ...@@ -16,14 +16,6 @@ MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
prompts = ["The chef prepared a delicious meal."] prompts = ["The chef prepared a delicious meal."]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
......
...@@ -27,14 +27,6 @@ TOKEN_IDS = [ ...@@ -27,14 +27,6 @@ TOKEN_IDS = [
] ]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
......
...@@ -16,14 +16,6 @@ MODEL_NAME = "internlm/internlm2-1_8b-reward" ...@@ -16,14 +16,6 @@ MODEL_NAME = "internlm/internlm2-1_8b-reward"
prompts = ["The chef prepared a delicious meal."] prompts = ["The chef prepared a delicious meal."]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
......
...@@ -14,14 +14,6 @@ from ...models.utils import softmax ...@@ -14,14 +14,6 @@ from ...models.utils import softmax
MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
......
...@@ -32,15 +32,16 @@ MODEL_CONFIGS = [ ...@@ -32,15 +32,16 @@ MODEL_CONFIGS = [
"tensor_parallel_size": 1, "tensor_parallel_size": 1,
"tokenizer_mode": "mistral", "tokenizer_mode": "mistral",
}, },
{ # TODO: re-enable once these tests are run with V1
"model": "sentence-transformers/all-MiniLM-L12-v2", # {
"enforce_eager": True, # "model": "sentence-transformers/all-MiniLM-L12-v2",
"gpu_memory_utilization": 0.20, # "enforce_eager": True,
"max_model_len": 64, # "gpu_memory_utilization": 0.20,
"max_num_batched_tokens": 64, # "max_model_len": 64,
"max_num_seqs": 64, # "max_num_batched_tokens": 64,
"tensor_parallel_size": 1, # "max_num_seqs": 64,
}, # "tensor_parallel_size": 1,
# },
] ]
......
...@@ -24,14 +24,6 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + ...@@ -24,14 +24,6 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
DTYPE = "bfloat16" DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = [
......
...@@ -14,14 +14,6 @@ MODEL_NAME = "BAAI/bge-reranker-base" ...@@ -14,14 +14,6 @@ MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16" DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
......
...@@ -12,15 +12,6 @@ from vllm.entrypoints.openai.protocol import ScoreResponse ...@@ -12,15 +12,6 @@ from vllm.entrypoints.openai.protocol import ScoreResponse
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
MODELS = [ MODELS = [
{ {
"name": "BAAI/bge-reranker-v2-m3", "name": "BAAI/bge-reranker-v2-m3",
......
...@@ -10,14 +10,6 @@ from vllm.platforms import current_platform ...@@ -10,14 +10,6 @@ from vllm.platforms import current_platform
from ...utils import check_embeddings_close, check_transformers_version from ...utils import check_embeddings_close, check_transformers_version
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
...@@ -32,21 +24,15 @@ def v1(run_with_both_engines): ...@@ -32,21 +24,15 @@ def v1(run_with_both_engines):
"intfloat/e5-mistral-7b-instruct", "intfloat/e5-mistral-7b-instruct",
# CPU v1 doesn't support sliding window # CPU v1 doesn't support sliding window
marks=[pytest.mark.core_model]), marks=[pytest.mark.core_model]),
# the qwen models interfere with each other (see PR
# https://github.com/vllm-project/vllm/pull/18720).
# To avoid this problem, for now we skip v0 since it will be
# deprecated anyway.
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), marks=[pytest.mark.cpu_model]),
# [Encoder-only] # [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"), pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder] # [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2", pytest.param("sentence-transformers/stsb-roberta-base-v2"),
marks=[pytest.mark.skip_v1]),
], ],
) )
def test_models( def test_models(
......
...@@ -13,14 +13,6 @@ from ....conftest import HfRunner ...@@ -13,14 +13,6 @@ from ....conftest import HfRunner
from ...utils import check_transformers_version from ...utils import check_transformers_version
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture @pytest.fixture
def math_step_prompts(): def math_step_prompts():
# ruff: noqa: E501 # ruff: noqa: E501
......
...@@ -23,15 +23,6 @@ TEXTS_2 = [ ...@@ -23,15 +23,6 @@ TEXTS_2 = [
"The capital of Germany is Berlin.", "The capital of Germany is Berlin.",
] ]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
DTYPE = "half" DTYPE = "half"
......
...@@ -323,8 +323,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -323,8 +323,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only] # [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True), trust_remote_code=True),
...@@ -337,9 +337,9 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -337,9 +337,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True, v0_only=True), trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True, v0_only=True), # noqa: E501 trust_remote_code=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53", max_transformers_version="4.53",
...@@ -347,9 +347,9 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -347,9 +347,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53", max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501
# [Multimodal] # [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
...@@ -364,20 +364,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { ...@@ -364,20 +364,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
# [Cross-encoder] # [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True, trust_remote_code=True,
hf_overrides={ hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501 "architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
} }
_AUTOMATIC_CONVERTED_MODELS = { _AUTOMATIC_CONVERTED_MODELS = {
# Use as_seq_cls_model for automatic conversion # Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501 "classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501 "method": "no_post_processing"}), # noqa: E501
......
...@@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder ...@@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.pooling_model_runner import (
ModelInputForGPUWithPoolingMetadata)
class MockAttentionBackend(AttentionBackend): class MockAttentionBackend(AttentionBackend):
...@@ -114,54 +111,3 @@ def test_model_runner_input(): ...@@ -114,54 +111,3 @@ def test_model_runner_input():
assert (received_model_input.sampling_metadata.selected_token_indices == assert (received_model_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices) sampling_metadata.selected_token_indices)
assert received_model_input.sampling_metadata.seq_groups is None assert received_model_input.sampling_metadata.seq_groups is None
def test_embedding_model_runner_input():
pooling_metadata = PoolingMetadata(
seq_groups=[[0]],
seq_data={},
prompt_lens=[1],
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
pooling_metadata=pooling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithPoolingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None
...@@ -1591,7 +1591,6 @@ class Scheduler: ...@@ -1591,7 +1591,6 @@ class Scheduler:
encoder_seq_data=encoder_seq_data, encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table, cross_block_table=cross_block_table,
state=seq_group.state, state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm # `multi_modal_data` will only be present for the 1st comm
# between engine and worker. # between engine and worker.
# the subsequent comms can still use delta, but # the subsequent comms can still use delta, but
......
...@@ -1566,8 +1566,7 @@ class EngineArgs: ...@@ -1566,8 +1566,7 @@ class EngineArgs:
use_spec_decode = self.speculative_config is not None use_spec_decode = self.speculative_config is not None
if (is_gpu and not use_sliding_window and not use_spec_decode if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora and not self.enable_lora):
and model_config.runner_type != "pooling"):
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
logger.warning( logger.warning(
"Chunked prefill is enabled by default for models " "Chunked prefill is enabled by default for models "
...@@ -1585,10 +1584,6 @@ class EngineArgs: ...@@ -1585,10 +1584,6 @@ class EngineArgs:
"OOM during the initial memory profiling phase, or result " "OOM during the initial memory profiling phase, or result "
"in low performance due to small KV cache size. Consider " "in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len) "setting --max-model-len to a smaller value.", max_model_len)
elif (self.enable_chunked_prefill
and model_config.runner_type == "pooling"):
msg = "Chunked prefill is not supported for pooling models"
raise ValueError(msg)
# if using prefix caching, we must set a hash algo # if using prefix caching, we must set a hash algo
if self.enable_prefix_caching: if self.enable_prefix_caching:
......
...@@ -72,8 +72,8 @@ STOP_ITERATION = Exception() # Sentinel ...@@ -72,8 +72,8 @@ STOP_ITERATION = Exception() # Sentinel
class AsyncStream: class AsyncStream:
"""A stream of RequestOutputs or PoolingRequestOutputs for a request """A stream of RequestOutputs for a request that can be iterated over
that can be iterated over asynchronously via an async generator.""" asynchronously via an async generator."""
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id self.request_id = request_id
...@@ -81,8 +81,7 @@ class AsyncStream: ...@@ -81,8 +81,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue() self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False self._finished = False
def put(self, item: Union[RequestOutput, PoolingRequestOutput, def put(self, item: Union[RequestOutput, Exception]) -> None:
Exception]) -> None:
if not self._finished: if not self._finished:
self._queue.put_nowait(item) self._queue.put_nowait(item)
...@@ -99,9 +98,7 @@ class AsyncStream: ...@@ -99,9 +98,7 @@ class AsyncStream:
def finished(self) -> bool: def finished(self) -> bool:
return self._finished return self._finished
async def generator( async def generator(self) -> AsyncGenerator[RequestOutput, None]:
self
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
try: try:
while True: while True:
result = await self._queue.get() result = await self._queue.get()
...@@ -151,8 +148,7 @@ class RequestTracker: ...@@ -151,8 +148,7 @@ class RequestTracker:
self.abort_request(rid, exception=exc) self.abort_request(rid, exception=exc)
def process_request_output(self, def process_request_output(self,
request_output: Union[RequestOutput, request_output: RequestOutput,
PoolingRequestOutput],
*, *,
verbose: bool = False) -> None: verbose: bool = False) -> None:
"""Process a request output from the engine.""" """Process a request output from the engine."""
...@@ -261,9 +257,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -261,9 +257,7 @@ class _AsyncLLMEngine(LLMEngine):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
async def step_async( async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
self, virtual_engine: int
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible. The workers are ran asynchronously if possible.
...@@ -405,7 +399,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -405,7 +399,7 @@ class _AsyncLLMEngine(LLMEngine):
self, self,
request_id: str, request_id: str,
prompt: PromptType, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: SamplingParams,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
...@@ -779,14 +773,14 @@ class AsyncLLMEngine(EngineClient): ...@@ -779,14 +773,14 @@ class AsyncLLMEngine(EngineClient):
self, self,
request_id: str, request_id: str,
prompt: PromptType, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: SamplingParams,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: ) -> AsyncGenerator[RequestOutput, None]:
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
self.start_background_loop() self.start_background_loop()
...@@ -908,7 +902,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -908,7 +902,7 @@ class AsyncLLMEngine(EngineClient):
await self.abort(request_id) await self.abort(request_id)
raise raise
async def encode( def encode(
self, self,
prompt: PromptType, prompt: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
...@@ -918,85 +912,8 @@ class AsyncLLMEngine(EngineClient): ...@@ -918,85 +912,8 @@ class AsyncLLMEngine(EngineClient):
priority: int = 0, priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model. raise NotImplementedError(
"Pooling models are not supported in vLLM V0")
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
[`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
```
# Please refer to entrypoints/api_server.py for
# the complete example.
# initialize the engine and the example input
# note that engine_args here is AsyncEngineArgs instance
engine = AsyncLLMEngine.from_engine_args(engine_args)
example_input = {
"input": "What is LLM?",
"request_id": 0,
}
# start the generation
results_generator = engine.encode(
example_input["input"],
PoolingParams(),
example_input["request_id"])
# get the results
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
# Return or raise an error
...
final_output = request_output
# Process and return the final output
...
```
"""
try:
async for output in await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise
async def abort(self, request_id: Union[str, Iterable[str]]) -> None: async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
"""Abort a request. """Abort a request.
...@@ -1104,8 +1021,8 @@ class AsyncLLMEngine(EngineClient): ...@@ -1104,8 +1021,8 @@ class AsyncLLMEngine(EngineClient):
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
return self.engine.is_sleeping() return self.engine.is_sleeping()
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> bool:
self.engine.add_lora(lora_request) return self.engine.add_lora(lora_request)
async def collective_rpc(self, async def collective_rpc(self,
method: str, method: str,
......
...@@ -40,12 +40,11 @@ from vllm.multimodal.cache import processor_only_cache_from_config ...@@ -40,12 +40,11 @@ from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup, Sequence, SequenceGroup, SequenceGroupBase,
SequenceGroupBase, SequenceGroupMetadata, SequenceGroupMetadata, SequenceGroupOutput,
SequenceGroupOutput, SequenceStatus) SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
...@@ -93,8 +92,7 @@ class SchedulerContext: ...@@ -93,8 +92,7 @@ class SchedulerContext:
def __init__(self) -> None: def __init__(self) -> None:
self.output_queue: Deque[OutputData] = deque() self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput, self.request_outputs: List[RequestOutput] = []
PoolingRequestOutput]] = []
self.seq_group_metadata_list: Optional[ self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None
...@@ -261,8 +259,7 @@ class LLMEngine: ...@@ -261,8 +259,7 @@ class LLMEngine:
self.model_executor = executor_class(vllm_config=vllm_config) self.model_executor = executor_class(vllm_config=vllm_config)
if self.model_config.runner_type != "pooling": self._initialize_kv_caches()
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled(): if is_usage_stats_enabled():
...@@ -541,7 +538,7 @@ class LLMEngine: ...@@ -541,7 +538,7 @@ class LLMEngine:
self, self,
request_id: str, request_id: str,
processed_inputs: ProcessorInputs, processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams], params: SamplingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
...@@ -577,7 +574,7 @@ class LLMEngine: ...@@ -577,7 +574,7 @@ class LLMEngine:
encoder_seq = (None if encoder_inputs is None else Sequence( encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request)) seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams # Create a SequenceGroup based on SamplingParams
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling( seq_group = self._create_sequence_group_with_sampling(
request_id, request_id,
...@@ -588,18 +585,8 @@ class LLMEngine: ...@@ -588,18 +585,8 @@ class LLMEngine:
trace_headers=trace_headers, trace_headers=trace_headers,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority) priority=priority)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
encoder_seq=encoder_seq,
priority=priority)
else: else:
raise ValueError( raise ValueError("SamplingParams must be provided.")
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler with least unfinished seqs. # Add the sequence group to the scheduler with least unfinished seqs.
costs = [ costs = [
...@@ -618,7 +605,7 @@ class LLMEngine: ...@@ -618,7 +605,7 @@ class LLMEngine:
self, self,
request_id: str, request_id: str,
prompt: PromptType, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: SamplingParams,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
...@@ -636,9 +623,8 @@ class LLMEngine: ...@@ -636,9 +623,8 @@ class LLMEngine:
prompt: The prompt to the LLM. See prompt: The prompt to the LLM. See
[PromptType][vllm.inputs.PromptType] [PromptType][vllm.inputs.PromptType]
for more details about the format of each input. for more details about the format of each input.
params: Parameters for sampling or pooling. params: Parameters for sampling.
[SamplingParams][vllm.SamplingParams] for text generation. [SamplingParams][vllm.SamplingParams] for text generation.
[PoolingParams][vllm.PoolingParams] for pooling.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
lora_request: The LoRA request to add. lora_request: The LoRA request to add.
...@@ -760,29 +746,6 @@ class LLMEngine: ...@@ -760,29 +746,6 @@ class LLMEngine:
return seq_group return seq_group
def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
encoder_seq=encoder_seq,
priority=priority)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID. """Aborts a request(s) with the given ID.
...@@ -856,18 +819,6 @@ class LLMEngine: ...@@ -856,18 +819,6 @@ class LLMEngine:
success = success and scheduler.reset_prefix_cache(device) success = success and scheduler.reset_prefix_cache(device)
return success return success
@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
outputs: List[PoolingSequenceGroupOutput],
) -> None:
seq_group.pooled_data = outputs[0].data
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _process_model_outputs(self, def _process_model_outputs(self,
ctx: SchedulerContext, ctx: SchedulerContext,
request_id: Optional[str] = None) -> None: request_id: Optional[str] = None) -> None:
...@@ -962,13 +913,10 @@ class LLMEngine: ...@@ -962,13 +913,10 @@ class LLMEngine:
seq_group.metrics.model_execute_time = ( seq_group.metrics.model_execute_time = (
o.model_execute_time) o.model_execute_time)
if self.model_config.runner_type == "pooling": self.output_processor.process_prompt_logprob(seq_group, output)
self._process_sequence_group_outputs(seq_group, output) if seq_group_meta.do_sample:
else: self.output_processor.process_outputs(seq_group, output,
self.output_processor.process_prompt_logprob(seq_group, output) is_async)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)
if seq_group.is_finished(): if seq_group.is_finished():
finished_now.append(i) finished_now.append(i)
...@@ -1090,7 +1038,7 @@ class LLMEngine: ...@@ -1090,7 +1038,7 @@ class LLMEngine:
seq.append_token_id(sample.output_token, sample.logprobs, seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed) sample.output_embed)
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
<figure markdown="span"> <figure markdown="span">
......
...@@ -120,6 +120,7 @@ class RPCLoadAdapterRequest: ...@@ -120,6 +120,7 @@ class RPCLoadAdapterRequest:
@dataclass @dataclass
class RPCAdapterLoadedResponse: class RPCAdapterLoadedResponse:
request_id: str request_id: str
lora_loaded: bool
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
......
...@@ -6,7 +6,7 @@ import copy ...@@ -6,7 +6,7 @@ import copy
import pickle import pickle
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List, from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
Mapping, Optional, Union, cast) Mapping, Optional, Union)
import cloudpickle import cloudpickle
import psutil import psutil
...@@ -477,10 +477,8 @@ class MQLLMEngineClient(EngineClient): ...@@ -477,10 +477,8 @@ class MQLLMEngineClient(EngineClient):
Any priority other than 0 will lead to an error if the Any priority other than 0 will lead to an error if the
scheduling policy is not "priority". scheduling policy is not "priority".
""" """
return cast( return self._process_request(prompt, sampling_params, request_id,
AsyncGenerator[RequestOutput, None], lora_request, trace_headers, priority)
self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, priority))
def encode( def encode(
self, self,
...@@ -490,45 +488,20 @@ class MQLLMEngineClient(EngineClient): ...@@ -490,45 +488,20 @@ class MQLLMEngineClient(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model. raise NotImplementedError(
"Pooling models are not supported in vLLM V0")
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields:
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
"""
return cast(
AsyncGenerator[PoolingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))
async def _process_request( async def _process_request(
self, self,
prompt: PromptType, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> AsyncGenerator[RequestOutput, None]:
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out. # If already dead, error out.
...@@ -547,7 +520,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -547,7 +520,7 @@ class MQLLMEngineClient(EngineClient):
try: try:
# 2) Detach logits processors so that they can be pickled # 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower) # separately (may require cloudpickle which is slower)
if isinstance(params, SamplingParams) and params.logits_processors: if params.logits_processors:
# Defensive shallow copy # Defensive shallow copy
params = copy.copy(params) params = copy.copy(params)
logits_processors = params.logits_processors logits_processors = params.logits_processors
...@@ -646,13 +619,14 @@ class MQLLMEngineClient(EngineClient): ...@@ -646,13 +619,14 @@ class MQLLMEngineClient(EngineClient):
raise request_output raise request_output
return request_output.is_sleeping return request_output.is_sleeping
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests # Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request) request = RPCLoadAdapterRequest(lora_request)
# Create output queue for this request. # Create output queue for this request.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() queue: asyncio.Queue[Union[
BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue()
self.output_queues[request.request_id] = queue self.output_queues[request.request_id] = queue
# Send the request # Send the request
...@@ -666,3 +640,4 @@ class MQLLMEngineClient(EngineClient): ...@@ -666,3 +640,4 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None # Raise on error, otherwise happily return None
if isinstance(request_output, BaseException): if isinstance(request_output, BaseException):
raise request_output raise request_output
return request_output.lora_loaded
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment