Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4851c202
Commit
4851c202
authored
Sep 13, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1' into v0.6.1-dev
parents
9b902f9e
3fd2b0d2
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
947 additions
and
441 deletions
+947
-441
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+171
-136
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+58
-36
tests/spec_decode/e2e/test_seed.py
tests/spec_decode/e2e/test_seed.py
+33
-19
tests/test_logger.py
tests/test_logger.py
+3
-3
tests/tool_use/utils.py
tests/tool_use/utils.py
+1
-1
tests/utils.py
tests/utils.py
+12
-7
tests/weight_loading/models-large.txt
tests/weight_loading/models-large.txt
+3
-0
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+1
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+204
-6
vllm/assets/video.py
vllm/assets/video.py
+85
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+26
-7
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+8
-1
vllm/benchmark_throughput.py
vllm/benchmark_throughput.py
+7
-8
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+1
-1
vllm/config.py
vllm/config.py
+51
-49
vllm/core/scheduler.py
vllm/core/scheduler.py
+47
-40
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+63
-40
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+12
-12
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+105
-50
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+56
-23
No files found.
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
4851c202
...
...
@@ -41,8 +41,9 @@ from transformers import AutoTokenizer
from
vllm
import
SamplingParams
from
...utils
import
fork_new_process_for_each_test
from
.conftest
import
(
get_output_from_llm_generator
,
run_
greedy_
equality_correctness_test
)
run_equality_correctness_test
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -73,6 +74,7 @@ from .conftest import (get_output_from_llm_generator,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
fork_new_process_for_each_test
def
test_spec_decode_e2e_with_detokenization
(
test_llm_generator
,
batch_size
:
int
):
"""Run generation with speculative decoding on a batch. Verify the engine
...
...
@@ -116,44 +118,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Use AsyncLLM engine
"use_async"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_with_async_engine
(
test_llm_generator
,
baseline_llm_generator
,
batch_size
:
int
):
"""Verify spec decode works well with async LLM engine.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
32
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
...
...
@@ -172,10 +136,10 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
},
{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -189,13 +153,15 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
"output_len"
,
[
# Use long output len for the small model test.
1
536
,
1
0
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
fork_new_process_for_each_test
def
test_spec_decode_e2e_greedy_correctness_tiny_model_bs1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model with batch size of one.
Since this test is cheaper than other e2e correctness tests, we generate
...
...
@@ -204,14 +170,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
ensure_all_accepted
=
test_llm_generator
.
same_draft_target_model
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
,
ensure_all_accepted
=
ensure_all_accepted
)
ensure_all_accepted
=
per_test_common_llm_kwargs
.
get
(
"model_name"
)
==
test_llm_kwargs
.
get
(
"speculative_model"
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
,
ensure_all_accepted
=
ensure_all_accepted
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -232,10 +202,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
},
{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -253,16 +223,22 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
fork_new_process_for_each_test
def
test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -280,10 +256,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
},
{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -298,24 +274,31 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
fork_new_process_for_each_test
def
test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
max_output_len
:
int
):
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
max_output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
=
False
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
,
seed
=
seed
,
temperature
=
0.0
,
ignore_eos
=
False
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# A "real" model (not tiny).
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
"model
_name
"
:
"meta-llama/Llama-2-7b-chat-hf"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -342,24 +325,30 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
256
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
fork_new_process_for_each_test
def
test_spec_decode_e2e_greedy_correctness_real_model_bs1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# A "real" model (not tiny).
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
"model
_name
"
:
"meta-llama/Llama-2-7b-chat-hf"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -386,17 +375,23 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
fork_new_process_for_each_test
def
test_spec_decode_e2e_greedy_correctness_real_model_large_bs
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -415,7 +410,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -433,23 +428,29 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
fork_new_process_for_each_test
def
test_spec_decode_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -487,22 +488,29 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_different_block_size
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
fork_new_process_for_each_test
def
test_spec_decode_different_block_size
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality over different block sizes.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -534,24 +542,31 @@ def test_spec_decode_different_block_size(baseline_llm_generator,
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_skip_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
fork_new_process_for_each_test
def
test_skip_speculation
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -571,21 +586,28 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_disable_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
fork_new_process_for_each_test
def
test_disable_speculation
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality when all sequences disable speculation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -613,22 +635,28 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_many_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
fork_new_process_for_each_test
def
test_many_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -657,15 +685,22 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_typical_acceptance_sampling
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
fork_new_process_for_each_test
def
test_typical_acceptance_sampling
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
4851c202
...
...
@@ -26,7 +26,7 @@ for the target model outputs.
import
pytest
from
.conftest
import
run_
greedy_
equality_correctness_test
from
.conftest
import
run_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
...
...
@@ -43,7 +43,7 @@ from .conftest import run_greedy_equality_correctness_test
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -59,15 +59,21 @@ from .conftest import run_greedy_equality_correctness_test
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -86,7 +92,7 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -105,24 +111,28 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
0
,
seed
=
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -159,23 +169,29 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -200,14 +216,20 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_seed.py
View file @
4851c202
...
...
@@ -2,11 +2,17 @@ import pytest
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"JackFram/llama-160m"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -31,26 +37,34 @@ from .conftest import run_equality_correctness_test
# Use smaller output len for fast test.
20
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_seeded_consistency
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
temperature
:
floa
t
,
output_len
:
int
):
def
test_seeded_consistency
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
in
t
,
temperature
:
float
,
output_len
:
int
):
"""Verify outputs are consistent across multiple runs with same seed
"""
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
True
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
disable_seed
=
False
,
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
False
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
disable_seed
=
True
,
)
tests/test_logger.py
View file @
4851c202
...
...
@@ -95,7 +95,7 @@ def test_logger_configuring_can_be_disabled():
config behavior, however mocks are used to ensure no changes in behavior or
configuration occur."""
with
patch
(
"
logging.config
.dictConfig"
)
as
dict_config_mock
:
with
patch
(
"
vllm.logger
.dictConfig"
)
as
dict_config_mock
:
_configure_vllm_root_logger
()
dict_config_mock
.
assert_not_called
()
...
...
@@ -175,9 +175,9 @@ def test_custom_logging_config_is_parsed_and_used_when_provided():
logging_config_file
.
flush
()
with
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
logging_config_file
.
name
),
patch
(
"
logging.config
.dictConfig"
)
as
dict_config_mock
:
"
vllm.logger
.dictConfig"
)
as
dict_config_mock
:
_configure_vllm_root_logger
()
assert
dict_config_mock
.
called_with
(
valid_logging_config
)
dict_config_mock
.
assert_
called_with
(
valid_logging_config
)
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
0
)
...
...
tests/tool_use/utils.py
View file @
4851c202
...
...
@@ -19,7 +19,7 @@ ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"]
CONFIGS
:
Dict
[
str
,
ServerConfig
]
=
{
"hermes"
:
{
"model"
:
"NousResearch/Hermes-
2-Pro
-Llama-3-8B"
,
"NousResearch/Hermes-
3
-Llama-3
.1
-8B"
,
"arguments"
:
[
"--tool-call-parser"
,
"hermes"
,
"--chat-template"
,
str
(
VLLM_PATH
/
"examples/tool_chat_template_hermes.jinja"
)
...
...
tests/utils.py
View file @
4851c202
...
...
@@ -20,7 +20,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.model_executor.model_loader.loader
import
DefaultM
odel
L
oader
from
vllm.model_executor.model_loader.loader
import
get_m
odel
_l
oader
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_port
,
is_hip
...
...
@@ -89,11 +89,11 @@ class RemoteOpenAIServer:
is_local
=
os
.
path
.
isdir
(
model
)
if
not
is_local
:
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
_config
=
engine_args
.
create_
engine
_config
()
dummy_loader
=
DefaultModelLoader
(
engine_config
.
load_config
)
dummy_loader
.
_prepare_weights
(
engine_config
.
model_config
.
model
,
engine_config
.
model_config
.
revision
,
fall_back_to_pt
=
True
)
model
_config
=
engine_args
.
create_
model
_config
()
load_config
=
engine_args
.
create_
load_config
(
)
model_loader
=
get_model_loader
(
load_config
)
model_loader
.
download_model
(
model_config
)
env
=
os
.
environ
.
copy
()
# the current process might initialize cuda,
...
...
@@ -178,7 +178,12 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
"""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
trust_remote_code
=
"--trust-remote-code"
if
trust_remote_code
in
arg1
or
trust_remote_code
in
arg2
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
,
trust_remote_code
=
True
)
else
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
prompt
=
"Hello, my name is"
token_ids
=
tokenizer
(
prompt
)[
"input_ids"
]
...
...
tests/weight_loading/models-large.txt
0 → 100644
View file @
4851c202
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
\ No newline at end of file
tests/weight_loading/models.txt
View file @
4851c202
...
...
@@ -19,8 +19,7 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
...
...
vllm/_custom_ops.py
View file @
4851c202
...
...
@@ -339,18 +339,28 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# b_g_idx, use_exllama, bit)
# TODO: has to be a better way to do this
try
:
torch
.
ops
.
_C
.
gptq_gemm
# noqa B018
@
torch
.
library
.
register_fake
(
"_C::gptq_gemm"
)
def
_gptq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
bit
:
int
)
->
torch
.
Tensor
:
return
torch
.
empty
((
a
.
size
(
0
),
b_q_weight
.
size
(
1
)),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
except
Exception
:
pass
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
bit
:
int
)
->
None
:
quant_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
lookup_table
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
squeezellm_gemm
(
vec
,
mat
,
mul
,
lookup_table
)
# marlin
def
marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
...
...
@@ -369,6 +379,194 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n
,
size_k
)
# TODO: has to be a better way to do this
try:
    # Touch one compiled op first: if this lookup raises, none of the
    # registrations below run, because the handler swallows every exception.
    # The same applies to any registration that fails part-way through: all
    # later ones are skipped silently.
    torch.ops._C.gptq_marlin_24_gemm  # noqa B018

    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor,
                                  b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: int,
                                  size_n: int, size_k: int) -> torch.Tensor:
        """Fake impl: uninitialized (size_m, size_n) tensor like `a`."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: int,
                               size_n: int,
                               size_k: int,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False) -> torch.Tensor:
        """Fake impl: uninitialized (size_m, size_n) tensor like `a`."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                              n: int) -> torch.Tensor:
        """Fake impl: uninitialized float16 (m, n) tensor on W's device."""
        out_shape = (m, n)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        """Fake impl: uninitialized float16 (1, row) tensor on W's device."""
        out_shape = (1, row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        """Fake impl: uninitialized float16 (X.size(0), row) tensor."""
        out_shape = (X.size(0), row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        """Fake impl: uninitialized float16 (size_m, size_n) tensor."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, dtype=torch.float16, device=a.device)

    @torch.library.register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: int, size_n: int,
                          size_k: int) -> torch.Tensor:
        """Fake impl: uninitialized float16 (size_m, size_n) tensor."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, dtype=torch.float16, device=a.device)

    @torch.library.register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                             zeros: torch.Tensor, split_k_iters: int,
                             thx: int, thy: int) -> torch.Tensor:
        """Fake impl: (qweight.size(0), qweight.size(1) * 8) like `scales`."""
        rows = qweight.size(0)
        # The output has eight columns per column of `qweight`.
        cols = qweight.size(1) * 8
        return torch.empty((rows, cols),
                           dtype=scales.dtype,
                           device=scales.device)

    @torch.library.register_fake("_C::awq_gemm")
    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                       qzeros: torch.Tensor, scales: torch.Tensor,
                       split_k_iters: int) -> torch.Tensor:
        """Fake impl: a (split_k_iters, rows, cols) buffer summed over dim 0.

        `rows` is `input.size(0)` and `cols` is `qweight.size(1) * 8`; dtype
        and device come from `input`.
        """
        partial_shape = (split_k_iters, input.size(0), qweight.size(1) * 8)
        partials = torch.empty(partial_shape,
                               dtype=input.dtype,
                               device=input.device)
        return partials.sum(0)

    @torch.library.register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                        codebooks: torch.Tensor, scales: torch.Tensor,
                        codebook_partition_sizes: List[int],
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Fake impl: `input` with its last dimension replaced.

        The new last dimension is `codes.size(0) * codebooks.size(2)`; all
        leading dimensions of `input` are kept.
        """
        out_features = codes.size(0) * codebooks.size(2)
        flat_input = input.reshape((-1, input.size(-1)))
        flat_output = torch.empty((flat_input.size(0), out_features),
                                  dtype=input.dtype,
                                  device=input.device)
        # Restore the leading dims; -1 lets reshape infer the last one.
        restored_shape = (*input.shape[:-1], -1)
        return flat_output.reshape(restored_shape)

    @torch.library.register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
            codes: torch.Tensor, codebooks: torch.Tensor,
            codebook_partition_sizes: List[int]) -> torch.Tensor:
        """Fake impl: (codes.size(0), codes.size(1) * 8) like `codebooks`."""
        out_shape = (codes.size(0), codes.size(1) * 8)
        return torch.empty(out_shape,
                           dtype=codebooks.dtype,
                           device=codebooks.device)

    @torch.library.register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        """Fake impl: uninitialized (size_m, size_n) tensor like `a`."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, dtype=a.dtype, device=a.device)

    @torch.library.register_fake("_C::machete_gemm")
    def machete_gemm_fake(
        a: torch.Tensor,
        b_q: torch.
        Tensor,  # Should be the tensor returned by machete_prepack_B
        b_type: ScalarType,
        b_scales: Optional[torch.Tensor] = None,
        b_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        c: Optional[torch.Tensor] = None,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        """Fake impl: uninitialized (a.size(0), b_q.size(1)) tensor like `a`."""
        out_shape = (a.size(0), b_q.size(1))
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                               b_type: ScalarType) -> torch.Tensor:
        """Fake impl: uninitialized tensor shaped like `b_q_weight`."""
        return torch.empty_like(b_q_weight)

    @torch.library.register_fake("_C::causal_conv1d_fwd")
    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                               bias_: Optional[torch.Tensor],
                               seq_idx_: Optional[torch.Tensor],
                               initial_states_: Optional[torch.Tensor],
                               final_states_out_: Optional[torch.Tensor],
                               silu_activation: bool) -> torch.Tensor:
        """Fake impl: uninitialized tensor shaped like `x`."""
        return torch.empty_like(x)

    @torch.library.register_fake("_C::causal_conv1d_update")
    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
                                  weight: torch.Tensor,
                                  bias_: Optional[torch.Tensor],
                                  silu_activation: bool) -> torch.Tensor:
        """Fake impl: uninitialized tensor shaped like `x`."""
        return torch.empty_like(x)

    @torch.library.register_fake("_C::selective_scan_fwd")
    def selective_scan_fwd_fake(
            u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
            B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
            z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
            delta_softplus: bool, index_: Optional[torch.Tensor],
            x: Optional[torch.Tensor]) -> List[torch.Tensor]:
        """Fake impl returning a list of two or three tensors.

        The first entry is shaped like `u`. The second is `x` itself when
        given, otherwise a new (u.size(0), u.size(1), A.size(1)) tensor like
        `u`. A third entry shaped like `z_` is appended only when `z_` is
        not None.
        """
        results = [torch.empty_like(u)]
        if x is not None:
            results.append(x)
        else:
            results.append(
                torch.empty((u.size(0), u.size(1), A.size(1)),
                            dtype=u.dtype,
                            device=u.device))
        if z_ is not None:
            results.append(torch.empty_like(z_))
        return results

except Exception:
    # Best effort: a failure anywhere above is ignored on purpose.
    pass
# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return the result of the `_C::cutlass_scaled_mm_supports_fp8` op.

    `cuda_device_capability` is handed to the op unchanged.
    """
    supports_fp8 = torch.ops._C.cutlass_scaled_mm_supports_fp8(
        cuda_device_capability)
    return supports_fp8
...
...
vllm/assets/video.py
0 → 100644
View file @
4851c202
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
List
,
Literal
import
numpy
as
np
import
numpy.typing
as
npt
from
huggingface_hub
import
hf_hub_download
from
PIL
import
Image
from
vllm.multimodal.utils
import
(
sample_frames_from_video
,
try_import_video_packages
)
from
.base
import
get_cache_dir
@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Return a local path to a video from the huggingface dataset repo
    raushan-testing-hf/videos-test.

    If `<cache dir>/video-eample-data/<filename>` already exists, that path
    is returned as a string. Otherwise the file is requested through
    `hf_hub_download` (with the same directory as its `cache_dir`) and the
    path it reports is returned. Results are memoized per `filename`.
    """
    asset_dir = get_cache_dir() / "video-eample-data"
    asset_dir.mkdir(parents=True, exist_ok=True)

    local_file = asset_dir / filename
    if local_file.exists():
        return str(local_file)

    return hf_hub_download(
        repo_id="raushan-testing-hf/videos-test",
        filename=filename,
        repo_type="dataset",
        cache_dir=asset_dir,
    )
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    """Decode a video file into a stacked array of sampled frames.

    Every frame that can be read is collected, stacked with `np.stack`, and
    passed to `sample_frames_from_video` together with `num_frames`.

    Args:
        path: Location of the video file handed to `cv2.VideoCapture`.
        num_frames: Frame count forwarded to the sampler. With the default
            of -1 the "not enough frames" check below can never trigger.
            NOTE(review): presumably -1 means "keep every frame" -- confirm
            against `sample_frames_from_video`.

    Returns:
        The array returned by `sample_frames_from_video`.

    Raises:
        ValueError: If the file cannot be opened, if no frame could be
            decoded, or if fewer than `num_frames` frames remain after
            sampling.
    """
    cv2 = try_import_video_packages()

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        decoded = []
        for _ in range(total_frames):
            ret, frame = cap.read()
            # A failed read is skipped rather than aborting the loop.
            if ret:
                decoded.append(frame)
    finally:
        # Release the capture handle on the error paths as well.
        cap.release()

    if not decoded:
        # np.stack([]) would raise an opaque ValueError; fail with context.
        raise ValueError(f"Could not read any frames from video file {path}")

    frames = np.stack(decoded)
    frames = sample_frames_from_video(frames, num_frames)
    if len(frames) < num_frames:
        raise ValueError(f"Could not read enough frames from video file {path}"
                         f" (expected {num_frames} frames, got {len(frames)})")
    return frames
def video_to_pil_images_list(path: str,
                             num_frames: int = -1) -> List[Image.Image]:
    """Decode a video into a list of PIL images.

    Frames come from `video_to_ndarrays(path, num_frames)`. Each one is run
    through `cv2.cvtColor` with `COLOR_BGR2RGB` before being wrapped by
    `Image.fromarray`.
    """
    cv2 = try_import_video_packages()
    frames = video_to_ndarrays(path, num_frames)

    images = []
    for bgr_frame in frames:
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        images.append(Image.fromarray(rgb_frame))
    return images
@dataclass(frozen=True)
class VideoAsset:
    """Handle to a sample video, exposing its frames in two formats.

    Both properties resolve the file through `download_video_asset` and
    decode it on access; decoded frames are not stored on the instance.
    """
    # File name passed to `download_video_asset`.
    name: Literal["sample_demo_1.mp4"]
    # Frame count forwarded to the decoding helpers.
    num_frames: int = -1

    @property
    def pil_images(self) -> List[Image.Image]:
        """Frames of the video as PIL images."""
        video_path = download_video_asset(self.name)
        ret = video_to_pil_images_list(video_path, self.num_frames)
        return ret

    @property
    def np_ndarrays(self) -> npt.NDArray:
        """Frames of the video as returned by `video_to_ndarrays`.

        The annotation matches `video_to_ndarrays`, which is declared to
        return a single `npt.NDArray` rather than a list of arrays.
        """
        video_path = download_video_asset(self.name)
        ret = video_to_ndarrays(video_path, self.num_frames)
        return ret
vllm/attention/backends/flash_attn.py
View file @
4851c202
...
...
@@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm_flash_attn
import
flash_attn_varlen_func
as
_flash_attn_varlen_func
from
vllm_flash_attn
import
flash_attn_with_kvcache
as
_flash_attn_with_kvcache
...
...
@@ -302,14 +303,12 @@ class FlashAttentionMetadata(AttentionMetadata):
)
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
):
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
):
"""
Update metadata in-place to advance one decode step.
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.
# When using cudagraph, the num_seqs is padded to the next captured
# batch size, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
...
...
@@ -347,6 +346,16 @@ class FlashAttentionMetadata(AttentionMetadata):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
ops
.
advance_step
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
...
...
@@ -462,9 +471,19 @@ class FlashAttentionMetadataBuilder(
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
max_blocks
=
input_block_tables
.
shape
[
1
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
num_blocks
=
len
(
block_table
)
if
num_blocks
<=
max_blocks
:
input_block_tables
[
i
,
:
num_blocks
]
=
block_table
else
:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables
[
i
,
:
max_blocks
]
=
block_table
[:
max_blocks
]
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
=
device
,
non_blocking
=
True
)
else
:
...
...
vllm/attention/backends/flashinfer.py
View file @
4851c202
...
...
@@ -224,6 +224,7 @@ class FlashInferState(AttentionState):
query_start_loc
=
query_start_loc_host
,
device
=
self
.
runner
.
device
,
data_type
=
kv_cache_dtype
,
q_data_type
=
self
.
runner
.
model_config
.
dtype
,
use_cuda_graph
=
True
,
decode_wrapper
=
self
.
_graph_decode_wrapper
,
prefill_wrapper
=
None
)
...
...
@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
page_size
:
Optional
[
int
]
=
None
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
# The data type of the query
q_data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
is_profile_run
:
bool
=
False
...
...
@@ -353,7 +356,10 @@ class FlashInferMetadata(AttentionMetadata):
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
data_type
=
self
.
data_type
)
# kv-cache data type.
data_type
=
self
.
data_type
,
# query data type.
q_data_type
=
self
.
q_data_type
)
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
...
...
@@ -617,6 +623,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc
=
query_start_loc
,
device
=
device
,
data_type
=
kv_cache_dtype
,
q_data_type
=
self
.
runner
.
model_config
.
dtype
,
use_cuda_graph
=
use_captured_graph
,
is_profile_run
=
self
.
is_profile_run
)
...
...
vllm/benchmark_throughput.py
View file @
4851c202
...
...
@@ -11,8 +11,9 @@ import uvloop
from
tqdm
import
tqdm
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
PreTrainedTokenizerBase
)
from
vllm.inputs
import
PromptInputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
AsyncEngineArgs
,
EngineArgs
from
vllm.entrypoints.openai.api_server
import
(
build_async_engine_client_from_engine_args
)
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
...
...
@@ -504,13 +505,11 @@ if __name__ == "__main__":
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.'
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"auto"
,
choices
=
[
"auto"
,
"cuda"
,
"cpu"
,
"openvino"
,
"tpu"
,
"xpu"
],
help
=
'device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.'
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"auto"
,
choices
=
DEVICE_OPTIONS
,
help
=
'device type for vLLM execution'
)
parser
.
add_argument
(
"--num-scheduler-steps"
,
type
=
int
,
...
...
vllm/compilation/wrapper.py
View file @
4851c202
...
...
@@ -10,7 +10,7 @@ import torch
import
vllm.envs
as
envs
class
TorchCompileWrapperWithCustomDispa
c
ther
:
class
TorchCompileWrapperWithCustomDispat
c
her
:
"""
A wrapper class for torch.compile, with a custom dispatch logic.
Subclasses should:
...
...
vllm/config.py
View file @
4851c202
...
...
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.platforms
import
current_platform
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
from
vllm.transformers_utils.config
import
(
get_config
,
from
vllm.transformers_utils.config
import
(
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_text_config
)
from
vllm.utils
import
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
GiB_bytes
,
...
...
@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
4096
_PP_SUPPORTED_MODELS
=
[
"AquilaModel"
,
"AquilaForCausalLM"
,
"AquilaModel"
,
"DeepseekV2ForCausalLM"
,
"GPT2LMHeadModel"
,
"InternLM2ForCausalLM"
,
"InternLMForCausalLM"
,
"InternVLChatModel"
,
"JAISLMHeadModel"
,
"LlamaForCausalLM"
,
"LLaMAForCausalLM"
,
"MistralForCausalLM"
,
"Phi3ForCausalLM"
,
"GPT2LMHeadModel"
,
"MixtralForCausalLM"
,
"NemotronForCausalLM"
,
"Phi3ForCausalLM"
,
"Qwen2ForCausalLM"
,
"Qwen2MoeForCausalLM"
,
"QWenLMHeadModel"
,
...
...
@@ -119,35 +121,37 @@ class ModelConfig:
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
"""
def
__init__
(
self
,
model
:
str
,
tokenizer
:
str
,
tokenizer_m
ode
:
str
,
trust_remote_code
:
bool
,
dtype
:
Union
[
str
,
torch
.
dtype
]
,
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dic
t
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
tokenizer_revisio
n
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
spec_target_max_model_le
n
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
max_context
_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
Fals
e
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]
]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
Non
e
,
use_async_output_proc
:
bool
=
Tru
e
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
N
on
e
)
->
None
:
def
__init__
(
self
,
model
:
str
,
tokenizer
:
str
,
tokenizer
_mode
:
str
,
trust_remote_c
ode
:
bool
,
dtype
:
Union
[
str
,
torch
.
dtype
]
,
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
code_
revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
floa
t
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_le
n
:
Optional
[
int
]
=
None
,
spec_target_
max_model_len
:
Optional
[
int
]
=
None
,
quantizatio
n
:
Optional
[
str
]
=
None
,
quantization
_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq
_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Non
e
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
use_async_output_proc
:
bool
=
Tru
e
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
Non
e
,
config_format
:
ConfigFormat
=
C
on
figFormat
.
AUTO
)
->
None
:
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
...
...
@@ -174,7 +178,8 @@ class ModelConfig:
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
,
rope_scaling
,
rope_theta
)
code_revision
,
rope_scaling
,
rope_theta
,
config_format
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_image_processor_config
=
get_hf_image_processor_config
(
self
.
model
,
revision
)
...
...
@@ -275,11 +280,11 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
]
# "fp8"
rocm_supported_quantization
=
[
"awq"
,
"gptq"
]
# "fp8"
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"experts_int8"
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"experts_int8"
]
tpu_supported_quantization
=
[
"tpu_int8"
]
neuron_supported_quantization
=
[
"neuron_quant"
]
...
...
@@ -744,6 +749,7 @@ class LoadFormat(str, enum.Enum):
SHARDED_STATE
=
"sharded_state"
GGUF
=
"gguf"
BITSANDBYTES
=
"bitsandbytes"
MISTRAL
=
"mistral"
@
dataclass
...
...
@@ -767,7 +773,7 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
...
...
@@ -870,7 +876,8 @@ class ParallelConfig:
from
vllm.executor
import
ray_utils
backend
=
"mp"
ray_found
=
ray_utils
.
ray_is_available
()
if
cuda_device_count_stateless
()
<
self
.
world_size
:
if
(
torch
.
cuda
.
is_available
()
and
cuda_device_count_stateless
()
<
self
.
world_size
):
if
not
ray_found
:
raise
ValueError
(
"Unable to load Ray which is "
"required for multi-node inference, "
...
...
@@ -1535,7 +1542,7 @@ class LoRAConfig:
if
model_config
.
quantization
and
model_config
.
quantization
not
in
[
"awq"
,
"gptq"
]:
# TODO support marlin
and squeezellm
# TODO support marlin
logger
.
warning
(
"%s quantization is not tested with LoRA yet."
,
model_config
.
quantization
)
...
...
@@ -1552,14 +1559,6 @@ class PromptAdapterConfig:
prompt_adapter_dtype
:
Optional
[
torch
.
dtype
]
=
None
def
__post_init__
(
self
):
library_name
=
'peft'
try
:
__import__
(
library_name
)
except
ImportError
as
e
:
raise
ImportError
(
f
"'
{
library_name
}
' is not installed for prompt adapter support."
f
"Please install it using 'pip install
{
library_name
}
'."
)
from
e
if
self
.
max_prompt_adapters
<
1
:
raise
ValueError
(
f
"max_prompt_adapters "
...
...
@@ -1735,8 +1734,11 @@ def _get_and_verify_max_len(
"with rope_scaling. Please raise an issue so we can "
"investigate."
)
assert
"factor"
in
rope_scaling
scaling_factor
=
rope_scaling
[
"factor"
]
if
rope_type
==
"mrope"
:
scaling_factor
=
1
else
:
assert
"factor"
in
rope_scaling
scaling_factor
=
rope_scaling
[
"factor"
]
if
rope_type
==
"yarn"
:
derived_max_model_len
=
rope_scaling
[
"original_max_position_embeddings"
]
...
...
vllm/core/scheduler.py
View file @
4851c202
...
...
@@ -537,13 +537,6 @@ class Scheduler:
preempted
:
List
[
SequenceGroup
]
=
ret
.
preempted
swapped_out
:
List
[
SequenceGroup
]
=
ret
.
swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# Store original running requests for the case of async + preemption
if
self
.
use_async_output_proc
:
orig_running
=
self
.
running
.
copy
()
running_queue
=
self
.
running
assert
len
(
self
.
_async_stopped
)
==
0
while
running_queue
:
...
...
@@ -552,6 +545,7 @@ class Scheduler:
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
)
if
num_running_tokens
==
0
:
# No budget => Stop
break
running_queue
.
popleft
()
...
...
@@ -565,18 +559,8 @@ class Scheduler:
self
.
_async_stopped
.
append
(
seq_group
)
continue
# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if
self
.
use_async_output_proc
and
not
self
.
_can_append_slots
(
seq_group
):
tmp
=
self
.
running
self
.
running
=
orig_running
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
()
self
.
running
=
tmp
# NOTE(woosuk): Preemption happens only when there is no available
# slot to keep all the sequence groups in the RUNNING state.
while
not
self
.
_can_append_slots
(
seq_group
):
budget
.
subtract_num_batched_tokens
(
seq_group
.
request_id
,
num_running_tokens
)
...
...
@@ -588,24 +572,43 @@ class Scheduler:
and
seq_group
.
lora_int_id
in
curr_loras
):
curr_loras
.
remove
(
seq_group
.
lora_int_id
)
# Determine victim sequence
cont_loop
=
True
if
running_queue
:
# Preempt the lowest-priority sequence group
s
.
# Preempt the lowest-priority sequence group.
victim_seq_group
=
running_queue
.
pop
()
else
:
# No other sequence group can be preempted.
# Preempt the current sequence group.
# Note: This is also where we stop this loop
# (since there is nothing else to preempt)
victim_seq_group
=
seq_group
cont_loop
=
False
# With async postprocessor, before preempting a sequence
# we need to ensure it has no pending async postprocessor
do_preempt
=
True
if
self
.
use_async_output_proc
:
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
(
request_id
=
victim_seq_group
.
request_id
)
# It may be that the async pending "victim_seq_group"
# becomes finished, in which case we simply free it.
if
victim_seq_group
.
is_finished
():
self
.
_free_finished_seq_group
(
victim_seq_group
)
do_preempt
=
False
# Do preemption
if
do_preempt
:
preempted_mode
=
self
.
_preempt
(
victim_seq_group
,
blocks_to_swap_out
)
if
preempted_mode
==
PreemptionMode
.
RECOMPUTE
:
preempted
.
append
(
victim_seq_group
)
else
:
swapped_out
.
append
(
victim_seq_group
)
else
:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
preempted_mode
=
self
.
_preempt
(
seq_group
,
blocks_to_swap_out
)
if
preempted_mode
==
PreemptionMode
.
RECOMPUTE
:
preempted
.
append
(
seq_group
)
else
:
swapped_out
.
append
(
seq_group
)
if
not
cont_loop
:
break
else
:
self
.
_append_slots
(
seq_group
,
blocks_to_copy
)
...
...
@@ -1264,22 +1267,26 @@ class Scheduler:
if
seq
.
is_finished
():
self
.
free_seq
(
seq
)
def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
    """Release scheduler state held for `seq_group`.

    When the whole group is finished, its cross-attention blocks are freed
    and its request id is appended to `self._finished_requests_ids`.
    `_free_finished_seqs` is then called for the group in every case.
    """
    group_finished = seq_group.is_finished()
    if group_finished:
        # Free cross-attention block table, if it exists
        self._free_seq_group_cross_attn_blocks(seq_group)

        # Record the request id; this list will be used to update the
        # Mamba cache in the next step.
        self._finished_requests_ids.append(seq_group.request_id)

    # Runs whether or not the group as a whole is finished.
    self._free_finished_seqs(seq_group)
def
free_finished_seq_groups
(
self
)
->
None
:
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
for
seq_group
in
self
.
running
:
if
seq_group
.
is_finished
():
# Free cross-attention block table, if it exists
self
.
_free_seq_group_cross_attn_blocks
(
seq_group
)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
else
:
self
.
_free_finished_seq_group
(
seq_group
)
if
not
seq_group
.
is_finished
():
remaining
.
append
(
seq_group
)
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
self
.
running
=
remaining
# Handle async stopped sequence groups
...
...
vllm/engine/arg_utils.py
View file @
4851c202
...
...
@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
import
torch
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
Device
Config
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
from
vllm.config
import
(
CacheConfig
,
ConfigFormat
,
Decoding
Config
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
...
...
@@ -26,6 +26,16 @@ logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES
=
[
"model"
,
"worker"
,
"all"
]
DEVICE_OPTIONS
=
[
"auto"
,
"cuda"
,
"neuron"
,
"cpu"
,
"openvino"
,
"tpu"
,
"xpu"
,
]
def
nullable_str
(
val
:
str
):
if
not
val
or
val
==
"None"
:
...
...
@@ -65,6 +75,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
False
download_dir
:
Optional
[
str
]
=
None
load_format
:
str
=
'auto'
config_format
:
str
=
'auto'
dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
quantization_param_path
:
Optional
[
str
]
=
None
...
...
@@ -234,6 +245,13 @@ class EngineArgs:
'section for more information.
\n
'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.
\n
'
)
parser
.
add_argument
(
'--config-format'
,
default
=
EngineArgs
.
config_format
,
choices
=
[
f
.
value
for
f
in
ConfigFormat
],
help
=
'The format of the model config to load.
\n\n
'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format '
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
...
...
@@ -545,10 +563,7 @@ class EngineArgs:
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
EngineArgs
.
device
,
choices
=
[
"auto"
,
"cuda"
,
"neuron"
,
"cpu"
,
"openvino"
,
"tpu"
,
"xpu"
],
choices
=
DEVICE_OPTIONS
,
help
=
'Device type for vLLM execution.'
)
parser
.
add_argument
(
'--num-scheduler-steps'
,
type
=
int
,
...
...
@@ -763,6 +778,43 @@ class EngineArgs:
engine_args
=
cls
(
**
{
attr
:
getattr
(
args
,
attr
)
for
attr
in
attrs
})
return
engine_args
def create_model_config(self) -> ModelConfig:
    """Build a `ModelConfig` from the fields of this `EngineArgs`.

    Every keyword is copied from the attribute of the same name, except
    `use_async_output_proc`, which is the negation of
    `self.disable_async_output_proc`.
    """
    use_async_output_proc = not self.disable_async_output_proc
    model_config = ModelConfig(
        model=self.model,
        tokenizer=self.tokenizer,
        tokenizer_mode=self.tokenizer_mode,
        trust_remote_code=self.trust_remote_code,
        dtype=self.dtype,
        seed=self.seed,
        revision=self.revision,
        code_revision=self.code_revision,
        rope_scaling=self.rope_scaling,
        rope_theta=self.rope_theta,
        tokenizer_revision=self.tokenizer_revision,
        max_model_len=self.max_model_len,
        quantization=self.quantization,
        quantization_param_path=self.quantization_param_path,
        enforce_eager=self.enforce_eager,
        max_context_len_to_capture=self.max_context_len_to_capture,
        max_seq_len_to_capture=self.max_seq_len_to_capture,
        max_logprobs=self.max_logprobs,
        disable_sliding_window=self.disable_sliding_window,
        skip_tokenizer_init=self.skip_tokenizer_init,
        served_model_name=self.served_model_name,
        limit_mm_per_prompt=self.limit_mm_per_prompt,
        use_async_output_proc=use_async_output_proc,
        override_neuron_config=self.override_neuron_config,
        config_format=self.config_format,
    )
    return model_config
def create_load_config(self) -> LoadConfig:
    """Build a `LoadConfig` from the loader-related fields of this object.

    `load_format`, `download_dir`, `model_loader_extra_config` and
    `ignore_patterns` are copied over as they are.
    """
    load_config = LoadConfig(
        load_format=self.load_format,
        download_dir=self.download_dir,
        model_loader_extra_config=self.model_loader_extra_config,
        ignore_patterns=self.ignore_patterns,
    )
    return load_config
def
create_engine_config
(
self
)
->
EngineConfig
:
# gguf file needs a specific model loader and doesn't use hf_repo
if
check_gguf_file
(
self
.
model
):
...
...
@@ -789,31 +841,8 @@ class EngineArgs:
f
", but got
{
self
.
cpu_offload_gb
}
"
)
device_config
=
DeviceConfig
(
device
=
self
.
device
)
model_config
=
ModelConfig
(
model
=
self
.
model
,
tokenizer
=
self
.
tokenizer
,
tokenizer_mode
=
self
.
tokenizer_mode
,
trust_remote_code
=
self
.
trust_remote_code
,
dtype
=
self
.
dtype
,
seed
=
self
.
seed
,
revision
=
self
.
revision
,
code_revision
=
self
.
code_revision
,
rope_scaling
=
self
.
rope_scaling
,
rope_theta
=
self
.
rope_theta
,
tokenizer_revision
=
self
.
tokenizer_revision
,
max_model_len
=
self
.
max_model_len
,
quantization
=
self
.
quantization
,
quantization_param_path
=
self
.
quantization_param_path
,
enforce_eager
=
self
.
enforce_eager
,
max_context_len_to_capture
=
self
.
max_context_len_to_capture
,
max_seq_len_to_capture
=
self
.
max_seq_len_to_capture
,
max_logprobs
=
self
.
max_logprobs
,
disable_sliding_window
=
self
.
disable_sliding_window
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
override_neuron_config
=
self
.
override_neuron_config
)
model_config
=
self
.
create_model_config
()
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
...
...
@@ -856,7 +885,6 @@ class EngineArgs:
if
(
is_gpu
and
not
use_sliding_window
and
not
use_spec_decode
and
not
self
.
enable_lora
and
not
self
.
enable_prompt_adapter
and
not
self
.
enable_prefix_caching
and
not
has_seqlen_agnostic_layers
):
self
.
enable_chunked_prefill
=
True
logger
.
warning
(
...
...
@@ -956,12 +984,7 @@ class EngineArgs:
self
.
model_loader_extra_config
[
"qlora_adapter_name_or_path"
]
=
self
.
qlora_adapter_name_or_path
load_config
=
LoadConfig
(
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
ignore_patterns
=
self
.
ignore_patterns
,
)
load_config
=
self
.
create_load_config
()
prompt_adapter_config
=
PromptAdapterConfig
(
max_prompt_adapters
=
self
.
max_prompt_adapters
,
...
...
vllm/engine/async_llm_engine.py
View file @
4851c202
...
...
@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine
]
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
output
s
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
s
)
else
:
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
output
=
[]
output
s
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
...
...
@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_asyn
c
,
is_last_step
)
)
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_pro
c
,
is_last_step
=
True
)
if
output
and
allow_async_output_proc
:
if
output
s
and
allow_async_output_proc
:
assert
len
(
output
output
s
)
==
1
,
"Async postprocessor expects only a single output set"
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
s
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
s
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
...
...
vllm/engine/llm_engine.py
View file @
4851c202
...
...
@@ -2,9 +2,9 @@ import functools
import
time
from
collections
import
deque
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
...
...
@@ -90,17 +90,36 @@ class SchedulerOutputState:
last_output
:
Optional
[
SamplerOutput
]
=
None
@
dataclass
class OutputData(NamedTuple):
    """One queued model-step result waiting for output post-processing.

    Records of this type are built by ``SchedulerContext.append_output`` and
    stored in ``SchedulerContext.output_queue``; ``_process_model_outputs``
    unpacks them positionally, so the field order is part of the contract.
    """
    # Sampler outputs returned by the model executor for the step.
    outputs: List[SamplerOutput]
    # Metadata of the sequence groups that were scheduled for the step.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Scheduler outputs the step was executed for.
    scheduler_outputs: SchedulerOutputs
    # Value of the caller's ``allow_async_output_proc`` flag at enqueue time.
    is_async: bool
    # Whether this record is the last step (callers here always pass True).
    is_last_step: bool
    # Indices into ``seq_group_metadata_list`` that were already processed
    # individually (per-request processing) and must not be handled again.
    skip: List[int]
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
bool
,
bool
]]
=
field
(
default_factory
=
lambda
:
deque
())
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
field
(
default_factory
=
lambda
:
[])
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def
__init__
(
self
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
self
.
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def append_output(self, outputs: List[SamplerOutput],
                  seq_group_metadata_list: List[SequenceGroupMetadata],
                  scheduler_outputs: SchedulerOutputs, is_async: bool,
                  is_last_step: bool):
    """Queue one step's results for later output post-processing.

    The arguments are bundled into an ``OutputData`` record whose ``skip``
    list starts out empty, and the record is appended to the right end of
    ``self.output_queue``.
    """
    record = OutputData(
        outputs=outputs,
        seq_group_metadata_list=seq_group_metadata_list,
        scheduler_outputs=scheduler_outputs,
        is_async=is_async,
        is_last_step=is_last_step,
        # Fresh list per record: nothing has been processed yet.
        skip=[],
    )
    self.output_queue.append(record)
class
LLMEngine
:
...
...
@@ -1254,23 +1273,15 @@ class LLMEngine:
return
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
,
request_id
:
Optional
[
str
]
=
None
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups
and return responses.
virtual_engine: The engine id to operate on
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
now
=
time
.
time
()
...
...
@@ -1278,9 +1289,14 @@ class LLMEngine:
return
None
# Get pending async postprocessor
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
)
=
ctx
.
output_queue
.
popleft
()
assert
outputs
is
not
None
if
request_id
:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
skip
)
=
ctx
.
output_queue
[
0
]
else
:
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
skip
)
=
ctx
.
output_queue
.
popleft
()
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
...
...
@@ -1294,9 +1310,30 @@ class LLMEngine:
else
:
outputs_by_sequence_group
=
outputs
# Determine the requests we need to operate on
if
request_id
:
indices
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
if
seq_group_meta
.
request_id
==
request_id
:
assert
i
not
in
skip
# Cannot be called twice
indices
.
append
(
i
)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if
not
indices
:
return
else
:
indices
=
range
(
len
(
seq_group_metadata_list
))
# type: ignore
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
for
i
in
indices
:
if
i
in
skip
:
continue
seq_group_meta
=
seq_group_metadata_list
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
...
...
@@ -1351,6 +1388,18 @@ class LLMEngine:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if
request_id
:
assert
len
(
indices
)
==
1
skip
.
append
(
indices
[
0
])
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
# Free currently finished requests
if
finished_now
:
for
scheduler
in
self
.
scheduler
:
...
...
@@ -1362,17 +1411,16 @@ class LLMEngine:
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
if
i
in
finished_before
or
i
in
finished_now
:
for
i
in
indices
:
if
i
in
skip
or
i
in
finished_before
or
i
in
finished_now
:
continue
# Avoids double processing
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
(
seq_group
.
is_finished
()
...
...
@@ -1388,6 +1436,7 @@ class LLMEngine:
if
(
ctx
.
request_outputs
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
...
...
@@ -1556,20 +1605,20 @@ class LLMEngine:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
output
=
self
.
model_executor
.
execute_model
(
output
s
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
s
)
else
:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# No outputs in this case
output
=
[]
output
s
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
...
...
@@ -1582,18 +1631,18 @@ class LLMEngine:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
# Add results to the output_queue
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_asyn
c
,
is_last_step
)
)
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_pro
c
,
is_last_step
=
True
)
if
output
s
and
allow_async_output_proc
:
assert
len
(
output
s
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
s
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
...
...
@@ -1601,7 +1650,7 @@ class LLMEngine:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
s
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
...
...
@@ -1922,6 +1971,12 @@ class LLMEngine:
self
.
tokenizer
.
check_health
()
self
.
model_executor
.
check_health
()
def start_profile(self) -> None:
    """Ask the model executor to start profiling."""
    executor = self.model_executor
    executor.start_profile()
def stop_profile(self) -> None:
    """Ask the model executor to stop profiling."""
    executor = self.model_executor
    executor.stop_profile()
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracer
is
not
None
...
...
vllm/entrypoints/chat_utils.py
View file @
4851c202
...
...
@@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from
pydantic
import
ConfigDict
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
typing_extensions
import
Required
,
TypeAlias
,
TypedDict
from
vllm.config
import
ModelConfig
...
...
@@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
async_get_and_parse_image
,
get_and_parse_audio
,
get_and_parse_image
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
logger
=
init_logger
(
__name__
)
...
...
@@ -107,7 +108,7 @@ class ConversationMessage(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls."""
ModalityStr
=
Literal
[
"image"
,
"audio"
]
ModalityStr
=
Literal
[
"image"
,
"audio"
,
"video"
]
_T
=
TypeVar
(
"_T"
)
...
...
@@ -147,20 +148,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
f
"<|image_
{
current_count
}
|>"
if
model_type
==
"minicpmv"
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
):
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
,
"pixtral"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
==
"qwen"
:
return
f
"Picture
{
current_count
}
: <img></img>"
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
return
"<image>"
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|image_pad|><|vision_end|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"audio"
:
if
model_type
==
"ultravox"
:
return
"<|reserved_special_token_0|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"video"
:
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|video_pad|><|vision_end|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
...
...
@@ -377,6 +387,9 @@ def _parse_chat_message_content_parts(
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
mm_parser
.
parse_audio
(
audio_url
[
"url"
])
elif
part_type
==
"refusal"
:
text
=
_RefusalParser
(
part
)[
"refusal"
]
texts
.
append
(
text
)
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
...
...
@@ -431,6 +444,21 @@ def _parse_chat_message_content(
return
result
def
_postprocess_messages
(
messages
:
List
[
ConversationMessage
])
->
None
:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for
message
in
messages
:
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)):
for
item
in
message
[
"tool_calls"
]:
item
[
"function"
][
"arguments"
]
=
json
.
loads
(
item
[
"function"
][
"arguments"
])
def
parse_chat_messages
(
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
...
...
@@ -444,6 +472,8 @@ def parse_chat_messages(
conversation
.
extend
(
sub_messages
)
_postprocess_messages
(
conversation
)
return
conversation
,
mm_tracker
.
all_mm_data
()
...
...
@@ -460,41 +490,44 @@ def parse_chat_messages_futures(
conversation
.
extend
(
sub_messages
)
_postprocess_messages
(
conversation
)
return
conversation
,
mm_tracker
.
all_mm_data
()
def
apply_chat_template
(
tokenizer
:
Any
Tokenizer
,
def
apply_
hf_
chat_template
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrained
Tokenizer
Fast
]
,
conversation
:
List
[
ConversationMessage
],
chat_template
:
Optional
[
str
],
*
,
tokenize
:
bool
=
False
,
# Different from HF's default
**
kwargs
:
Any
,
)
->
Union
[
str
,
List
[
int
]]
:
)
->
str
:
if
chat_template
is
None
and
tokenizer
.
chat_template
is
None
:
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for
message
in
conversation
:
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)):
for
i
in
range
(
len
(
message
[
"tool_calls"
])):
args
:
str
=
message
[
"tool_calls"
][
i
][
"function"
][
"arguments"
]
parsed_args
:
Dict
=
json
.
loads
(
args
)
message
[
"tool_calls"
][
i
][
"function"
][
"arguments"
]
=
parsed_args
prompt
=
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
return
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
# type: ignore[arg-type]
chat_template
=
chat_template
,
tokenize
=
tokenize
,
**
kwargs
,
)
return
prompt
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Render chat messages through the Mistral tokenizer.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` is invoked.
        messages: Chat messages, passed through unchanged.
        chat_template: Not forwarded to the tokenizer; a non-``None`` value
            only triggers a warning.
        **kwargs: Extra keyword arguments forwarded to the tokenizer.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns (annotated as a
        list of token ids).
    """
    template_was_given = chat_template is not None
    if template_was_given:
        logger.warning(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    token_ids = tokenizer.apply_chat_template(messages=messages, **kwargs)
    return token_ids
Prev
1
2
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment