Unverified Commit 799397ee authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Support embedding models in V1 (#16188)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Signed-off-by: default avatarMax de Bayser <maxdebayser@gmail.com>
Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent 49599150
...@@ -12,7 +12,10 @@ def parse_args(): ...@@ -12,7 +12,10 @@ def parse_args():
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments # Set example specific arguments
parser.set_defaults( parser.set_defaults(
model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True model="intfloat/e5-mistral-7b-instruct",
task="embed",
enforce_eager=True,
max_model_len=1024,
) )
return parser.parse_args() return parser.parse_args()
......
...@@ -94,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData: ...@@ -94,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
engine_args = EngineArgs( engine_args = EngineArgs(
model="TIGER-Lab/VLM2Vec-Full", model="TIGER-Lab/VLM2Vec-Full",
task="embed", task="embed",
max_model_len=4096,
trust_remote_code=True, trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4}, mm_processor_kwargs={"num_crops": 4},
limit_mm_per_prompt={"image": 1}, limit_mm_per_prompt={"image": 1},
......
...@@ -31,7 +31,7 @@ class TestSetting: ...@@ -31,7 +31,7 @@ class TestSetting:
# basic llama model # basic llama model
TestSetting( TestSetting(
model="meta-llama/Llama-3.2-1B-Instruct", model="meta-llama/Llama-3.2-1B-Instruct",
model_args=[], model_args=["--max-model-len", "2048"],
pp_size=2, pp_size=2,
tp_size=2, tp_size=2,
attn_backend="FLASHINFER", attn_backend="FLASHINFER",
...@@ -41,7 +41,7 @@ class TestSetting: ...@@ -41,7 +41,7 @@ class TestSetting:
# llama model with quantization # llama model with quantization
TestSetting( TestSetting(
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
model_args=["--quantization", "gptq"], model_args=["--quantization", "gptq", "--max-model-len", "2048"],
pp_size=1, pp_size=1,
tp_size=1, tp_size=1,
attn_backend="FLASH_ATTN", attn_backend="FLASH_ATTN",
...@@ -51,7 +51,7 @@ class TestSetting: ...@@ -51,7 +51,7 @@ class TestSetting:
# MoE model # MoE model
TestSetting( TestSetting(
model="ibm/PowerMoE-3b", model="ibm/PowerMoE-3b",
model_args=[], model_args=["--max-model-len", "2048"],
pp_size=1, pp_size=1,
tp_size=2, tp_size=2,
attn_backend="FLASH_ATTN", attn_backend="FLASH_ATTN",
...@@ -61,23 +61,27 @@ class TestSetting: ...@@ -61,23 +61,27 @@ class TestSetting:
# embedding model # embedding model
TestSetting( TestSetting(
model="BAAI/bge-multilingual-gemma2", model="BAAI/bge-multilingual-gemma2",
model_args=["--task", "embed", "--dtype", "bfloat16"], model_args=[
"--task", "embed", "--dtype", "bfloat16", "--max-model-len",
"2048"
],
pp_size=1, pp_size=1,
tp_size=1, tp_size=1,
attn_backend="FLASH_ATTN", attn_backend="FLASH_ATTN",
method="encode", method="encode",
fullgraph=True, fullgraph=True,
), ),
# encoder-based embedding model (BERT) # TODO: bert models are not supported in V1 yet
TestSetting( # # encoder-based embedding model (BERT)
model="BAAI/bge-base-en-v1.5", # TestSetting(
model_args=["--task", "embed"], # model="BAAI/bge-base-en-v1.5",
pp_size=1, # model_args=["--task", "embed"],
tp_size=1, # pp_size=1,
attn_backend="XFORMERS", # tp_size=1,
method="encode", # attn_backend="XFORMERS",
fullgraph=True, # method="encode",
), # fullgraph=True,
# ),
# vision language model # vision language model
TestSetting( TestSetting(
model="microsoft/Phi-3.5-vision-instruct", model="microsoft/Phi-3.5-vision-instruct",
......
...@@ -145,6 +145,7 @@ def run_with_both_engines(request, monkeypatch): ...@@ -145,6 +145,7 @@ def run_with_both_engines(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without # Automatically runs tests twice, once with V1 and once without
use_v1 = request.param use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1 # Tests decorated with `@skip_v1` are only run without v1
skip_v0 = request.node.get_closest_marker("skip_v0")
skip_v1 = request.node.get_closest_marker("skip_v1") skip_v1 = request.node.get_closest_marker("skip_v1")
if use_v1: if use_v1:
...@@ -152,6 +153,8 @@ def run_with_both_engines(request, monkeypatch): ...@@ -152,6 +153,8 @@ def run_with_both_engines(request, monkeypatch):
pytest.skip("Skipping test on vllm V1") pytest.skip("Skipping test on vllm V1")
monkeypatch.setenv('VLLM_USE_V1', '1') monkeypatch.setenv('VLLM_USE_V1', '1')
else: else:
if skip_v0:
pytest.skip("Skipping test on vllm V0")
monkeypatch.setenv('VLLM_USE_V1', '0') monkeypatch.setenv('VLLM_USE_V1', '0')
yield yield
......
...@@ -8,6 +8,8 @@ import pytest ...@@ -8,6 +8,8 @@ import pytest
from vllm import LLM, PoolingParams, PoolingRequestOutput from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from ...models.utils import check_embeddings_close
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"
PROMPTS = [ PROMPTS = [
...@@ -27,6 +29,14 @@ TOKEN_IDS = [ ...@@ -27,6 +29,14 @@ TOKEN_IDS = [
] ]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
...@@ -46,9 +56,15 @@ def llm(): ...@@ -46,9 +56,15 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: list[PoolingRequestOutput], def assert_outputs_match(o1: list[PoolingRequestOutput],
o2: list[PoolingRequestOutput]): o2: list[PoolingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2] check_embeddings_close(
embeddings_0_lst=[o.outputs.data for o in o1],
embeddings_1_lst=[o.outputs.data for o in o2],
name_0="hf",
name_1="vllm",
tol=1e-2,
)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
...@@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, ...@@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
pooling_params=pooling_params) pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output) assert_outputs_match(v1_output, v2_output)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
...@@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): ...@@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
} for p in TOKEN_IDS], } for p in TOKEN_IDS],
pooling_params=pooling_params, pooling_params=pooling_params,
) )
assert_outputs_equal(v1_output, v2_output) assert_outputs_match(v1_output, v2_output)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
......
...@@ -21,6 +21,14 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + ...@@ -21,6 +21,14 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
DTYPE = "bfloat16" DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = [
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import pytest import pytest
import requests import requests
from tests.models.utils import check_embeddings_close
from vllm.entrypoints.openai.protocol import PoolingResponse from vllm.entrypoints.openai.protocol import PoolingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
...@@ -223,8 +224,11 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ...@@ -223,8 +224,11 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
np.frombuffer(base64.b64decode(data.data), np.frombuffer(base64.b64decode(data.data),
dtype="float32").tolist()) dtype="float32").tolist())
assert responses_float.data[0].data == decoded_responses_base64_data[0] check_embeddings_close(
assert responses_float.data[1].data == decoded_responses_base64_data[1] embeddings_0_lst=[d.data for d in responses_float.data],
embeddings_1_lst=decoded_responses_base64_data,
name_0="float32",
name_1="base64")
# Default response is float32 decoded from base64 by OpenAI Client # Default response is float32 decoded from base64 by OpenAI Client
default_response = requests.post( default_response = requests.post(
...@@ -237,5 +241,8 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ...@@ -237,5 +241,8 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
default_response.raise_for_status() default_response.raise_for_status()
responses_default = PoolingResponse.model_validate(default_response.json()) responses_default = PoolingResponse.model_validate(default_response.json())
assert responses_float.data[0].data == responses_default.data[0].data check_embeddings_close(
assert responses_float.data[1].data == responses_default.data[1].data embeddings_0_lst=[d.data for d in responses_default.data],
embeddings_1_lst=[d.data for d in responses_default.data],
name_0="float32",
name_1="base64")
...@@ -12,6 +12,14 @@ MODEL_NAME = "BAAI/bge-reranker-base" ...@@ -12,6 +12,14 @@ MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16" DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
......
...@@ -11,6 +11,15 @@ from vllm.entrypoints.openai.protocol import ScoreResponse ...@@ -11,6 +11,15 @@ from vllm.entrypoints.openai.protocol import ScoreResponse
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
MODELS = [ MODELS = [
{ {
"name": "BAAI/bge-reranker-v2-m3", "name": "BAAI/bge-reranker-v2-m3",
......
...@@ -6,6 +6,14 @@ from transformers import AutoModelForSequenceClassification ...@@ -6,6 +6,14 @@ from transformers import AutoModelForSequenceClassification
from vllm.platforms import current_platform from vllm.platforms import current_platform
# TODO: enable when float32 is supported by V1
# @pytest.fixture(autouse=True)
# def v1(run_with_both_engines):
# # Simple autouse wrapper to run both engines for each test
# # This can be promoted up to conftest.py to run for every
# # test in a package
# pass
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
...@@ -29,7 +37,7 @@ def test_models( ...@@ -29,7 +37,7 @@ def test_models(
# switch to use ROCm CK FA backend # switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts) vllm_outputs = vllm_model.classify(example_prompts)
with hf_runner(model, with hf_runner(model,
......
...@@ -8,6 +8,14 @@ from vllm.platforms import current_platform ...@@ -8,6 +8,14 @@ from vllm.platforms import current_platform
from ...utils import check_embeddings_close from ...utils import check_embeddings_close
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
...@@ -20,15 +28,27 @@ from ...utils import check_embeddings_close ...@@ -20,15 +28,27 @@ from ...utils import check_embeddings_close
marks=[pytest.mark.core_model]), marks=[pytest.mark.core_model]),
pytest.param("intfloat/e5-mistral-7b-instruct", pytest.param("intfloat/e5-mistral-7b-instruct",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]), marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), # the qwen models interfere with each other (see PR
# https://github.com/vllm-project/vllm/pull/18720).
# To avoid this problem, for now we skip v0 since it will be
# deprecated anyway.
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0]),
# [Encoder-only] # [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5", pytest.param("BAAI/bge-base-en-v1.5",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]), marks=[
pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.mark.core_model, pytest.mark.cpu_model,
pytest.param("intfloat/multilingual-e5-small"), pytest.mark.skip_v1
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), ]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder] # [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2"), pytest.param("sentence-transformers/stsb-roberta-base-v2",
marks=[pytest.mark.skip_v1]),
], ],
) )
def test_models( def test_models(
...@@ -62,7 +82,7 @@ def test_models( ...@@ -62,7 +82,7 @@ def test_models(
with vllm_runner(model, with vllm_runner(model,
task="embed", task="embed",
max_model_len=None, max_model_len=512,
**vllm_extra_kwargs) as vllm_model: **vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts) vllm_outputs = vllm_model.encode(example_prompts)
......
...@@ -265,8 +265,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -265,8 +265,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only] # [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True), trust_remote_code=True),
...@@ -279,16 +279,16 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -279,16 +279,16 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True), trust_remote_code=True, v0_only=True),
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True), trust_remote_code=True, v0_only=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
# [Multimodal] # [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
...@@ -300,10 +300,10 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -300,10 +300,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
_CROSS_ENCODER_EXAMPLE_MODELS = { _CROSS_ENCODER_EXAMPLE_MODELS = {
# [Text-only] # [Text-only]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
} }
_MULTIMODAL_EXAMPLE_MODELS = { _MULTIMODAL_EXAMPLE_MODELS = {
......
...@@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer, ...@@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer,
None, None,
params, params,
None, None,
None,
0.0, 0.0,
None, None,
cache_salt=None, cache_salt=None,
......
...@@ -43,6 +43,7 @@ def make_request(request_id, ...@@ -43,6 +43,7 @@ def make_request(request_id,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100, eos_token_id=100,
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
......
...@@ -39,6 +39,7 @@ def make_request(request_id, ...@@ -39,6 +39,7 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17, sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs), prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100, eos_token_id=100,
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
......
...@@ -135,6 +135,7 @@ def create_requests(num_requests: int, ...@@ -135,6 +135,7 @@ def create_requests(num_requests: int,
request_id=f"{i}", request_id=f"{i}",
prompt_token_ids=[i] * num_tokens, prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs, multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=None,
...@@ -283,6 +284,7 @@ def test_schedule_partial_requests(): ...@@ -283,6 +284,7 @@ def test_schedule_partial_requests():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output, model_runner_output) scheduler.update_from_output(output, model_runner_output)
...@@ -333,6 +335,7 @@ def test_no_mm_input_chunking(): ...@@ -333,6 +335,7 @@ def test_no_mm_input_chunking():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output, model_runner_output) scheduler.update_from_output(output, model_runner_output)
...@@ -396,6 +399,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -396,6 +399,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output, model_runner_output) scheduler.update_from_output(output, model_runner_output)
...@@ -420,6 +424,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -420,6 +424,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output1, model_runner_output) scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule() output2 = scheduler.schedule()
...@@ -473,7 +478,8 @@ def test_stop_via_update_from_output(): ...@@ -473,7 +478,8 @@ def test_stop_via_update_from_output():
11]], # First request hits EOS, second continues 11]], # First request hits EOS, second continues
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -523,7 +529,8 @@ def test_stop_via_update_from_output(): ...@@ -523,7 +529,8 @@ def test_stop_via_update_from_output():
[13, 14]], # First request hits stop token [13, 14]], # First request hits stop token
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -572,7 +579,8 @@ def test_stop_via_update_from_output(): ...@@ -572,7 +579,8 @@ def test_stop_via_update_from_output():
[13]], # First request exceeds max_tokens [13]], # First request exceeds max_tokens
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -614,7 +622,8 @@ def test_stop_via_update_from_output(): ...@@ -614,7 +622,8 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -663,6 +672,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], ...@@ -663,6 +672,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(scheduler_output0, model_runner_output) scheduler.update_from_output(scheduler_output0, model_runner_output)
...@@ -680,6 +690,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], ...@@ -680,6 +690,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(scheduler_output1, model_runner_output) scheduler.update_from_output(scheduler_output1, model_runner_output)
...@@ -730,6 +741,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -730,6 +741,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=spec_tokens, spec_token_ids=spec_tokens,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
engine_core_outputs = scheduler.update_from_output(output, engine_core_outputs = scheduler.update_from_output(output,
model_runner_output) model_runner_output)
...@@ -769,6 +781,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -769,6 +781,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
engine_core_outputs = scheduler.update_from_output(output, engine_core_outputs = scheduler.update_from_output(output,
model_runner_output) model_runner_output)
...@@ -896,6 +909,7 @@ def test_kv_connector_basic(): ...@@ -896,6 +909,7 @@ def test_kv_connector_basic():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# Ensure ScheduleOutput is correct. # Ensure ScheduleOutput is correct.
...@@ -941,6 +955,7 @@ def test_kv_connector_basic(): ...@@ -941,6 +955,7 @@ def test_kv_connector_basic():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# We should get a local cache hit of NUM_TOKENS_PREFIX and # We should get a local cache hit of NUM_TOKENS_PREFIX and
...@@ -1007,6 +1022,7 @@ def test_kv_connector_unable_to_allocate(): ...@@ -1007,6 +1022,7 @@ def test_kv_connector_unable_to_allocate():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# Just one request should be running. # Just one request should be running.
...@@ -1087,6 +1103,7 @@ def test_kv_connector_handles_preemption(): ...@@ -1087,6 +1103,7 @@ def test_kv_connector_handles_preemption():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# All can be scheduled - 1st token. # All can be scheduled - 1st token.
...@@ -1181,6 +1198,7 @@ def make_output(scheduler: Scheduler): ...@@ -1181,6 +1198,7 @@ def make_output(scheduler: Scheduler):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
......
...@@ -39,6 +39,7 @@ def make_request() -> EngineCoreRequest: ...@@ -39,6 +39,7 @@ def make_request() -> EngineCoreRequest:
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
......
...@@ -53,6 +53,7 @@ def make_request( ...@@ -53,6 +53,7 @@ def make_request(
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=params, sampling_params=params,
pooling_params=None,
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
......
...@@ -33,6 +33,7 @@ def test_fast_inc_detok_invalid_utf8_err_case(): ...@@ -33,6 +33,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
None, None,
params, params,
None, None,
None,
0.0, 0.0,
None, None,
cache_salt=None, cache_salt=None,
......
...@@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
output_kind=request_output_kind, output_kind=request_output_kind,
stop=[], stop=[],
include_stop_str_in_output=False, include_stop_str_in_output=False,
)) ),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
...@@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
include_stop_str_in_output=False, include_stop_str_in_output=False,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs, prompt_logprobs=num_prompt_logprobs,
)) ),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
...@@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=None, prompt_logprobs=None,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
)) ),
pooling_params=None)
# Add request to the detokenizer. # Add request to the detokenizer.
output_processor.add_request(request, prompt_string) output_processor.add_request(request, prompt_string)
...@@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool,
include_stop_str_in_output=include_stop_str_in_output, include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=None, prompt_logprobs=None,
)) ),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
...@@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors):
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
pooling_params=None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment