Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef640440
Unverified
Commit
ef640440
authored
Mar 07, 2025
by
afeldman-nm
Committed by
GitHub
Mar 08, 2025
Browse files
[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (#13949)
parent
66e16a03
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
292 additions
and
162 deletions
+292
-162
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+110
-2
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+49
-12
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+0
-36
tests/v1/engine/test_llm_engine.py
tests/v1/engine/test_llm_engine.py
+0
-15
tests/v1/engine/utils.py
tests/v1/engine/utils.py
+0
-3
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+85
-60
tests/v1/sample/utils.py
tests/v1/sample/utils.py
+24
-9
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+24
-19
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+0
-6
No files found.
tests/v1/core/test_prefix_caching.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""
"""Compare the with and without prefix caching."""
from
typing
import
Optional
import
pytest
import
pytest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
...
@@ -15,7 +17,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
...
@@ -15,7 +17,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
def
make_request
(
request_id
,
def
make_request
(
request_id
,
prompt_token_ids
,
prompt_token_ids
,
mm_positions
=
None
,
mm_positions
=
None
,
mm_hashes
=
None
):
mm_hashes
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
if
mm_positions
is
None
:
if
mm_positions
is
None
:
multi_modal_inputs
=
None
multi_modal_inputs
=
None
else
:
else
:
...
@@ -28,7 +31,8 @@ def make_request(request_id,
...
@@ -28,7 +31,8 @@ def make_request(request_id,
multi_modal_inputs
=
multi_modal_inputs
,
multi_modal_inputs
=
multi_modal_inputs
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_placeholders
=
mm_positions
,
multi_modal_placeholders
=
mm_positions
,
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
eos_token_id
=
100
,
eos_token_id
=
100
,
arrival_time
=
0
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
...
@@ -144,6 +148,110 @@ def test_prefill():
...
@@ -144,6 +148,110 @@ def test_prefill():
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
def test_prefill_plp():
    '''Exercise prefill under APC when some requests ask for prompt logprobs.

    Scenario:
    1. A prompt-logprobs (plp) request is scheduled first; it gets no cache
       hit, and its full blocks end up carrying block hashes.
    2. A non-plp request sharing the prefix is scheduled; it reuses the
       blocks left behind by step 1.
    3. Another plp request sharing the prefix is scheduled; it gets no
       cache hit and is handed different blocks with the same hashes.
    '''
    block_size = 16
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    # Shared prefix: three complete blocks (48 tokens), block i holds token i.
    common_token_ids = [tok for tok in range(3) for _ in range(16)]

    # ---- Request #0: prompt logprobs, complete cache miss ----
    # The prefix is followed by a partial block of 7 tokens.
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req0, 55, computed_blocks)
    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
    # Remembered so request #2 can be compared against it later.
    req0_block_hashes = [b.block_hash for b in blocks]

    # Full blocks: each hash chains on the previous block's hash value.
    parent_block_hash = None
    for block_id in (0, 1, 2):
        start = block_id * 16
        stop = (block_id + 1) * 16
        block_tokens = tuple(all_token_ids[start:stop])
        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
        pool_block = manager.block_pool.blocks[block_id]
        assert pool_block.block_hash == block_hash
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash.hash_value

    # Partial / preallocated blocks: referenced but not hashed.
    for block_id in (3, 4):
        assert manager.block_pool.blocks[block_id].block_hash is None
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    # ---- Request #1: no prompt logprobs, hits the shared prefix ----
    # The original blocks are still in use by request #0 at this point.
    # The prefix is followed by a partial block of 5 tokens.
    unique_token_ids = [3] * 5
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * block_size
    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [5, 6]
    for shared_block in computed_blocks:
        # Held by both request #0 and request #1.
        assert shared_block.ref_cnt == 2

    # Seven of the ten blocks are in use, leaving three free.
    assert manager.block_pool.free_block_queue.num_free_blocks == 3

    manager.free(req0)
    manager.free(req1)

    # Everything is back in the free queue.
    assert manager.block_pool.free_block_queue.num_free_blocks == 10
    # Expected queue order:
    #   never allocated  -> 7, 8, 9
    #   unique to req0   -> 4, 3
    #   unique to req1   -> 6, 5
    #   shared prefix    -> 2, 1, 0
    free_block_ids = [
        b.block_id
        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ]
    assert free_block_ids == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

    # ---- Request #2: prompt logprobs, must NOT hit the shared prefix ----
    # It ends up duplicating the blocks request #0 left in the cache.
    unique_token_ids = [3] * 6
    req2 = make_request("2",
                        common_token_ids + unique_token_ids,
                        prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    # NOTE(review): req2's prompt is 48 + 6 = 54 tokens, yet 55 is passed
    # here exactly as for req0 -- confirm this is intended.
    blocks = manager.allocate_slots(req2, 55, computed_blocks)
    block_ids = [b.block_id for b in blocks]
    # The duplicates carry the same hashes as request #0's blocks but live
    # under different block ids.
    assert [b.block_hash for b in blocks] == req0_block_hashes
    assert block_ids != [0, 1, 2, 3, 4]

    # The hashes were already validated for request #0, so only the
    # reference counts remain to be checked.
    for block_id in block_ids:
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    manager.free(req2)
def
test_decode
():
def
test_decode
():
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
block_size
=
16
,
block_size
=
16
,
...
...
tests/v1/core/test_scheduler.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
typing
import
Optional
import
pytest
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -16,7 +18,21 @@ def create_scheduler(
...
@@ -16,7 +18,21 @@ def create_scheduler(
model
:
str
=
"facebook/opt-125m"
,
model
:
str
=
"facebook/opt-125m"
,
max_num_seqs
:
int
=
16
,
max_num_seqs
:
int
=
16
,
max_num_batched_tokens
:
int
=
8192
,
max_num_batched_tokens
:
int
=
8192
,
enable_prefix_caching
:
Optional
[
bool
]
=
None
,
)
->
Scheduler
:
)
->
Scheduler
:
'''Create scheduler under test.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batched_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
:class:`Scheduler` instance
'''
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
...
@@ -31,11 +47,16 @@ def create_scheduler(
...
@@ -31,11 +47,16 @@ def create_scheduler(
dtype
=
"float16"
,
dtype
=
"float16"
,
seed
=
42
,
seed
=
42
,
)
)
# Cache config, optionally force APC
kwargs_cache
=
({}
if
enable_prefix_caching
is
None
else
{
'enable_prefix_caching'
:
enable_prefix_caching
})
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
16
,
block_size
=
16
,
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
cache_dtype
=
"auto"
,
**
kwargs_cache
,
)
)
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
...
@@ -54,16 +75,16 @@ def create_scheduler(
...
@@ -54,16 +75,16 @@ def create_scheduler(
)
)
def
create_requests
(
def
create_requests
(
num_requests
:
int
,
num_requests
:
int
,
num_tokens
:
int
=
10
,
num_tokens
:
int
=
10
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
max_tokens
:
int
=
16
,
max_tokens
:
int
=
16
,
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
):
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
stop_token_ids
,
prompt_logprobs
=
prompt_logprobs
)
requests
=
[]
requests
=
[]
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
if
mm_positions
is
not
None
:
if
mm_positions
is
not
None
:
...
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
...
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
assert
scheduler
.
get_num_unfinished_requests
()
==
len
(
requests
)
-
i
-
1
assert
scheduler
.
get_num_unfinished_requests
()
==
len
(
requests
)
-
i
-
1
def
test_schedule
():
@
pytest
.
mark
.
parametrize
(
"enable_prefix_caching, prompt_logprobs"
,
[
scheduler
=
create_scheduler
()
(
None
,
None
),
requests
=
create_requests
(
num_requests
=
10
)
(
True
,
5
),
])
def
test_schedule
(
enable_prefix_caching
:
Optional
[
bool
],
prompt_logprobs
:
Optional
[
int
]):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler
=
create_scheduler
(
enable_prefix_caching
=
enable_prefix_caching
)
requests
=
create_requests
(
num_requests
=
10
,
prompt_logprobs
=
prompt_logprobs
)
for
request
in
requests
:
for
request
in
requests
:
scheduler
.
add_request
(
request
)
scheduler
.
add_request
(
request
)
...
@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
...
@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
assert
list
(
requests
[
0
].
output_token_ids
)
==
[
EOS_TOKEN_ID
,
10
,
11
]
assert
list
(
requests
[
0
].
output_token_ids
)
==
[
EOS_TOKEN_ID
,
10
,
11
]
def
test_schedule_concurrent_batches
():
@
pytest
.
mark
.
parametrize
(
"enable_prefix_caching, prompt_logprobs"
,
[
(
None
,
None
),
(
True
,
5
),
])
def
test_schedule_concurrent_batches
(
enable_prefix_caching
:
Optional
[
bool
],
prompt_logprobs
:
Optional
[
int
]):
scheduler
=
create_scheduler
(
scheduler
=
create_scheduler
(
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
enable_prefix_caching
=
enable_prefix_caching
,
)
)
requests
=
create_requests
(
requests
=
create_requests
(
num_requests
=
2
,
num_requests
=
2
,
num_tokens
=
512
,
num_tokens
=
512
,
prompt_logprobs
=
prompt_logprobs
,
)
)
# Schedule the first request.
# Schedule the first request.
...
...
tests/v1/engine/test_async_llm.py
View file @
ef640440
...
@@ -6,7 +6,6 @@ from typing import Optional
...
@@ -6,7 +6,6 @@ from typing import Optional
import
pytest
import
pytest
from
tests.v1.engine.utils
import
PLP_APC_UNSUPPORTED_MSG
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
...
@@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM,
...
@@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM,
return
count
,
request_id
return
count
,
request_id
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_async_llm_refuses_prompt_logprobs_with_apc(
        monkeypatch, output_kind: RequestOutputKind):
    """Check that AsyncLLM rejects prompt logprobs when APC is enabled.

    The test passes if an engine built with automatic prefix caching
    raises ``ValueError`` for a request that enables ``prompt_logprobs``,
    and the error text equals ``PLP_APC_UNSUPPORTED_MSG``.

    Args:
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1``.
        output_kind: request output kind passed through to ``generate``.
    """
    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
    # better way to test V1 so that in the future when we switch, we don't
    # have to change all the tests.
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # Build an AsyncLLM engine with APC switched on.
    apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                      enable_prefix_caching=True,
                                      gpu_memory_utilization=0.8,
                                      disable_log_requests=True)
    engine = AsyncLLM.from_engine_args(apc_engine_args)
    try:
        with pytest.raises(ValueError) as excinfo:
            # A request with prompt logprobs enabled is expected to fail.
            await asyncio.create_task(
                generate(engine,
                         "request-0",
                         TEXT_PROMPT,
                         output_kind,
                         10,
                         prompt_logprobs=5))
        # Checked after the `raises` block: any statement placed after the
        # raising call inside the block would never execute.
        assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
    finally:
        # Shut the engine down whether or not the checks passed.
        engine.shutdown()
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
"output_kind"
,
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
@
pytest
.
mark
.
parametrize
(
"engine_args_and_prompt"
,
@
pytest
.
mark
.
parametrize
(
"engine_args_and_prompt"
,
...
...
tests/v1/engine/test_llm_engine.py
View file @
ef640440
...
@@ -5,7 +5,6 @@ from typing import Optional
...
@@ -5,7 +5,6 @@ from typing import Optional
import
pytest
import
pytest
from
tests.v1.engine.utils
import
PLP_APC_UNSUPPORTED_MSG
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
MODEL
=
"facebook/opt-125m"
MODEL
=
"facebook/opt-125m"
...
@@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
...
@@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
raise
AssertionError
(
raise
AssertionError
(
f
"
{
len
(
completion_counts
)
}
unique completions; expected"
f
"
{
len
(
completion_counts
)
}
unique completions; expected"
f
"
{
n
}
. Repeats:
{
repeats
}
"
)
f
"
{
n
}
. Repeats:
{
repeats
}
"
)
def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
    """Check that LLMEngine rejects prompt logprobs when APC is enabled.

    The test passes if generating with ``prompt_logprobs`` set raises
    ``ValueError`` and the error text equals ``PLP_APC_UNSUPPORTED_MSG``.

    Args:
        vllm_model_apc: fixture whose ``model`` attribute is the ``LLM``
            under test (presumably configured with automatic prefix
            caching -- the fixture is defined outside this view).
    """
    model: LLM = vllm_model_apc.model
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     prompt_logprobs=5)
    with pytest.raises(ValueError) as excinfo:
        model.generate("Hello, my name is", sampling_params)

    # Checked after the `raises` block: a statement following the raising
    # call inside the block would never execute.
    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
tests/v1/engine/utils.py
View file @
ef640440
...
@@ -30,9 +30,6 @@ FULL_STRINGS = [
...
@@ -30,9 +30,6 @@ FULL_STRINGS = [
STOP_STRINGS
=
[
"I love working on"
,
"company by far"
,
"brother in"
]
STOP_STRINGS
=
[
"I love working on"
,
"company by far"
,
"brother in"
]
PROMPT_LEN
=
5
PROMPT_LEN
=
5
# Exact error text that the engine tests compare against when prompt
# logprobs are requested while prefix caching is enabled.
PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
                           "supported on VLLM V1.")
random
.
seed
(
42
)
random
.
seed
(
42
)
...
...
tests/v1/sample/test_logprobs.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
itertools
import
itertools
from
collections.abc
import
Generator
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
override_backend_env_variable
from
tests.kernels.utils
import
override_backend_env_variable
from
tests.v1.sample.utils
import
(
from
tests.v1.sample.utils
import
(
BatchLogprobsComposition
,
BatchLogprobsSpecType
,
assert_incr_detok_str_matches_non_incr_detok_str
,
assert_incr_detok_str_matches_non_incr_detok_str
,
compute_correct_cumulative_logprob
,
get_test_batch
)
compute_correct_cumulative_logprob
,
get_test_batch
)
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
...conftest
import
VllmRunner
from
...conftest
import
HfRunner
,
VllmRunner
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
DTYPE
=
"half"
DTYPE
=
"half"
NONE
=
BatchLogprobsComposition
.
NONE
SAMPLE
=
BatchLogprobsComposition
.
SAMPLE
PROMPT
=
BatchLogprobsComposition
.
PROMPT
SAMPLE_PROMPT
=
BatchLogprobsComposition
.
SAMPLE_PROMPT
@
pytest
.
fixture
(
scope
=
"module"
)
def
vllm_model
(
vllm_runner
):
@
pytest
.
fixture
(
scope
=
"module"
,
# Parameterize APC
params
=
[
False
,
True
])
def
vllm_model
(
vllm_runner
,
request
)
->
Generator
[
VllmRunner
,
None
,
None
]:
with
vllm_runner
(
with
vllm_runner
(
MODEL
,
MODEL
,
dtype
=
DTYPE
,
dtype
=
DTYPE
,
...
@@ -31,22 +41,22 @@ def vllm_model(vllm_runner):
...
@@ -31,22 +41,22 @@ def vllm_model(vllm_runner):
enforce_eager
=
True
,
enforce_eager
=
True
,
#TODO: enable this once we support it for
#TODO: enable this once we support it for
# prompt logprobs.
# prompt logprobs.
enable_prefix_caching
=
False
,
enable_prefix_caching
=
request
.
param
,
gpu_memory_utilization
=
0.5
,
gpu_memory_utilization
=
0.5
,
)
as
vllm_model
:
)
as
vllm_model
:
yield
vllm_model
yield
vllm_model
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
hf_model
(
hf_runner
):
def
hf_model
(
hf_runner
)
->
Generator
[
HfRunner
,
None
,
None
]
:
with
hf_runner
(
MODEL
,
dtype
=
DTYPE
)
as
hf_model
:
with
hf_runner
(
MODEL
,
dtype
=
DTYPE
)
as
hf_model
:
yield
hf_model
yield
hf_model
def
_repeat_logprob_config
(
def
_repeat_logprob_config
(
test_prompts
,
test_prompts
,
logprob_prompt_logprob_list
:
list
[
tuple
]
,
logprob_prompt_logprob_list
:
BatchLogprobsSpecType
,
)
->
list
[
tuple
]
:
)
->
BatchLogprobsSpecType
:
"""Ensure each test prompt has a logprob config.
"""Ensure each test prompt has a logprob config.
A logprob config specifies the optional (i.e.
A logprob config specifies the optional (i.e.
...
@@ -91,42 +101,17 @@ def _repeat_logprob_config(
...
@@ -91,42 +101,17 @@ def _repeat_logprob_config(
return
logprob_prompt_logprob_list
return
logprob_prompt_logprob_list
def
_test_case_get_logprobs_and_prompt_logprobs
(
def
_run_and_validate
(
hf_model
,
vllm_model
:
VllmRunner
,
vllm_model
,
test_prompts
:
list
[
str
],
batch_logprobs_composition
:
str
,
vllm_sampling_params
:
SamplingParams
,
hf_logprobs
:
list
[
list
[
torch
.
Tensor
]],
hf_outputs
:
list
[
tuple
[
list
[
int
],
str
]],
logprob_prompt_logprob_list
:
BatchLogprobsSpecType
,
temperature
:
float
,
temperature
:
float
,
example_prompts
,
max_tokens
:
int
,
do_apc
:
bool
,
)
->
None
:
)
->
None
:
test_prompts
=
example_prompts
max_tokens
=
5
hf_outputs
=
hf_model
.
generate_greedy
(
test_prompts
,
max_tokens
=
max_tokens
,
)
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
test_prompts
,
max_tokens
=
max_tokens
,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list
=
get_test_batch
(
batch_logprobs_composition
)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list
=
_repeat_logprob_config
(
test_prompts
,
logprob_prompt_logprob_list
)
# Generate SamplingParams
vllm_sampling_params
=
[
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
num_lp
,
prompt_logprobs
=
num_plp
,
temperature
=
temperature
,
seed
=
1984
)
for
num_lp
,
num_plp
in
logprob_prompt_logprob_list
]
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
model
.
generate
(
test_prompts
,
sampling_params
=
vllm_sampling_params
)
test_prompts
,
sampling_params
=
vllm_sampling_params
)
...
@@ -267,14 +252,13 @@ def _test_case_get_logprobs_and_prompt_logprobs(
...
@@ -267,14 +252,13 @@ def _test_case_get_logprobs_and_prompt_logprobs(
assert
vllm_result
.
prompt_logprobs
is
None
assert
vllm_result
.
prompt_logprobs
is
None
#@pytest.mark.skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"batch_logprobs_composition"
,
@
pytest
.
mark
.
parametrize
(
"batch_logprobs_composition"
,
[
"
NONE
"
,
"
SAMPLE
"
,
"
PROMPT
"
,
"
SAMPLE_PROMPT
"
])
[
NONE
,
SAMPLE
,
PROMPT
,
SAMPLE_PROMPT
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.0
,
2.0
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.0
,
2.0
])
def
test_get_logprobs_and_prompt_logprobs
(
def
test_get_logprobs_and_prompt_logprobs
(
hf_model
,
hf_model
,
vllm_model
,
vllm_model
,
batch_logprobs_composition
:
str
,
batch_logprobs_composition
:
BatchLogprobsComposition
,
temperature
:
float
,
temperature
:
float
,
example_prompts
,
example_prompts
,
)
->
None
:
)
->
None
:
...
@@ -292,19 +276,62 @@ def test_get_logprobs_and_prompt_logprobs(
...
@@ -292,19 +276,62 @@ def test_get_logprobs_and_prompt_logprobs(
batch_logprobs_composition controls the logprobs configurations for
batch_logprobs_composition controls the logprobs configurations for
requests in the batch under test.
requests in the batch under test.
APC tests run two test iterations so that cache hits occur.
To save time, only test one APC-enabled scenario
(sample & prompt logprobs enabled, temperature>0.0).
Args:
Args:
hf_model
hf_model
: HuggingFace reference model fixture
vllm_model
vllm_model
: vLLM model fixture
batch_logprobs_composition: logprobs configuration for test batch
batch_logprobs_composition: logprobs configuration for test batch
example_prompts
temperature: "temperature" sampling parameter
monkeypatch
example_prompts: example prompt fixture
"""
"""
_test_case_get_logprobs_and_prompt_logprobs
(
do_apc
=
vllm_model
.
model
.
llm_engine
.
cache_config
.
enable_prefix_caching
hf_model
=
hf_model
,
if
do_apc
and
(
temperature
<
2.0
or
batch_logprobs_composition
!=
SAMPLE_PROMPT
):
# Skip some test-cases to save time.
pytest
.
skip
()
test_prompts
=
example_prompts
max_tokens
=
5
hf_outputs
=
hf_model
.
generate_greedy
(
test_prompts
,
max_tokens
=
max_tokens
,
)
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
test_prompts
,
max_tokens
=
max_tokens
,
)
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list
=
get_test_batch
(
batch_logprobs_composition
)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list
=
_repeat_logprob_config
(
test_prompts
,
logprob_prompt_logprob_list
)
# Generate SamplingParams
vllm_sampling_params
=
[
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
num_lp
,
prompt_logprobs
=
num_plp
,
temperature
=
temperature
,
seed
=
1984
)
for
num_lp
,
num_plp
in
logprob_prompt_logprob_list
]
for
_
in
range
(
2
if
do_apc
else
1
):
_run_and_validate
(
vllm_model
=
vllm_model
,
vllm_model
=
vllm_model
,
batch_logprobs_composition
=
batch_logprobs_composition
,
test_prompts
=
test_prompts
,
vllm_sampling_params
=
vllm_sampling_params
,
hf_logprobs
=
hf_logprobs
,
hf_outputs
=
hf_outputs
,
logprob_prompt_logprob_list
=
logprob_prompt_logprob_list
,
temperature
=
temperature
,
temperature
=
temperature
,
example_prompts
=
example_prompts
)
max_tokens
=
max_tokens
,
do_apc
=
do_apc
)
def
test_max_logprobs
(
monkeypatch
):
def
test_max_logprobs
(
monkeypatch
):
...
@@ -312,6 +339,8 @@ def test_max_logprobs(monkeypatch):
...
@@ -312,6 +339,8 @@ def test_max_logprobs(monkeypatch):
Should also fail for `prompt_logprobs > max_logprobs`
Should also fail for `prompt_logprobs > max_logprobs`
APC should not matter as this test checks basic request validation.
Args:
Args:
monkeypatch
monkeypatch
"""
"""
...
@@ -330,14 +359,12 @@ def test_max_logprobs(monkeypatch):
...
@@ -330,14 +359,12 @@ def test_max_logprobs(monkeypatch):
runner
.
generate
([
"Hello world"
],
sampling_params
=
bad_sampling_params
)
runner
.
generate
([
"Hello world"
],
sampling_params
=
bad_sampling_params
)
def
test_none_logprobs
(
vllm_model
,
example_prompts
,
monkeypatch
):
def
test_none_logprobs
(
vllm_model
,
example_prompts
):
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
Args:
Args:
vllm_model: vLLM model fixture
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
example_prompts: list of example prompts (test fixture)
monkeypatch: supports editing env vars and rolling back changes
after the test
"""
"""
max_tokens
=
5
max_tokens
=
5
...
@@ -356,14 +383,12 @@ def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
...
@@ -356,14 +383,12 @@ def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
assert
results_logprobs_none
[
i
].
prompt_logprobs
is
None
assert
results_logprobs_none
[
i
].
prompt_logprobs
is
None
def
test_zero_logprobs
(
vllm_model
,
example_prompts
,
monkeypatch
):
def
test_zero_logprobs
(
vllm_model
,
example_prompts
):
"""Engine should return sampled token and prompt token logprobs
"""Engine should return sampled token and prompt token logprobs
Args:
Args:
vllm_model: vLLM model fixture
vllm_model: vLLM model fixture
example_prompts: list of example prompts (test fixture)
example_prompts: list of example prompts (test fixture)
monkeypatch: supports editing env vars and rolling back changes
after the test
"""
"""
max_tokens
=
5
max_tokens
=
5
...
...
tests/v1/sample/utils.py
View file @
ef640440
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
re
import
re
from
enum
import
Enum
from
typing
import
Optional
from
vllm
import
CompletionOutput
from
vllm
import
CompletionOutput
def
get_test_batch
(
batch_logprobs_composition
:
str
)
->
list
[
tuple
]:
class BatchLogprobsComposition(Enum):
    """Types of logprobs configs to include in test batch"""
    # No request in the batch asks for sample or prompt logprobs.
    NONE = 0
    # Some requests ask for sample logprobs only; the rest ask for none.
    SAMPLE = 1
    # Some requests ask for prompt logprobs only; the rest ask for none.
    PROMPT = 2
    # Mix of requests: both kinds, sample only, prompt only, and none.
    SAMPLE_PROMPT = 3


# Per-request (num_sample_logprobs, num_prompt_logprobs) pairs for a batch;
# either entry is None when that kind of logprobs is not requested.
BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]]
def
get_test_batch
(
batch_logprobs_composition
:
BatchLogprobsComposition
)
->
BatchLogprobsSpecType
:
"""Generate logprobs configs for a batch of requests
"""Generate logprobs configs for a batch of requests
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
num_prompt_logprobs. The batch logprobs configuration is the list of request
num_prompt_logprobs. The batch logprobs configuration is the list of request
logprobs configs.
logprobs configs.
batch_logprobs_composition ==
"
NONE
"
yields a batch with no sample or prompt
batch_logprobs_composition == NONE yields a batch with no sample or prompt
logprobs
logprobs
batch_logprobs_composition ==
"
SAMPLE
"
yields a batch with some requests
batch_logprobs_composition == SAMPLE yields a batch with some requests
configured for sample logprobs only, and others configured for no logprobs
configured for sample logprobs only, and others configured for no logprobs
batch_logprobs_composition ==
"
PROMPT
"
yields a batch with some requests
batch_logprobs_composition == PROMPT yields a batch with some requests
configured for prompt logprobs only, and others configured for no logprobs
configured for prompt logprobs only, and others configured for no logprobs
batch_logprobs_composition ==
"
SAMPLE_PROMPT
"
yields a batch with some
batch_logprobs_composition == SAMPLE_PROMPT yields a batch with some
requests configured for sample logprobs and prompt logprobs, some configured
requests configured for sample logprobs and prompt logprobs, some configured
for only sample logprobs or only prompt logprobs, and some configured for
for only sample logprobs or only prompt logprobs, and some configured for
no logprobs
no logprobs
...
@@ -34,10 +49,10 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
...
@@ -34,10 +49,10 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
tuples
tuples
"""
"""
if
batch_logprobs_composition
==
"
NONE
"
:
if
batch_logprobs_composition
==
BatchLogprobsComposition
.
NONE
:
# No requests with sample or prompt logprobs
# No requests with sample or prompt logprobs
return
[(
None
,
None
)]
return
[(
None
,
None
)]
elif
batch_logprobs_composition
==
"
SAMPLE
"
:
elif
batch_logprobs_composition
==
BatchLogprobsComposition
.
SAMPLE
:
# Requests requiring sample logprobs or no logprobs
# Requests requiring sample logprobs or no logprobs
return
[
return
[
(
None
,
None
),
(
None
,
None
),
...
@@ -45,7 +60,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
...
@@ -45,7 +60,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
(
5
,
None
),
(
5
,
None
),
(
3
,
None
),
(
3
,
None
),
]
]
elif
batch_logprobs_composition
==
"
PROMPT
"
:
elif
batch_logprobs_composition
==
BatchLogprobsComposition
.
PROMPT
:
# Requests requiring prompt logprobs or no logprobs
# Requests requiring prompt logprobs or no logprobs
return
[
return
[
(
None
,
None
),
(
None
,
None
),
...
@@ -53,7 +68,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
...
@@ -53,7 +68,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
(
None
,
6
),
(
None
,
6
),
(
None
,
5
),
(
None
,
5
),
]
]
elif
batch_logprobs_composition
==
"
SAMPLE_PROMPT
"
:
elif
batch_logprobs_composition
==
BatchLogprobsComposition
.
SAMPLE_PROMPT
:
# Requests requiring either no logprobs, just
# Requests requiring either no logprobs, just
# sample logprobs, just prompt logprobs, or
# sample logprobs, just prompt logprobs, or
# both sample and prompt logprobs
# both sample and prompt logprobs
...
...
vllm/v1/core/kv_cache_manager.py
View file @
ef640440
...
@@ -105,8 +105,6 @@ class KVCacheManager:
...
@@ -105,8 +105,6 @@ class KVCacheManager:
# Prefix caching is disabled.
# Prefix caching is disabled.
return
[],
0
return
[],
0
computed_blocks
=
[]
# The block hashes for the request may already be computed
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
# if the scheduler has tried to schedule the request before.
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
...
@@ -114,16 +112,20 @@ class KVCacheManager:
...
@@ -114,16 +112,20 @@ class KVCacheManager:
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
request
)
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
request
)
self
.
req_to_block_hashes
[
request
.
request_id
]
=
block_hashes
self
.
req_to_block_hashes
[
request
.
request_id
]
=
block_hashes
self
.
prefix_cache_stats
.
requests
+=
1
if
request
.
sampling_params
.
prompt_logprobs
is
None
:
# Check for cache hits
computed_blocks
=
[]
for
block_hash
in
block_hashes
:
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# block_hashes is a chain of block hashes. If a block hash
# in the cached_block_hash_to_id, the following block hashes are
# is not in the cached_block_hash_to_id, the following
# not computed yet for sure.
# block hashes are not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
computed_blocks
.
append
(
cached_block
)
else
:
else
:
break
break
self
.
prefix_cache_stats
.
requests
+=
1
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
...
@@ -132,6 +134,9 @@ class KVCacheManager:
...
@@ -132,6 +134,9 @@ class KVCacheManager:
# `block_size`.
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
return
computed_blocks
,
num_computed_tokens
else
:
# Skip cache hits for prompt logprobs
return
[],
0
def
allocate_slots
(
def
allocate_slots
(
self
,
self
,
...
...
vllm/v1/engine/processor.py
View file @
ef640440
...
@@ -72,12 +72,6 @@ class Processor:
...
@@ -72,12 +72,6 @@ class Processor:
f
"Requested prompt logprobs of
{
params
.
prompt_logprobs
}
, "
f
"Requested prompt logprobs of
{
params
.
prompt_logprobs
}
, "
f
"which is greater than max allowed:
{
max_logprobs
}
"
)
f
"which is greater than max allowed:
{
max_logprobs
}
"
)
# TODO(andy): enable this in follow up by recomputing.
if
(
params
.
prompt_logprobs
is
not
None
and
self
.
cache_config
.
enable_prefix_caching
):
raise
ValueError
(
"Prefix caching with prompt logprobs not yet "
"supported on VLLM V1."
)
def
_validate_sampling_params
(
def
_validate_sampling_params
(
self
,
self
,
params
:
SamplingParams
,
params
:
SamplingParams
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment