Commit 66c079ae (unverified), authored by wang.yuqi, committed by GitHub
Browse files

[Frontend][4/n] Improve pooling entrypoints | pooling. (#39153)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent b6c9be50
...@@ -87,7 +87,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -87,7 +87,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
serving_render = OpenAIServingRender( serving_render = OpenAIServingRender(
model_config=engine.model_config, model_config=engine.model_config,
renderer=engine.renderer, renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry, model_registry=models.registry,
request_logger=None, request_logger=None,
chat_template=None, chat_template=None,
...@@ -123,7 +122,6 @@ async def test_chat_error_non_stream(): ...@@ -123,7 +122,6 @@ async def test_chat_error_non_stream():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -173,7 +171,6 @@ async def test_chat_error_stream(): ...@@ -173,7 +171,6 @@ async def test_chat_error_stream():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
......
...@@ -567,7 +567,6 @@ def _build_serving_render( ...@@ -567,7 +567,6 @@ def _build_serving_render(
return OpenAIServingRender( return OpenAIServingRender(
model_config=engine.model_config, model_config=engine.model_config,
renderer=engine.renderer, renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=model_registry, model_registry=model_registry,
request_logger=None, request_logger=None,
chat_template=CHAT_TEMPLATE, chat_template=CHAT_TEMPLATE,
...@@ -599,7 +598,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -599,7 +598,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
class MockEngine: class MockEngine:
model_config: MockModelConfig = field(default_factory=MockModelConfig) model_config: MockModelConfig = field(default_factory=MockModelConfig)
input_processor: MagicMock = field(default_factory=MagicMock) input_processor: MagicMock = field(default_factory=MagicMock)
io_processor: MagicMock = field(default_factory=MagicMock)
renderer: MagicMock = field(default_factory=MagicMock) renderer: MagicMock = field(default_factory=MagicMock)
...@@ -632,7 +630,6 @@ async def test_serving_chat_returns_correct_model_name(): ...@@ -632,7 +630,6 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -662,7 +659,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -662,7 +659,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -693,7 +689,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -693,7 +689,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat # Initialize the serving chat
...@@ -737,7 +732,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -737,7 +732,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -779,7 +773,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -779,7 +773,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat # Initialize the serving chat
...@@ -823,7 +816,6 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(): ...@@ -823,7 +816,6 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True) mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_tokenizer = MagicMock(spec=MistralTokenizer) mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer( mock_renderer = MistralRenderer(
...@@ -863,7 +855,6 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(): ...@@ -863,7 +855,6 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True) mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_tokenizer = MagicMock(spec=MistralTokenizer) mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer( mock_renderer = MistralRenderer(
...@@ -906,7 +897,6 @@ async def test_serving_chat_could_load_correct_generation_config(): ...@@ -906,7 +897,6 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat # Initialize the serving chat
...@@ -952,7 +942,6 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): ...@@ -952,7 +942,6 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -1003,7 +992,6 @@ async def test_serving_chat_data_parallel_rank_extraction(): ...@@ -1003,7 +992,6 @@ async def test_serving_chat_data_parallel_rank_extraction():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Mock the generate method to return an async generator # Mock the generate method to return an async generator
...@@ -1095,7 +1083,6 @@ class TestServingChatWithHarmony: ...@@ -1095,7 +1083,6 @@ class TestServingChatWithHarmony:
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
return mock_engine return mock_engine
...@@ -1732,7 +1719,6 @@ async def test_tool_choice_validation_without_parser(): ...@@ -1732,7 +1719,6 @@ async def test_tool_choice_validation_without_parser():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels( models = OpenAIServingModels(
...@@ -1802,7 +1788,6 @@ async def test_streaming_n_gt1_independent_tool_parsers(): ...@@ -1802,7 +1788,6 @@ async def test_streaming_n_gt1_independent_tool_parsers():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels( models = OpenAIServingModels(
......
...@@ -79,7 +79,6 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: ...@@ -79,7 +79,6 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
serving_render = OpenAIServingRender( serving_render = OpenAIServingRender(
model_config=engine.model_config, model_config=engine.model_config,
renderer=engine.renderer, renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry, model_registry=models.registry,
request_logger=None, request_logger=None,
chat_template=None, chat_template=None,
...@@ -107,7 +106,6 @@ async def test_completion_error_non_stream(): ...@@ -107,7 +106,6 @@ async def test_completion_error_non_stream():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine) serving_completion = _build_serving_completion(mock_engine)
...@@ -157,7 +155,6 @@ async def test_completion_error_stream(): ...@@ -157,7 +155,6 @@ async def test_completion_error_stream():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine) serving_completion = _build_serving_completion(mock_engine)
......
...@@ -137,7 +137,6 @@ def mock_serving_setup(): ...@@ -137,7 +137,6 @@ def mock_serving_setup():
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config) mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels( models = OpenAIServingModels(
...@@ -148,7 +147,6 @@ def mock_serving_setup(): ...@@ -148,7 +147,6 @@ def mock_serving_setup():
serving_render = OpenAIServingRender( serving_render = OpenAIServingRender(
model_config=mock_engine.model_config, model_config=mock_engine.model_config,
renderer=mock_engine.renderer, renderer=mock_engine.renderer,
io_processor=mock_engine.io_processor,
model_registry=models.registry, model_registry=models.registry,
request_logger=None, request_logger=None,
chat_template=None, chat_template=None,
......
...@@ -77,7 +77,6 @@ def _create_mock_engine(): ...@@ -77,7 +77,6 @@ def _create_mock_engine():
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
# renderer is accessed by OpenAIServing.__init__ and serving.py # renderer is accessed by OpenAIServing.__init__ and serving.py
mock_renderer = MagicMock() mock_renderer = MagicMock()
......
...@@ -218,7 +218,6 @@ class TestInitializeToolSessions: ...@@ -218,7 +218,6 @@ class TestInitializeToolSessions:
engine_client.model_config = model_config engine_client.model_config = model_config
engine_client.input_processor = MagicMock() engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock() engine_client.renderer = MagicMock()
models = MagicMock() models = MagicMock()
...@@ -307,7 +306,6 @@ class TestValidateGeneratorInput: ...@@ -307,7 +306,6 @@ class TestValidateGeneratorInput:
engine_client.model_config = model_config engine_client.model_config = model_config
engine_client.input_processor = MagicMock() engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock() engine_client.renderer = MagicMock()
models = MagicMock() models = MagicMock()
...@@ -369,7 +367,6 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch): ...@@ -369,7 +367,6 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
model_config.get_diff_sampling_param.return_value = {} model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config engine_client.model_config = model_config
engine_client.input_processor = MagicMock() engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock() engine_client.renderer = MagicMock()
tokenizer = FakeTokenizer() tokenizer = FakeTokenizer()
...@@ -672,7 +669,6 @@ def _make_serving_instance_with_reasoning(): ...@@ -672,7 +669,6 @@ def _make_serving_instance_with_reasoning():
model_config.get_diff_sampling_param.return_value = {} model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config engine_client.model_config = model_config
engine_client.input_processor = MagicMock() engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock() engine_client.renderer = MagicMock()
models = MagicMock() models = MagicMock()
......
...@@ -110,8 +110,11 @@ def test_score_api(llm: LLM): ...@@ -110,8 +110,11 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False) llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["embed", "token_embed"]) @pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask): def test_unsupported_tasks(llm: LLM, task: PoolingTask):
err_msg = "Embedding API is not supported by this model.+" if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False) llm.encode(prompt, pooling_task=task, use_tqdm=False)
...@@ -469,4 +469,8 @@ async def test_pooling_not_supported( ...@@ -469,4 +469,8 @@ async def test_pooling_not_supported(
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
...@@ -107,8 +107,11 @@ def test_pooling_params(llm: LLM): ...@@ -107,8 +107,11 @@ def test_pooling_params(llm: LLM):
) )
@pytest.mark.parametrize("task", ["token_classify", "classify"]) @pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask): def test_unsupported_tasks(llm: LLM, task: PoolingTask):
err_msg = "Classification API is not supported by this model.+" if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False) llm.encode(prompt, pooling_task=task, use_tqdm=False)
...@@ -767,4 +767,8 @@ async def test_pooling_not_supported( ...@@ -767,4 +767,8 @@ async def test_pooling_not_supported(
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
...@@ -411,4 +411,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str): ...@@ -411,4 +411,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
...@@ -484,4 +484,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str): ...@@ -484,4 +484,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
...@@ -65,14 +65,17 @@ def test_score_api(llm: LLM): ...@@ -65,14 +65,17 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False) llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"]) @pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm): def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "classify": if task == "classify":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"): with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False) llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text assert "deprecated" in caplog_vllm.text
else: else:
err_msg = "Embedding API is not supported by this model.+" if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False) llm.encode(prompt, pooling_task=task, use_tqdm=False)
...@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st ...@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"]) @pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported( async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str server: RemoteOpenAIServer, model_name: str, task: str
): ):
...@@ -64,7 +64,8 @@ async def test_pooling_not_supported( ...@@ -64,7 +64,8 @@ async def test_pooling_not_supported(
}, },
) )
if task != "classify": if task == "plugin":
assert response.json()["error"]["type"] == "BadRequestError" err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}" err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg) assert response.json()["error"]["message"].startswith(err_msg)
...@@ -62,14 +62,17 @@ def test_token_ids_prompts(llm: LLM): ...@@ -62,14 +62,17 @@ def test_token_ids_prompts(llm: LLM):
assert outputs[0].outputs.data.shape == (11, 384) assert outputs[0].outputs.data.shape == (11, 384)
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"]) @pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm): def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "embed": if task == "embed":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"): with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False) llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text assert "deprecated" in caplog_vllm.text
else: else:
err_msg = "Classification API is not supported by this model.+" if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False) llm.encode(prompt, pooling_task=task, use_tqdm=False)
...@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str): ...@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"]) @pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
async def test_pooling_not_supported( async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str server: RemoteOpenAIServer, model_name: str, task: str
): ):
...@@ -87,7 +87,8 @@ async def test_pooling_not_supported( ...@@ -87,7 +87,8 @@ async def test_pooling_not_supported(
}, },
) )
if task != "embed": if task == "plugin":
assert response.json()["error"]["type"] == "BadRequestError" err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}" err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg) assert response.json()["error"]["message"].startswith(err_msg)
...@@ -86,7 +86,6 @@ def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens: ...@@ -86,7 +86,6 @@ def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens:
serving_render = OpenAIServingRender( serving_render = OpenAIServingRender(
model_config=engine.model_config, model_config=engine.model_config,
renderer=engine.renderer, renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry, model_registry=models.registry,
request_logger=None, request_logger=None,
chat_template=None, chat_template=None,
...@@ -148,7 +147,6 @@ def _mock_engine() -> MagicMock: ...@@ -148,7 +147,6 @@ def _mock_engine() -> MagicMock:
engine.errored = False engine.errored = False
engine.model_config = MockModelConfig() engine.model_config = MockModelConfig()
engine.input_processor = MagicMock() engine.input_processor = MagicMock()
engine.io_processor = MagicMock()
engine.renderer = _build_renderer(engine.model_config) engine.renderer = _build_renderer(engine.model_config)
return engine return engine
......
...@@ -34,7 +34,6 @@ async def _async_serving_models_init() -> OpenAIServingModels: ...@@ -34,7 +34,6 @@ async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config.max_model_len = 2048 mock_model_config.max_model_len = 2048
mock_engine_client.model_config = mock_model_config mock_engine_client.model_config = mock_model_config
mock_engine_client.input_processor = MagicMock() mock_engine_client.input_processor = MagicMock()
mock_engine_client.io_processor = MagicMock()
mock_engine_client.renderer = MagicMock() mock_engine_client.renderer = MagicMock()
serving_models = OpenAIServingModels( serving_models = OpenAIServingModels(
......
...@@ -514,7 +514,6 @@ async def test_header_dp_rank_argument(): ...@@ -514,7 +514,6 @@ async def test_header_dp_rank_argument():
serving_render = OpenAIServingRender( serving_render = OpenAIServingRender(
model_config=engine.model_config, model_config=engine.model_config,
renderer=engine.renderer, renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry, model_registry=models.registry,
request_logger=None, request_logger=None,
chat_template=None, chat_template=None,
......
...@@ -14,7 +14,6 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -14,7 +14,6 @@ from vllm.distributed.weight_transfer.base import (
from vllm.inputs import EngineInput, PromptType from vllm.inputs import EngineInput, PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -44,7 +43,6 @@ class EngineClient(ABC): ...@@ -44,7 +43,6 @@ class EngineClient(ABC):
vllm_config: VllmConfig vllm_config: VllmConfig
model_config: ModelConfig model_config: ModelConfig
renderer: BaseRenderer renderer: BaseRenderer
io_processor: IOProcessor | None
input_processor: InputProcessor input_processor: InputProcessor
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment