Unverified Commit f399182e authored by Chenheli Hua's avatar Chenheli Hua Committed by GitHub
Browse files

Run ruff format on a few files. (#24075)


Signed-off-by: default avatarChenheli Hua <huachenheli@outlook.com>
parent 1c413105
...@@ -46,23 +46,27 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" ...@@ -46,23 +46,27 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def phi3v_model_config(): def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID, return ModelConfig(
runner="generate", PHI3V_MODEL_ID,
trust_remote_code=True, runner="generate",
limit_mm_per_prompt={ trust_remote_code=True,
"image": 2, limit_mm_per_prompt={
}) "image": 2,
},
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def phi3v_model_config_mm_interleaved(): def phi3v_model_config_mm_interleaved():
return ModelConfig(PHI3V_MODEL_ID, return ModelConfig(
runner="generate", PHI3V_MODEL_ID,
trust_remote_code=True, runner="generate",
interleave_mm_strings=True, trust_remote_code=True,
limit_mm_per_prompt={ interleave_mm_strings=True,
"image": 2, limit_mm_per_prompt={
}) "image": 2,
},
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -77,14 +81,16 @@ def phi3v_tokenizer(): ...@@ -77,14 +81,16 @@ def phi3v_tokenizer():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def qwen25omni_model_config_mm_interleaved(): def qwen25omni_model_config_mm_interleaved():
return ModelConfig(QWEN25OMNI_MODEL_ID, return ModelConfig(
runner="generate", QWEN25OMNI_MODEL_ID,
interleave_mm_strings=True, runner="generate",
limit_mm_per_prompt={ interleave_mm_strings=True,
"image": 2, limit_mm_per_prompt={
"audio": 1, "image": 2,
"video": 1, "audio": 1,
}) "video": 1,
},
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -99,11 +105,13 @@ def qwen25omni_tokenizer(): ...@@ -99,11 +105,13 @@ def qwen25omni_tokenizer():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mllama_model_config(): def mllama_model_config():
return ModelConfig(MLLAMA_MODEL_ID, return ModelConfig(
runner="generate", MLLAMA_MODEL_ID,
limit_mm_per_prompt={ runner="generate",
"image": 2, limit_mm_per_prompt={
}) "image": 2,
},
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -118,11 +126,13 @@ def mllama_tokenizer(): ...@@ -118,11 +126,13 @@ def mllama_tokenizer():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mistral_model_config(): def mistral_model_config():
return ModelConfig(MISTRAL_MODEL_ID, return ModelConfig(
runner="generate", MISTRAL_MODEL_ID,
limit_mm_per_prompt={ runner="generate",
"image": 2, limit_mm_per_prompt={
}) "image": 2,
},
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -137,21 +147,21 @@ def mistral_tokenizer(): ...@@ -137,21 +147,21 @@ def mistral_tokenizer():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def image_url(): def image_url():
image = ImageAsset('cherry_blossom') image = ImageAsset("cherry_blossom")
base64 = encode_image_base64(image.pil_image) base64 = encode_image_base64(image.pil_image)
return f"data:image/jpeg;base64,{base64}" return f"data:image/jpeg;base64,{base64}"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def video_url(): def video_url():
video = VideoAsset('baby_reading', 1) video = VideoAsset("baby_reading", 1)
base64 = encode_video_base64(video.np_ndarrays) base64 = encode_video_base64(video.np_ndarrays)
return f"data:video/jpeg;base64,{base64}" return f"data:video/jpeg;base64,{base64}"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def audio_url(): def audio_url():
audio = AudioAsset('mary_had_lamb') audio = AudioAsset("mary_had_lamb")
base64 = encode_audio_base64(*audio.audio_and_sample_rate) base64 = encode_audio_base64(*audio.audio_and_sample_rate)
return f"data:audio/ogg;base64,{base64}" return f"data:audio/ogg;base64,{base64}"
...@@ -195,15 +205,18 @@ def test_parse_chat_messages_single_image( ...@@ -195,15 +205,18 @@ def test_parse_chat_messages_single_image(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": image_url "image_url": {
} "url": image_url
}, { }
"type": "text", },
"text": "What's in the image?" {
}] "type": "text",
"text": "What's in the image?"
},
],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -223,58 +236,69 @@ def test_parse_chat_messages_empty_system( ...@@ -223,58 +236,69 @@ def test_parse_chat_messages_empty_system(
): ):
# Test string format # Test string format
conversation, _ = parse_chat_messages( conversation, _ = parse_chat_messages(
[{ [
"role": "system", {
"content": "" "role": "system",
}, { "content": ""
"role": "user", },
"content": [{ {
"type": "text", "role": "user",
"text": "Who are you?" "content": [{
}] "type": "text",
}], "text": "Who are you?"
}],
},
],
mistral_model_config, mistral_model_config,
mistral_tokenizer, mistral_tokenizer,
content_format="string", content_format="string",
) )
assert conversation == [{ assert conversation == [
"role": "system", {
"content": "" "role": "system",
}, { "content": ""
"role": "user", },
"content": "Who are you?" {
}] "role": "user",
"content": "Who are you?"
},
]
# Test openai format # Test openai format
conversation, _ = parse_chat_messages( conversation, _ = parse_chat_messages(
[{ [
{
"role": "system",
"content": ""
},
{
"role": "user",
"content": [{
"type": "text",
"text": "Who are you?"
}],
},
],
mistral_model_config,
mistral_tokenizer,
content_format="openai",
)
assert conversation == [
{
"role": "system", "role": "system",
"content": "" "content": [{
}, { "type": "text",
"text": ""
}]
},
{
"role": "user", "role": "user",
"content": [{ "content": [{
"type": "text", "type": "text",
"text": "Who are you?" "text": "Who are you?"
}] }]
}], },
mistral_model_config, ]
mistral_tokenizer,
content_format="openai",
)
assert conversation == [{
"role": "system",
"content": [{
"type": "text",
"text": ""
}]
}, {
"role":
"user",
"content": [{
"type": "text",
"text": "Who are you?"
}]
}]
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -287,15 +311,18 @@ async def test_parse_chat_messages_single_image_async( ...@@ -287,15 +311,18 @@ async def test_parse_chat_messages_single_image_async(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": image_url "image_url": {
} "url": image_url
}, { }
"type": "text", },
"text": "What's in the image?" {
}] "type": "text",
"text": "What's in the image?"
},
],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -318,18 +345,22 @@ def test_parse_chat_messages_multiple_images( ...@@ -318,18 +345,22 @@ def test_parse_chat_messages_multiple_images(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": image_url "image_url": {
} "url": image_url
}, { }
"type": "image_pil", },
"image_pil": ImageAsset('cherry_blossom').pil_image {
}, { "type": "image_pil",
"type": "text", "image_pil": ImageAsset("cherry_blossom").pil_image,
"text": "What's in these images?" },
}] {
"type": "text",
"text": "What's in these images?"
},
],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -340,7 +371,7 @@ def test_parse_chat_messages_multiple_images( ...@@ -340,7 +371,7 @@ def test_parse_chat_messages_multiple_images(
"role": "role":
"user", "user",
"content": "content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?" "<|image_1|>\n<|image_2|>\nWhat's in these images?",
}] }]
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
...@@ -355,18 +386,22 @@ async def test_parse_chat_messages_multiple_images_async( ...@@ -355,18 +386,22 @@ async def test_parse_chat_messages_multiple_images_async(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": image_url "image_url": {
} "url": image_url
}, { }
"type": "image_pil", },
"image_pil": ImageAsset('cherry_blossom').pil_image {
}, { "type": "image_pil",
"type": "text", "image_pil": ImageAsset("cherry_blossom").pil_image,
"text": "What's in these images?" },
}] {
"type": "text",
"text": "What's in these images?"
},
],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -377,7 +412,7 @@ async def test_parse_chat_messages_multiple_images_async( ...@@ -377,7 +412,7 @@ async def test_parse_chat_messages_multiple_images_async(
"role": "role":
"user", "user",
"content": "content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?" "<|image_1|>\n<|image_2|>\nWhat's in these images?",
}] }]
_assert_mm_data_is_image_input(await mm_future, 2) _assert_mm_data_is_image_input(await mm_future, 2)
...@@ -391,22 +426,26 @@ def test_parse_chat_messages_placeholder_already_in_prompt( ...@@ -391,22 +426,26 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": image_url "image_url": {
} "url": image_url
}, { }
"type": "image_url", },
"image_url": { {
"url": image_url "type": "image_url",
} "image_url": {
}, { "url": image_url
"type": }
"text", },
"text": {
"What's in <|image_1|> and how does it compare to <|image_2|>?" "type":
}] "text",
"text":
"What's in <|image_1|> and how does it compare to <|image_2|>?", # noqa: E501
},
],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -416,7 +455,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt( ...@@ -416,7 +455,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
"role": "role":
"user", "user",
"content": "content":
"What's in <|image_1|> and how does it compare to <|image_2|>?" "What's in <|image_1|> and how does it compare to <|image_2|>?",
}] }]
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
...@@ -447,9 +486,9 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( ...@@ -447,9 +486,9 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
"type": "type":
"text", "text",
"text": "text":
"What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 "What's in <|image_1|> and how does it compare to the other one?", # noqa: E501
} },
] ],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -461,7 +500,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( ...@@ -461,7 +500,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
"user", "user",
"content": "content":
"<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
"other one?" "other one?",
}] }]
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
...@@ -472,34 +511,44 @@ def test_parse_chat_messages_multiple_images_across_messages( ...@@ -472,34 +511,44 @@ def test_parse_chat_messages_multiple_images_across_messages(
image_url, image_url,
): ):
conversation, mm_data = parse_chat_messages( conversation, mm_data = parse_chat_messages(
[{ [
"role": {
"user", "role":
"content": [{ "user",
"type": "image_url", "content": [
"image_url": { {
"url": image_url "type": "image_url",
} "image_url": {
}, { "url": image_url
"type": "text", }
"text": "What's in this image?" },
}] {
}, { "type": "text",
"role": "assistant", "text": "What's in this image?"
"content": "Some stuff." },
}, { ],
"role": },
"user", {
"content": [{ "role": "assistant",
"type": "image_url", "content": "Some stuff."
"image_url": { },
"url": image_url {
} "role":
}, { "user",
"type": "text", "content": [
"text": "What about this one?" {
}] "type": "image_url",
}], "image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What about this one?"
},
],
},
],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
content_format="string", content_format="string",
...@@ -527,19 +576,23 @@ def test_parse_chat_messages_context_text_format( ...@@ -527,19 +576,23 @@ def test_parse_chat_messages_context_text_format(
phi3v_tokenizer, phi3v_tokenizer,
): ):
conversation, mm_data = parse_chat_messages( conversation, mm_data = parse_chat_messages(
[{ [
"role": "user", {
"content": [{ "role": "user",
"type": "text", "content": [{
"text": "What's in this text?" "type": "text",
}] "text": "What's in this text?"
}, { }],
"role": "assistant", },
"content": "Some stuff." {
}, { "role": "assistant",
"role": "user", "content": "Some stuff."
"content": "What about this one?" },
}], {
"role": "user",
"content": "What about this one?"
},
],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
content_format="openai", content_format="openai",
...@@ -551,21 +604,21 @@ def test_parse_chat_messages_context_text_format( ...@@ -551,21 +604,21 @@ def test_parse_chat_messages_context_text_format(
"content": [{ "content": [{
"type": "text", "type": "text",
"text": "What's in this text?" "text": "What's in this text?"
}] }],
}, },
{ {
"role": "assistant", "role": "assistant",
"content": [{ "content": [{
"type": "text", "type": "text",
"text": "Some stuff." "text": "Some stuff."
}] }],
}, },
{ {
"role": "user", "role": "user",
"content": [{ "content": [{
"type": "text", "type": "text",
"text": "What about this one?" "text": "What about this one?"
}] }],
}, },
] ]
...@@ -578,31 +631,37 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ...@@ -578,31 +631,37 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message="coroutine 'async_get_and_parse_image' was never awaited") message="coroutine 'async_get_and_parse_image' was never awaited",
)
with pytest.raises(ValueError, match="At most"): with pytest.raises(ValueError, match="At most"):
parse_chat_messages( parse_chat_messages(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": image_url "image_url": {
} "url": image_url
}, { },
"type": "image_url", },
"image_url": { {
"url": image_url "type": "image_url",
} "image_url": {
}, { "url": image_url
"type": "image_url", },
"image_url": { },
"url": image_url {
} "type": "image_url",
}, { "image_url": {
"type": "text", "url": image_url
"text": "What's in these images?" },
}] },
{
"type": "text",
"text": "What's in these images?"
},
],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -618,42 +677,54 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ...@@ -618,42 +677,54 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message="coroutine 'async_get_and_parse_image' was never awaited") message="coroutine 'async_get_and_parse_image' was never awaited",
)
with pytest.raises(ValueError, match="At most"): with pytest.raises(ValueError, match="At most"):
parse_chat_messages( parse_chat_messages(
[{ [
"role": {
"user", "role":
"content": [{ "user",
"type": "image_url", "content": [
"image_url": { {
"url": image_url "type": "image_url",
} "image_url": {
}, { "url": image_url
"type": "text", },
"text": "What's in this image?" },
}] {
}, { "type": "text",
"role": "assistant", "text": "What's in this image?"
"content": "Some stuff." },
}, { ],
"role": },
"user", {
"content": [{ "role": "assistant",
"type": "image_url", "content": "Some stuff."
"image_url": { },
"url": image_url {
} "role":
}, { "user",
"type": "image_url", "content": [
"image_url": { {
"url": image_url "type": "image_url",
} "image_url": {
}, { "url": image_url
"type": "text", },
"text": "What about these two?" },
}] {
}], "type": "image_url",
"image_url": {
"url": image_url
},
},
{
"type": "text",
"text": "What about these two?"
},
],
},
],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
content_format="string", content_format="string",
...@@ -670,12 +741,14 @@ def test_parse_chat_messages_multiple_images_uncommon_input( ...@@ -670,12 +741,14 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
"role": "role":
"user", "user",
"content": [ "content": [
"What's in these images?", { "What's in these images?",
{
"image_url": image_url "image_url": image_url
}, { },
{
"image_url": image_url "image_url": image_url
} },
] ],
}], }],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -686,7 +759,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input( ...@@ -686,7 +759,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
"role": "role":
"user", "user",
"content": "content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?" "<|image_1|>\n<|image_2|>\nWhat's in these images?",
}] }]
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
...@@ -700,26 +773,32 @@ def test_parse_chat_messages_multiple_images_interleave( ...@@ -700,26 +773,32 @@ def test_parse_chat_messages_multiple_images_interleave(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "text", {
"text": "I need you to compare this image" "type": "text",
}, { "text": "I need you to compare this image",
"type": "image_url", },
"image_url": { {
"url": image_url "type": "image_url",
} "image_url": {
}, { "url": image_url
"type": "text", }
"text": "and this one" },
}, { {
"type": "image_url", "type": "text",
"image_url": { "text": "and this one"
"url": image_url },
} {
}, { "type": "image_url",
"type": "text", "image_url": {
"text": "Do they have differences?" "url": image_url
}] }
},
{
"type": "text",
"text": "Do they have differences?"
},
],
}], }],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -731,7 +810,7 @@ def test_parse_chat_messages_multiple_images_interleave( ...@@ -731,7 +810,7 @@ def test_parse_chat_messages_multiple_images_interleave(
"user", "user",
"content": "content":
"I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501
"Do they have differences?" "Do they have differences?",
}] }]
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
...@@ -746,26 +825,32 @@ async def test_parse_chat_messages_multiple_images_interleave_async( ...@@ -746,26 +825,32 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
"type": "text", {
"text": "I need you to compare this image" "type": "text",
}, { "text": "I need you to compare this image",
"type": "image_url", },
"image_url": { {
"url": image_url "type": "image_url",
} "image_url": {
}, { "url": image_url
"type": "text", }
"text": "and this one" },
}, { {
"type": "image_url", "type": "text",
"image_url": { "text": "and this one"
"url": image_url },
} {
}, { "type": "image_url",
"type": "text", "image_url": {
"text": "Do they have differences?" "url": image_url
}] }
},
{
"type": "text",
"text": "Do they have differences?"
},
],
}], }],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -777,7 +862,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async( ...@@ -777,7 +862,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
"user", "user",
"content": "content":
"I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501
"Do they have differences?" "Do they have differences?",
}] }]
_assert_mm_data_is_image_input(await mm_data, 2) _assert_mm_data_is_image_input(await mm_data, 2)
...@@ -788,135 +873,161 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( ...@@ -788,135 +873,161 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
image_url, image_url,
): ):
conversation, mm_data = parse_chat_messages( conversation, mm_data = parse_chat_messages(
[{ [
"role": {
"user", "role":
"content": [ "user",
{ "content": [
"type": "text", {
"text": "What's on this image?" "type": "text",
}, "text": "What's on this image?"
{ },
"type": "image_url", {
"image_url": { "type": "image_url",
"url": image_url "image_url": {
} "url": image_url
}, }
{ },
"type": "text", {
"text": "Be accurate." "type": "text",
}, "text": "Be accurate."
] },
}, { ],
"role": "assistant", },
"content": "Some stuff." {
}, { "role": "assistant",
"role": "content": "Some stuff."
"user", },
"content": [{ {
"type": "text", "role":
"text": "What's on this image?" "user",
}, { "content": [
"type": "image_url", {
"image_url": { "type": "text",
"url": image_url "text": "What's on this image?"
} },
}] {
}], "type": "image_url",
"image_url": {
"url": image_url
}
},
],
},
],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer, phi3v_tokenizer,
content_format="string", content_format="string",
) )
assert conversation == [{ assert conversation == [
"role": {
"user", "role": "user",
"content": "content": "What's on this image?\n<|image_1|>\nBe accurate.",
"What's on this image?\n<|image_1|>\nBe accurate." },
}, { {
"role": "assistant", "role": "assistant",
"content": "Some stuff." "content": "Some stuff."
}, { },
"role": "user", {
"content": "What's on this image?\n<|image_2|>" "role": "user",
}] "content": "What's on this image?\n<|image_2|>"
},
]
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, qwen25omni_model_config_mm_interleaved,
image_url, video_url, audio_url): qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
):
conversation, mm_data = parse_chat_messages( conversation, mm_data = parse_chat_messages(
[{ [
{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's on this image?"
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "Now listen to this audio"
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
],
},
{
"role": "assistant",
"content": "Some stuff."
},
{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's on this image?"
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "And what's in the video?"
},
{
"type": "video_url",
"video_url": {
"url": video_url
}
},
],
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
assert conversation == [
{
"role": "role":
"user", "user",
"content": [ "content":
{ "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
"type": "text", "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501
"text": "What's on this image?" },
}, {
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "Now listen to this audio"
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
]
}, {
"role": "assistant", "role": "assistant",
"content": "Some stuff." "content": "Some stuff."
}, { },
{
"role": "role":
"user", "user",
"content": [{ "content":
"type": "text", "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
"text": "What's on this image?" "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>",
}, { },
"type": "image_url", ]
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "And what's in the video?"
}, {
"type": "video_url",
"video_url": {
"url": video_url
}
}]
}],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
assert conversation == [{
"role":
"user",
"content":
"What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
"Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>"
}, {
"role": "assistant",
"content": "Some stuff."
}, {
"role":
"user",
"content":
"What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
"And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>"
}]
_assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1})
...@@ -929,7 +1040,8 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( ...@@ -929,7 +1040,8 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=r"Found more '<|image_1|>' placeholders in input prompt " match=r"Found more '<|image_1|>' placeholders in input prompt "
"than actual multimodal data items."): "than actual multimodal data items.",
):
parse_chat_messages( parse_chat_messages(
[{ [{
"role": "role":
...@@ -952,9 +1064,9 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( ...@@ -952,9 +1064,9 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
"text", "text",
"text": "text":
"I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501
"Do they have differences?" "Do they have differences?",
}, },
] ],
}], }],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -973,12 +1085,15 @@ def test_mllama_single_image( ...@@ -973,12 +1085,15 @@ def test_mllama_single_image(
[{ [{
"role": "role":
"user", "user",
"content": [{ "content": [
'type': 'text', {
'text': 'The content of this image is:' "type": "text",
}, { "text": "The content of this image is:"
"image_url": image_url },
}] {
"image_url": image_url
},
],
}], }],
mllama_model_config, mllama_model_config,
mllama_tokenizer, mllama_tokenizer,
...@@ -986,14 +1101,17 @@ def test_mllama_single_image( ...@@ -986,14 +1101,17 @@ def test_mllama_single_image(
) )
_assert_mm_data_is_image_input(mm_data, 1) _assert_mm_data_is_image_input(mm_data, 1)
assert conversation == [{ assert conversation == [{
'role': "role":
'user', "user",
'content': [{ "content": [
'type': 'text', {
'text': 'The content of this image is:' "type": "text",
}, { "text": "The content of this image is:"
'type': 'image' },
}] {
"type": "image"
},
],
}] }]
...@@ -1009,20 +1127,20 @@ def test_mllama_interleaved_images( ...@@ -1009,20 +1127,20 @@ def test_mllama_interleaved_images(
"user", "user",
"content": [ "content": [
{ {
'type': 'text', "type": "text",
'text': 'The content of the first image is:' "text": "The content of the first image is:",
}, },
{ {
"image_url": image_url "image_url": image_url
}, },
{ {
'type': 'text', "type": "text",
'text': 'The content of the second image is:' "text": "The content of the second image is:",
}, },
{ {
"image_url": image_url "image_url": image_url
}, },
] ],
}], }],
mllama_model_config, mllama_model_config,
mllama_tokenizer, mllama_tokenizer,
...@@ -1030,19 +1148,24 @@ def test_mllama_interleaved_images( ...@@ -1030,19 +1148,24 @@ def test_mllama_interleaved_images(
) )
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
assert conversation == [{ assert conversation == [{
'role': "role":
'user', "user",
'content': [{ "content": [
'type': 'text', {
'text': 'The content of the first image is:' "type": "text",
}, { "text": "The content of the first image is:"
'type': 'image' },
}, { {
'type': 'text', "type": "image"
'text': 'The content of the second image is:' },
}, { {
'type': 'image' "type": "text",
}] "text": "The content of the second image is:"
},
{
"type": "image"
},
],
}] }]
...@@ -1053,34 +1176,36 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ...@@ -1053,34 +1176,36 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
def get_conversation(is_hf: bool): def get_conversation(is_hf: bool):
img_part = {"type": "image_url", "image_url": {"url": image_url}} img_part = {"type": "image_url", "image_url": {"url": image_url}}
if is_hf: if is_hf:
img_part = {'type': 'image'} img_part = {"type": "image"}
return [{ return [{
'role': "role":
'user', "user",
'content': [ "content": [
{ {
'type': 'text', "type": "text",
'text': 'The content of the first image is:' "text": "The content of the first image is:",
}, },
img_part, img_part,
{ {
'type': 'text', "type": "text",
'text': 'The content of the second image is:' "text": "The content of the second image is:",
}, },
img_part, img_part,
{ {
'type': 'text', "type": "text",
'text': 'What animal is in the first image?' "text": "What animal is in the first image?",
}, },
] ],
}] }]
# Build a config for the model # Build a config for the model
model_config = ModelConfig(model, model_config = ModelConfig(
runner="generate", model,
limit_mm_per_prompt={ runner="generate",
"image": 2, limit_mm_per_prompt={
}) "image": 2,
},
)
# Build the tokenizer group and grab the underlying tokenizer # Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
...@@ -1126,7 +1251,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ...@@ -1126,7 +1251,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
[ [
QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str
HERMES_MODEL_ID, # tokenizer.chat_template is of type dict HERMES_MODEL_ID, # tokenizer.chat_template is of type dict
]) ],
)
@pytest.mark.parametrize("use_tools", [True, False]) @pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models.""" """checks that chat_template is a dict type for HF models."""
...@@ -1152,14 +1278,14 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -1152,14 +1278,14 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
) )
tokenizer = tokenizer_group.tokenizer tokenizer = tokenizer_group.tokenizer
tools = [{ tools = ([{
"type": "function", "type": "function",
"function": { "function": {
"name": "dummy_function_name", "name": "dummy_function_name",
"description": "This is a dummy function", "description": "This is a dummy function",
"parameters": sample_json_schema "parameters": sample_json_schema,
} },
}] if use_tools else None }] if use_tools else None)
# Test detecting the tokenizer's chat_template # Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template( chat_template = resolve_hf_chat_template(
......
...@@ -103,6 +103,7 @@ class PILImage(BaseModel): ...@@ -103,6 +103,7 @@ class PILImage(BaseModel):
""" """
A PIL.Image.Image object. A PIL.Image.Image object.
""" """
image_pil: Image.Image image_pil: Image.Image
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
...@@ -115,6 +116,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): ...@@ -115,6 +116,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
"image_pil": ImageAsset('cherry_blossom').pil_image "image_pil": ImageAsset('cherry_blossom').pil_image
} }
""" """
image_pil: Required[PILImage] image_pil: Required[PILImage]
...@@ -127,6 +129,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): ...@@ -127,6 +129,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"image_url": "https://example.com/image.jpg" "image_url": "https://example.com/image.jpg"
} }
""" """
image_url: Required[str] image_url: Required[str]
...@@ -138,6 +141,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): ...@@ -138,6 +141,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"audio_url": "https://example.com/audio.mp3" "audio_url": "https://example.com/audio.mp3"
} }
""" """
audio_url: Required[str] audio_url: Required[str]
...@@ -149,6 +153,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): ...@@ -149,6 +153,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"video_url": "https://example.com/video.mp4" "video_url": "https://example.com/video.mp4"
} }
""" """
video_url: Required[str] video_url: Required[str]
...@@ -174,19 +179,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): ...@@ -174,19 +179,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
ChatCompletionContentPartParam: TypeAlias = Union[ ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, ChatCompletionContentPartVideoParam,
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPILImageParam, CustomChatCompletionContentPILImageParam,
CustomChatCompletionContentSimpleImageParam, CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam, ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam, CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str, CustomChatCompletionContentSimpleVideoParam,
CustomThinkCompletionContentParam] str,
CustomThinkCompletionContentParam,
]
class CustomChatCompletionMessageParam(TypedDict, total=False): class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API.""" """Enables custom roles in the Chat Completion API."""
role: Required[str] role: Required[str]
"""The role of the message's author.""" """The role of the message's author."""
...@@ -207,9 +217,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): ...@@ -207,9 +217,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls.""" """The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, ChatCompletionMessageParam = Union[
CustomChatCompletionMessageParam, OpenAIChatCompletionMessageParam,
OpenAIHarmonyMessage] CustomChatCompletionMessageParam,
OpenAIHarmonyMessage,
]
# TODO: Make fields ReadOnly once mypy supports it # TODO: Make fields ReadOnly once mypy supports it
...@@ -262,13 +274,13 @@ def _is_var_or_elems_access( ...@@ -262,13 +274,13 @@ def _is_var_or_elems_access(
key: Optional[str] = None, key: Optional[str] = None,
) -> bool: ) -> bool:
if isinstance(node, jinja2.nodes.Filter): if isinstance(node, jinja2.nodes.Filter):
return (node.node is not None return node.node is not None and _is_var_or_elems_access(
and _is_var_or_elems_access(node.node, varname, key)) node.node, varname, key)
if isinstance(node, jinja2.nodes.Test): if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key) return _is_var_or_elems_access(node.node, varname, key)
if (isinstance(node, jinja2.nodes.Getitem) if isinstance(node, jinja2.nodes.Getitem) and isinstance(
and isinstance(node.arg, jinja2.nodes.Slice)): node.arg, jinja2.nodes.Slice):
return _is_var_or_elems_access(node.node, varname, key) return _is_var_or_elems_access(node.node, varname, key)
# yapf: disable # yapf: disable
...@@ -373,15 +385,18 @@ def resolve_mistral_chat_template( ...@@ -373,15 +385,18 @@ def resolve_mistral_chat_template(
) -> Optional[str]: ) -> Optional[str]:
if chat_template is not None: if chat_template is not None:
logger.warning_once( logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.") "'chat_template' cannot be overridden for mistral tokenizer."
)
if "add_generation_prompt" in kwargs: if "add_generation_prompt" in kwargs:
logger.warning_once( logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, " "'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored."
)
if "continue_final_message" in kwargs: if "continue_final_message" in kwargs:
logger.warning_once( logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, " "'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored."
)
return None return None
...@@ -401,23 +416,35 @@ def resolve_hf_chat_template( ...@@ -401,23 +416,35 @@ def resolve_hf_chat_template(
try: try:
processor = cached_get_processor( processor = cached_get_processor(
tokenizer.name_or_path, tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, processor_cls=(
ProcessorMixin), PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
),
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
if isinstance(processor, ProcessorMixin) and \ if (
hasattr(processor, 'chat_template') and \ isinstance(processor, ProcessorMixin)
processor.chat_template is not None: and hasattr(processor, "chat_template")
and processor.chat_template is not None
):
return processor.chat_template return processor.chat_template
except Exception: except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 logger.debug(
"Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path,
exc_info=True,
) # noqa: E501
# 3rd priority: AutoTokenizer chat template # 3rd priority: AutoTokenizer chat template
try: try:
return tokenizer.get_chat_template(chat_template, tools=tools) return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception: except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s", logger.debug(
tokenizer.name_or_path, exc_info=True) "Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
# 4th priority: Predefined fallbacks # 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path( path = get_chat_template_fallback_path(
...@@ -425,12 +452,16 @@ def resolve_hf_chat_template( ...@@ -425,12 +452,16 @@ def resolve_hf_chat_template(
tokenizer_name_or_path=model_config.tokenizer, tokenizer_name_or_path=model_config.tokenizer,
) )
if path is not None: if path is not None:
logger.info("Loading chat template fallback for %s as there isn't one " logger.info(
"defined on HF Hub.", tokenizer.name_or_path) "Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path) chat_template = load_chat_template(path)
else: else:
logger.debug("There is no chat template fallback for %s", logger.debug(
tokenizer.name_or_path) "There is no chat template fallback for %s", tokenizer.name_or_path
)
return chat_template return chat_template
...@@ -452,11 +483,17 @@ def _resolve_chat_template_content_format( ...@@ -452,11 +483,17 @@ def _resolve_chat_template_content_format(
else: else:
hf_chat_template = None hf_chat_template = None
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) jinja_text = (
else load_chat_template(chat_template, is_literal=True)) hf_chat_template
if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True)
)
detected_format = ("string" if jinja_text is None else detected_format = (
_detect_content_format(jinja_text, default="string")) "string"
if jinja_text is None
else _detect_content_format(jinja_text, default="string")
)
return detected_format return detected_format
...@@ -512,7 +549,6 @@ def resolve_chat_template_content_format( ...@@ -512,7 +549,6 @@ def resolve_chat_template_content_format(
return detected_format return detected_format
ModalityStr = Literal["image", "audio", "video", "image_embeds"] ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -539,6 +575,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -539,6 +575,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@cached_property @cached_property
def model_cls(self) -> type[SupportsMultiModal]: def model_cls(self) -> type[SupportsMultiModal]:
from vllm.model_executor.model_loader import get_model_cls from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config) model_cls = get_model_cls(self.model_config)
return cast(type[SupportsMultiModal], model_cls) return cast(type[SupportsMultiModal], model_cls)
...@@ -574,28 +611,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -574,28 +611,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]): class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]: def all_mm_data(self) -> Optional[MultiModalDataDict]:
if not self._items_by_modality: if not self._items_by_modality:
return None return None
mm_inputs = {} mm_inputs = {}
items_by_modality = dict(self._items_by_modality) items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(\ raise ValueError(
"Mixing raw image and embedding inputs is not allowed") "Mixing raw image and embedding inputs is not allowed"
)
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1: if len(image_embeds_lst) > 1:
raise ValueError(\ raise ValueError(
"Only one message can have {'type': 'image_embeds'}") "Only one message can have {'type': 'image_embeds'}"
)
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
...@@ -603,32 +641,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): ...@@ -603,32 +641,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]: async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if not self._items_by_modality: if not self._items_by_modality:
return None return None
mm_inputs = {} mm_inputs = {}
items_by_modality = { items_by_modality = {
modality: await asyncio.gather(*items) modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items() for modality, items in self._items_by_modality.items()
} }
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError( raise ValueError(
"Mixing raw image and embedding inputs is not allowed") "Mixing raw image and embedding inputs is not allowed"
)
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1: if len(image_embeds_lst) > 1:
raise ValueError( raise ValueError(
"Only one message can have {'type': 'image_embeds'}") "Only one message can have {'type': 'image_embeds'}"
)
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
...@@ -636,7 +675,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): ...@@ -636,7 +675,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
class BaseMultiModalContentParser(ABC): class BaseMultiModalContentParser(ABC):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -648,8 +686,9 @@ class BaseMultiModalContentParser(ABC): ...@@ -648,8 +686,9 @@ class BaseMultiModalContentParser(ABC):
# } # }
self._placeholder_storage: dict[str, list] = defaultdict(list) self._placeholder_storage: dict[str, list] = defaultdict(list)
def _add_placeholder(self, modality: ModalityStr, def _add_placeholder(
placeholder: Optional[str]): self, modality: ModalityStr, placeholder: Optional[str]
):
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
if placeholder: if placeholder:
self._placeholder_storage[mod_placeholder].append(placeholder) self._placeholder_storage[mod_placeholder].append(placeholder)
...@@ -662,8 +701,9 @@ class BaseMultiModalContentParser(ABC): ...@@ -662,8 +701,9 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_image_embeds(self, def parse_image_embeds(
image_embeds: Union[str, dict[str, str]]) -> None: self, image_embeds: Union[str, dict[str, str]]
) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -684,7 +724,6 @@ class BaseMultiModalContentParser(ABC): ...@@ -684,7 +724,6 @@ class BaseMultiModalContentParser(ABC):
class MultiModalContentParser(BaseMultiModalContentParser): class MultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: MultiModalItemTracker) -> None: def __init__(self, tracker: MultiModalItemTracker) -> None:
super().__init__() super().__init__()
...@@ -701,8 +740,9 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -701,8 +740,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image) placeholder = self._tracker.add("image", image)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_embeds(self, def parse_image_embeds(
image_embeds: Union[str, dict[str, str]]) -> None: self, image_embeds: Union[str, dict[str, str]]
) -> None:
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
embeds = { embeds = {
k: self._connector.fetch_image_embedding(v) k: self._connector.fetch_image_embedding(v)
...@@ -741,14 +781,13 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -741,14 +781,13 @@ class MultiModalContentParser(BaseMultiModalContentParser):
class AsyncMultiModalContentParser(BaseMultiModalContentParser): class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=self._tracker._model_config.media_io_kwargs, media_io_kwargs=self._tracker._model_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path allowed_local_media_path=tracker.allowed_local_media_path,
) )
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
...@@ -757,8 +796,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -757,8 +796,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image_coro) placeholder = self._tracker.add("image", image_coro)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_embeds(self, def parse_image_embeds(
image_embeds: Union[str, dict[str, str]]) -> None: self, image_embeds: Union[str, dict[str, str]]
) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
...@@ -769,8 +809,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -769,8 +809,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future.set_result(embeds) future.set_result(embeds)
if isinstance(image_embeds, str): if isinstance(image_embeds, str):
embedding = self._connector.\ embedding = self._connector.fetch_image_embedding(image_embeds)
fetch_image_embedding(image_embeds)
future.set_result(embedding) future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future) placeholder = self._tracker.add("image_embeds", future)
...@@ -809,20 +848,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): ...@@ -809,20 +848,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
return return
elif isinstance(chat_template, Path) and not chat_template.exists(): elif isinstance(chat_template, Path) and not chat_template.exists():
raise FileNotFoundError( raise FileNotFoundError("the supplied chat template path doesn't exist")
"the supplied chat template path doesn't exist")
elif isinstance(chat_template, str): elif isinstance(chat_template, str):
JINJA_CHARS = "{}\n" JINJA_CHARS = "{}\n"
if not any(c in chat_template if (
for c in JINJA_CHARS) and not Path(chat_template).exists(): not any(c in chat_template for c in JINJA_CHARS)
and not Path(chat_template).exists()
):
raise ValueError( raise ValueError(
f"The supplied chat template string ({chat_template}) " f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!") f"appears path-like, but doesn't exist!"
)
else: else:
raise TypeError( raise TypeError(
f"{type(chat_template)} is not a valid chat template type") f"{type(chat_template)} is not a valid chat template type"
)
def _load_chat_template( def _load_chat_template(
...@@ -835,8 +877,9 @@ def _load_chat_template( ...@@ -835,8 +877,9 @@ def _load_chat_template(
if is_literal: if is_literal:
if isinstance(chat_template, Path): if isinstance(chat_template, Path):
raise TypeError("chat_template is expected to be read directly " raise TypeError(
"from its value") "chat_template is expected to be read directly from its value"
)
return chat_template return chat_template
...@@ -849,9 +892,11 @@ def _load_chat_template( ...@@ -849,9 +892,11 @@ def _load_chat_template(
JINJA_CHARS = "{}\n" JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS): if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) " msg = (
f"looks like a file path, but it failed to be " f"The supplied chat template ({chat_template}) "
f"opened. Reason: {e}") f"looks like a file path, but it failed to be "
f"opened. Reason: {e}"
)
raise ValueError(msg) from e raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to # If opening a file fails, set chat template to be args to
...@@ -870,8 +915,9 @@ def load_chat_template( ...@@ -870,8 +915,9 @@ def load_chat_template(
return _cached_load_chat_template(chat_template, is_literal=is_literal) return _cached_load_chat_template(chat_template, is_literal=is_literal)
def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], def _get_interleaved_text_prompt(
texts: list[str]) -> str: placeholder_storage: dict[str, list], texts: list[str]
) -> str:
for idx, elem in enumerate(texts): for idx, elem in enumerate(texts):
if elem in placeholder_storage: if elem in placeholder_storage:
texts[idx] = placeholder_storage[elem].pop(0) texts[idx] = placeholder_storage[elem].pop(0)
...@@ -881,10 +927,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], ...@@ -881,10 +927,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
# TODO: Let user specify how to insert multimodal tokens into prompt # TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template) # (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], def _get_full_multimodal_text_prompt(
texts: list[str], placeholder_storage: dict[str, list],
interleave_strings: bool texts: list[str],
) -> str: interleave_strings: bool,
) -> str:
"""Combine multimodal prompts for a multimodal language model.""" """Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it looks like # flatten storage to make it looks like
...@@ -907,7 +954,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], ...@@ -907,7 +954,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
# Look through the text prompt to check for missing placeholders # Look through the text prompt to check for missing placeholders
missing_placeholders: list[str] = [] missing_placeholders: list[str] = []
for placeholder in placeholder_counts: for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is # For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder) placeholder_counts[placeholder] -= text_prompt.count(placeholder)
...@@ -916,15 +962,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], ...@@ -916,15 +962,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
"Placeholder count is negative! " "Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled " "Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) " "(current value: %s) "
"when manually placing image placeholders.", interleave_strings "when manually placing image placeholders.",
interleave_strings,
) )
logger.debug("Input prompt: %s", text_prompt) logger.debug("Input prompt: %s", text_prompt)
raise ValueError( raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than " f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.") "actual multimodal data items."
)
missing_placeholders.extend([placeholder] * missing_placeholders.extend(
placeholder_counts[placeholder]) [placeholder] * placeholder_counts[placeholder]
)
# NOTE: Default behaviour: we always add missing placeholders # NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False # at the front of the prompt, if interleave_strings=False
...@@ -944,7 +993,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python ...@@ -944,7 +993,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter( _ResponsesInputImageParser = TypeAdapter(
ResponseInputImageParam).validate_python ResponseInputImageParam
).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
# Define a mapping from part types to their corresponding parsing functions. # Define a mapping from part types to their corresponding parsing functions.
...@@ -952,32 +1002,35 @@ MM_PARSER_MAP: dict[ ...@@ -952,32 +1002,35 @@ MM_PARSER_MAP: dict[
str, str,
Callable[[ChatCompletionContentPartParam], _ContentPart], Callable[[ChatCompletionContentPartParam], _ContentPart],
] = { ] = {
"text": "text": lambda part: _TextParser(part).get("text", None),
lambda part: _TextParser(part).get("text", None), "thinking": lambda part: _ThinkParser(part).get("thinking", None),
"thinking": "input_text": lambda part: _TextParser(part).get("text", None),
lambda part: _ThinkParser(part).get("thinking", None), "input_image": lambda part: _ResponsesInputImageParser(part).get(
"input_text": "image_url", None
lambda part: _TextParser(part).get("text", None), ),
"input_image": "image_url": lambda part: _ImageParser(part)
lambda part: _ResponsesInputImageParser(part).get("image_url", None), .get("image_url", {})
"image_url": .get("url", None),
lambda part: _ImageParser(part).get("image_url", {}).get("url", None), "image_embeds": lambda part: _ImageEmbedsParser(part).get(
"image_embeds": "image_embeds", None
lambda part: _ImageEmbedsParser(part).get("image_embeds", None), ),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url": "audio_url": lambda part: _AudioParser(part)
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), .get("audio_url", {})
"input_audio": .get("url", None),
lambda part: _InputAudioParser(part).get("input_audio", None), "input_audio": lambda part: _InputAudioParser(part).get(
"refusal": "input_audio", None
lambda part: _RefusalParser(part).get("refusal", None), ),
"video_url": "refusal": lambda part: _RefusalParser(part).get("refusal", None),
lambda part: _VideoParser(part).get("video_url", {}).get("url", None), "video_url": lambda part: _VideoParser(part)
.get("video_url", {})
.get("url", None),
} }
def _parse_chat_message_content_mm_part( def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
""" """
Parses a given multi-modal content part based on its type. Parses a given multi-modal content part based on its type.
...@@ -993,7 +1046,8 @@ def _parse_chat_message_content_mm_part( ...@@ -993,7 +1046,8 @@ def _parse_chat_message_content_mm_part(
ValueError: If the 'type' field is missing and no direct URL is found. ValueError: If the 'type' field is missing and no direct URL is found.
""" """
assert isinstance( assert isinstance(
part, dict) # This is needed to avoid mypy errors: part.get() from str part, dict
) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None) part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP: if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
...@@ -1002,8 +1056,10 @@ def _parse_chat_message_content_mm_part( ...@@ -1002,8 +1056,10 @@ def _parse_chat_message_content_mm_part(
# Special case for 'image_url.detail' # Special case for 'image_url.detail'
# We only support 'auto', which is the default # We only support 'auto', which is the default
if part_type == "image_url" and part.get("detail", "auto") != "auto": if part_type == "image_url" and part.get("detail", "auto") != "auto":
logger.warning("'image_url.detail' is currently not supported " logger.warning(
"and will be ignored.") "'image_url.detail' is currently not supported "
"and will be ignored."
)
return part_type, content return part_type, content
...@@ -1011,19 +1067,22 @@ def _parse_chat_message_content_mm_part( ...@@ -1011,19 +1067,22 @@ def _parse_chat_message_content_mm_part(
# 'type' is required field by pydantic # 'type' is required field by pydantic
if part_type is None: if part_type is None:
if part.get("image_url") is not None: if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam, image_params = cast(
part) CustomChatCompletionContentSimpleImageParam, part
)
return "image_url", image_params.get("image_url", "") return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None: if part.get("audio_url") is not None:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam, audio_params = cast(
part) CustomChatCompletionContentSimpleAudioParam, part
)
return "audio_url", audio_params.get("audio_url", "") return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None: if part.get("input_audio") is not None:
input_audio_params = cast(dict[str, str], part) input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params return "input_audio", input_audio_params
if part.get("video_url") is not None: if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam, video_params = cast(
part) CustomChatCompletionContentSimpleVideoParam, part
)
return "video_url", video_params.get("video_url", "") return "video_url", video_params.get("video_url", "")
# Raise an error if no 'type' or direct URL is found. # Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.") raise ValueError("Missing 'type' field in multimodal part.")
...@@ -1033,9 +1092,16 @@ def _parse_chat_message_content_mm_part( ...@@ -1033,9 +1092,16 @@ def _parse_chat_message_content_mm_part(
return part_type, "unknown part_type content" return part_type, "unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
"image_embeds", "image_pil", "text",
"audio_url", "input_audio", "video_url") "refusal",
"image_url",
"image_embeds",
"image_pil",
"audio_url",
"input_audio",
"video_url",
)
def _parse_chat_message_content_parts( def _parse_chat_message_content_parts(
...@@ -1055,21 +1121,20 @@ def _parse_chat_message_content_parts( ...@@ -1055,21 +1121,20 @@ def _parse_chat_message_content_parts(
part, part,
mm_parser, mm_parser,
wrap_dicts=wrap_dicts, wrap_dicts=wrap_dicts,
interleave_strings=interleave_strings interleave_strings=interleave_strings,
) )
if parse_res: if parse_res:
content.append(parse_res) content.append(parse_res)
if wrap_dicts: if wrap_dicts:
# Parsing wraps images and texts as interleaved dictionaries # Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role, return [ConversationMessage(role=role, content=content)] # type: ignore
content=content)] # type: ignore
texts = cast(list[str], content) texts = cast(list[str], content)
mm_placeholder_storage = mm_parser.mm_placeholder_storage() mm_placeholder_storage = mm_parser.mm_placeholder_storage()
if mm_placeholder_storage: if mm_placeholder_storage:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, text_prompt = _get_full_multimodal_text_prompt(
texts, mm_placeholder_storage, texts, interleave_strings
interleave_strings) )
else: else:
text_prompt = "\n".join(texts) text_prompt = "\n".join(texts)
...@@ -1099,13 +1164,16 @@ def _parse_chat_message_content_part( ...@@ -1099,13 +1164,16 @@ def _parse_chat_message_content_part(
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning( logger.warning(
"Skipping multimodal part '%s' (type: '%s') " "Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content.", part, part_type) "with empty / unparsable content.",
part,
part_type,
)
return None return None
if part_type in ("text", "input_text", "refusal", "thinking"): if part_type in ("text", "input_text", "refusal", "thinking"):
str_content = cast(str, content) str_content = cast(str, content)
if wrap_dicts: if wrap_dicts:
return {'type': 'text', 'text': str_content} return {"type": "text", "text": str_content}
else: else:
return str_content return str_content
...@@ -1137,8 +1205,12 @@ def _parse_chat_message_content_part( ...@@ -1137,8 +1205,12 @@ def _parse_chat_message_content_part(
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
return {'type': modality} if wrap_dicts else ( return (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None {"type": modality}
if wrap_dicts
else (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
)
) )
...@@ -1171,14 +1243,16 @@ def _parse_chat_message_content( ...@@ -1171,14 +1243,16 @@ def _parse_chat_message_content(
) )
for result_msg in result: for result_msg in result:
if role == 'assistant': if role == "assistant":
parsed_msg = _AssistantParser(message) parsed_msg = _AssistantParser(message)
# The 'tool_calls' is not None check ensures compatibility. # The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly # It's needed only if downstream code doesn't strictly
# follow the OpenAI spec. # follow the OpenAI spec.
if ("tool_calls" in parsed_msg if (
and parsed_msg["tool_calls"] is not None): "tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None
):
result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool": elif role == "tool":
parsed_msg = _ToolParser(message) parsed_msg = _ToolParser(message)
...@@ -1198,12 +1272,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: ...@@ -1198,12 +1272,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get # so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict # from openAI format) to dict
for message in messages: for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message if (
and isinstance(message["tool_calls"], list)): message["role"] == "assistant"
and "tool_calls" in message
and isinstance(message["tool_calls"], list)
):
for item in message["tool_calls"]: for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads( item["function"]["arguments"] = json.loads(
item["function"]["arguments"]) item["function"]["arguments"]
)
def parse_chat_messages( def parse_chat_messages(
...@@ -1224,7 +1301,7 @@ def parse_chat_messages( ...@@ -1224,7 +1301,7 @@ def parse_chat_messages(
content_format == "string" content_format == "string"
and model_config.multimodal_config is not None and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings and model_config.multimodal_config.interleave_mm_strings
) ),
) )
conversation.extend(sub_messages) conversation.extend(sub_messages)
...@@ -1252,7 +1329,7 @@ def parse_chat_messages_futures( ...@@ -1252,7 +1329,7 @@ def parse_chat_messages_futures(
content_format == "string" content_format == "string"
and model_config.multimodal_config is not None and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings and model_config.multimodal_config.interleave_mm_strings
) ),
) )
conversation.extend(sub_messages) conversation.extend(sub_messages)
...@@ -1283,10 +1360,10 @@ def apply_hf_chat_template( ...@@ -1283,10 +1360,10 @@ def apply_hf_chat_template(
raise ValueError( raise ValueError(
"As of transformers v4.44, default chat template is no longer " "As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer " "allowed, so you must provide a chat template if the tokenizer "
"does not define one.") "does not define one."
)
try: try:
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type]
...@@ -1298,13 +1375,14 @@ def apply_hf_chat_template( ...@@ -1298,13 +1375,14 @@ def apply_hf_chat_template(
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities. # internal exception management capabilities.
except Exception as e: except Exception as e:
# Log and report any library-related exceptions for further # Log and report any library-related exceptions for further
# investigation. # investigation.
logger.exception( logger.exception(
"An error occurred in `transformers` while applying chat template") "An error occurred in `transformers` while applying chat template"
)
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
def apply_mistral_chat_template( def apply_mistral_chat_template(
tokenizer: MistralTokenizer, tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
...@@ -1337,26 +1415,26 @@ def apply_mistral_chat_template( ...@@ -1337,26 +1415,26 @@ def apply_mistral_chat_template(
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities. # internal exception management capabilities.
except Exception as e: except Exception as e:
# Log and report any library-related exceptions for further # Log and report any library-related exceptions for further
# investigation. # investigation.
logger.exception( logger.exception(
"An error occurred in `mistral_common` while applying chat " "An error occurred in `mistral_common` while applying chat template"
"template") )
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
idx = 0 idx = 0
for msg in conversation: for msg in conversation:
if msg['role'] == 'assistant': if msg["role"] == "assistant":
tool_calls = msg.get('tool_calls') tool_calls = msg.get("tool_calls")
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx return idx
def make_tool_call_id(id_type:str='random', func_name=None, idx=None):
if id_type=='kimi_k2': def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
return f'functions.{func_name}:{idx}' if id_type == "kimi_k2":
return f"functions.{func_name}:{idx}"
else: else:
# by default return random # by default return random
return f"chatcmpl-tool-{random_uuid()}" return f"chatcmpl-tool-{random_uuid()}"
...@@ -82,16 +82,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, ...@@ -82,16 +82,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
logger = init_logger(__name__) logger = init_logger(__name__)
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, CompletionLikeRequest = Union[
EmbeddingCompletionRequest, RerankRequest, CompletionRequest,
ClassificationRequest, ScoreRequest, DetokenizeRequest,
TokenizeCompletionRequest] EmbeddingCompletionRequest,
RerankRequest,
ClassificationRequest,
ScoreRequest,
TokenizeCompletionRequest,
]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest] TokenizeChatRequest]
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest, AnyRequest = Union[
ResponsesRequest, IOProcessorRequest] CompletionLikeRequest,
ChatLikeRequest,
SpeechToTextRequest,
ResponsesRequest,
IOProcessorRequest,
]
AnyResponse = Union[ AnyResponse = Union[
CompletionResponse, CompletionResponse,
...@@ -135,6 +145,7 @@ class RequestProcessingMixin(BaseModel): ...@@ -135,6 +145,7 @@ class RequestProcessingMixin(BaseModel):
Mixin for request processing, Mixin for request processing,
handling prompt preparation and engine input. handling prompt preparation and engine input.
""" """
request_prompts: Optional[Sequence[RequestPrompt]] = [] request_prompts: Optional[Sequence[RequestPrompt]] = []
engine_prompts: Optional[Union[list[EngineTokensPrompt], engine_prompts: Optional[Union[list[EngineTokensPrompt],
list[EngineEmbedsPrompt]]] = [] list[EngineEmbedsPrompt]]] = []
...@@ -147,6 +158,7 @@ class ResponseGenerationMixin(BaseModel): ...@@ -147,6 +158,7 @@ class ResponseGenerationMixin(BaseModel):
Mixin for response generation, Mixin for response generation,
managing result generators and final batch results. managing result generators and final batch results.
""" """
result_generator: Optional[AsyncGenerator[tuple[int, Union[ result_generator: Optional[AsyncGenerator[tuple[int, Union[
RequestOutput, PoolingRequestOutput]], None]] = None RequestOutput, PoolingRequestOutput]], None]] = None
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
...@@ -155,8 +167,12 @@ class ResponseGenerationMixin(BaseModel): ...@@ -155,8 +167,12 @@ class ResponseGenerationMixin(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, class ServeContext(
Generic[RequestT]): RequestProcessingMixin,
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
# Shared across all requests # Shared across all requests
request: RequestT request: RequestT
raw_request: Optional[Request] = None raw_request: Optional[Request] = None
...@@ -298,8 +314,8 @@ class OpenAIServing: ...@@ -298,8 +314,8 @@ class OpenAIServing:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
None) None)
if truncate_prompt_tokens is not None and \ if (truncate_prompt_tokens is not None
truncate_prompt_tokens > self.max_model_len: and truncate_prompt_tokens > self.max_model_len):
return self.create_error_response( return self.create_error_response(
"truncate_prompt_tokens value is " "truncate_prompt_tokens value is "
"greater than max_model_len." "greater than max_model_len."
...@@ -344,10 +360,12 @@ class OpenAIServing: ...@@ -344,10 +360,12 @@ class OpenAIServing:
return self.create_error_response( return self.create_error_response(
"Request prompts not available") "Request prompts not available")
self._log_inputs(request_id_item, self._log_inputs(
ctx.request_prompts[i], request_id_item,
params=pooling_params, ctx.request_prompts[i],
lora_request=ctx.lora_request) params=pooling_params,
lora_request=ctx.lora_request,
)
# Mypy has an existing bug related to inferring the variance of # Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`: # TypedDicts with `builtins.enumerate`:
...@@ -410,10 +428,11 @@ class OpenAIServing: ...@@ -410,10 +428,11 @@ class OpenAIServing:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
def create_error_response( def create_error_response(
self, self,
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
if self.log_error_stack: if self.log_error_stack:
exc_type, _, _ = sys.exc_info() exc_type, _, _ = sys.exc_info()
if exc_type is not None: if exc_type is not None:
...@@ -424,10 +443,11 @@ class OpenAIServing: ...@@ -424,10 +443,11 @@ class OpenAIServing:
message=message, type=err_type, code=status_code.value)) message=message, type=err_type, code=status_code.value))
def create_streaming_error_response( def create_streaming_error_response(
self, self,
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
json_str = json.dumps( json_str = json.dumps(
self.create_error_response(message=message, self.create_error_response(message=message,
err_type=err_type, err_type=err_type,
...@@ -438,25 +458,25 @@ class OpenAIServing: ...@@ -438,25 +458,25 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
error_response = None error_response = None
if self._is_model_supported(request.model): if self._is_model_supported(request.model):
return None return None
if request.model in self.models.lora_requests: if request.model in self.models.lora_requests:
return None return None
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
load_result := await self.models.resolve_lora(request.model)): (load_result := await self.models.resolve_lora(request.model))):
if isinstance(load_result, LoRARequest): if isinstance(load_result, LoRARequest):
return None return None
if isinstance(load_result, ErrorResponse) and \ if (isinstance(load_result, ErrorResponse) and
load_result.error.code == HTTPStatus.BAD_REQUEST.value: load_result.error.code == HTTPStatus.BAD_REQUEST.value):
error_response = load_result error_response = load_result
return error_response or self.create_error_response( return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.", message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError", err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND,
)
def _get_active_default_mm_loras( def _get_active_default_mm_loras(
self, request: AnyRequest) -> Optional[LoRARequest]: self, request: AnyRequest) -> Optional[LoRARequest]:
...@@ -487,7 +507,6 @@ class OpenAIServing: ...@@ -487,7 +507,6 @@ class OpenAIServing:
request: AnyRequest, request: AnyRequest,
supports_default_mm_loras: bool = False, supports_default_mm_loras: bool = False,
) -> Optional[LoRARequest]: ) -> Optional[LoRARequest]:
if request.model in self.models.lora_requests: if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model] return self.models.lora_requests[request.model]
...@@ -548,13 +567,15 @@ class OpenAIServing: ...@@ -548,13 +567,15 @@ class OpenAIServing:
prompt, prompt,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
truncation=True, truncation=True,
max_length=self.max_model_len) max_length=self.max_model_len,
)
else: else:
encoded = await async_tokenizer( encoded = await async_tokenizer(
prompt, prompt,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
truncation=True, truncation=True,
max_length=truncate_prompt_tokens) max_length=truncate_prompt_tokens,
)
input_ids = encoded.input_ids input_ids = encoded.input_ids
input_text = prompt input_text = prompt
...@@ -595,16 +616,22 @@ class OpenAIServing: ...@@ -595,16 +616,22 @@ class OpenAIServing:
# Note: EmbeddingRequest, ClassificationRequest, # Note: EmbeddingRequest, ClassificationRequest,
# and ScoreRequest doesn't have max_tokens # and ScoreRequest doesn't have max_tokens
if isinstance(request, if isinstance(
(EmbeddingChatRequest, EmbeddingCompletionRequest, request,
ScoreRequest, RerankRequest, ClassificationRequest)): (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ScoreRequest,
RerankRequest,
ClassificationRequest,
),
):
# Note: input length can be up to the entire model context length # Note: input length can be up to the entire model context length
# since these requests don't generate tokens. # since these requests don't generate tokens.
if token_num > self.max_model_len: if token_num > self.max_model_len:
operations: dict[type[AnyRequest], str] = { operations: dict[type[AnyRequest], str] = {
ScoreRequest: "score", ScoreRequest: "score",
ClassificationRequest: "classification" ClassificationRequest: "classification",
} }
operation = operations.get(type(request), operation = operations.get(type(request),
"embedding generation") "embedding generation")
...@@ -618,8 +645,11 @@ class OpenAIServing: ...@@ -618,8 +645,11 @@ class OpenAIServing:
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation # and does not require model context length validation
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, if isinstance(
DetokenizeRequest)): request,
(TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest),
):
return TextTokensPrompt(prompt=input_text, return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids) prompt_token_ids=input_ids)
...@@ -639,8 +669,8 @@ class OpenAIServing: ...@@ -639,8 +669,8 @@ class OpenAIServing:
f"{token_num} input tokens. Please reduce the length of " f"{token_num} input tokens. Please reduce the length of "
"the input messages.") "the input messages.")
if max_tokens is not None and \ if (max_tokens is not None
token_num + max_tokens > self.max_model_len: and token_num + max_tokens > self.max_model_len):
raise ValueError( raise ValueError(
"'max_tokens' or 'max_completion_tokens' is too large: " "'max_tokens' or 'max_completion_tokens' is too large: "
f"{max_tokens}. This model's maximum context length is " f"{max_tokens}. This model's maximum context length is "
...@@ -745,13 +775,14 @@ class OpenAIServing: ...@@ -745,13 +775,14 @@ class OpenAIServing:
tasks = [] tasks = []
for prompt_input in batch_inputs: for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is False: if prompt_input["is_tokens"] is False:
assert tokenizer is not None, \ assert tokenizer is not None, (
"Tokenizer is required for text prompts" "Tokenizer is required for text prompts")
task = self._normalize_prompt_text_to_input( task = self._normalize_prompt_text_to_input(
request, request,
prompt_input["content"], prompt_input["content"],
tokenizer=tokenizer, tokenizer=tokenizer,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens,
)
else: else:
task = self._normalize_prompt_tokens_to_input( task = self._normalize_prompt_tokens_to_input(
request, prompt_input["content"], tokenizer=tokenizer) request, prompt_input["content"], tokenizer=tokenizer)
...@@ -766,9 +797,14 @@ class OpenAIServing: ...@@ -766,9 +797,14 @@ class OpenAIServing:
@overload @overload
async def _preprocess_completion( async def _preprocess_completion(
self, self,
request: Union[DetokenizeRequest, EmbeddingCompletionRequest, request: Union[
RerankRequest, ClassificationRequest, ScoreRequest, DetokenizeRequest,
TokenizeCompletionRequest], EmbeddingCompletionRequest,
RerankRequest,
ClassificationRequest,
ScoreRequest,
TokenizeCompletionRequest,
],
tokenizer: Optional[AnyTokenizer], tokenizer: Optional[AnyTokenizer],
input_or_inputs: Union[str, list[str], list[int], list[list[int]]], input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
add_special_tokens: bool = ..., add_special_tokens: bool = ...,
...@@ -783,8 +819,10 @@ class OpenAIServing: ...@@ -783,8 +819,10 @@ class OpenAIServing:
input_or_inputs: Optional[Union[str, list[str], list[int], input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]], list[list[int]]]],
add_special_tokens: bool = ..., add_special_tokens: bool = ...,
) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ ) -> tuple[
EngineTokensPrompt, EngineEmbedsPrompt]]]: list[Union[TextTokensPrompt, EmbedsPrompt]],
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
]:
... ...
async def _preprocess_completion( async def _preprocess_completion(
...@@ -794,32 +832,38 @@ class OpenAIServing: ...@@ -794,32 +832,38 @@ class OpenAIServing:
input_or_inputs: Optional[Union[str, list[str], list[int], input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]], list[list[int]]]],
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> tuple[Union[list[TextTokensPrompt], list[Union[ ) -> tuple[
TextTokensPrompt, EmbedsPrompt]]], Union[ Union[list[TextTokensPrompt], list[Union[TextTokensPrompt,
list[EngineTokensPrompt], list[Union[EngineTokensPrompt, EmbedsPrompt]]],
EngineEmbedsPrompt]]]]: Union[
if not isinstance(request, list[EngineTokensPrompt],
CompletionRequest) and input_or_inputs is None: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
],
]:
if (not isinstance(request, CompletionRequest)
and input_or_inputs is None):
raise ValueError( raise ValueError(
"Prompt embeds with non-completion requests is not" "Prompt embeds with non-completion requests is not"
" currently supported.") " currently supported.")
(request_prompts_text, request_prompts_embeds (
) = await self._tokenize_prompt_input_or_inputs_async( request_prompts_text,
request, request_prompts_embeds,
tokenizer, ) = await self._tokenize_prompt_input_or_inputs_async(
input_or_inputs, request,
add_special_tokens=add_special_tokens, tokenizer,
) input_or_inputs,
add_special_tokens=add_special_tokens,
)
engine_prompts_text = [ engine_prompts_text = [
EngineTokensPrompt( EngineTokensPrompt(
prompt_token_ids=request_prompt_text["prompt_token_ids"]) prompt_token_ids=request_prompt_text["prompt_token_ids"])
for request_prompt_text in request_prompts_text for request_prompt_text in request_prompts_text
] ]
cache_salt = request.cache_salt if ( cache_salt = (request.cache_salt if
hasattr(request, "cache_salt") (hasattr(request, "cache_salt")
and request.cache_salt is not None) else None and request.cache_salt is not None) else None)
if cache_salt: if cache_salt:
for prompt_text in engine_prompts_text: for prompt_text in engine_prompts_text:
prompt_text["cache_salt"] = cache_salt prompt_text["cache_salt"] = cache_salt
...@@ -831,8 +875,8 @@ class OpenAIServing: ...@@ -831,8 +875,8 @@ class OpenAIServing:
# non-completion requests and if we don't add the overload here, # non-completion requests and if we don't add the overload here,
# everywhere this function is used outside of serving_completion will # everywhere this function is used outside of serving_completion will
# need logic asserting that only text prompts are in the request. # need logic asserting that only text prompts are in the request.
if not isinstance(request, if (not isinstance(request, CompletionRequest)
CompletionRequest) and input_or_inputs is not None: and input_or_inputs is not None):
return request_prompts_text, engine_prompts_text return request_prompts_text, engine_prompts_text
engine_prompts_embeds = [ engine_prompts_embeds = [
...@@ -862,8 +906,11 @@ class OpenAIServing: ...@@ -862,8 +906,11 @@ class OpenAIServing:
chat_template_kwargs: Optional[dict[str, Any]] = None, chat_template_kwargs: Optional[dict[str, Any]] = None,
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], ) -> tuple[
list[EngineTokensPrompt]]: list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
model_config = self.model_config model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
...@@ -925,8 +972,8 @@ class OpenAIServing: ...@@ -925,8 +972,8 @@ class OpenAIServing:
if tokenizer is None: if tokenizer is None:
assert isinstance(request_prompt, str), ( assert isinstance(request_prompt, str), (
"Prompt has to be a string", \ "Prompt has to be a string",
"when the tokenizer is not initialised" "when the tokenizer is not initialised",
) )
prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_inputs = TextTokensPrompt(prompt=request_prompt,
prompt_token_ids=[1]) prompt_token_ids=[1])
...@@ -943,7 +990,8 @@ class OpenAIServing: ...@@ -943,7 +990,8 @@ class OpenAIServing:
"Prompt has to be either a string or a list of token ids") "Prompt has to be either a string or a list of token ids")
prompt_inputs = TextTokensPrompt( prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(request_prompt), prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt) prompt_token_ids=request_prompt,
)
engine_prompt = EngineTokensPrompt( engine_prompt = EngineTokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"]) prompt_token_ids=prompt_inputs["prompt_token_ids"])
...@@ -1007,22 +1055,23 @@ class OpenAIServing: ...@@ -1007,22 +1055,23 @@ class OpenAIServing:
prompt_token_ids=prompt_token_ids) prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids request_prompt = prompt_token_ids
# Update the sampling params. # Update the sampling params.
sampling_params.max_tokens = (self.max_model_len - sampling_params.max_tokens = self.max_model_len - len(
len(prompt_token_ids)) prompt_token_ids)
# OPTIMIZATION # OPTIMIZATION
priority = orig_priority - 1 priority = orig_priority - 1
@staticmethod @staticmethod
def _load_prompt_embeds( def _load_prompt_embeds(
prompt_embeds: Optional[Union[bytes, list[bytes]]], prompt_embeds: Optional[Union[bytes, list[bytes]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> list[EmbedsPrompt]: ) -> list[EmbedsPrompt]:
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(io.BytesIO( tensor = torch.load(
pybase64.b64decode(embed, validate=True)), io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True, weights_only=True,
map_location=torch.device("cpu")) map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32, torch.float32,
torch.bfloat16, torch.bfloat16,
...@@ -1061,7 +1110,7 @@ class OpenAIServing: ...@@ -1061,7 +1110,7 @@ class OpenAIServing:
prompt = inputs prompt = inputs
elif isinstance(inputs, list): elif isinstance(inputs, list):
prompt_token_ids = inputs prompt_token_ids = inputs
elif 'prompt_embeds' in inputs: elif "prompt_embeds" in inputs:
prompt_embeds = inputs.get("prompt_embeds") prompt_embeds = inputs.get("prompt_embeds")
else: else:
prompt = inputs["prompt"] prompt = inputs["prompt"]
...@@ -1101,10 +1150,12 @@ class OpenAIServing: ...@@ -1101,10 +1150,12 @@ class OpenAIServing:
return raw_request.headers.get("X-Request-Id", default) return raw_request.headers.get("X-Request-Id", default)
@staticmethod @staticmethod
def _get_decoded_token(logprob: Logprob, def _get_decoded_token(
token_id: int, logprob: Logprob,
tokenizer: AnyTokenizer, token_id: int,
return_as_token_id: bool = False) -> str: tokenizer: AnyTokenizer,
return_as_token_id: bool = False,
) -> str:
if return_as_token_id: if return_as_token_id:
return f"token_id:{token_id}" return f"token_id:{token_id}"
...@@ -1117,9 +1168,11 @@ class OpenAIServing: ...@@ -1117,9 +1168,11 @@ class OpenAIServing:
return True return True
return self.models.is_base_model(model_name) return self.models.is_base_model(model_name)
def _get_model_name(self, def _get_model_name(
model_name: Optional[str] = None, self,
lora_request: Optional[LoRARequest] = None) -> str: model_name: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
) -> str:
if lora_request: if lora_request:
return lora_request.lora_name return lora_request.lora_name
if not model_name: if not model_name:
...@@ -1129,7 +1182,7 @@ class OpenAIServing: ...@@ -1129,7 +1182,7 @@ class OpenAIServing:
def clamp_prompt_logprobs( def clamp_prompt_logprobs(
prompt_logprobs: Union[PromptLogprobs, prompt_logprobs: Union[PromptLogprobs,
None]) -> Union[PromptLogprobs, None]: None], ) -> Union[PromptLogprobs, None]:
if prompt_logprobs is None: if prompt_logprobs is None:
return prompt_logprobs return prompt_logprobs
...@@ -1137,6 +1190,6 @@ def clamp_prompt_logprobs( ...@@ -1137,6 +1190,6 @@ def clamp_prompt_logprobs(
if logprob_dict is None: if logprob_dict is None:
continue continue
for logprob_values in logprob_dict.values(): for logprob_values in logprob_dict.values():
if logprob_values.logprob == float('-inf'): if logprob_values.logprob == float("-inf"):
logprob_values.logprob = -9999.0 logprob_values.logprob = -9999.0
return prompt_logprobs return prompt_logprobs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment