Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef640440
Unverified
Commit
ef640440
authored
Mar 07, 2025
by
afeldman-nm
Committed by
GitHub
Mar 08, 2025
Browse files
[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (#13949)
parent
66e16a03
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
292 additions
and
162 deletions
+292
-162
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+110
-2
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+49
-12
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+0
-36
tests/v1/engine/test_llm_engine.py
tests/v1/engine/test_llm_engine.py
+0
-15
tests/v1/engine/utils.py
tests/v1/engine/utils.py
+0
-3
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+85
-60
tests/v1/sample/utils.py
tests/v1/sample/utils.py
+24
-9
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+24
-19
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+0
-6
No files found.
tests/v1/core/test_prefix_caching.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""
from
typing
import
Optional
import
pytest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
...
...
@@ -15,7 +17,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
def
make_request
(
request_id
,
prompt_token_ids
,
mm_positions
=
None
,
mm_hashes
=
None
):
mm_hashes
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
if
mm_positions
is
None
:
multi_modal_inputs
=
None
else
:
...
...
@@ -28,7 +31,8 @@ def make_request(request_id,
multi_modal_inputs
=
multi_modal_inputs
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_placeholders
=
mm_positions
,
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
eos_token_id
=
100
,
arrival_time
=
0
,
lora_request
=
None
,
...
...
@@ -144,6 +148,110 @@ def test_prefill():
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
def
test_prefill_plp
():
'''Test prefill with APC and some prompt logprobs (plp) requests.
1. Schedule plp request and validate APC block allocation
2. Schedule non-plp request and validate blocks
3. Schedule plp request; no hit should occur; validate blocks
'''
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
# Request #0 is a prompt logprobs request
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
,
prompt_logprobs
=
5
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
])
==
3
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
]
# Check full block metadata
parent_block_hash
=
None
for
block_id
in
(
0
,
1
,
2
):
block_tokens
=
tuple
(
all_token_ids
[
block_id
*
16
:(
block_id
+
1
)
*
16
])
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
==
block_hash
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
# Check partial/preallocated block metadata
for
block_id
in
(
3
,
4
):
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
# Request #1 is a non-prompt-logprobs request:
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
,
6
]
for
block
in
computed_blocks
:
assert
block
.
ref_cnt
==
2
# At this point, we should have 3 free blocks left.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
3
manager
.
free
(
req0
)
manager
.
free
(
req1
)
# All blocks should be available.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
assert
[
b
.
block_id
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
7
,
8
,
9
,
4
,
3
,
6
,
5
,
2
,
1
,
0
]
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
,
prompt_logprobs
=
5
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req2
,
55
,
computed_blocks
)
block_ids
=
[
b
.
block_id
for
b
in
blocks
]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
]
==
req0_block_hashes
assert
block_ids
!=
[
0
,
1
,
2
,
3
,
4
]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
for
block_id
in
block_ids
:
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
manager
.
free
(
req2
)
def
test_decode
():
manager
=
KVCacheManager
(
block_size
=
16
,
...
...
tests/v1/core/test_scheduler.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
pytest
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -16,7 +18,21 @@ def create_scheduler(
model
:
str
=
"facebook/opt-125m"
,
max_num_seqs
:
int
=
16
,
max_num_batched_tokens
:
int
=
8192
,
enable_prefix_caching
:
Optional
[
bool
]
=
None
,
)
->
Scheduler
:
'''Create scheduler under test.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
:class:`Scheduler` instance
'''
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
max_num_batched_tokens
=
max_num_batched_tokens
,
...
...
@@ -31,11 +47,16 @@ def create_scheduler(
dtype
=
"float16"
,
seed
=
42
,
)
# Cache config, optionally force APC
kwargs_cache
=
({}
if
enable_prefix_caching
is
None
else
{
'enable_prefix_caching'
:
enable_prefix_caching
})
cache_config
=
CacheConfig
(
block_size
=
16
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
**
kwargs_cache
,
)
vllm_config
=
VllmConfig
(
scheduler_config
=
scheduler_config
,
...
...
@@ -54,16 +75,16 @@ def create_scheduler(
)
def
create_requests
(
num_requests
:
int
,
num_tokens
:
int
=
10
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
max_tokens
:
int
=
16
,
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
):
def
create_requests
(
num_requests
:
int
,
num_tokens
:
int
=
10
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
max_tokens
:
int
=
16
,
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
stop_token_ids
,
prompt_logprobs
=
prompt_logprobs
)
requests
=
[]
for
i
in
range
(
num_requests
):
if
mm_positions
is
not
None
:
...
...
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
assert
scheduler
.
get_num_unfinished_requests
()
==
len
(
requests
)
-
i
-
1
def
test_schedule
():
scheduler
=
create_scheduler
()
requests
=
create_requests
(
num_requests
=
10
)
@
pytest
.
mark
.
parametrize
(
"enable_prefix_caching, prompt_logprobs"
,
[
(
None
,
None
),
(
True
,
5
),
])
def
test_schedule
(
enable_prefix_caching
:
Optional
[
bool
],
prompt_logprobs
:
Optional
[
int
]):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler
=
create_scheduler
(
enable_prefix_caching
=
enable_prefix_caching
)
requests
=
create_requests
(
num_requests
=
10
,
prompt_logprobs
=
prompt_logprobs
)
for
request
in
requests
:
scheduler
.
add_request
(
request
)
...
...
@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
assert
list
(
requests
[
0
].
output_token_ids
)
==
[
EOS_TOKEN_ID
,
10
,
11
]
def
test_schedule_concurrent_batches
():
@
pytest
.
mark
.
parametrize
(
"enable_prefix_caching, prompt_logprobs"
,
[
(
None
,
None
),
(
True
,
5
),
])
def
test_schedule_concurrent_batches
(
enable_prefix_caching
:
Optional
[
bool
],
prompt_logprobs
:
Optional
[
int
]):
scheduler
=
create_scheduler
(
max_num_batched_tokens
=
1024
,
max_num_seqs
=
2
,
enable_prefix_caching
=
enable_prefix_caching
,
)
requests
=
create_requests
(
num_requests
=
2
,
num_tokens
=
512
,
prompt_logprobs
=
prompt_logprobs
,
)
# Schedule the first request.
...
...
tests/v1/engine/test_async_llm.py
View file @
ef640440
...
...
@@ -6,7 +6,6 @@ from typing import Optional
import
pytest
from
tests.v1.engine.utils
import
PLP_APC_UNSUPPORTED_MSG
from
vllm
import
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.engine.arg_utils
import
AsyncEngineArgs
...
...
@@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM,
return
count
,
request_id
@
pytest
.
mark
.
parametrize
(
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
@
pytest
.
mark
.
asyncio
async
def
test_async_llm_refuses_prompt_logprobs_with_apc
(
monkeypatch
,
output_kind
:
RequestOutputKind
):
"""Test passes if AsyncLLM raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""
# TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
# better way to test V1 so that in the future when we switch, we don't
# have to change all the tests.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
# Create AsyncLLM engine with APC
apc_engine_args
=
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
,
enable_prefix_caching
=
True
,
gpu_memory_utilization
=
0.8
,
disable_log_requests
=
True
)
engine
=
AsyncLLM
.
from_engine_args
(
apc_engine_args
)
try
:
with
pytest
.
raises
(
ValueError
)
as
excinfo
:
# Issue a request with prompt logprobs enabled, which should fail
await
asyncio
.
create_task
(
generate
(
engine
,
"request-0"
,
TEXT_PROMPT
,
output_kind
,
10
,
prompt_logprobs
=
5
))
# Validate exception string is correct
assert
str
(
excinfo
.
value
)
==
PLP_APC_UNSUPPORTED_MSG
finally
:
# Shut down engine
engine
.
shutdown
()
@
pytest
.
mark
.
parametrize
(
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
@
pytest
.
mark
.
parametrize
(
"engine_args_and_prompt"
,
...
...
tests/v1/engine/test_llm_engine.py
View file @
ef640440
...
...
@@ -5,7 +5,6 @@ from typing import Optional
import
pytest
from
tests.v1.engine.utils
import
PLP_APC_UNSUPPORTED_MSG
from
vllm
import
LLM
,
SamplingParams
MODEL
=
"facebook/opt-125m"
...
...
@@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
raise
AssertionError
(
f
"
{
len
(
completion_counts
)
}
unique completions; expected"
f
"
{
n
}
. Repeats:
{
repeats
}
"
)
def
test_llm_engine_refuses_prompt_logprobs_with_apc
(
vllm_model_apc
):
"""Test passes if LLMEngine raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""
model
:
LLM
=
vllm_model_apc
.
model
with
pytest
.
raises
(
ValueError
)
as
excinfo
:
model
.
generate
(
"Hello, my name is"
,
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
prompt_logprobs
=
5
))
# Validate exception string is correct
assert
str
(
excinfo
.
value
)
==
PLP_APC_UNSUPPORTED_MSG
tests/v1/engine/utils.py
View file @
ef640440
...
...
@@ -30,9 +30,6 @@ FULL_STRINGS = [
STOP_STRINGS
=
[
"I love working on"
,
"company by far"
,
"brother in"
]
PROMPT_LEN
=
5
PLP_APC_UNSUPPORTED_MSG
=
(
"Prefix caching with prompt logprobs not yet "
"supported on VLLM V1."
)
random
.
seed
(
42
)
...
...
tests/v1/sample/test_logprobs.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
import
itertools
from
collections.abc
import
Generator
import
pytest
import
torch
from
tests.kernels.utils
import
override_backend_env_variable
from
tests.v1.sample.utils
import
(
BatchLogprobsComposition
,
BatchLogprobsSpecType
,
assert_incr_detok_str_matches_non_incr_detok_str
,
compute_correct_cumulative_logprob
,
get_test_batch
)
from
vllm
import
SamplingParams
from
...conftest
import
VllmRunner
from
...conftest
import
HfRunner
,
VllmRunner
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
DTYPE
=
"half"
NONE
=
BatchLogprobsComposition
.
NONE
SAMPLE
=
BatchLogprobsComposition
.
SAMPLE
PROMPT
=
BatchLogprobsComposition
.
PROMPT
SAMPLE_PROMPT
=
BatchLogprobsComposition
.
SAMPLE_PROMPT
@
pytest
.
fixture
(
scope
=
"module"
)
def
vllm_model
(
vllm_runner
):
@
pytest
.
fixture
(
scope
=
"module"
,
# Parameterize APC
params
=
[
False
,
True
])
def
vllm_model
(
vllm_runner
,
request
)
->
Generator
[
VllmRunner
,
None
,
None
]:
with
vllm_runner
(
MODEL
,
dtype
=
DTYPE
,
...
...
@@ -31,22 +41,22 @@ def vllm_model(vllm_runner):
enforce_eager
=
True
,
#TODO: enable this once we support it for
# prompt logprobs.
enable_prefix_caching
=
False
,
enable_prefix_caching
=
request
.
param
,
gpu_memory_utilization
=
0.5
,
)
as
vllm_model
:
yield
vllm_model
@
pytest
.
fixture
(
scope
=
"module"
)
def
hf_model
(
hf_runner
):
def
hf_model
(
hf_runner
)
->
Generator
[
HfRunner
,
None
,
None
]
:
with
hf_runner
(
MODEL
,
dtype
=
DTYPE
)
as
hf_model
:
yield
hf_model
def
_repeat_logprob_config
(
test_prompts
,
logprob_prompt_logprob_list
:
list
[
tuple
]
,
)
->
list
[
tuple
]
:
logprob_prompt_logprob_list
:
BatchLogprobsSpecType
,
)
->
BatchLogprobsSpecType
:
"""Ensure each test prompt has a logprob config.
A logprob config specifies the optional (i.e.
...
...
@@ -91,42 +101,17 @@ def _repeat_logprob_config(
return
logprob_prompt_logprob_list
def
_test_case_get_logprobs_and_prompt_logprobs
(
hf_model
,
vllm_model
,
batch_logprobs_composition
:
str
,
def
_run_and_validate
(
vllm_model
:
VllmRunner
,
test_prompts
:
list
[
str
],
vllm_sampling_params
:
SamplingParams
,
hf_logprobs
:
list
[
list
[
torch
.
Tensor
]],
hf_outputs
:
list
[
tuple
[
list
[
int
],
str
]],
logprob_prompt_logprob_list
:
BatchLogprobsSpecType
,
temperature
:
float
,
example_prompts
,
max_tokens
:
int
,
do_apc
:
bool
,
)
->
None
:
test_prompts
=
example_prompts
max_tokens
=
5
hf_outputs
=
hf_model
.
generate_greedy
(
test_prompts
,
max_tokens
=
max_tokens
,
)
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
test_prompts
,
max_tokens
=
max_tokens
,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list
=
get_test_batch
(
batch_logprobs_composition
)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list
=
_repeat_logprob_config
(
test_prompts
,
logprob_prompt_logprob_list
)
# Generate SamplingParams
vllm_sampling_params
=
[
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
num_lp
,
prompt_logprobs
=
num_plp
,
temperature
=
temperature
,
seed
=
1984
)
for
num_lp
,
num_plp
in
logprob_prompt_logprob_list
]
vllm_results
=
vllm_model
.
model
.
generate
(
test_prompts
,
sampling_params
=
vllm_sampling_params
)
...
...
@@ -267,14 +252,13 @@ def _test_case_get_logprobs_and_prompt_logprobs(
assert
vllm_result
.
prompt_logprobs
is
None
#@pytest.mark.skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"batch_logprobs_composition"
,
[
"
NONE
"
,
"
SAMPLE
"
,
"
PROMPT
"
,
"
SAMPLE_PROMPT
"
])
[
NONE
,
SAMPLE
,
PROMPT
,
SAMPLE_PROMPT
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.0
,
2.0
])
def
test_get_logprobs_and_prompt_logprobs
(
hf_model
,
vllm_model
,
batch_logprobs_composition
:
str
,
batch_logprobs_composition
:
BatchLogprobsComposition
,
temperature
:
float
,
example_prompts
,
)
->
None
:
...
...
@@ -292,25 +276,70 @@ def test_get_logprobs_and_prompt_logprobs(
batch_logprobs_composition controls the logprobs configurations for
requests in the batch under test.
APC tests run two test iterations so that cache hits occur.
To save time, only test one APC-enabled scenario
(sample & prompt logprobs enabled, temperature>0.0).
Args:
hf_model
vllm_model
hf_model
: HuggingFace reference model fixture
vllm_model
: vLLM model fixture
batch_logprobs_composition: logprobs configuration for test batch
example_prompts
monkeypatch
temperature: "temperature" sampling parameter
example_prompts: example prompt fixture
"""
_test_case_get_logprobs_and_prompt_logprobs
(
hf_model
=
hf_model
,
vllm_model
=
vllm_model
,
batch_logprobs_composition
=
batch_logprobs_composition
,
temperature
=
temperature
,
example_prompts
=
example_prompts
)
do_apc
=
vllm_model
.
model
.
llm_engine
.
cache_config
.
enable_prefix_caching
if
do_apc
and
(
temperature
<
2.0
or
batch_logprobs_composition
!=
SAMPLE_PROMPT
):
# Skip some test-cases to save time.
pytest
.
skip
()
test_prompts
=
example_prompts
max_tokens
=
5
hf_outputs
=
hf_model
.
generate_greedy
(
test_prompts
,
max_tokens
=
max_tokens
,
)
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
test_prompts
,
max_tokens
=
max_tokens
,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list
=
get_test_batch
(
batch_logprobs_composition
)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list
=
_repeat_logprob_config
(
test_prompts
,
logprob_prompt_logprob_list
)
# Generate SamplingParams
vllm_sampling_params
=
[
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
num_lp
,
prompt_logprobs
=
num_plp
,
temperature
=
temperature
,
seed
=
1984
)
for
num_lp
,
num_plp
in
logprob_prompt_logprob_list
]
for
_
in
range
(
2
if
do_apc
else
1
):
_run_and_validate
(
vllm_model
=
vllm_model
,
test_prompts
=
test_prompts
,
vllm_sampling_params
=
vllm_sampling_params
,
hf_logprobs
=
hf_logprobs
,
hf_outputs
=
hf_outputs
,
logprob_prompt_logprob_list
=
logprob_prompt_logprob_list
,
temperature
=
temperature
,
max_tokens
=
max_tokens
,
do_apc
=
do_apc
)
def
test_max_logprobs
(
monkeypatch
):
"""vLLM v1 engine should fail a request with `logprobs > max_logprobs`
Should also fail for `prompt_logprobs > max_logprobs`
APC should not matter as this test checks basic request validation.
Args:
monkeypatch
...
...
@@ -330,14 +359,12 @@ def test_max_logprobs(monkeypatch):
runner
.
generate
([
"Hello world"
],
sampling_params
=
bad_sampling_params
)
def
test_none_logprobs
(
vllm_model
,
example_prompts
,
monkeypatch
):
def
test_none_logprobs
(
vllm_model
,
example_prompts
):
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
Args:
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
monkeypatch: supports editing env vars and rolling back changes
after the test
"""
max_tokens
=
5
...
...
@@ -356,14 +383,12 @@ def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
assert
results_logprobs_none
[
i
].
prompt_logprobs
is
None
def
test_zero_logprobs
(
vllm_model
,
example_prompts
,
monkeypatch
):
def
test_zero_logprobs
(
vllm_model
,
example_prompts
):
"""Engine should return sampled token and prompt token logprobs
Args:
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
monkeypatch: supports editing env vars and rolling back changes
after the test
"""
max_tokens
=
5
...
...
tests/v1/sample/utils.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
import
re
from
enum
import
Enum
from
typing
import
Optional
from
vllm
import
CompletionOutput
def
get_test_batch
(
batch_logprobs_composition
:
str
)
->
list
[
tuple
]:
class
BatchLogprobsComposition
(
Enum
):
"""Types of logprobs configs to include in test batch"""
NONE
=
0
SAMPLE
=
1
PROMPT
=
2
SAMPLE_PROMPT
=
3
BatchLogprobsSpecType
=
list
[
tuple
[
Optional
[
int
],
Optional
[
int
]]]
def
get_test_batch
(
batch_logprobs_composition
:
BatchLogprobsComposition
)
->
BatchLogprobsSpecType
:
"""Generate logprobs configs for a batch of requests
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
num_prompt_logprobs. The batch logprobs configuration is the list of request
logprobs configs.
batch_logprobs_composition ==
"
NONE
"
yields a batch with no sample or prompt
batch_logprobs_composition == NONE yields a batch with no sample or prompt
logprobs
batch_logprobs_composition ==
"
SAMPLE
"
yields a batch with some requests
batch_logprobs_composition == SAMPLE yields a batch with some requests
configured for sample logprobs only, and others configured for no logprobs
batch_logprobs_composition ==
"
PROMPT
"
yields a batch with some requests
batch_logprobs_composition == PROMPT yields a batch with some requests
configured for prompt logprobs only, and others configured for no logprobs
batch_logprobs_composition ==
"
SAMPLE_PROMPT
"
yields a batch with some
batch_logprobs_composition == SAMPLE_PROMPT yields a batch with some
requests configured for sample logprobs and prompt logprobs, some configured
for only sample logprobs or only prompt logprobs, and some configured for
no logprobs
...
...
@@ -34,10 +49,10 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
tuples
"""
if
batch_logprobs_composition
==
"
NONE
"
:
if
batch_logprobs_composition
==
BatchLogprobsComposition
.
NONE
:
# No requests with sample or prompt logprobs
return
[(
None
,
None
)]
elif
batch_logprobs_composition
==
"
SAMPLE
"
:
elif
batch_logprobs_composition
==
BatchLogprobsComposition
.
SAMPLE
:
# Requests requiring sample logprobs or no logprobs
return
[
(
None
,
None
),
...
...
@@ -45,7 +60,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
(
5
,
None
),
(
3
,
None
),
]
elif
batch_logprobs_composition
==
"
PROMPT
"
:
elif
batch_logprobs_composition
==
BatchLogprobsComposition
.
PROMPT
:
# Requests requiring prompt logprobs or no logprobs
return
[
(
None
,
None
),
...
...
@@ -53,7 +68,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
(
None
,
6
),
(
None
,
5
),
]
elif
batch_logprobs_composition
==
"
SAMPLE_PROMPT
"
:
elif
batch_logprobs_composition
==
BatchLogprobsComposition
.
SAMPLE_PROMPT
:
# Requests requiring either no logprobs, just
# sample logprobs, just prompt logprobs, or
# both sample and prompt logprobs
...
...
vllm/v1/core/kv_cache_manager.py
View file @
ef640440
...
...
@@ -105,8 +105,6 @@ class KVCacheManager:
# Prefix caching is disabled.
return
[],
0
computed_blocks
=
[]
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
...
...
@@ -114,24 +112,31 @@ class KVCacheManager:
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
request
)
self
.
req_to_block_hashes
[
request
.
request_id
]
=
block_hashes
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
else
:
break
self
.
prefix_cache_stats
.
requests
+=
1
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
if
request
.
sampling_params
.
prompt_logprobs
is
None
:
# Check for cache hits
computed_blocks
=
[]
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash
# is not in the cached_block_hash_to_id, the following
# block hashes are not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
else
:
break
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
else
:
# Skip cache hits for prompt logprobs
return
[],
0
def
allocate_slots
(
self
,
...
...
vllm/v1/engine/processor.py
View file @
ef640440
...
...
@@ -72,12 +72,6 @@ class Processor:
f
"Requested prompt logprobs of
{
params
.
prompt_logprobs
}
, "
f
"which is greater than max allowed:
{
max_logprobs
}
"
)
# TODO(andy): enable this in follow up by recomputing.
if
(
params
.
prompt_logprobs
is
not
None
and
self
.
cache_config
.
enable_prefix_caching
):
raise
ValueError
(
"Prefix caching with prompt logprobs not yet "
"supported on VLLM V1."
)
def
_validate_sampling_params
(
self
,
params
:
SamplingParams
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment