"tests/vscode:/vscode.git/clone" did not exist on "443c7cf4cf891e6957d4b31655e58cabceb5a2a7"
Commit 4851c202 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.1' into v0.6.1-dev

parents 9b902f9e 3fd2b0d2
...@@ -41,8 +41,9 @@ from transformers import AutoTokenizer ...@@ -41,8 +41,9 @@ from transformers import AutoTokenizer
from vllm import SamplingParams from vllm import SamplingParams
from ...utils import fork_new_process_for_each_test
from .conftest import (get_output_from_llm_generator, from .conftest import (get_output_from_llm_generator,
run_greedy_equality_correctness_test) run_equality_correctness_test)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -73,6 +74,7 @@ from .conftest import (get_output_from_llm_generator, ...@@ -73,6 +74,7 @@ from .conftest import (get_output_from_llm_generator,
@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_with_detokenization(test_llm_generator, def test_spec_decode_e2e_with_detokenization(test_llm_generator,
batch_size: int): batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine """Run generation with speculative decoding on a batch. Verify the engine
...@@ -116,44 +118,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, ...@@ -116,44 +118,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert actual_tokens.strip() == expected_tokens.strip() assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Use AsyncLLM engine
"use_async": True,
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
baseline_llm_generator,
batch_size: int):
"""Verify spec decode works well with async LLM engine.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
...@@ -172,10 +136,10 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator, ...@@ -172,10 +136,10 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
# Try two different tiny base models. # Try two different tiny base models.
# Note that one is equal to the draft model, another isn't. # Note that one is equal to the draft model, another isn't.
{ {
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
}, },
{ {
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -189,13 +153,15 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator, ...@@ -189,13 +153,15 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
"output_len", "output_len",
[ [
# Use long output len for the small model test. # Use long output len for the small model test.
1536, 10,
]) ])
@pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
output_len: int): baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model with batch size of one. """Verify greedy equality on a tiny model with batch size of one.
Since this test is cheaper than other e2e correctness tests, we generate Since this test is cheaper than other e2e correctness tests, we generate
...@@ -204,14 +170,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ...@@ -204,14 +170,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
When the draft model is the same as the target model, we further check When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted. whether all speculative tokens are accepted.
""" """
ensure_all_accepted = test_llm_generator.same_draft_target_model ensure_all_accepted = per_test_common_llm_kwargs.get(
run_greedy_equality_correctness_test( "model_name") == test_llm_kwargs.get("speculative_model")
baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True, test_llm_kwargs,
ensure_all_accepted=ensure_all_accepted) batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
ensure_all_accepted=ensure_all_accepted)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -232,10 +202,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ...@@ -232,10 +202,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
# Try two different tiny base models. # Try two different tiny base models.
# Note that one is equal to the draft model, another isn't. # Note that one is equal to the draft model, another isn't.
{ {
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
}, },
{ {
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -253,16 +223,22 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ...@@ -253,16 +223,22 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
]) ])
@pytest.mark.parametrize("batch_size", [64]) @pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
output_len: int): baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model and large batch size. """Verify greedy equality on a tiny model and large batch size.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -280,10 +256,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( ...@@ -280,10 +256,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
# Try two different tiny base models. # Try two different tiny base models.
# Note that one is equal to the draft model, another isn't. # Note that one is equal to the draft model, another isn't.
{ {
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
}, },
{ {
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -298,24 +274,31 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( ...@@ -298,24 +274,31 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
]) ])
@pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
baseline_llm_generator, test_llm_generator, batch_size: int, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
max_output_len: int): baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
max_output_len: int, seed: int):
"""Verify greedy equality on a tiny model, with a large batch size, and when """Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token. sampling respects the EOS token.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len, baseline_llm_kwargs,
force_output_len=False) test_llm_kwargs,
batch_size,
max_output_len,
seed=seed,
temperature=0.0,
ignore_eos=False)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
# A "real" model (not tiny). # A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf", "model_name": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -342,24 +325,30 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( ...@@ -342,24 +325,30 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
256, 256,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_real_model_bs1( def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
output_len: int): baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a "real" model and batch size of 1. This is """Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier. separate from large BS tests to make identifying the source of bugs easier.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
# A "real" model (not tiny). # A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf", "model_name": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -386,17 +375,23 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( ...@@ -386,17 +375,23 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
64, 64,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
output_len: int): baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with a "real" model on a nontrivial batch size. """Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload. This is the closest test to a real production workload.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -415,7 +410,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( ...@@ -415,7 +410,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -433,23 +428,29 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( ...@@ -433,23 +428,29 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
]) ])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_with_preemption( def test_spec_decode_e2e_greedy_correctness_with_preemption(
baseline_llm_generator, test_llm_generator, batch_size: int, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
output_len: int): baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid- """Verify greedy equality, even when some sequences are preempted mid-
generation. generation.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -487,22 +488,29 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( ...@@ -487,22 +488,29 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_spec_decode_different_block_size(baseline_llm_generator, @fork_new_process_for_each_test
test_llm_generator, batch_size: int, def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
output_len: int): per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality over different block sizes. """Verify greedy equality over different block sizes.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -534,24 +542,31 @@ def test_spec_decode_different_block_size(baseline_llm_generator, ...@@ -534,24 +542,31 @@ def test_spec_decode_different_block_size(baseline_llm_generator,
64, 64,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator, @fork_new_process_for_each_test
batch_size: int, output_len: int): def test_skip_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when some (or all) sequences skip speculation. """Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding. are skipped in speculative decoding.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -571,21 +586,28 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, ...@@ -571,21 +586,28 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10]) @pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator, @fork_new_process_for_each_test
batch_size: int, output_len: int): def test_disable_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when all sequences disable speculation. """Verify greedy equality when all sequences disable speculation.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -613,22 +635,28 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator, ...@@ -613,22 +635,28 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator,
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, @fork_new_process_for_each_test
output_len: int): def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding produces exact equality to without spec """Verify that speculative decoding produces exact equality to without spec
decode with many different values of k. decode with many different values of k.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -657,15 +685,22 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, ...@@ -657,15 +685,22 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator, @fork_new_process_for_each_test
test_llm_generator, batch_size: int, def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs,
output_len: int): per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify that speculative decoding produces exact equality to without spec """Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method. sampling method.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
...@@ -26,7 +26,7 @@ for the target model outputs. ...@@ -26,7 +26,7 @@ for the target model outputs.
import pytest import pytest
from .conftest import run_greedy_equality_correctness_test from .conftest import run_equality_correctness_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -43,7 +43,7 @@ from .conftest import run_greedy_equality_correctness_test ...@@ -43,7 +43,7 @@ from .conftest import run_greedy_equality_correctness_test
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -59,15 +59,21 @@ from .conftest import run_greedy_equality_correctness_test ...@@ -59,15 +59,21 @@ from .conftest import run_greedy_equality_correctness_test
]) ])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator, def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
test_llm_generator, batch_size: int, per_test_common_llm_kwargs,
output_len: int): baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model with different batch size.""" """Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -86,7 +92,7 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator, ...@@ -86,7 +92,7 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
"model": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -105,24 +111,28 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator, ...@@ -105,24 +111,28 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
]) ])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator, def test_ngram_e2e_greedy_correctness_with_preemption(
test_llm_generator, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
batch_size: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
output_len: int): seed: int):
"""Verify greedy equality, even when some sequences are preempted mid- """Verify greedy equality, even when some sequences are preempted mid-
generation. generation.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0,
seed=seed)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -159,23 +169,29 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator, ...@@ -159,23 +169,29 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator, def test_ngram_different_k(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int): per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality """Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and to without spec decode with many different values of k and
different ngram_prompt_lookup_max. different ngram_prompt_lookup_max.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -200,14 +216,20 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator, ...@@ -200,14 +216,20 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
32, 32,
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int): per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality """Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and to without spec decode with many different values of k and
different ngram_prompt_lookup_max. different ngram_prompt_lookup_max.
""" """
run_greedy_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(vllm_runner,
test_llm_generator, common_llm_kwargs,
batch_size, per_test_common_llm_kwargs,
max_output_len=output_len, baseline_llm_kwargs,
force_output_len=True) test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
...@@ -2,11 +2,17 @@ import pytest ...@@ -2,11 +2,17 @@ import pytest
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
# main model
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "JackFram/llama-160m"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model": "JackFram/llama-68m", "model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
...@@ -31,26 +37,34 @@ from .conftest import run_equality_correctness_test ...@@ -31,26 +37,34 @@ from .conftest import run_equality_correctness_test
# Use smaller output len for fast test. # Use smaller output len for fast test.
20, 20,
]) ])
@pytest.mark.parametrize("seed", [None]) def test_seeded_consistency(vllm_runner, common_llm_kwargs,
def test_seeded_consistency(baseline_llm_generator, test_llm_generator, per_test_common_llm_kwargs, baseline_llm_kwargs,
batch_size: int, temperature: float, test_llm_kwargs, batch_size: int,
output_len: int): temperature: float, output_len: int):
"""Verify outputs are consistent across multiple runs with same seed """Verify outputs are consistent across multiple runs with same seed
""" """
run_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(
test_llm_generator, vllm_runner,
batch_size, common_llm_kwargs,
max_output_len=output_len, per_test_common_llm_kwargs,
temperature=temperature, baseline_llm_kwargs,
seeded=True, test_llm_kwargs,
force_output_len=True) batch_size,
max_output_len=output_len,
temperature=temperature,
disable_seed=False,
)
# Ensure this same test does fail if we _don't_ include per-request seeds # Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator, run_equality_correctness_test(
test_llm_generator, vllm_runner,
batch_size, common_llm_kwargs,
max_output_len=output_len, per_test_common_llm_kwargs,
temperature=temperature, baseline_llm_kwargs,
seeded=False, test_llm_kwargs,
force_output_len=True) batch_size,
max_output_len=output_len,
temperature=temperature,
disable_seed=True,
)
...@@ -95,7 +95,7 @@ def test_logger_configuring_can_be_disabled(): ...@@ -95,7 +95,7 @@ def test_logger_configuring_can_be_disabled():
config behavior, however mocks are used to ensure no changes in behavior or config behavior, however mocks are used to ensure no changes in behavior or
configuration occur.""" configuration occur."""
with patch("logging.config.dictConfig") as dict_config_mock: with patch("vllm.logger.dictConfig") as dict_config_mock:
_configure_vllm_root_logger() _configure_vllm_root_logger()
dict_config_mock.assert_not_called() dict_config_mock.assert_not_called()
...@@ -175,9 +175,9 @@ def test_custom_logging_config_is_parsed_and_used_when_provided(): ...@@ -175,9 +175,9 @@ def test_custom_logging_config_is_parsed_and_used_when_provided():
logging_config_file.flush() logging_config_file.flush()
with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH",
logging_config_file.name), patch( logging_config_file.name), patch(
"logging.config.dictConfig") as dict_config_mock: "vllm.logger.dictConfig") as dict_config_mock:
_configure_vllm_root_logger() _configure_vllm_root_logger()
assert dict_config_mock.called_with(valid_logging_config) dict_config_mock.assert_called_with(valid_logging_config)
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0)
......
...@@ -19,7 +19,7 @@ ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] ...@@ -19,7 +19,7 @@ ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"]
CONFIGS: Dict[str, ServerConfig] = { CONFIGS: Dict[str, ServerConfig] = {
"hermes": { "hermes": {
"model": "model":
"NousResearch/Hermes-2-Pro-Llama-3-8B", "NousResearch/Hermes-3-Llama-3.1-8B",
"arguments": [ "arguments": [
"--tool-call-parser", "hermes", "--chat-template", "--tool-call-parser", "hermes", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
......
...@@ -20,7 +20,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -20,7 +20,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
...@@ -89,11 +89,11 @@ class RemoteOpenAIServer: ...@@ -89,11 +89,11 @@ class RemoteOpenAIServer:
is_local = os.path.isdir(model) is_local = os.path.isdir(model)
if not is_local: if not is_local:
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine_config = engine_args.create_engine_config() model_config = engine_args.create_model_config()
dummy_loader = DefaultModelLoader(engine_config.load_config) load_config = engine_args.create_load_config()
dummy_loader._prepare_weights(engine_config.model_config.model,
engine_config.model_config.revision, model_loader = get_model_loader(load_config)
fall_back_to_pt=True) model_loader.download_model(model_config)
env = os.environ.copy() env = os.environ.copy()
# the current process might initialize cuda, # the current process might initialize cuda,
...@@ -178,7 +178,12 @@ def compare_two_settings(model: str, ...@@ -178,7 +178,12 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server. env2: The second set of environment variables to pass to the API server.
""" """
tokenizer = AutoTokenizer.from_pretrained(model) trust_remote_code = "--trust-remote-code"
if trust_remote_code in arg1 or trust_remote_code in arg2:
tokenizer = AutoTokenizer.from_pretrained(model,
trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(model)
prompt = "Hello, my name is" prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"] token_ids = tokenizer(prompt)["input_ids"]
......
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
\ No newline at end of file
...@@ -19,8 +19,7 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main ...@@ -19,8 +19,7 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
awq, casperhansen/mixtral-instruct-awq, main awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
......
...@@ -339,18 +339,28 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -339,18 +339,28 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# b_g_idx, use_exllama, bit) # b_g_idx, use_exllama, bit)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_gemm # noqa B018
@torch.library.register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
use_exllama: bool, bit: int) -> torch.Tensor:
return torch.empty((a.size(0), b_q_weight.size(1)),
dtype=a.dtype,
device=a.device)
except Exception:
pass
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
quant_ops.gptq_shuffle(q_weight, q_perm, bit) quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None:
torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin # marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
...@@ -369,6 +379,194 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -369,6 +379,194 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k) size_n, size_k)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_marlin_24_gemm # noqa B018
@torch.library.register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
b_zeros: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
) -> torch.Tensor:
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
in_c = qweight.size(0)
qout_c = qweight.size(1)
out_c = qout_c * 8
return torch.empty((in_c, out_c),
dtype=scales.dtype,
device=scales.device)
@torch.library.register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
num_in_feats = input.size(0)
return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
dtype=input.dtype,
device=input.device).sum(0)
@torch.library.register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
bias: Optional[torch.Tensor]) -> torch.Tensor:
out_features = codes.size(0) * codebooks.size(2)
flat_input = input.reshape((-1, input.size(-1)))
flat_output = torch.empty((flat_input.size(0), out_features),
dtype=input.dtype,
device=input.device)
output_sizes = list(input.shape)
output_sizes.pop()
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))
@torch.library.register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
in_features = codes.size(1) * 8
out_features = codes.size(0)
return torch.empty((out_features, in_features),
dtype=codebooks.dtype,
device=codebooks.device)
@torch.library.register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@torch.library.register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
b_q: torch.
Tensor, # Should be the tensor returned by machete_prepack_B
b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None,
b_group_size: Optional[int] = None,
c: Optional[torch.Tensor] = None,
alpha: Optional[float] = None,
beta: Optional[float] = None,
schedule: Optional[str] = None,
) -> torch.Tensor:
m = a.size(0)
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight)
@torch.library.register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
a = torch.empty_like(u)
if x is not None:
b = x
else:
b = torch.empty((u.size(0), u.size(1), A.size(1)),
dtype=u.dtype,
device=u.device)
if z_ is not None:
c = torch.empty_like(z_)
return [a, b, c]
else:
return [a, b]
except Exception:
pass
# cutlass # cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
......
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Literal
import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image
from vllm.multimodal.utils import (sample_frames_from_video,
try_import_video_packages)
from .base import get_cache_dir
@lru_cache
def download_video_asset(filename: str) -> str:
"""
Download and open an image from huggingface
repo: raushan-testing-hf/videos-test
"""
video_directory = get_cache_dir() / "video-eample-data"
video_directory.mkdir(parents=True, exist_ok=True)
video_path = video_directory / filename
video_path_str = str(video_path)
if not video_path.exists():
video_path_str = hf_hub_download(
repo_id="raushan-testing-hf/videos-test",
filename=filename,
repo_type="dataset",
cache_dir=video_directory,
)
return video_path_str
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
cv2 = try_import_video_packages()
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError(f"Could not open video file {path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frames = []
for i in range(total_frames):
ret, frame = cap.read()
if ret:
frames.append(frame)
cap.release()
frames = np.stack(frames)
frames = sample_frames_from_video(frames, num_frames)
if len(frames) < num_frames:
raise ValueError(f"Could not read enough frames from video file {path}"
f" (expected {num_frames} frames, got {len(frames)})")
return frames
def video_to_pil_images_list(path: str,
num_frames: int = -1) -> List[Image.Image]:
cv2 = try_import_video_packages()
frames = video_to_ndarrays(path, num_frames)
return [
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
for frame in frames
]
@dataclass(frozen=True)
class VideoAsset:
name: Literal["sample_demo_1.mp4"]
num_frames: int = -1
@property
def pil_images(self) -> List[Image.Image]:
video_path = download_video_asset(self.name)
ret = video_to_pil_images_list(video_path, self.num_frames)
return ret
@property
def np_ndarrays(self) -> List[npt.NDArray]:
video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
...@@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, ...@@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
...@@ -302,14 +303,12 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -302,14 +303,12 @@ class FlashAttentionMetadata(AttentionMetadata):
) )
return self._cached_decode_metadata return self._cached_decode_metadata
def advance_step(self, num_seqs: int, num_queries: int): def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
""" """
Update metadata in-place to advance one decode step. Update metadata in-place to advance one decode step.
""" """
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.
# When using cudagraph, the num_seqs is padded to the next captured # When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in # batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries # the batch. For --enforce-eager mode, num_seqs == num_queries
...@@ -347,6 +346,16 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -347,6 +346,16 @@ class FlashAttentionMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]): AttentionMetadataBuilder[FlashAttentionMetadata]):
...@@ -462,9 +471,19 @@ class FlashAttentionMetadataBuilder( ...@@ -462,9 +471,19 @@ class FlashAttentionMetadataBuilder(
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size] input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to( block_tables = torch.from_numpy(input_block_tables).to(
device=device, non_blocking=True) device=device, non_blocking=True)
else: else:
......
...@@ -224,6 +224,7 @@ class FlashInferState(AttentionState): ...@@ -224,6 +224,7 @@ class FlashInferState(AttentionState):
query_start_loc=query_start_loc_host, query_start_loc=query_start_loc_host,
device=self.runner.device, device=self.runner.device,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True, use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper, decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None) prefill_wrapper=None)
...@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None page_size: Optional[int] = None
# The data type of the paged kv cache # The data type of the paged kv cache
data_type: torch.dtype = None data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda") device: torch.device = torch.device("cuda")
is_profile_run: bool = False is_profile_run: bool = False
...@@ -353,7 +356,10 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -353,7 +356,10 @@ class FlashInferMetadata(AttentionMetadata):
self.page_size, self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope. # Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE", pos_encoding_mode="NONE",
data_type=self.data_type) # kv-cache data type.
data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)
def asdict_zerocopy(self, def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None skip_fields: Optional[Set[str]] = None
...@@ -617,6 +623,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -617,6 +623,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
device=device, device=device,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run) is_profile_run=self.is_profile_run)
......
...@@ -11,8 +11,9 @@ import uvloop ...@@ -11,8 +11,9 @@ import uvloop
from tqdm import tqdm from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args) build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...@@ -504,13 +505,11 @@ if __name__ == "__main__": ...@@ -504,13 +505,11 @@ if __name__ == "__main__":
'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' 'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.') 'instead supported for common inference criteria.')
parser.add_argument( parser.add_argument("--device",
"--device", type=str,
type=str, default="auto",
default="auto", choices=DEVICE_OPTIONS,
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], help='device type for vLLM execution')
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument( parser.add_argument(
"--num-scheduler-steps", "--num-scheduler-steps",
type=int, type=int,
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
class TorchCompileWrapperWithCustomDispacther: class TorchCompileWrapperWithCustomDispatcher:
""" """
A wrapper class for torch.compile, with a custom dispatch logic. A wrapper class for torch.compile, with a custom dispatch logic.
Subclasses should: Subclasses should:
......
...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS ...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (get_config, from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config, get_hf_image_processor_config,
get_hf_text_config) get_hf_text_config)
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
...@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 ...@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
_PP_SUPPORTED_MODELS = [ _PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM", "AquilaForCausalLM",
"AquilaModel",
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"GPT2LMHeadModel",
"InternLM2ForCausalLM",
"InternLMForCausalLM", "InternLMForCausalLM",
"InternVLChatModel",
"JAISLMHeadModel", "JAISLMHeadModel",
"LlamaForCausalLM", "LlamaForCausalLM",
"LLaMAForCausalLM", "LLaMAForCausalLM",
"MistralForCausalLM", "MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
"MixtralForCausalLM", "MixtralForCausalLM",
"NemotronForCausalLM", "NemotronForCausalLM",
"Phi3ForCausalLM",
"Qwen2ForCausalLM", "Qwen2ForCausalLM",
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"QWenLMHeadModel", "QWenLMHeadModel",
...@@ -119,35 +121,37 @@ class ModelConfig: ...@@ -119,35 +121,37 @@ class ModelConfig:
override default neuron config that are specific to Neuron devices, override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments. can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
""" """
def __init__( def __init__(self,
self, model: str,
model: str, tokenizer: str,
tokenizer: str, tokenizer_mode: str,
tokenizer_mode: str, trust_remote_code: bool,
trust_remote_code: bool, dtype: Union[str, torch.dtype],
dtype: Union[str, torch.dtype], seed: int,
seed: int, revision: Optional[str] = None,
revision: Optional[str] = None, code_revision: Optional[str] = None,
code_revision: Optional[str] = None, rope_scaling: Optional[dict] = None,
rope_scaling: Optional[dict] = None, rope_theta: Optional[float] = None,
rope_theta: Optional[float] = None, tokenizer_revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None,
max_model_len: Optional[int] = None, spec_target_max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None, quantization: Optional[str] = None,
quantization: Optional[str] = None, quantization_param_path: Optional[str] = None,
quantization_param_path: Optional[str] = None, enforce_eager: Optional[bool] = None,
enforce_eager: Optional[bool] = None, max_context_len_to_capture: Optional[int] = None,
max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20,
max_logprobs: int = 20, disable_sliding_window: bool = False,
disable_sliding_window: bool = False, skip_tokenizer_init: bool = False,
skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None,
served_model_name: Optional[Union[str, List[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True,
use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None,
override_neuron_config: Optional[Dict[str, Any]] = None) -> None: config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
...@@ -174,7 +178,8 @@ class ModelConfig: ...@@ -174,7 +178,8 @@ class ModelConfig:
self.skip_tokenizer_init = skip_tokenizer_init self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision, self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta) code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision) self.model, revision)
...@@ -275,11 +280,11 @@ class ModelConfig: ...@@ -275,11 +280,11 @@ class ModelConfig:
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["awq", "gptq", "squeezellm"] # "fp8" rocm_supported_quantization = ["awq", "gptq"] # "fp8"
optimized_quantization_methods = [ optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors", "awq_marlin", "fbgemm_fp8", "compressed_tensors",
"experts_int8" "compressed-tensors", "experts_int8"
] ]
tpu_supported_quantization = ["tpu_int8"] tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"] neuron_supported_quantization = ["neuron_quant"]
...@@ -744,6 +749,7 @@ class LoadFormat(str, enum.Enum): ...@@ -744,6 +749,7 @@ class LoadFormat(str, enum.Enum):
SHARDED_STATE = "sharded_state" SHARDED_STATE = "sharded_state"
GGUF = "gguf" GGUF = "gguf"
BITSANDBYTES = "bitsandbytes" BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
@dataclass @dataclass
...@@ -767,7 +773,7 @@ class LoadConfig: ...@@ -767,7 +773,7 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model. ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's Default to "original/**/*" to avoid repeated loading of llama's
checkpoints. checkpoints.
""" """
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
...@@ -870,7 +876,8 @@ class ParallelConfig: ...@@ -870,7 +876,8 @@ class ParallelConfig:
from vllm.executor import ray_utils from vllm.executor import ray_utils
backend = "mp" backend = "mp"
ray_found = ray_utils.ray_is_available() ray_found = ray_utils.ray_is_available()
if cuda_device_count_stateless() < self.world_size: if (torch.cuda.is_available()
and cuda_device_count_stateless() < self.world_size):
if not ray_found: if not ray_found:
raise ValueError("Unable to load Ray which is " raise ValueError("Unable to load Ray which is "
"required for multi-node inference, " "required for multi-node inference, "
...@@ -1535,7 +1542,7 @@ class LoRAConfig: ...@@ -1535,7 +1542,7 @@ class LoRAConfig:
if model_config.quantization and model_config.quantization not in [ if model_config.quantization and model_config.quantization not in [
"awq", "gptq" "awq", "gptq"
]: ]:
# TODO support marlin and squeezellm # TODO support marlin
logger.warning("%s quantization is not tested with LoRA yet.", logger.warning("%s quantization is not tested with LoRA yet.",
model_config.quantization) model_config.quantization)
...@@ -1552,14 +1559,6 @@ class PromptAdapterConfig: ...@@ -1552,14 +1559,6 @@ class PromptAdapterConfig:
prompt_adapter_dtype: Optional[torch.dtype] = None prompt_adapter_dtype: Optional[torch.dtype] = None
def __post_init__(self): def __post_init__(self):
library_name = 'peft'
try:
__import__(library_name)
except ImportError as e:
raise ImportError(
f"'{library_name}' is not installed for prompt adapter support."
f"Please install it using 'pip install {library_name}'."
) from e
if self.max_prompt_adapters < 1: if self.max_prompt_adapters < 1:
raise ValueError(f"max_prompt_adapters " raise ValueError(f"max_prompt_adapters "
...@@ -1735,8 +1734,11 @@ def _get_and_verify_max_len( ...@@ -1735,8 +1734,11 @@ def _get_and_verify_max_len(
"with rope_scaling. Please raise an issue so we can " "with rope_scaling. Please raise an issue so we can "
"investigate.") "investigate.")
assert "factor" in rope_scaling if rope_type == "mrope":
scaling_factor = rope_scaling["factor"] scaling_factor = 1
else:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "yarn": if rope_type == "yarn":
derived_max_model_len = rope_scaling[ derived_max_model_len = rope_scaling[
"original_max_position_embeddings"] "original_max_position_embeddings"]
......
...@@ -537,13 +537,6 @@ class Scheduler: ...@@ -537,13 +537,6 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out swapped_out: List[SequenceGroup] = ret.swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# Store original running requests for the case of async + preemption
if self.use_async_output_proc:
orig_running = self.running.copy()
running_queue = self.running running_queue = self.running
assert len(self._async_stopped) == 0 assert len(self._async_stopped) == 0
while running_queue: while running_queue:
...@@ -552,6 +545,7 @@ class Scheduler: ...@@ -552,6 +545,7 @@ class Scheduler:
seq_group, SequenceStatus.RUNNING, enable_chunking, budget) seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
if num_running_tokens == 0: if num_running_tokens == 0:
# No budget => Stop
break break
running_queue.popleft() running_queue.popleft()
...@@ -565,18 +559,8 @@ class Scheduler: ...@@ -565,18 +559,8 @@ class Scheduler:
self._async_stopped.append(seq_group) self._async_stopped.append(seq_group)
continue continue
# With async postprocessor, when preemption kicks in, we need # NOTE(woosuk): Preemption happens only when there is no available
# first to drain the async postprocessor, so that all async # slot to keep all the sequence groups in the RUNNING state.
# block_table freeing is applied before the preemption freeing
# is applied.
if self.use_async_output_proc and not self._can_append_slots(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback is not None
self.output_proc_callback()
self.running = tmp
while not self._can_append_slots(seq_group): while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id, budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens) num_running_tokens)
...@@ -588,24 +572,43 @@ class Scheduler: ...@@ -588,24 +572,43 @@ class Scheduler:
and seq_group.lora_int_id in curr_loras): and seq_group.lora_int_id in curr_loras):
curr_loras.remove(seq_group.lora_int_id) curr_loras.remove(seq_group.lora_int_id)
# Determine victim sequence
cont_loop = True
if running_queue: if running_queue:
# Preempt the lowest-priority sequence groups. # Preempt the lowest-priority sequence group.
victim_seq_group = running_queue.pop() victim_seq_group = running_queue.pop()
else:
# No other sequence group can be preempted.
# Preempt the current sequence group.
# Note: This is also where we stop this loop
# (since there is nothing else to preempt)
victim_seq_group = seq_group
cont_loop = False
# With async postprocessor, before preempting a sequence
# we need to ensure it has no pending async postprocessor
do_preempt = True
if self.use_async_output_proc:
assert self.output_proc_callback is not None
self.output_proc_callback(
request_id=victim_seq_group.request_id)
# It may be that the async pending "victim_seq_group"
# becomes finished, in which case we simply free it.
if victim_seq_group.is_finished():
self._free_finished_seq_group(victim_seq_group)
do_preempt = False
# Do preemption
if do_preempt:
preempted_mode = self._preempt(victim_seq_group, preempted_mode = self._preempt(victim_seq_group,
blocks_to_swap_out) blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE: if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(victim_seq_group) preempted.append(victim_seq_group)
else: else:
swapped_out.append(victim_seq_group) swapped_out.append(victim_seq_group)
else:
# No other sequence groups can be preempted. if not cont_loop:
# Preempt the current sequence group.
preempted_mode = self._preempt(seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(seq_group)
else:
swapped_out.append(seq_group)
break break
else: else:
self._append_slots(seq_group, blocks_to_copy) self._append_slots(seq_group, blocks_to_copy)
...@@ -1264,22 +1267,26 @@ class Scheduler: ...@@ -1264,22 +1267,26 @@ class Scheduler:
if seq.is_finished(): if seq.is_finished():
self.free_seq(seq) self.free_seq(seq)
def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
if seq_group.is_finished():
# Free cross-attention block table, if it exists
self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque() remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running: for seq_group in self.running:
if seq_group.is_finished(): self._free_finished_seq_group(seq_group)
# Free cross-attention block table, if it exists if not seq_group.is_finished():
self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group) remaining.append(seq_group)
# Free finished seqs
self._free_finished_seqs(seq_group)
self.running = remaining self.running = remaining
# Handle async stopped sequence groups # Handle async stopped sequence groups
......
...@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, ...@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
EngineConfig, LoadConfig, LoadFormat, LoRAConfig, DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
ModelConfig, ObservabilityConfig, ParallelConfig, LoRAConfig, ModelConfig, ObservabilityConfig,
PromptAdapterConfig, SchedulerConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig) SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -26,6 +26,16 @@ logger = init_logger(__name__) ...@@ -26,6 +26,16 @@ logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
DEVICE_OPTIONS = [
"auto",
"cuda",
"neuron",
"cpu",
"openvino",
"tpu",
"xpu",
]
def nullable_str(val: str): def nullable_str(val: str):
if not val or val == "None": if not val or val == "None":
...@@ -65,6 +75,7 @@ class EngineArgs: ...@@ -65,6 +75,7 @@ class EngineArgs:
trust_remote_code: bool = False trust_remote_code: bool = False
download_dir: Optional[str] = None download_dir: Optional[str] = None
load_format: str = 'auto' load_format: str = 'auto'
config_format: str = 'auto'
dtype: str = 'auto' dtype: str = 'auto'
kv_cache_dtype: str = 'auto' kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None quantization_param_path: Optional[str] = None
...@@ -234,6 +245,13 @@ class EngineArgs: ...@@ -234,6 +245,13 @@ class EngineArgs:
'section for more information.\n' 'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes ' '* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n') 'quantization.\n')
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
choices=[f.value for f in ConfigFormat],
help='The format of the model config to load.\n\n'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format ')
parser.add_argument( parser.add_argument(
'--dtype', '--dtype',
type=str, type=str,
...@@ -545,10 +563,7 @@ class EngineArgs: ...@@ -545,10 +563,7 @@ class EngineArgs:
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
choices=[ choices=DEVICE_OPTIONS,
"auto", "cuda", "neuron", "cpu", "openvino",
"tpu", "xpu"
],
help='Device type for vLLM execution.') help='Device type for vLLM execution.')
parser.add_argument('--num-scheduler-steps', parser.add_argument('--num-scheduler-steps',
type=int, type=int,
...@@ -763,6 +778,43 @@ class EngineArgs: ...@@ -763,6 +778,43 @@ class EngineArgs:
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args return engine_args
def create_model_config(self) -> ModelConfig:
return ModelConfig(
model=self.model,
tokenizer=self.tokenizer,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
)
def create_load_config(self) -> LoadConfig:
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)
def create_engine_config(self) -> EngineConfig: def create_engine_config(self) -> EngineConfig:
# gguf file needs a specific model loader and doesn't use hf_repo # gguf file needs a specific model loader and doesn't use hf_repo
if check_gguf_file(self.model): if check_gguf_file(self.model):
...@@ -789,31 +841,8 @@ class EngineArgs: ...@@ -789,31 +841,8 @@ class EngineArgs:
f", but got {self.cpu_offload_gb}") f", but got {self.cpu_offload_gb}")
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
model_config = ModelConfig( model_config = self.create_model_config()
model=self.model,
tokenizer=self.tokenizer,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config)
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len self.max_model_len, # neuron needs block_size = max_model_len
...@@ -856,7 +885,6 @@ class EngineArgs: ...@@ -856,7 +885,6 @@ class EngineArgs:
if (is_gpu and not use_sliding_window and not use_spec_decode if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora and not self.enable_lora
and not self.enable_prompt_adapter and not self.enable_prompt_adapter
and not self.enable_prefix_caching
and not has_seqlen_agnostic_layers): and not has_seqlen_agnostic_layers):
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
logger.warning( logger.warning(
...@@ -956,12 +984,7 @@ class EngineArgs: ...@@ -956,12 +984,7 @@ class EngineArgs:
self.model_loader_extra_config[ self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
load_config = LoadConfig( load_config = self.create_load_config()
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)
prompt_adapter_config = PromptAdapterConfig( prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters, max_prompt_adapters=self.max_prompt_adapters,
......
...@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine] virtual_engine]
# Execute the model. # Execute the model.
output = await self.model_executor.execute_model_async( outputs = await self.model_executor.execute_model_async(
execute_model_req) execute_model_req)
# we need to do this here so that last step's sampled_token_ids can # we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output) self._update_cached_scheduler_output(virtual_engine, outputs)
else: else:
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
output = [] outputs = []
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
...@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
self.cached_scheduler_outputs[ self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState() virtual_engine] = SchedulerOutputState()
is_async = allow_async_output_proc ctx.append_output(outputs=outputs,
is_last_step = True seq_group_metadata_list=seq_group_metadata_list,
ctx.output_queue.append( scheduler_outputs=scheduler_outputs,
(output, seq_group_metadata_list, scheduler_outputs, is_async, is_async=allow_async_output_proc,
is_last_step)) is_last_step=True)
if output and allow_async_output_proc: if outputs and allow_async_output_proc:
assert len( assert len(
output outputs
) == 1, "Async postprocessor expects only a single output set" ) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step( self._advance_to_next_step(
output[0], seq_group_metadata_list, outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc: if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, outputs)
# Tracing # Tracing
self.do_tracing(scheduler_outputs) self.do_tracing(scheduler_outputs)
......
...@@ -2,9 +2,9 @@ import functools ...@@ -2,9 +2,9 @@ import functools
import time import time
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
Mapping, Optional) Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union from typing import Set, Tuple, Type, Union
...@@ -90,17 +90,36 @@ class SchedulerOutputState: ...@@ -90,17 +90,36 @@ class SchedulerOutputState:
last_output: Optional[SamplerOutput] = None last_output: Optional[SamplerOutput] = None
@dataclass class OutputData(NamedTuple):
outputs: List[SamplerOutput]
seq_group_metadata_list: List[SequenceGroupMetadata]
scheduler_outputs: SchedulerOutputs
is_async: bool
is_last_step: bool
skip: List[int]
class SchedulerContext: class SchedulerContext:
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
List[SequenceGroupMetadata], SchedulerOutputs, def __init__(self):
bool, self.output_queue: Deque[OutputData] = deque()
bool]] = field(default_factory=lambda: deque()) self.request_outputs: List[Union[RequestOutput,
request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
EmbeddingRequestOutput]] = field( self.seq_group_metadata_list: Optional[
default_factory=lambda: []) List[SequenceGroupMetadata]] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
is_last_step: bool):
self.output_queue.append(
OutputData(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=is_async,
is_last_step=is_last_step,
skip=[]))
class LLMEngine: class LLMEngine:
...@@ -1254,23 +1273,15 @@ class LLMEngine: ...@@ -1254,23 +1273,15 @@ class LLMEngine:
return return
def _process_model_outputs(self, ctx: SchedulerContext) -> None: def _process_model_outputs(self,
"""Apply the model output to the sequences in the scheduled seq groups. ctx: SchedulerContext,
request_id: Optional[str] = None) -> None:
"""Apply the model output to the sequences in the scheduled seq groups
and return responses.
virtual_engine: The engine id to operate on ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
""" """
now = time.time() now = time.time()
...@@ -1278,9 +1289,14 @@ class LLMEngine: ...@@ -1278,9 +1289,14 @@ class LLMEngine:
return None return None
# Get pending async postprocessor # Get pending async postprocessor
(outputs, seq_group_metadata_list, scheduler_outputs, is_async, if request_id:
is_last_step) = ctx.output_queue.popleft() # When we process only one request, no pop is required
assert outputs is not None # (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue.popleft()
# Sanity check # Sanity check
assert len(seq_group_metadata_list) == len( assert len(seq_group_metadata_list) == len(
...@@ -1294,9 +1310,30 @@ class LLMEngine: ...@@ -1294,9 +1310,30 @@ class LLMEngine:
else: else:
outputs_by_sequence_group = outputs outputs_by_sequence_group = outputs
# Determine the requests we need to operate on
if request_id:
indices = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
if seq_group_meta.request_id == request_id:
assert i not in skip # Cannot be called twice
indices.append(i)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if not indices:
return
else:
indices = range(len(seq_group_metadata_list)) # type: ignore
finished_before: List[int] = [] finished_before: List[int] = []
finished_now: List[int] = [] finished_now: List[int] = []
for i, seq_group_meta in enumerate(seq_group_metadata_list): for i in indices:
if i in skip:
continue
seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
...@@ -1351,6 +1388,18 @@ class LLMEngine: ...@@ -1351,6 +1388,18 @@ class LLMEngine:
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if request_id:
assert len(indices) == 1
skip.append(indices[0])
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return
# Free currently finished requests # Free currently finished requests
if finished_now: if finished_now:
for scheduler in self.scheduler: for scheduler in self.scheduler:
...@@ -1362,17 +1411,16 @@ class LLMEngine: ...@@ -1362,17 +1411,16 @@ class LLMEngine:
if (finished_now if (finished_now
and self.process_request_outputs_callback is not None): and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs) self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return return
# Create the outputs # Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list for i in indices:
# must match with the indices if i in skip or i in finished_before or i in finished_now:
for i, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
if i in finished_before or i in finished_now:
continue # Avoids double processing continue # Avoids double processing
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished() if (seq_group.is_finished()
...@@ -1388,6 +1436,7 @@ class LLMEngine: ...@@ -1388,6 +1436,7 @@ class LLMEngine:
if (ctx.request_outputs if (ctx.request_outputs
and self.process_request_outputs_callback is not None): and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs) self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
# For async case, we need to record the stats here. # For async case, we need to record the stats here.
# For non-async case, the stats are done in the # For non-async case, the stats are done in the
...@@ -1556,20 +1605,20 @@ class LLMEngine: ...@@ -1556,20 +1605,20 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
output = self.model_executor.execute_model( outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
# We need to do this here so that last step's sampled_token_ids can # We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output) self._update_cached_scheduler_output(virtual_engine, outputs)
else: else:
# Nothing scheduled => If there is pending async postprocessor, # Nothing scheduled => If there is pending async postprocessor,
# then finish it here. # then finish it here.
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
# No outputs in this case # No outputs in this case
output = [] outputs = []
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
...@@ -1582,18 +1631,18 @@ class LLMEngine: ...@@ -1582,18 +1631,18 @@ class LLMEngine:
self.cached_scheduler_outputs[0] = SchedulerOutputState() self.cached_scheduler_outputs[0] = SchedulerOutputState()
# Add results to the output_queue # Add results to the output_queue
is_async = allow_async_output_proc ctx.append_output(outputs=outputs,
is_last_step = True seq_group_metadata_list=seq_group_metadata_list,
ctx.output_queue.append( scheduler_outputs=scheduler_outputs,
(output, seq_group_metadata_list, scheduler_outputs, is_async, is_async=allow_async_output_proc,
is_last_step)) is_last_step=True)
if output and allow_async_output_proc: if outputs and allow_async_output_proc:
assert len(output) == 1, ( assert len(outputs) == 1, (
"Async postprocessor expects only a single output set") "Async postprocessor expects only a single output set")
self._advance_to_next_step( self._advance_to_next_step(
output[0], seq_group_metadata_list, outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
# Check if need to run the usual non-async path # Check if need to run the usual non-async path
...@@ -1601,7 +1650,7 @@ class LLMEngine: ...@@ -1601,7 +1650,7 @@ class LLMEngine:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, outputs)
# Tracing # Tracing
self.do_tracing(scheduler_outputs) self.do_tracing(scheduler_outputs)
...@@ -1922,6 +1971,12 @@ class LLMEngine: ...@@ -1922,6 +1971,12 @@ class LLMEngine:
self.tokenizer.check_health() self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
def start_profile(self) -> None:
self.model_executor.start_profile()
def stop_profile(self) -> None:
self.model_executor.stop_profile()
def is_tracing_enabled(self) -> bool: def is_tracing_enabled(self) -> bool:
return self.tracer is not None return self.tracer is not None
......
...@@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam, ...@@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict from pydantic import ConfigDict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Required, TypeAlias, TypedDict from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict ...@@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio, from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image, async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image) get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -107,7 +108,7 @@ class ConversationMessage(TypedDict, total=False): ...@@ -107,7 +108,7 @@ class ConversationMessage(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls.""" """The tool calls generated by the model, such as function calls."""
ModalityStr = Literal["image", "audio"] ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -147,20 +148,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -147,20 +148,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return f"<|image_{current_count}|>" return f"<|image_{current_count}|>"
if model_type == "minicpmv": if model_type == "minicpmv":
return "(<image>./</image>)" return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
"pixtral"):
# These models do not use image tokens in the prompt # These models do not use image tokens in the prompt
return None return None
if model_type == "qwen":
return f"Picture {current_count}: <img></img>"
if model_type.startswith("llava"): if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer, return self._cached_token_str(self._tokenizer,
hf_config.image_token_index) hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"): if model_type in ("chameleon", "internvl_chat"):
return "<image>" return "<image>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type == "ultravox":
return "<|reserved_special_token_0|>" return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}")
else: else:
raise TypeError(f"Unknown modality: {modality}") raise TypeError(f"Unknown modality: {modality}")
...@@ -377,6 +387,9 @@ def _parse_chat_message_content_parts( ...@@ -377,6 +387,9 @@ def _parse_chat_message_content_parts(
audio_url = _AudioParser(part)["audio_url"] audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"]) mm_parser.parse_audio(audio_url["url"])
elif part_type == "refusal":
text = _RefusalParser(part)["refusal"]
texts.append(text)
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
...@@ -431,6 +444,21 @@ def _parse_chat_message_content( ...@@ -431,6 +444,21 @@ def _parse_chat_message_content(
return result return result
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads(
item["function"]["arguments"])
def parse_chat_messages( def parse_chat_messages(
messages: List[ChatCompletionMessageParam], messages: List[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
...@@ -444,6 +472,8 @@ def parse_chat_messages( ...@@ -444,6 +472,8 @@ def parse_chat_messages(
conversation.extend(sub_messages) conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data() return conversation, mm_tracker.all_mm_data()
...@@ -460,41 +490,44 @@ def parse_chat_messages_futures( ...@@ -460,41 +490,44 @@ def parse_chat_messages_futures(
conversation.extend(sub_messages) conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data() return conversation, mm_tracker.all_mm_data()
def apply_chat_template( def apply_hf_chat_template(
tokenizer: AnyTokenizer, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
chat_template: Optional[str], chat_template: Optional[str],
*, *,
tokenize: bool = False, # Different from HF's default tokenize: bool = False, # Different from HF's default
**kwargs: Any, **kwargs: Any,
) -> Union[str, List[int]]: ) -> str:
if chat_template is None and tokenizer.chat_template is None: if chat_template is None and tokenizer.chat_template is None:
raise ValueError( raise ValueError(
"As of transformers v4.44, default chat template is no longer " "As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer " "allowed, so you must provide a chat template if the tokenizer "
"does not define one.") "does not define one.")
# per the Transformers docs & maintainers, tool call arguments in return tokenizer.apply_chat_template(
# assistant-role messages with tool_calls need to be dicts not JSON str - conversation=conversation, # type: ignore[arg-type]
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in conversation:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for i in range(len(message["tool_calls"])):
args: str = message["tool_calls"][i]["function"]["arguments"]
parsed_args: Dict = json.loads(args)
message["tool_calls"][i]["function"]["arguments"] = parsed_args
prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template, chat_template=chat_template,
tokenize=tokenize, tokenize=tokenize,
**kwargs, **kwargs,
) )
return prompt
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
**kwargs: Any,
) -> List[int]:
if chat_template is not None:
logger.warning(
"'chat_template' cannot be overridden for mistral tokenizer.")
return tokenizer.apply_chat_template(
messages=messages,
**kwargs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment