"vscode:/vscode.git/clone" did not exist on "d70249e2e9b15655020efe08f3ef46fd908758aa"
Unverified Commit 2554b27b authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

[V0 Deprecation] Remove pooling model support in V0 (#23434)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 934bebf1
......@@ -118,6 +118,8 @@ class PPTestSettings:
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
vllm_major_versions = ["1"] if runner == "pooling" else ["0"]
return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
......@@ -126,7 +128,7 @@ class PPTestSettings:
chunked_prefill=False),
],
distributed_backends=["mp"],
vllm_major_versions=["0"],
vllm_major_versions=vllm_major_versions,
runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
......@@ -213,7 +215,9 @@ TEXT_GENERATION_MODELS = {
EMBEDDING_MODELS = { # type: ignore[var-annotated]
# [Text-only]
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
# TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883
# is fixed
#"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
load_format="dummy", runner="pooling"
),
......
......@@ -16,14 +16,6 @@ MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
prompts = ["The chef prepared a delicious meal."]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
......
......@@ -27,14 +27,6 @@ TOKEN_IDS = [
]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
......
......@@ -16,14 +16,6 @@ MODEL_NAME = "internlm/internlm2-1_8b-reward"
prompts = ["The chef prepared a delicious meal."]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
......
......@@ -14,14 +14,6 @@ from ...models.utils import softmax
MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
......
......@@ -32,15 +32,16 @@ MODEL_CONFIGS = [
"tensor_parallel_size": 1,
"tokenizer_mode": "mistral",
},
{
"model": "sentence-transformers/all-MiniLM-L12-v2",
"enforce_eager": True,
"gpu_memory_utilization": 0.20,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
},
# TODO: re-enable once these tests are run with V1
# {
# "model": "sentence-transformers/all-MiniLM-L12-v2",
# "enforce_eager": True,
# "gpu_memory_utilization": 0.20,
# "max_model_len": 64,
# "max_num_batched_tokens": 64,
# "max_num_seqs": 64,
# "tensor_parallel_size": 1,
# },
]
......
......@@ -24,14 +24,6 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def server():
args = [
......
......@@ -14,14 +14,6 @@ MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def server():
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
......
......@@ -12,15 +12,6 @@ from vllm.entrypoints.openai.protocol import ScoreResponse
from ...utils import RemoteOpenAIServer
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
MODELS = [
{
"name": "BAAI/bge-reranker-v2-m3",
......
......@@ -10,14 +10,6 @@ from vllm.platforms import current_platform
from ...utils import check_embeddings_close, check_transformers_version
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.mark.parametrize(
"model",
[
......@@ -32,21 +24,15 @@ def v1(run_with_both_engines):
"intfloat/e5-mistral-7b-instruct",
# CPU v1 doesn't support sliding window
marks=[pytest.mark.core_model]),
# the qwen models interfere with each other (see PR
# https://github.com/vllm-project/vllm/pull/18720).
# To avoid this problem, for now we skip v0 since it will be
# deprecated anyway.
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
marks=[pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
# [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
def test_models(
......
......@@ -13,14 +13,6 @@ from ....conftest import HfRunner
from ...utils import check_transformers_version
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture
def math_step_prompts():
# ruff: noqa: E501
......
......@@ -23,15 +23,6 @@ TEXTS_2 = [
"The capital of Germany is Berlin.",
]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
DTYPE = "half"
......
......@@ -323,8 +323,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
......@@ -337,9 +337,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True, v0_only=True),
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True, v0_only=True), # noqa: E501
trust_remote_code=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53",
......@@ -347,9 +347,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
......@@ -364,20 +364,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
}
_AUTOMATIC_CONVERTED_MODELS = {
# Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
......
......@@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.pooling_model_runner import (
ModelInputForGPUWithPoolingMetadata)
class MockAttentionBackend(AttentionBackend):
......@@ -114,54 +111,3 @@ def test_model_runner_input():
assert (received_model_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert received_model_input.sampling_metadata.seq_groups is None
def test_embedding_model_runner_input():
pooling_metadata = PoolingMetadata(
seq_groups=[[0]],
seq_data={},
prompt_lens=[1],
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
pooling_metadata=pooling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithPoolingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None
......@@ -1591,7 +1591,6 @@ class Scheduler:
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
......
......@@ -1566,8 +1566,7 @@ class EngineArgs:
use_spec_decode = self.speculative_config is not None
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and model_config.runner_type != "pooling"):
and not self.enable_lora):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models "
......@@ -1585,10 +1584,6 @@ class EngineArgs:
"OOM during the initial memory profiling phase, or result "
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
elif (self.enable_chunked_prefill
and model_config.runner_type == "pooling"):
msg = "Chunked prefill is not supported for pooling models"
raise ValueError(msg)
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching:
......
......@@ -72,8 +72,8 @@ STOP_ITERATION = Exception() # Sentinel
class AsyncStream:
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
"""A stream of RequestOutputs for a request that can be iterated over
asynchronously via an async generator."""
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
......@@ -81,8 +81,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
def put(self, item: Union[RequestOutput, Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
......@@ -99,9 +98,7 @@ class AsyncStream:
def finished(self) -> bool:
return self._finished
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
async def generator(self) -> AsyncGenerator[RequestOutput, None]:
try:
while True:
result = await self._queue.get()
......@@ -151,8 +148,7 @@ class RequestTracker:
self.abort_request(rid, exception=exc)
def process_request_output(self,
request_output: Union[RequestOutput,
PoolingRequestOutput],
request_output: RequestOutput,
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
......@@ -261,9 +257,7 @@ class _AsyncLLMEngine(LLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def step_async(
self, virtual_engine: int
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
......@@ -405,7 +399,7 @@ class _AsyncLLMEngine(LLMEngine):
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
......@@ -779,14 +773,14 @@ class AsyncLLMEngine(EngineClient):
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
) -> AsyncGenerator[RequestOutput, None]:
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
......@@ -908,7 +902,7 @@ class AsyncLLMEngine(EngineClient):
await self.abort(request_id)
raise
async def encode(
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
......@@ -918,85 +912,8 @@ class AsyncLLMEngine(EngineClient):
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
[`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
```
# Please refer to entrypoints/api_server.py for
# the complete example.
# initialize the engine and the example input
# note that engine_args here is AsyncEngineArgs instance
engine = AsyncLLMEngine.from_engine_args(engine_args)
example_input = {
"input": "What is LLM?",
"request_id": 0,
}
# start the generation
results_generator = engine.encode(
example_input["input"],
PoolingParams(),
example_input["request_id"])
# get the results
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
# Return or raise an error
...
final_output = request_output
# Process and return the final output
...
```
"""
try:
async for output in await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise
raise NotImplementedError(
"Pooling models are not supported in vLLM V0")
async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
"""Abort a request.
......@@ -1104,8 +1021,8 @@ class AsyncLLMEngine(EngineClient):
async def is_sleeping(self) -> bool:
return self.engine.is_sleeping()
async def add_lora(self, lora_request: LoRARequest) -> None:
self.engine.add_lora(lora_request)
async def add_lora(self, lora_request: LoRARequest) -> bool:
return self.engine.add_lora(lora_request)
async def collective_rpc(self,
method: str,
......
......@@ -40,12 +40,11 @@ from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupOutput, SequenceStatus)
Sequence, SequenceGroup, SequenceGroupBase,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer
......@@ -93,8 +92,7 @@ class SchedulerContext:
def __init__(self) -> None:
self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput,
PoolingRequestOutput]] = []
self.request_outputs: List[RequestOutput] = []
self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None
......@@ -261,7 +259,6 @@ class LLMEngine:
self.model_executor = executor_class(vllm_config=vllm_config)
if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
......@@ -541,7 +538,7 @@ class LLMEngine:
self,
request_id: str,
processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,
......@@ -577,7 +574,7 @@ class LLMEngine:
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams
# Create a SequenceGroup based on SamplingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
......@@ -588,18 +585,8 @@ class LLMEngine:
trace_headers=trace_headers,
encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
encoder_seq=encoder_seq,
priority=priority)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
raise ValueError("SamplingParams must be provided.")
# Add the sequence group to the scheduler with least unfinished seqs.
costs = [
......@@ -618,7 +605,7 @@ class LLMEngine:
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
......@@ -636,9 +623,8 @@ class LLMEngine:
prompt: The prompt to the LLM. See
[PromptType][vllm.inputs.PromptType]
for more details about the format of each input.
params: Parameters for sampling or pooling.
params: Parameters for sampling.
[SamplingParams][vllm.SamplingParams] for text generation.
[PoolingParams][vllm.PoolingParams] for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
lora_request: The LoRA request to add.
......@@ -760,29 +746,6 @@ class LLMEngine:
return seq_group
def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
encoder_seq=encoder_seq,
priority=priority)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
......@@ -856,18 +819,6 @@ class LLMEngine:
success = success and scheduler.reset_prefix_cache(device)
return success
@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
outputs: List[PoolingSequenceGroupOutput],
) -> None:
seq_group.pooled_data = outputs[0].data
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _process_model_outputs(self,
ctx: SchedulerContext,
request_id: Optional[str] = None) -> None:
......@@ -962,13 +913,10 @@ class LLMEngine:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.runner_type == "pooling":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)
self.output_processor.process_outputs(seq_group, output,
is_async)
if seq_group.is_finished():
finished_now.append(i)
......@@ -1090,7 +1038,7 @@ class LLMEngine:
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
<figure markdown="span">
......
......@@ -120,6 +120,7 @@ class RPCLoadAdapterRequest:
@dataclass
class RPCAdapterLoadedResponse:
request_id: str
lora_loaded: bool
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
......
......@@ -6,7 +6,7 @@ import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
Mapping, Optional, Union, cast)
Mapping, Optional, Union)
import cloudpickle
import psutil
......@@ -477,10 +477,8 @@ class MQLLMEngineClient(EngineClient):
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
return cast(
AsyncGenerator[RequestOutput, None],
self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, priority))
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, priority)
def encode(
self,
......@@ -490,45 +488,20 @@ class MQLLMEngineClient(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields:
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
"""
return cast(
AsyncGenerator[PoolingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))
raise NotImplementedError(
"Pooling models are not supported in vLLM V0")
async def _process_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
......@@ -547,7 +520,7 @@ class MQLLMEngineClient(EngineClient):
try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if isinstance(params, SamplingParams) and params.logits_processors:
if params.logits_processors:
# Defensive shallow copy
params = copy.copy(params)
logits_processors = params.logits_processors
......@@ -646,13 +619,14 @@ class MQLLMEngineClient(EngineClient):
raise request_output
return request_output.is_sleeping
async def add_lora(self, lora_request: LoRARequest) -> None:
async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)
# Create output queue for this request.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
queue: asyncio.Queue[Union[
BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue()
self.output_queues[request.request_id] = queue
# Send the request
......@@ -666,3 +640,4 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
return request_output.lora_loaded
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment