Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
081057de
Commit
081057de
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-ori
parents
7cf5d5c4
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1482 additions
and
257 deletions
+1482
-257
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+5
-13
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+123
-98
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+432
-28
tests/v1/e2e/test_cascade_attention.py
tests/v1/e2e/test_cascade_attention.py
+9
-1
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+18
-11
tests/v1/engine/conftest.py
tests/v1/engine/conftest.py
+1
-1
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+33
-0
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+85
-24
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+0
-1
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+97
-26
tests/v1/engine/utils.py
tests/v1/engine/utils.py
+2
-3
tests/v1/entrypoints/conftest.py
tests/v1/entrypoints/conftest.py
+18
-16
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+173
-4
tests/v1/shutdown/test_delete.py
tests/v1/shutdown/test_delete.py
+97
-0
tests/v1/shutdown/test_forward_error.py
tests/v1/shutdown/test_forward_error.py
+129
-0
tests/v1/shutdown/test_processor_error.py
tests/v1/shutdown/test_processor_error.py
+69
-0
tests/v1/shutdown/test_startup_error.py
tests/v1/shutdown/test_startup_error.py
+97
-0
tests/v1/shutdown/utils.py
tests/v1/shutdown/utils.py
+5
-0
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+57
-0
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+32
-31
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
tests/v1/core/test_kv_cache_utils.py
View file @
081057de
...
@@ -37,7 +37,6 @@ def make_request(request_id,
...
@@ -37,7 +37,6 @@ def make_request(request_id,
return
Request
(
return
Request
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
None
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_inputs
=
multi_modal_inputs
,
multi_modal_inputs
=
multi_modal_inputs
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_hashes
=
mm_hashes
,
...
@@ -311,7 +310,7 @@ def test_metrics():
...
@@ -311,7 +310,7 @@ def test_metrics():
def
stats
(
requests
,
queries
,
hits
):
def
stats
(
requests
,
queries
,
hits
):
return
PrefixCacheStats
(
requests
=
requests
,
queries
=
queries
,
hits
=
hits
)
return
PrefixCacheStats
(
requests
=
requests
,
queries
=
queries
,
hits
=
hits
)
metrics
=
PrefixCachingMetrics
(
interval
=
5
)
metrics
=
PrefixCachingMetrics
(
max_recent_requests
=
5
)
assert
metrics
.
hit_rate
==
0.0
assert
metrics
.
hit_rate
==
0.0
metrics
.
observe
(
stats
(
1
,
20
,
9
))
metrics
.
observe
(
stats
(
1
,
20
,
9
))
...
@@ -496,8 +495,7 @@ def test_allocate_with_lookahead():
...
@@ -496,8 +495,7 @@ def test_allocate_with_lookahead():
# Test case 1: Requires additional lookahead tokens
# Test case 1: Requires additional lookahead tokens
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
max_model_len
=
100
,
max_model_len
=
100
)
num_preallocate_tokens
=
0
)
blocks
=
kv_cache_manager
.
allocate_slots
(
blocks
=
kv_cache_manager
.
allocate_slots
(
request
,
request
,
num_tokens
=
3
,
num_tokens
=
3
,
...
@@ -507,25 +505,19 @@ def test_allocate_with_lookahead():
...
@@ -507,25 +505,19 @@ def test_allocate_with_lookahead():
# Test case 2: With precomputed blocks
# Test case 2: With precomputed blocks
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
max_model_len
=
100
,
max_model_len
=
100
)
num_preallocate_tokens
=
4
)
# num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
# required_blocks = ceil((3 + 2) /4) = 2
# required_blocks = ceil((3 + 2) /4) = 2
# total_blocks = 1 + 2 = 3
blocks
=
kv_cache_manager
.
allocate_slots
(
blocks
=
kv_cache_manager
.
allocate_slots
(
request
,
request
,
num_tokens
=
3
,
num_tokens
=
3
,
num_lookahead_tokens
=
2
,
num_lookahead_tokens
=
2
,
)
)
assert
len
(
blocks
)
==
3
assert
len
(
blocks
)
==
2
# Test case 3: With precomputed blocks
# Test case 3: With precomputed blocks
# num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
# required_blocks = ceil((3 + 4) / 4) = 2
# required_blocks = ceil((3 + 4) / 4) = 2
# total_blocks = 0 + 2 = 2
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
max_model_len
=
100
,
max_model_len
=
100
)
num_preallocate_tokens
=
4
)
blocks
=
kv_cache_manager
.
allocate_slots
(
blocks
=
kv_cache_manager
.
allocate_slots
(
request
,
request
,
num_tokens
=
3
,
num_tokens
=
3
,
...
...
tests/v1/core/test_prefix_caching.py
View file @
081057de
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
cdiv
,
sha256
from
vllm.utils
import
sha256
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
KVCacheBlock
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
KVCacheBlock
,
...
@@ -29,7 +29,6 @@ def make_request(request_id,
...
@@ -29,7 +29,6 @@ def make_request(request_id,
return
Request
(
return
Request
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
None
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_inputs
=
multi_modal_inputs
,
multi_modal_inputs
=
multi_modal_inputs
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_hashes
=
mm_hashes
,
...
@@ -61,7 +60,6 @@ def test_prefill(hash_algo):
...
@@ -61,7 +60,6 @@ def test_prefill(hash_algo):
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
caching_hash_algo
=
hash_algo
,
caching_hash_algo
=
hash_algo
,
num_preallocate_tokens
=
16
,
)
)
# choose the hash function according to the parameter
# choose the hash function according to the parameter
...
@@ -80,7 +78,7 @@ def test_prefill(hash_algo):
...
@@ -80,7 +78,7 @@ def test_prefill(hash_algo):
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
]
# Check full block metadata
# Check full block metadata
parent_block_hash
=
None
parent_block_hash
=
None
...
@@ -92,8 +90,8 @@ def test_prefill(hash_algo):
...
@@ -92,8 +90,8 @@ def test_prefill(hash_algo):
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
.
hash_value
# Check partial
/preallocated
block metadata
# Check partial block metadata
for
block_id
in
(
4
,
5
):
for
block_id
in
(
4
,
):
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
...
@@ -107,12 +105,12 @@ def test_prefill(hash_algo):
...
@@ -107,12 +105,12 @@ def test_prefill(hash_algo):
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
7
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
]
for
block
in
computed_blocks
:
for
block
in
computed_blocks
:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
# At this point, we should have
3
free blocks left.
# At this point, we should have
5
free blocks left.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
3
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
manager
.
free
(
req0
)
manager
.
free
(
req0
)
manager
.
free
(
req1
)
manager
.
free
(
req1
)
...
@@ -120,14 +118,14 @@ def test_prefill(hash_algo):
...
@@ -120,14 +118,14 @@ def test_prefill(hash_algo):
# All blocks should be available.
# All blocks should be available.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# The order should be
# [unallocated (8, 9, 10)]
# [unallocated (
6, 7,
8, 9, 10)]
# [unique_req0 (
5,
4)]
# [unique_req0 (4)]
# [unique_req1 (
7, 6
)]
# [unique_req1 (
5
)]
# [common (3, 2, 1)]
# [common (3, 2, 1)]
assert
[
assert
[
b
.
block_id
b
.
block_id
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
8
,
9
,
10
,
5
,
4
,
7
,
6
,
3
,
2
,
1
]
]
==
[
6
,
7
,
8
,
9
,
10
,
4
,
5
,
3
,
2
,
1
]
# Cache hit in the common prefix when the original block is already free.
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
# Incomplete 1 block (6 tokens)
...
@@ -139,29 +137,29 @@ def test_prefill(hash_algo):
...
@@ -139,29 +137,29 @@ def test_prefill(hash_algo):
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
8
,
9
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
]
# Although we only have
5
free blocks, we have 8 blocks in
# Although we only have
6
free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
# the free block queue due to lazy removal.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
6
assert
all
([
assert
all
([
b
.
ref_cnt
==
0
b
.
ref_cnt
==
0
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
])
])
assert
len
([
assert
len
([
b
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
b
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
])
==
5
])
==
6
manager
.
free
(
req2
)
manager
.
free
(
req2
)
# Cache miss and eviction.
# Cache miss and eviction.
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
9
))
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
10
))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
9
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
10
,
computed_blocks
)
# This block ID order also checks the eviction order.
# This block ID order also checks the eviction order.
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
10
,
5
,
4
,
7
,
6
,
9
,
8
,
3
,
2
,
1
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
...
@@ -178,7 +176,6 @@ def test_prefill_plp():
...
@@ -178,7 +176,6 @@ def test_prefill_plp():
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
)
# the default hash function is hash
# the default hash function is hash
hash_fn
=
hash
hash_fn
=
hash
...
@@ -197,7 +194,7 @@ def test_prefill_plp():
...
@@ -197,7 +194,7 @@ def test_prefill_plp():
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
]
# Check full block metadata
# Check full block metadata
...
@@ -210,8 +207,8 @@ def test_prefill_plp():
...
@@ -210,8 +207,8 @@ def test_prefill_plp():
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
.
hash_value
# Check partial
/preallocated
block metadata
# Check partial block metadata
for
block_id
in
(
4
,
5
):
for
block_id
in
(
4
,
):
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
...
@@ -226,12 +223,12 @@ def test_prefill_plp():
...
@@ -226,12 +223,12 @@ def test_prefill_plp():
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
7
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
]
for
block
in
computed_blocks
:
for
block
in
computed_blocks
:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
# At this point, we should have
3
free blocks left.
# At this point, we should have
5
free blocks left.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
3
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
manager
.
free
(
req0
)
manager
.
free
(
req0
)
manager
.
free
(
req1
)
manager
.
free
(
req1
)
...
@@ -239,14 +236,14 @@ def test_prefill_plp():
...
@@ -239,14 +236,14 @@ def test_prefill_plp():
# All blocks should be available.
# All blocks should be available.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# The order should be
# [unallocated (8, 9, 10)]
# [unallocated (
6, 7,
8, 9, 10)]
# [unique_req0 (
5,
4)]
# [unique_req0 (4)]
# [unique_req1 (
7, 6
)]
# [unique_req1 (
5
)]
# [common (3, 2, 1)]
# [common (3, 2, 1)]
assert
[
assert
[
b
.
block_id
b
.
block_id
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
8
,
9
,
10
,
5
,
4
,
7
,
6
,
3
,
2
,
1
]
]
==
[
6
,
7
,
8
,
9
,
10
,
4
,
5
,
3
,
2
,
1
]
# Request #2 is a prompt-logprobs request:
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
# NO cache hit in the common prefix; duplicates request #0 cached blocks
...
@@ -262,7 +259,7 @@ def test_prefill_plp():
...
@@ -262,7 +259,7 @@ def test_prefill_plp():
block_ids
=
[
b
.
block_id
for
b
in
blocks
]
block_ids
=
[
b
.
block_id
for
b
in
blocks
]
# Duplicate cached blocks have different ids but same hashes vs request #0
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
]
==
req0_block_hashes
assert
[
b
.
block_hash
for
b
in
blocks
]
==
req0_block_hashes
assert
block_ids
!=
[
1
,
2
,
3
,
4
,
5
]
assert
block_ids
!=
[
1
,
2
,
3
,
4
]
# Request #2 block hashes are valid since request #0 hashes are.
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
# Check block reference counts.
...
@@ -277,7 +274,6 @@ def test_decode():
...
@@ -277,7 +274,6 @@ def test_decode():
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
)
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
...
@@ -291,7 +287,7 @@ def test_decode():
...
@@ -291,7 +287,7 @@ def test_decode():
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
]
# Append slots without allocating a new block.
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
req0
.
num_computed_tokens
=
55
...
@@ -299,28 +295,18 @@ def test_decode():
...
@@ -299,28 +295,18 @@ def test_decode():
req0
.
append_output_token_ids
(
8
)
req0
.
append_output_token_ids
(
8
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
4
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
4
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
assert
manager
.
req_to_blocks
[
req0
.
request_id
][
-
2
].
block_hash
is
None
assert
manager
.
req_to_blocks
[
req0
.
request_id
][
-
1
].
block_hash
is
None
# Append slots without allocating a new block, but start using the
# Append slots with allocating a new block.
# preallocated block.
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
#
6
tokens to fill the previous block, and 10 tokens to fill
#
9
tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
# the preallocated block.
for
_
in
range
(
5
+
10
):
for
_
in
range
(
9
+
10
):
req0
.
append_output_token_ids
(
7
)
req0
.
append_output_token_ids
(
7
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
1
5
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
1
9
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
1
assert
manager
.
req_to_blocks
[
req0
.
request_id
][
-
2
].
block_hash
is
not
None
assert
manager
.
req_to_blocks
[
req0
.
request_id
][
-
2
].
block_hash
is
not
None
assert
manager
.
req_to_blocks
[
req0
.
request_id
][
-
1
].
block_hash
is
None
# Append slots with allocating a new block.
req0
.
num_computed_tokens
=
74
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for
_
in
range
(
6
+
11
):
req0
.
append_output_token_ids
(
12
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
17
)
# Plus one preallocated block.
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
2
def
test_evict
():
def
test_evict
():
...
@@ -328,7 +314,6 @@ def test_evict():
...
@@ -328,7 +314,6 @@ def test_evict():
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
)
last_token_id
=
5
*
16
+
7
last_token_id
=
5
*
16
+
7
...
@@ -337,7 +322,7 @@ def test_evict():
...
@@ -337,7 +322,7 @@ def test_evict():
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
computed_blocks
)
assert
len
(
blocks
)
==
7
# 5 full + 1 partial
+ 1 preallocated
assert
len
(
blocks
)
==
6
# 5 full + 1 partial
# 3 blocks.
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
...
@@ -349,7 +334,8 @@ def test_evict():
...
@@ -349,7 +334,8 @@ def test_evict():
assert
len
(
blocks
)
==
3
# 3 full blocks
assert
len
(
blocks
)
==
3
# 3 full blocks
last_token_id
+=
3
*
16
last_token_id
+=
3
*
16
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
# 10 - (6 + 3) == 1
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
1
manager
.
free
(
req0
)
manager
.
free
(
req0
)
manager
.
free
(
req1
)
manager
.
free
(
req1
)
...
@@ -357,7 +343,7 @@ def test_evict():
...
@@ -357,7 +343,7 @@ def test_evict():
assert
[
assert
[
b
.
block_id
b
.
block_id
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
7
,
6
,
5
,
4
,
3
,
2
,
1
,
10
,
9
,
8
]
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
# Touch the first 2 blocks.
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
...
@@ -365,8 +351,8 @@ def test_evict():
...
@@ -365,8 +351,8 @@ def test_evict():
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
1
,
2
]
assert
num_computed_tokens
==
2
*
16
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
6
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
10
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
6
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
def
test_hash_block_correct_reuse
():
def
test_hash_block_correct_reuse
():
...
@@ -379,7 +365,6 @@ def test_hash_block_correct_reuse():
...
@@ -379,7 +365,6 @@ def test_hash_block_correct_reuse():
make_kv_cache_config
(
16
,
2
),
make_kv_cache_config
(
16
,
2
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
)
# Allocate 1 block and cache it.
# Allocate 1 block and cache it.
...
@@ -416,7 +401,6 @@ def test_computed_blocks_not_evicted():
...
@@ -416,7 +401,6 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config
(
block_size
,
3
),
make_kv_cache_config
(
block_size
,
3
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
)
# Allocate a block and cache it.
# Allocate a block and cache it.
...
@@ -465,7 +449,6 @@ def test_basic_prefix_caching_disabled():
...
@@ -465,7 +449,6 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config
(
block_size
,
5
),
make_kv_cache_config
(
block_size
,
5
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
False
,
enable_caching
=
False
,
num_preallocate_tokens
=
0
,
)
)
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
...
@@ -496,40 +479,6 @@ def test_basic_prefix_caching_disabled():
...
@@ -496,40 +479,6 @@ def test_basic_prefix_caching_disabled():
assert
not
blocks
assert
not
blocks
@
pytest
.
mark
.
parametrize
(
"num_preallocate_tokens"
,
list
(
range
(
0
,
8
)))
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
4
])
def
test_preallocate_blocks
(
num_preallocate_tokens
:
int
,
block_size
:
int
):
"""
This tests that the preallocated blocks are correctly added.
"""
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
enable_caching
=
True
,
num_preallocate_tokens
=
num_preallocate_tokens
,
)
num_preallocated_blocks
=
cdiv
(
num_preallocate_tokens
,
block_size
)
req
=
make_request
(
"0"
,
list
(
range
(
block_size
*
30
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
# Just ask for 1 block.
blocks
=
manager
.
allocate_slots
(
req
,
block_size
,
computed_blocks
)
req
.
num_computed_tokens
=
block_size
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
# Assume all computed, only when num_preallocate_tokens > 0, we need to
# consume the previously preallocated blocks.
if
num_preallocated_blocks
>
0
:
manager
.
allocate_slots
(
req
,
block_size
*
(
len
(
blocks
)
-
1
))
req
.
num_computed_tokens
=
block_size
*
len
(
blocks
)
# Append 1 block.
blocks
=
manager
.
allocate_slots
(
req
,
block_size
)
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
hash
])
def
test_cache_blocks
(
hash_fn
):
def
test_cache_blocks
(
hash_fn
):
"""
"""
...
@@ -588,7 +537,6 @@ def test_mm_prefix_caching():
...
@@ -588,7 +537,6 @@ def test_mm_prefix_caching():
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
...
@@ -626,7 +574,7 @@ def test_mm_prefix_caching():
...
@@ -626,7 +574,7 @@ def test_mm_prefix_caching():
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
]
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
# Append slots without allocating a new block.
...
@@ -667,7 +615,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -667,7 +615,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config
(
block_size
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
)
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
# | Common-0 | Common-1 | Common-2 | ... |
...
@@ -721,7 +668,6 @@ def test_reset_prefix_cache():
...
@@ -721,7 +668,6 @@ def test_reset_prefix_cache():
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
)
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
...
@@ -751,3 +697,82 @@ def test_reset_prefix_cache():
...
@@ -751,3 +697,82 @@ def test_reset_prefix_cache():
assert
manager
.
reset_prefix_cache
()
assert
manager
.
reset_prefix_cache
()
assert
not
manager
.
block_pool
.
cached_block_hash_to_block
assert
not
manager
.
block_pool
.
cached_block_hash_to_block
assert
all
([
blk
.
block_hash
is
None
for
blk
in
manager
.
block_pool
.
blocks
])
assert
all
([
blk
.
block_hash
is
None
for
blk
in
manager
.
block_pool
.
blocks
])
def
test_prefix_cache_stats_disabled
():
"""Test that prefix_cache_stats is None when log_stats is False."""
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
enable_caching
=
True
,
log_stats
=
False
,
# Disable logging stats
)
assert
manager
.
prefix_cache_stats
is
None
# Call all functions that check whether log_stats is disabled.
req
=
make_request
(
"0"
,
list
(
range
(
16
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req
,
16
,
computed_blocks
)
manager
.
reset_prefix_cache
()
# Ensure prefix_cache_stats remains None
assert
manager
.
prefix_cache_stats
is
None
def
test_eagle_enabled_removes_last_block
():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""
block_size
=
16
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
num_blocks
=
10
),
max_model_len
=
8192
,
enable_caching
=
True
,
use_eagle
=
True
,
)
# Request with 3 full blocks (48 tokens)
token_ids
=
[
0
]
*
(
3
*
block_size
)
req
=
make_request
(
"divisible_request"
,
token_ids
)
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
computed_blocks
)
manager
.
free
(
req
)
# New request with same tokens + Eagle enabled
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Should retain 2 blocks:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. last_block_hash is not None → Eagle pop is not SKIPPED
assert
len
(
computed_blocks
)
==
1
assert
num_tokens
==
1
*
block_size
# 32 tokens
def
test_eagle_with_partial_blocks
():
"""Test Eagle behavior with requests containing partial blocks."""
block_size
=
16
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
num_blocks
=
10
),
max_model_len
=
8192
,
enable_caching
=
True
,
use_eagle
=
True
,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
)
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
computed_blocks
)
manager
.
free
(
req
)
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
)
==
1
assert
num_tokens
==
1
*
block_size
tests/v1/core/test_scheduler.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
typing
import
Optional
from
unittest.mock
import
Mock
import
pytest
import
pytest
import
torch
import
torch
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
(
CacheConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
@@ -25,6 +27,11 @@ def create_scheduler(
...
@@ -25,6 +27,11 @@ def create_scheduler(
enable_prefix_caching
:
Optional
[
bool
]
=
None
,
enable_prefix_caching
:
Optional
[
bool
]
=
None
,
long_prefill_token_threshold
:
int
=
0
,
long_prefill_token_threshold
:
int
=
0
,
disable_chunked_mm_input
:
bool
=
False
,
disable_chunked_mm_input
:
bool
=
False
,
use_kv_connector
:
bool
=
False
,
num_blocks
:
int
=
10000
,
block_size
:
int
=
16
,
max_model_len
:
Optional
[
int
]
=
None
,
num_speculative_tokens
:
Optional
[
int
]
=
None
,
)
->
Scheduler
:
)
->
Scheduler
:
'''Create scheduler under test.
'''Create scheduler under test.
...
@@ -39,12 +46,15 @@ def create_scheduler(
...
@@ -39,12 +46,15 @@ def create_scheduler(
Returns:
Returns:
:class:`Scheduler` instance
:class:`Scheduler` instance
'''
'''
if
max_model_len
is
None
:
max_model_len
=
max_num_batched_tokens
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_model_len
=
max_
num_batched_tok
en
s
,
max_model_len
=
max_
model_l
en
,
long_prefill_token_threshold
=
long_prefill_token_threshold
,
long_prefill_token_threshold
=
long_prefill_token_threshold
,
disable_chunked_mm_input
=
disable_chunked_mm_input
,
disable_chunked_mm_input
=
disable_chunked_mm_input
,
enable_chunked_prefill
=
True
,
)
)
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
model
=
model
,
model
=
model
,
...
@@ -60,31 +70,42 @@ def create_scheduler(
...
@@ -60,31 +70,42 @@ def create_scheduler(
'enable_prefix_caching'
:
enable_prefix_caching
'enable_prefix_caching'
:
enable_prefix_caching
})
})
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
16
,
block_size
=
block_size
,
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
cache_dtype
=
"auto"
,
**
kwargs_cache
,
**
kwargs_cache
,
)
)
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"SharedStorageConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
"local_storage"
},
)
if
use_kv_connector
else
None
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
if
num_speculative_tokens
is
not
None
:
speculative_config
=
SpeculativeConfig
(
model
=
"ngram"
,
num_speculative_tokens
=
num_speculative_tokens
)
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
model_config
=
model_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
kv_transfer_config
=
kv_transfer_config
,
speculative_config
=
speculative_config
,
)
)
kv_cache_config
=
KVCacheConfig
(
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
10000
,
# A large number of blocks to hold all requests
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
tensors
=
{},
tensors
=
{},
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
'layer'
],
KVCacheGroupSpec
([
'layer'
],
FullAttentionSpec
(
16
,
1
,
1
,
torch
.
float32
,
False
))
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
False
))
],
],
)
)
cache_config
.
num_gpu_blocks
=
10000
cache_config
.
num_gpu_blocks
=
num_blocks
return
Scheduler
(
return
Scheduler
(
scheduler_config
,
vllm_config
=
vllm_config
,
model_config
,
cache_config
,
lora_config
=
None
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
log_stats
=
True
,
log_stats
=
True
,
structured_output_manager
=
StructuredOutputManager
(
vllm_config
),
structured_output_manager
=
StructuredOutputManager
(
vllm_config
),
...
@@ -111,7 +132,6 @@ def create_requests(num_requests: int,
...
@@ -111,7 +132,6 @@ def create_requests(num_requests: int,
mm_inputs
=
None
mm_inputs
=
None
request
=
Request
(
request
=
Request
(
request_id
=
f
"
{
i
}
"
,
request_id
=
f
"
{
i
}
"
,
prompt
=
None
,
prompt_token_ids
=
[
i
]
*
num_tokens
,
prompt_token_ids
=
[
i
]
*
num_tokens
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
multi_modal_inputs
=
mm_inputs
,
multi_modal_inputs
=
mm_inputs
,
...
@@ -286,6 +306,7 @@ def test_no_mm_input_chunking():
...
@@ -286,6 +306,7 @@ def test_no_mm_input_chunking():
model
=
"llava-hf/llava-1.5-7b-hf"
,
model
=
"llava-hf/llava-1.5-7b-hf"
,
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
disable_chunked_mm_input
=
True
,
disable_chunked_mm_input
=
True
,
max_model_len
=
2048
,
)
)
mm_positions
=
[[
PlaceholderRange
(
offset
=
400
,
length
=
800
)]]
mm_positions
=
[[
PlaceholderRange
(
offset
=
400
,
length
=
800
)]]
requests
=
create_requests
(
num_requests
=
1
,
requests
=
create_requests
(
num_requests
=
1
,
...
@@ -414,7 +435,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
...
@@ -414,7 +435,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
def
test_stop_via_update_from_output
():
def
test_stop_via_update_from_output
():
"""Test stopping behavior through update_from_output"""
"""Test stopping behavior through update_from_output"""
scheduler
=
create_scheduler
()
scheduler
=
create_scheduler
(
num_speculative_tokens
=
1
)
# Test case 1: Stop on EOS token
# Test case 1: Stop on EOS token
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
10
)
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
10
)
...
@@ -422,7 +443,6 @@ def test_stop_via_update_from_output():
...
@@ -422,7 +443,6 @@ def test_stop_via_update_from_output():
req
.
num_computed_tokens
=
req
.
num_tokens
req
.
num_computed_tokens
=
req
.
num_tokens
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
running
.
append
(
req
)
scheduler
.
running
.
append
(
req
)
scheduler
.
scheduled_req_ids
.
add
(
req
.
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
...
@@ -466,7 +486,7 @@ def test_stop_via_update_from_output():
...
@@ -466,7 +486,7 @@ def test_stop_via_update_from_output():
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
10
,
11
]
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
10
,
11
]
# Test case 2: Stop on custom stop token
# Test case 2: Stop on custom stop token
scheduler
=
create_scheduler
()
scheduler
=
create_scheduler
(
num_speculative_tokens
=
2
)
requests
=
create_requests
(
num_requests
=
2
,
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
10
,
max_tokens
=
10
,
stop_token_ids
=
[
42
,
43
])
stop_token_ids
=
[
42
,
43
])
...
@@ -474,7 +494,6 @@ def test_stop_via_update_from_output():
...
@@ -474,7 +494,6 @@ def test_stop_via_update_from_output():
req
.
num_computed_tokens
=
req
.
num_tokens
req
.
num_computed_tokens
=
req
.
num_tokens
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
running
.
append
(
req
)
scheduler
.
running
.
append
(
req
)
scheduler
.
scheduled_req_ids
.
add
(
req
.
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
...
@@ -518,13 +537,12 @@ def test_stop_via_update_from_output():
...
@@ -518,13 +537,12 @@ def test_stop_via_update_from_output():
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
13
,
14
]
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
13
,
14
]
# Test case 3: Stop on max tokens
# Test case 3: Stop on max tokens
scheduler
=
create_scheduler
()
scheduler
=
create_scheduler
(
num_speculative_tokens
=
2
)
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
2
)
requests
=
create_requests
(
num_requests
=
2
,
max_tokens
=
2
)
for
req
in
requests
:
for
req
in
requests
:
req
.
num_computed_tokens
=
req
.
num_tokens
req
.
num_computed_tokens
=
req
.
num_tokens
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
requests
[
req
.
request_id
]
=
req
scheduler
.
running
.
append
(
req
)
scheduler
.
running
.
append
(
req
)
scheduler
.
scheduled_req_ids
.
add
(
req
.
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
scheduled_cached_reqs
=
[],
...
@@ -568,13 +586,12 @@ def test_stop_via_update_from_output():
...
@@ -568,13 +586,12 @@ def test_stop_via_update_from_output():
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
13
]
assert
list
(
requests
[
1
].
output_token_ids
)
==
[
13
]
# Test case 4: Ignore EOS flag
# Test case 4: Ignore EOS flag
scheduler
=
create_scheduler
()
scheduler
=
create_scheduler
(
num_speculative_tokens
=
2
)
requests
=
create_requests
(
num_requests
=
1
,
max_tokens
=
10
)
requests
=
create_requests
(
num_requests
=
1
,
max_tokens
=
10
)
requests
[
0
].
sampling_params
.
ignore_eos
=
True
requests
[
0
].
sampling_params
.
ignore_eos
=
True
requests
[
0
].
num_computed_tokens
=
requests
[
0
].
num_tokens
requests
[
0
].
num_computed_tokens
=
requests
[
0
].
num_tokens
scheduler
.
requests
[
requests
[
0
].
request_id
]
=
requests
[
0
]
scheduler
.
requests
[
requests
[
0
].
request_id
]
=
requests
[
0
]
scheduler
.
running
.
append
(
requests
[
0
])
scheduler
.
running
.
append
(
requests
[
0
])
scheduler
.
scheduled_req_ids
.
add
(
requests
[
0
].
request_id
)
scheduler_output
=
SchedulerOutput
(
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_new_reqs
=
[],
...
@@ -671,13 +688,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
...
@@ -671,13 +688,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"spec_tokens,output_tokens,expected"
,
"spec_tokens,output_tokens,expected"
,
[
[
([[
1
,
2
,
3
]],
[[
1
,
2
,
3
,
4
]],
(
3
,
3
)),
# perfect match
([[
1
,
2
,
3
]],
[[
1
,
2
,
3
,
4
]],
(
1
,
3
,
3
,
[
1
,
1
,
1
])),
# perfect match
([[
1
,
2
,
3
]],
[[
1
,
5
]],
(
3
,
1
)),
# early mismatch
([[
1
,
2
,
3
]],
[[
1
,
5
]],
(
1
,
3
,
1
,
[
1
,
0
,
0
])),
# early mismatch
([[
1
,
2
],
[
3
]],
[[
1
,
2
,
5
],
[
3
,
4
]],
(
3
,
3
)),
# multiple sequences
([[
1
,
2
],
[
3
]],
[[
1
,
2
,
5
],
[
3
,
4
]],
([[
1
]],
[[
1
,
2
]],
(
1
,
1
)),
# single token sequence
(
2
,
3
,
3
,
[
2
,
1
])),
# multiple sequences
([[]],
[[
5
]],
(
0
,
0
)),
# empty sequence
([[
1
]],
[[
1
,
2
]],
(
1
,
1
,
1
,
[
1
])),
# single token sequence
([[]],
[[
5
]],
(
0
,
0
,
0
,
[
0
])),
# empty sequence
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
[[
1
,
2
,
7
],
[
4
,
8
]],
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
[[
1
,
2
,
7
],
[
4
,
8
]],
(
6
,
3
)),
# multiple mismatches
(
2
,
6
,
3
,
[
2
,
1
,
0
]
)),
# multiple mismatches
])
])
def
test_schedule_spec_decoding_stats
(
spec_tokens
,
output_tokens
,
expected
):
def
test_schedule_spec_decoding_stats
(
spec_tokens
,
output_tokens
,
expected
):
"""Test scheduling behavior with speculative decoding.
"""Test scheduling behavior with speculative decoding.
...
@@ -686,7 +704,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
...
@@ -686,7 +704,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
1. Speculated tokens get scheduled correctly
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
2. Spec decoding stats properly count number of draft and accepted tokens
"""
"""
scheduler
=
create_scheduler
()
num_spec_tokens
=
max
(
1
,
max
(
len
(
t
)
for
t
in
spec_tokens
))
scheduler
=
create_scheduler
(
num_speculative_tokens
=
num_spec_tokens
)
requests
=
create_requests
(
num_requests
=
len
(
spec_tokens
),
num_tokens
=
1
)
requests
=
create_requests
(
num_requests
=
len
(
spec_tokens
),
num_tokens
=
1
)
req_ids
=
[]
req_ids
=
[]
req_to_index
=
{}
req_to_index
=
{}
...
@@ -759,5 +778,390 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
...
@@ -759,5 +778,390 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
else
:
else
:
assert
scheduler_stats
.
spec_decoding_stats
is
not
None
assert
scheduler_stats
.
spec_decoding_stats
is
not
None
stats
=
scheduler_stats
.
spec_decoding_stats
stats
=
scheduler_stats
.
spec_decoding_stats
assert
stats
.
num_draft_tokens
==
expected
[
0
]
assert
stats
.
num_drafts
==
expected
[
0
]
assert
stats
.
num_accepted_tokens
==
expected
[
1
]
assert
stats
.
num_draft_tokens
==
expected
[
1
]
assert
stats
.
num_accepted_tokens
==
expected
[
2
]
assert
stats
.
num_accepted_tokens_per_pos
==
expected
[
3
]
def _assert_right_scheduler_output(
    output: SchedulerOutput,
    num_requests: int,
    expected_num_scheduled_tokens: int,
):
    """Check if SchedulerOutput is correct after remote KV cache hit.

    Args:
        output: The scheduler output to inspect.
        num_requests: Expected number of entries in
            ``output.kv_connector_metadata.requests``.
        expected_num_scheduled_tokens: Token count that every entry of
            ``output.num_scheduled_tokens`` must equal.
    """
    # We should inject the kv_connector_metadata.
    assert len(output.kv_connector_metadata.requests) == num_requests

    # Only num_tokens - matched_num_new_tokens should be scheduled.
    # The request ids are irrelevant here, so iterate the values only.
    for num_scheduled_tokens in output.num_scheduled_tokens.values():
        assert num_scheduled_tokens == expected_num_scheduled_tokens
def _assert_right_kv_cache_manager(
    scheduler: Scheduler,
    req_ids: list[str],
    num_tokens: int,
    block_size: int,
    num_requests: int,
    num_total_blocks: int,
):
    """Check whether KVCacheManager is correct after allocate.

    Args:
        scheduler: Scheduler whose ``kv_cache_manager`` is inspected.
        req_ids: Ids of the requests that were just scheduled.
        num_tokens: Number of prompt tokens per request.
        block_size: Number of tokens per KV cache block.
        num_requests: Number of requests holding blocks.
        num_total_blocks: Free-block count observed before any allocation.
    """
    kv_cache_manager = scheduler.kv_cache_manager

    # Make sure the request stats are right.
    EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
    for req_id in req_ids:
        blocks = kv_cache_manager.req_to_blocks[req_id]
        hashes = kv_cache_manager.req_to_block_hashes[req_id]
        assert (kv_cache_manager.num_cached_block[req_id] ==
                EXPECTED_TOTAL_BLOCKS)
        assert len(blocks) == EXPECTED_TOTAL_BLOCKS
        assert len(hashes) == EXPECTED_TOTAL_BLOCKS

    # Make sure we actually touched all the blocks.
    # NOTE: true division on purpose - if num_tokens is not a multiple of
    # block_size this is fractional and can never equal the integer
    # free-block count, so the assertion fails instead of rounding.
    BLOCKS_PER_REQ = num_tokens / block_size
    assert (kv_cache_manager.block_pool.get_num_free_blocks() ==
            num_total_blocks - num_requests * BLOCKS_PER_REQ)
def _step_until_done(
    scheduler: Scheduler,
    output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
):
    """Loop over schedule(), update_from_output() until finished.

    Args:
        scheduler: Scheduler under test.
        output: Output of the schedule() call the caller already made; it
            is consumed first, before any new step is scheduled.
        model_runner_output: Model runner output that is replayed on every
            update_from_output() call.
    """
    # Finish the step that the caller scheduled.
    scheduler.update_from_output(output, model_runner_output)

    while True:
        # Schedule + a few iterations until stopping.
        output = scheduler.schedule()
        assert len(scheduler.running)

        # We should be in the decode phase now: one token per request and
        # nothing left for the connector to load.
        for num_scheduled_tokens in output.num_scheduled_tokens.values():
            assert num_scheduled_tokens == 1
        assert len(output.kv_connector_metadata.requests) == 0

        ecos = scheduler.update_from_output(output, model_runner_output)
        # Stop once no returned output is still unfinished (this is also
        # the case when no outputs are returned at all).
        if all(eco.finish_reason is not None for eco in ecos.outputs):
            break
def test_kv_connector_basic():
    """
    Test whether Scheduler with KVConnector schedules tokens, allocates
    memory, and cleans up requests as expected under normal operation.

    Two rounds of requests are run: the first gets only a (mocked) external
    cache hit, the second gets a local prefix-cache hit plus the external
    hit. After each round all blocks must be free again.
    """
    # Setup Scheduler.
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=True,
    )
    block_pool = scheduler.kv_cache_manager.block_pool
    NUM_TOTAL_BLOCKS = block_pool.get_num_free_blocks()
    BLOCK_SIZE = scheduler.cache_config.block_size

    # Mock External Cache Hit.
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (
        NUM_MATCHED_NEW_TOKENS)

    ######################################################
    # FIRST SET OF REQUESTS - External Hit Only
    NUM_REQUESTS = 2
    NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
    MAX_TOKENS = 3
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # Every step "samples" the same single token for each request.
    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # Ensure SchedulerOutput is correct.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens should be scheduled.
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS,
                                   BLOCK_SIZE, NUM_REQUESTS,
                                   NUM_TOTAL_BLOCKS)

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS

    ######################################################
    # SECOND SET OF REQUESTS - Local And External Hit
    NUM_TOKENS_PREFIX = NUM_TOKENS
    # We will get a local prefix cache hit for the first
    # NUM_TOKENS_PREFIX tokens since they are used above.
    NUM_TOKENS = NUM_TOKENS_PREFIX * 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # We should get a local cache hit of NUM_TOKENS_PREFIX and
    # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens after local + remote cache hit.
        expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX -
                                       NUM_MATCHED_NEW_TOKENS),
    )

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS,
                                   BLOCK_SIZE, NUM_REQUESTS,
                                   NUM_TOTAL_BLOCKS)

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
def test_kv_connector_unable_to_allocate():
    """
    Test whether a scheduler with KVConnector copes with being unable to
    allocate (running out of blocks in allocate_slots()).

    Two requests are added, but only one fits in the cache at a time: the
    second one has to stay in the waiting queue until the first finishes.
    """
    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 4
    NUM_BLOCKS = 10
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=True,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
    )
    block_pool = scheduler.kv_cache_manager.block_pool
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (
        NUM_MATCHED_NEW_TOKENS)

    # Create two requests. The second request will not be able to
    # allocate slots because it will not have enough blocks.
    NUM_REQUESTS = 2
    NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
    MAX_TOKENS = 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # Every step "samples" the same single token for each request.
    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # Just one request should be running.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        num_requests=1,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # All memory should be freed, with one request waiting.
    # NOTE(review): the expected free count is NUM_BLOCKS - 1, presumably
    # because one block is permanently reserved - confirm.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # Just one request should be running.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        num_requests=1,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # All memory should be freed, with no requests waiting / running.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    assert block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0
def test_kv_connector_handles_preemption():
    """
    Test whether a scheduler with KVConnector handles preemption.

    Two requests initially fit in the cache together. Once decoding needs
    a new block, one of them is moved back to the waiting queue; it is
    rescheduled after the other request finishes and must then see both a
    local and a remote cache hit.
    """
    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 2
    # NOTE: there is 1 null block, so this is 6 blocks.
    NUM_BLOCKS = 7
    scheduler = create_scheduler(
        enable_prefix_caching=True,
        use_kv_connector=True,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
    )
    block_pool = scheduler.kv_cache_manager.block_pool

    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (
        NUM_MATCHED_NEW_TOKENS)

    # Create two requests.
    # Both can be scheduled at first, but the second request
    # will be preempted and re-scheduled.
    NUM_REQUESTS = 2
    NUM_TOKENS = BLOCK_SIZE * 2 + 1
    MAX_TOKENS = BLOCK_SIZE * 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # Every step "samples" the same single token for each request.
    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # All can be scheduled - 1st token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 2 remote kv cache hits.
        num_requests=2,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # All can be scheduled - 2nd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # This will generate a new block and cause a preemption - 3rd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Only 1 can be scheduled - 4th (and last token).
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1
    # All memory should be freed since nothing is running.
    assert block_pool.get_num_free_blocks() == NUM_BLOCKS - 1

    # Restarts the preempted request - generate 3rd token.
    # This will have a local and remote cache hit.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 1 remote kv_cache hit!
        num_requests=1,
        # Only 1 block was preempted and there is a single
        # remote hit. So only single new token is scheduled.
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Only 1 can be scheduled - 4th (and last token).
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    # All memory should be freed since nothing is running.
    assert block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
tests/v1/e2e/test_cascade_attention.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
...utils
import
fork_new_process_for_each_test
def
test_cascade_attention
(
example_system_message
,
monkeypatch
):
@
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
[
"FLASH_ATTN_VLLM_V1"
,
"FLASHINFER_VLLM_V1"
])
def
test_cascade_attention
(
example_system_message
,
monkeypatch
,
attn_backend
):
prompt
=
"
\n
<User>: Implement fibonacci sequence in Python.
\n
<Claude>:"
prompt
=
"
\n
<User>: Implement fibonacci sequence in Python.
\n
<Claude>:"
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
llm
=
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
)
llm
=
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
...
...
tests/v1/e2e/test_spec_decode.py
View file @
081057de
...
@@ -44,18 +44,20 @@ def test_prompts():
...
@@ -44,18 +44,20 @@ def test_prompts():
@
pytest
.
fixture
@
pytest
.
fixture
def
sampling_config
():
def
sampling_config
():
# Only support greedy for now
return
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
,
ignore_eos
=
False
)
return
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
,
ignore_eos
=
False
)
@
pytest
.
fixture
@
pytest
.
fixture
def
model_name
():
def
model_name
():
return
"meta-llama/
Meta-
Llama-3-8B-Instruct"
return
"meta-llama/Llama-3
.1
-8B-Instruct"
@
pytest
.
fixture
def
eagle_model_name
():
def
eagle_model_name
():
return
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
return
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
def eagle3_model_name():
    """Return the model id of the EAGLE3 draft model used by the tests."""
    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def
test_ngram_correctness
(
def
test_ngram_correctness
(
...
@@ -102,12 +104,13 @@ def test_ngram_correctness(
...
@@ -102,12 +104,13 @@ def test_ngram_correctness(
del
spec_llm
del
spec_llm
@
pytest
.
mark
.
parametrize
(
"use_eagle3"
,
[
False
,
True
],
ids
=
[
"eagle"
,
"eagle3"
])
def
test_eagle_correctness
(
def
test_eagle_correctness
(
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
test_prompts
:
list
[
list
[
dict
[
str
,
Any
]]],
test_prompts
:
list
[
list
[
dict
[
str
,
Any
]]],
sampling_config
:
SamplingParams
,
sampling_config
:
SamplingParams
,
model_name
:
str
,
model_name
:
str
,
eagle
_model_name
:
str
,
use_
eagle
3
:
bool
,
):
):
'''
'''
Compare the outputs of a original LLM and a speculative LLM
Compare the outputs of a original LLM and a speculative LLM
...
@@ -116,18 +119,22 @@ def test_eagle_correctness(
...
@@ -116,18 +119,22 @@ def test_eagle_correctness(
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
ref_llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
ref_llm
=
LLM
(
model
=
model_name
,
max_model_len
=
2048
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
del
ref_llm
del
ref_llm
spec_model_name
=
eagle3_model_name
(
)
if
use_eagle3
else
eagle_model_name
()
spec_llm
=
LLM
(
spec_llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
trust_remote_code
=
True
,
speculative_config
=
{
speculative_config
=
{
"method"
:
"eagle"
,
"method"
:
"eagle3"
if
use_eagle3
else
"eagle"
,
"model"
:
eagle
_model_name
,
"model"
:
spec
_model_name
,
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
"max_model_len"
:
2048
,
},
},
max_model_len
=
1024
,
max_model_len
=
2048
,
)
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
matches
=
0
matches
=
0
...
@@ -140,7 +147,7 @@ def test_eagle_correctness(
...
@@ -140,7 +147,7 @@ def test_eagle_correctness(
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
# Heuristic: expect at least
70
% of the prompts to match exactly
# Heuristic: expect at least
66
% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
# Upon failure, inspect the outputs to check for inaccuracy.
assert
matches
>
int
(
0.
7
*
len
(
ref_outputs
))
assert
matches
>
int
(
0.
66
*
len
(
ref_outputs
))
del
spec_llm
del
spec_llm
tests/v1/engine/conftest.py
View file @
081057de
...
@@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
...
@@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
tokenizer_group
=
init_tokenizer_from_configs
(
tokenizer_group
=
init_tokenizer_from_configs
(
vllm_config
.
model_config
,
vllm_config
.
scheduler_config
,
vllm_config
.
model_config
,
vllm_config
.
scheduler_config
,
vllm_config
.
parallel_config
,
vllm_config
.
lora_config
),
vllm_config
.
lora_config
),
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
full_tokens
=
[
tokenizer
(
text
).
input_ids
for
text
in
FULL_STRINGS
],
full_tokens
=
[
tokenizer
(
text
).
input_ids
for
text
in
FULL_STRINGS
],
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
prompt_tokens
,
...
...
tests/v1/engine/test_async_llm.py
View file @
081057de
...
@@ -3,16 +3,19 @@
...
@@ -3,16 +3,19 @@
import
asyncio
import
asyncio
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
from
typing
import
Optional
from
typing
import
Optional
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.inputs
import
PromptType
from
vllm.inputs
import
PromptType
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.metrics.loggers
import
LoggingStatLogger
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
reason
=
"V1 currently only supported on CUDA."
,
pytest
.
skip
(
reason
=
"V1 currently only supported on CUDA."
,
...
@@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
...
@@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
# Assert only the last output has the finished flag set
# Assert only the last output has the finished flag set
assert
all
(
not
out
.
finished
for
out
in
outputs
[:
-
1
])
assert
all
(
not
out
.
finished
for
out
in
outputs
[:
-
1
])
assert
outputs
[
-
1
].
finished
assert
outputs
[
-
1
].
finished
class MockLoggingStatLogger(LoggingStatLogger):
    """A ``LoggingStatLogger`` whose ``log`` attribute is a ``MagicMock``.

    The parent initializer still runs with the supplied arguments; only the
    ``log`` attribute is shadowed on the instance so that a test can check
    how many times it was called.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Instance attribute takes precedence over the inherited method.
        self.log = MagicMock()
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Check that stat logger classes given at construction are used.

    An ``AsyncLLM`` is created with ``MockLoggingStatLogger`` as its only
    stat logger, one stats-logging pass is triggered, and the mocked
    ``log`` is expected to have been called exactly once.
    """
    with monkeypatch.context() as patch_ctx, ExitStack() as cleanup:
        patch_ctx.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(
            TEXT_ENGINE_ARGS,
            stat_loggers=[MockLoggingStatLogger],
        )
        # Registered right away so shutdown runs even when an assert fails.
        cleanup.callback(engine.shutdown)

        await engine.do_log_stats()

        # Exactly one group of loggers holding exactly one logger is expected.
        assert len(engine.stat_loggers) == 1
        assert len(engine.stat_loggers[0]) == 1
        sole_logger = engine.stat_loggers[0][0]
        sole_logger.log.assert_called_once()
tests/v1/engine/test_engine_core.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
copy
import
copy
import
threading
import
time
import
time
import
uuid
import
uuid
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -32,8 +31,7 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
...
@@ -32,8 +31,7 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def
make_request
()
->
EngineCoreRequest
:
def
make_request
()
->
EngineCoreRequest
:
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
uuid
.
uuid4
(),
request_id
=
str
(
uuid
.
uuid4
()),
prompt
=
PROMPT
,
prompt_token_ids
=
PROMPT_TOKENS
,
prompt_token_ids
=
PROMPT_TOKENS
,
mm_inputs
=
None
,
mm_inputs
=
None
,
mm_hashes
=
None
,
mm_hashes
=
None
,
...
@@ -244,33 +242,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
...
@@ -244,33 +242,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
self
,
kv_cache_configs
:
list
[
KVCacheConfig
])
->
None
:
self
,
kv_cache_configs
:
list
[
KVCacheConfig
])
->
None
:
super
().
initialize_from_config
(
kv_cache_configs
)
super
().
initialize_from_config
(
kv_cache_configs
)
#
This executor actually can only run 1 batch at a time
#
Create a thread pool with a single worker
self
.
semaphore
=
t
hread
ing
.
Semaphore
(
1
)
self
.
thread_pool
=
T
hread
PoolExecutor
(
max_workers
=
1
)
def
execute_model
(
def
execute_model
(
self
,
self
,
scheduler_output
,
scheduler_output
,
)
->
Future
[
ModelRunnerOutput
]:
)
->
Future
[
ModelRunnerOutput
]:
"""Make execute_model non-blocking."""
"""Make execute_model non-blocking."""
future
:
Future
[
ModelRunnerOutput
]
=
Future
()
def
_thread_wrapper
(
scheduler_output
,
future
):
def
_execute
():
with
self
.
semaphore
:
output
=
self
.
collective_rpc
(
"execute_model"
,
output
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,
))
args
=
(
scheduler_output
,
))
# Make a copy because output[0] may be reused
# Make a copy because output[0] may be reused
# by the next batch.
# by the next batch.
return
copy
.
deepcopy
(
output
[
0
])
output
=
copy
.
deepcopy
(
output
[
0
])
future
.
set_result
(
output
)
threading
.
Thread
(
target
=
_thread_wrapper
,
# Use the thread pool instead of creating a new thread
args
=
(
scheduler_output
,
future
)).
start
()
return
self
.
thread_pool
.
submit
(
_execute
)
return
future
@
property
@
property
def
max_concurrent_batches
(
self
)
->
int
:
def
max_concurrent_batches
(
self
)
->
int
:
return
2
return
2
def
shutdown
(
self
):
if
hasattr
(
self
,
'thread_pool'
):
self
.
thread_pool
.
shutdown
(
wait
=
False
)
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
...
@@ -299,14 +297,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
...
@@ -299,14 +297,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Schedule Batch 1: (10, req0)
# Schedule Batch 1: (10, req0)
assert
engine_core
.
step_with_batch_queue
()
is
None
assert
engine_core
.
step_with_batch_queue
()
is
None
assert
engine_core
.
batch_queue
.
qsize
()
==
1
assert
engine_core
.
batch_queue
.
qsize
()
==
1
scheduler_output
=
engine_core
.
batch_queue
.
queue
[
-
1
][
1
]
assert
scheduler_output
.
num_scheduled_tokens
[
0
]
==
10
# num_computed_tokens should have been updated immediately.
assert
engine_core
.
scheduler
.
requests
[
req0
.
request_id
].
num_computed_tokens
==
10
# Schedule Batch 2: (2, req0), (8, req1)
assert
engine_core
.
step_with_batch_queue
()
is
None
assert
engine_core
.
step_with_batch_queue
()
is
None
assert
engine_core
.
batch_queue
.
qsize
()
==
2
assert
engine_core
.
batch_queue
.
qsize
()
==
2
scheduler_output
=
engine_core
.
batch_queue
.
queue
[
-
1
][
1
]
assert
scheduler_output
.
num_scheduled_tokens
[
0
]
==
2
assert
scheduler_output
.
num_scheduled_tokens
[
1
]
==
8
# num_computed_tokens should have been updated immediately.
assert
engine_core
.
scheduler
.
requests
[
0
].
num_computed_tokens
==
12
assert
engine_core
.
scheduler
.
requests
[
1
].
num_computed_tokens
==
8
assert
engine_core
.
scheduler
.
get_num_unfinished_requests
()
==
2
assert
engine_core
.
scheduler
.
get_num_unfinished_requests
()
==
2
# Loop through both requests.
# Batch queue is full. Finish Batch 1.
while
engine_core
.
scheduler
.
get_num_unfinished_requests
()
==
2
:
engine_core
.
step_with_batch_queue
()
engine_core
.
step_with_batch_queue
()
# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# because it is in the decoding stage now.
engine_core
.
step_with_batch_queue
()
assert
engine_core
.
batch_queue
.
qsize
()
==
2
scheduler_output
=
engine_core
.
batch_queue
.
queue
[
-
1
][
1
]
assert
scheduler_output
.
num_scheduled_tokens
[
1
]
==
4
# Reaching here when got the result of the first request.
# Batch queue is full. Finish Batch 2. Get first token of req0.
while
engine_core
.
scheduler
.
get_num_unfinished_requests
()
==
1
:
output
=
engine_core
.
step_with_batch_queue
()
engine_core
.
step_with_batch_queue
()
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
engine_core
.
scheduler
.
requests
[
req0
.
request_id
].
num_tokens
==
13
# Schedule Batch 4: (1, req0).
engine_core
.
step_with_batch_queue
()
assert
engine_core
.
batch_queue
.
qsize
()
==
2
scheduler_output
=
engine_core
.
batch_queue
.
queue
[
-
1
][
1
]
assert
scheduler_output
.
num_scheduled_tokens
[
0
]
==
1
# Batch queue is full. Finish Batch 3. Get first token of req1.
output
=
engine_core
.
step_with_batch_queue
()
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
engine_core
.
scheduler
.
requests
[
req1
.
request_id
].
num_tokens
==
13
# Schedule Batch 5: (1, req1).
engine_core
.
step_with_batch_queue
()
assert
engine_core
.
batch_queue
.
qsize
()
==
2
scheduler_output
=
engine_core
.
batch_queue
.
queue
[
-
1
][
1
]
assert
scheduler_output
.
num_scheduled_tokens
[
1
]
==
1
# Loop until req0 is finished.
step
=
0
req_id
=
0
expected_num_tokens
=
[
engine_core
.
scheduler
.
requests
[
0
].
num_tokens
+
1
,
engine_core
.
scheduler
.
requests
[
1
].
num_tokens
+
1
,
]
while
engine_core
.
scheduler
.
get_num_unfinished_requests
()
==
2
:
output
=
engine_core
.
step_with_batch_queue
()
if
step
%
2
==
0
:
# Even steps consumes an output.
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
if
req_id
in
engine_core
.
scheduler
.
requests
:
assert
engine_core
.
scheduler
.
requests
[
req_id
].
num_tokens
==
expected_num_tokens
[
req_id
]
expected_num_tokens
[
req_id
]
+=
1
req_id
=
(
req_id
+
1
)
%
2
else
:
# Odd steps schedules a new batch.
assert
output
is
None
step
+=
1
tests/v1/engine/test_engine_core_client.py
View file @
081057de
...
@@ -35,7 +35,6 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
...
@@ -35,7 +35,6 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def
make_request
(
params
:
SamplingParams
)
->
EngineCoreRequest
:
def
make_request
(
params
:
SamplingParams
)
->
EngineCoreRequest
:
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
str
(
uuid
.
uuid4
()),
request_id
=
str
(
uuid
.
uuid4
()),
prompt
=
PROMPT
,
prompt_token_ids
=
PROMPT_TOKENS
,
prompt_token_ids
=
PROMPT_TOKENS
,
mm_inputs
=
None
,
mm_inputs
=
None
,
mm_hashes
=
None
,
mm_hashes
=
None
,
...
...
tests/v1/engine/test_output_processor.py
View file @
081057de
...
@@ -50,7 +50,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
...
@@ -50,7 +50,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
# Make N requests.
# Make N requests.
requests
=
[
requests
=
[
EngineCoreRequest
(
request_id
=
f
"request-
{
idx
}
"
,
EngineCoreRequest
(
request_id
=
f
"request-
{
idx
}
"
,
prompt
=
prompt
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
arrival_time
=
0
,
mm_inputs
=
None
,
mm_inputs
=
None
,
...
@@ -64,14 +63,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
...
@@ -64,14 +63,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
output_kind
=
request_output_kind
,
output_kind
=
request_output_kind
,
stop
=
[],
stop
=
[],
include_stop_str_in_output
=
False
,
include_stop_str_in_output
=
False
,
))
for
idx
,
(
prompt
,
prompt_tokens
)
in
enumerate
(
))
zip
(
dummy_test_vectors
.
prompt_strings
,
for
idx
,
prompt_tokens
in
enumerate
(
dummy_test_vectors
.
prompt_tokens
)
dummy_test_vectors
.
prompt_tokens
))
]
]
# Add requests to the detokenizer.
# Add requests to the detokenizer.
for
request
in
requests
:
for
request
,
prompt
in
zip
(
requests
,
dummy_test_vectors
.
prompt_strings
)
:
output_processor
.
add_request
(
request
)
output_processor
.
add_request
(
request
,
prompt
)
gen_strings
=
{}
gen_strings
=
{}
gen_tokens
=
{}
gen_tokens
=
{}
...
@@ -398,7 +396,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
...
@@ -398,7 +396,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
]
]
requests
=
[
requests
=
[
EngineCoreRequest
(
request_id
=
request_id_list
[
idx
],
EngineCoreRequest
(
request_id
=
request_id_list
[
idx
],
prompt
=
prompt
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
arrival_time
=
0
,
mm_inputs
=
None
,
mm_inputs
=
None
,
...
@@ -414,14 +411,13 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
...
@@ -414,14 +411,13 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
include_stop_str_in_output
=
False
,
include_stop_str_in_output
=
False
,
logprobs
=
num_sample_logprobs
,
logprobs
=
num_sample_logprobs
,
prompt_logprobs
=
num_prompt_logprobs
,
prompt_logprobs
=
num_prompt_logprobs
,
))
for
idx
,
(
prompt
,
prompt_tokens
)
in
enumerate
(
))
zip
(
dummy_test_vectors
.
prompt_strings
,
for
idx
,
prompt_tokens
in
enumerate
(
dummy_test_vectors
.
prompt_tokens
)
dummy_test_vectors
.
prompt_tokens
))
]
]
# Add requests to the detokenizer.
# Add requests to the detokenizer.
for
request
in
requests
:
for
request
,
prompt
in
zip
(
requests
,
dummy_test_vectors
.
prompt_strings
)
:
output_processor
.
add_request
(
request
)
output_processor
.
add_request
(
request
,
prompt
)
gen_tokens
=
{}
gen_tokens
=
{}
gen_logprobs
=
{}
gen_logprobs
=
{}
...
@@ -562,7 +558,6 @@ def test_stop_token(include_stop_str_in_output: bool,
...
@@ -562,7 +558,6 @@ def test_stop_token(include_stop_str_in_output: bool,
request_id
=
"request-0"
request_id
=
"request-0"
request
=
EngineCoreRequest
(
request
=
EngineCoreRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt_string
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
arrival_time
=
0
,
mm_inputs
=
None
,
mm_inputs
=
None
,
...
@@ -583,7 +578,7 @@ def test_stop_token(include_stop_str_in_output: bool,
...
@@ -583,7 +578,7 @@ def test_stop_token(include_stop_str_in_output: bool,
))
))
# Add request to the detokenizer.
# Add request to the detokenizer.
output_processor
.
add_request
(
request
)
output_processor
.
add_request
(
request
,
prompt_string
)
# Loop over engine core steps; run output processor
# Loop over engine core steps; run output processor
gen_string
=
""
gen_string
=
""
...
@@ -659,7 +654,6 @@ def test_stop_string(include_stop_str_in_output: bool,
...
@@ -659,7 +654,6 @@ def test_stop_string(include_stop_str_in_output: bool,
requests
=
[
requests
=
[
EngineCoreRequest
(
EngineCoreRequest
(
request_id
=
request_id_list
[
idx
],
request_id
=
request_id_list
[
idx
],
prompt
=
prompt
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
arrival_time
=
0
,
mm_inputs
=
None
,
mm_inputs
=
None
,
...
@@ -675,14 +669,13 @@ def test_stop_string(include_stop_str_in_output: bool,
...
@@ -675,14 +669,13 @@ def test_stop_string(include_stop_str_in_output: bool,
include_stop_str_in_output
=
include_stop_str_in_output
,
include_stop_str_in_output
=
include_stop_str_in_output
,
logprobs
=
num_sample_logprobs
,
logprobs
=
num_sample_logprobs
,
prompt_logprobs
=
None
,
prompt_logprobs
=
None
,
))
for
idx
,
(
prompt
,
prompt_tokens
)
in
enumerate
(
))
zip
(
dummy_test_vectors
.
prompt_strings
,
for
idx
,
prompt_tokens
in
enumerate
(
dummy_test_vectors
.
prompt_tokens
)
dummy_test_vectors
.
prompt_tokens
))
]
]
# Add requests to the detokenizer.
# Add requests to the detokenizer.
for
request
in
requests
:
for
request
,
prompt
in
zip
(
requests
,
dummy_test_vectors
.
prompt_strings
)
:
output_processor
.
add_request
(
request
)
output_processor
.
add_request
(
request
,
prompt
)
gen_strings
=
{}
gen_strings
=
{}
gen_tokens
=
{}
gen_tokens
=
{}
...
@@ -774,7 +767,6 @@ def test_iteration_stats(dummy_test_vectors):
...
@@ -774,7 +767,6 @@ def test_iteration_stats(dummy_test_vectors):
requests
=
[
requests
=
[
EngineCoreRequest
(
EngineCoreRequest
(
request_id
=
f
"request-
{
idx
}
"
,
request_id
=
f
"request-
{
idx
}
"
,
prompt
=
prompt
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
arrival_time
=
0
,
mm_inputs
=
None
,
mm_inputs
=
None
,
...
@@ -783,15 +775,13 @@ def test_iteration_stats(dummy_test_vectors):
...
@@ -783,15 +775,13 @@ def test_iteration_stats(dummy_test_vectors):
eos_token_id
=
None
,
eos_token_id
=
None
,
lora_request
=
None
,
lora_request
=
None
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
)
for
idx
,
(
prompt
,
prompt_tokens
)
in
enumerate
(
)
for
idx
,
prompt_tokens
in
enumerate
(
dummy_test_vectors
.
prompt_tokens
)
zip
(
dummy_test_vectors
.
prompt_strings
,
dummy_test_vectors
.
prompt_tokens
))
]
]
# Add all requests except one to the OutputProcessor.
# Add all requests except one to the OutputProcessor.
num_active
=
len
(
dummy_test_vectors
.
generation_tokens
)
-
1
num_active
=
len
(
dummy_test_vectors
.
generation_tokens
)
-
1
for
request
in
requests
[:
num_active
]:
for
request
in
requests
[:
num_active
]:
output_processor
.
add_request
(
request
)
output_processor
.
add_request
(
request
,
None
)
inactive_request
=
requests
[
num_active
]
inactive_request
=
requests
[
num_active
]
# First iteration has 2 prefills.
# First iteration has 2 prefills.
...
@@ -817,7 +807,7 @@ def test_iteration_stats(dummy_test_vectors):
...
@@ -817,7 +807,7 @@ def test_iteration_stats(dummy_test_vectors):
assert
iteration_stats
.
num_generation_tokens
==
num_active
assert
iteration_stats
.
num_generation_tokens
==
num_active
# Add a new request - prefill and 2 decodes in this step.
# Add a new request - prefill and 2 decodes in this step.
output_processor
.
add_request
(
inactive_request
)
output_processor
.
add_request
(
inactive_request
,
None
)
num_active
+=
1
num_active
+=
1
outputs
=
engine_core
.
get_outputs
()[:
num_active
]
outputs
=
engine_core
.
get_outputs
()[:
num_active
]
iteration_stats
=
IterationStats
()
iteration_stats
=
IterationStats
()
...
@@ -921,3 +911,84 @@ async def test_request_output_collector():
...
@@ -921,3 +911,84 @@ async def test_request_output_collector():
# Cumulative logprobs should be the last one.
# Cumulative logprobs should be the last one.
cumulative_logprob_expected
=
1.0
*
num_to_put
cumulative_logprob_expected
=
1.0
*
num_to_put
assert
output
.
outputs
[
0
].
cumulative_logprob
==
cumulative_logprob_expected
assert
output
.
outputs
[
0
].
cumulative_logprob
==
cumulative_logprob_expected
@pytest.mark.asyncio
async def test_cumulative_output_collector_n():
    """Verify a CUMULATIVE collector combines completions keyed by index.

    Two request outputs are put before the first ``get``: the first carries
    completions with index 0 and 1, the second with index 0 and 2. The
    collected result must hold three completions, with index 0 taken from
    the later output.
    """

    def _completion(index, text, token_ids):
        # Every completion in this test differs only in these three fields.
        return CompletionOutput(
            index=index,
            text=text,
            token_ids=token_ids,
            cumulative_logprob=None,
            logprobs=None,
            finish_reason=None,
        )

    def _request_output(completions):
        return RequestOutput(
            request_id="my-request-id",
            prompt=None,
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=completions,
            finished=False,
        )

    collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
    outputs = [
        _request_output([
            _completion(0, "a", [0]),
            _completion(1, "b", [1]),
        ]),
        _request_output([
            _completion(0, "ab", [0, 1]),
            _completion(2, "c", [2]),
        ]),
    ]
    for pending in outputs:
        collector.put(pending)

    # Get the output and check that the text and token_ids are correct.
    result = await collector.get()

    # We are expecting
    # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}]
    assert len(result.outputs) == 3

    def _with_index(index):
        return [item for item in result.outputs if item.index == index]

    # Index 0: replaced by the newer cumulative text.
    first = _with_index(0)
    assert len(first) == 1
    assert first[0].text == "ab"

    # Index 1: only present in the first output, must be kept untouched.
    second = _with_index(1)
    assert len(second) == 1
    assert second[0].text == "b"
    assert second[0].token_ids == [1]

    # Index 2: only present in the second output.
    third = _with_index(2)
    assert len(third) == 1
    assert third[0].text == "c"
tests/v1/engine/utils.py
View file @
081057de
...
@@ -8,8 +8,7 @@ import torch
...
@@ -8,8 +8,7 @@ import torch
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
BaseTokenizerGroup
)
from
vllm.v1.engine
import
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine
import
EngineCoreOutput
,
FinishReason
from
vllm.v1.outputs
import
LogprobsLists
,
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsLists
,
LogprobsTensors
...
@@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors(
...
@@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors(
class
DummyOutputProcessorTestVectors
:
class
DummyOutputProcessorTestVectors
:
"""Dummy test vectors for output processor tests"""
"""Dummy test vectors for output processor tests"""
tokenizer
:
GeneralTokenizerType
tokenizer
:
GeneralTokenizerType
tokenizer_group
:
Base
TokenizerGroup
tokenizer_group
:
TokenizerGroup
vllm_config
:
EngineArgs
vllm_config
:
EngineArgs
full_tokens
:
list
[
list
[
int
]]
# Prompt + generated tokens
full_tokens
:
list
[
list
[
int
]]
# Prompt + generated tokens
prompt_tokens
:
list
[
list
[
int
]]
prompt_tokens
:
list
[
list
[
int
]]
...
...
tests/v1/entrypoints/conftest.py
View file @
081057de
...
@@ -47,6 +47,14 @@ def sample_json_schema():
...
@@ -47,6 +47,14 @@ def sample_json_schema():
"type"
:
"string"
,
"type"
:
"string"
,
}
}
},
},
"grade"
:
{
"type"
:
"string"
,
"pattern"
:
"^[A-D]$"
# Regex pattern
},
"email"
:
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+
\\
.[a-zA-Z]{2,}$"
},
"work_history"
:
{
"work_history"
:
{
"type"
:
"array"
,
"type"
:
"array"
,
"items"
:
{
"items"
:
{
...
@@ -56,17 +64,20 @@ def sample_json_schema():
...
@@ -56,17 +64,20 @@ def sample_json_schema():
"type"
:
"string"
"type"
:
"string"
},
},
"duration"
:
{
"duration"
:
{
"type"
:
"number"
"type"
:
"number"
,
"minimum"
:
0.0
,
"maximum"
:
100.0
,
# Numeric range
},
},
"position"
:
{
"position"
:
{
"type"
:
"string"
"type"
:
"string"
}
}
},
},
"required"
:
[
"company"
,
"position"
]
"required"
:
[
"company"
,
"duration"
,
"position"
]
}
}
}
}
},
},
"required"
:
[
"name"
,
"age"
,
"skills"
,
"work_history"
]
"required"
:
[
"name"
,
"age"
,
"skills"
,
"grade"
,
"email"
,
"work_history"
]
}
}
...
@@ -78,27 +89,18 @@ def unsupported_json_schema():
...
@@ -78,27 +89,18 @@ def unsupported_json_schema():
"properties"
:
{
"properties"
:
{
"score"
:
{
"score"
:
{
"type"
:
"integer"
,
"type"
:
"integer"
,
"minimum"
:
0
,
"multipleOf"
:
5
# Numeric multiple
"maximum"
:
100
# Numeric range
},
"grade"
:
{
"type"
:
"string"
,
"pattern"
:
"^[A-D]$"
# Regex pattern
},
"email"
:
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+
\\
.[a-zA-Z]{2,}$"
},
},
"tags"
:
{
"tags"
:
{
"type"
:
"array"
,
"type"
:
"array"
,
"items"
:
{
"items"
:
{
"type"
:
"string"
,
"type"
:
"string"
,
"
pattern"
:
"
minLength"
:
10
,
"
^[a-z]{1,10}$"
# Combining length and pattern restrictions
"
maxLength"
:
20
}
}
}
}
},
},
"required"
:
[
"score"
,
"grade"
,
"email"
,
"tags"
]
"required"
:
[
"score"
,
"tags"
]
}
}
...
...
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
081057de
...
@@ -13,6 +13,7 @@ from pydantic import BaseModel
...
@@ -13,6 +13,7 @@ from pydantic import BaseModel
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE
=
[
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE
=
[
...
@@ -63,10 +64,13 @@ def test_structured_output(
...
@@ -63,10 +64,13 @@ def test_structured_output(
):
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
# Don't use eager execution on TPUs because we want to test for no
# recompilation at runtime
enforce_eager
=
bool
(
not
current_platform
.
is_tpu
())
# Use a single LLM instance for several scenarios to
# Use a single LLM instance for several scenarios to
# speed up the test suite.
# speed up the test suite.
llm
=
LLM
(
model
=
model_name
,
llm
=
LLM
(
model
=
model_name
,
enforce_eager
=
True
,
enforce_eager
=
enforce_eager
,
max_model_len
=
1024
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
,
guided_decoding_backend
=
guided_decoding_backend
,
tokenizer_mode
=
tokenizer_mode
)
tokenizer_mode
=
tokenizer_mode
)
...
@@ -346,6 +350,7 @@ def test_structured_output(
...
@@ -346,6 +350,7 @@ def test_structured_output(
temperature
=
1.0
,
temperature
=
1.0
,
max_tokens
=
1000
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
json_schema
))
guided_decoding
=
GuidedDecodingParams
(
json
=
json_schema
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
"Generate a description of a frog using 50 characters."
,
prompts
=
"Generate a description of a frog using 50 characters."
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
@@ -364,6 +369,106 @@ def test_structured_output(
...
@@ -364,6 +369,106 @@ def test_structured_output(
output_json
=
json
.
loads
(
generated_text
)
output_json
=
json
.
loads
(
generated_text
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
json_schema
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
json_schema
)
#
# Test 11: Generate structured output using structural_tag format
#
structural_tag_config
=
{
"type"
:
"structural_tag"
,
"structures"
:
[{
"begin"
:
"<function=get_weather>"
,
"schema"
:
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}
}
},
"end"
:
"</function>"
}],
"triggers"
:
[
"<function="
]
}
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
,
guided_decoding
=
GuidedDecodingParams
(
structural_tag
=
json
.
dumps
(
structural_tag_config
)))
prompt
=
"""
You have access to the following function to retrieve the weather in a city:
{
"name": "get_weather",
"parameters": {
"city": {
"param_type": "string",
"description": "The city to get the weather for",
"required": True
}
}
}
If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name
as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.
Given the previous instructions, what is the weather in New York City?
"""
# Change this once other backends support structural_tag
if
guided_decoding_backend
.
startswith
(
"xgrammar"
):
outputs
=
llm
.
generate
(
prompts
=
prompt
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
else
:
outputs
=
[]
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
# Search for function call pattern in the response
function_call_pattern
=
r
'<function=get_weather>(.*?)</function>'
matches
=
re
.
findall
(
function_call_pattern
,
generated_text
)
if
not
matches
:
print
(
f
"Warning: No function calls found in response: "
f
"
{
generated_text
!
r
}
"
)
continue
# Take the first function call if multiple are found
json_str
=
matches
[
0
]
try
:
json_content
=
json
.
loads
(
json_str
)
assert
"city"
in
json_content
assert
isinstance
(
json_content
[
"city"
],
str
)
print
(
f
"Found valid function call:
{
generated_text
!
r
}
"
)
except
(
json
.
JSONDecodeError
,
AssertionError
)
as
e
:
pytest
.
fail
(
"Invalid function call format: "
f
"
{
generated_text
!
r
}
\n
Error:
{
str
(
e
)
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"model_name, tokenizer_mode"
,
@
pytest
.
mark
.
parametrize
(
"model_name, tokenizer_mode"
,
...
@@ -386,13 +491,21 @@ def test_structured_output_auto_mode(
...
@@ -386,13 +491,21 @@ def test_structured_output_auto_mode(
max_tokens
=
1000
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json_schema
))
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json_schema
))
prompts
=
(
"Give an example JSON object for a grade "
"that fits this schema: "
f
"
{
unsupported_json_schema
}
"
)
# This would fail with the default of "xgrammar", but in "auto"
# This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically.
# we will handle fallback automatically.
outputs
=
llm
.
generate
(
prompts
=
(
"Give an example JSON object for a grade "
outputs
=
llm
.
generate
(
prompts
=
prompts
,
"that fits this schema: "
f
"
{
unsupported_json_schema
}
"
),
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
use_tqdm
=
True
)
# Make sure `auto` backend handling doesn't mess up sampling_params
# and that we can reuse it without error.
outputs
.
extend
(
llm
.
generate
(
prompts
=
prompts
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
))
assert
outputs
is
not
None
assert
outputs
is
not
None
for
output
in
outputs
:
for
output
in
outputs
:
assert
output
is
not
None
assert
output
is
not
None
...
@@ -404,3 +517,59 @@ def test_structured_output_auto_mode(
...
@@ -404,3 +517,59 @@ def test_structured_output_auto_mode(
# Parse to verify it is valid JSON
# Parse to verify it is valid JSON
parsed_json
=
json
.
loads
(
generated_text
)
parsed_json
=
json
.
loads
(
generated_text
)
assert
isinstance
(
parsed_json
,
dict
)
assert
isinstance
(
parsed_json
,
dict
)
@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    """Check the guidance backend option that forbids extra JSON properties.

    The prompt asks for keys a1..a20 while the schema declares only a1..a3.
    With the ``no-additional-properties`` option the generated object must
    validate against the schema and must not contain a4, a5 or a6.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    backend = 'guidance:no-additional-properties,disable-any-whitespace'
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend=backend)

    declared_keys = ('a1', 'a2', 'a3')
    schema = {
        'type': 'object',
        'properties': {key: {
            'type': 'string'
        }
                       for key in declared_keys},
        'required': list(declared_keys),
    }

    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
        "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend):
        """Generate once with the given backend and return the parsed JSON."""
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=256,
            guided_decoding=GuidedDecodingParams(json=schema,
                                                 backend=backend),
        )

        outputs = llm.generate(prompts=prompt,
                               sampling_params=sampling_params)
        assert outputs is not None

        generated_text = outputs[0].outputs[0].text
        assert generated_text is not None

        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
        return parsed_json

    generated = generate_with_backend(
        'guidance:no-additional-properties,disable-any-whitespace')

    # Declared keys must be present ...
    for expected_key in ("a1", "a2", "a3"):
        assert expected_key in generated
    # ... and the first few undeclared ones the prompt asks for must not be.
    for forbidden_key in ("a4", "a5", "a6"):
        assert forbidden_key not in generated
tests/v1/shutdown/test_delete.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""
import
pytest
from
tests.utils
import
wait_for_gpu_memory_to_clear
from
tests.v1.shutdown.utils
import
(
SHUTDOWN_TEST_THRESHOLD_BYTES
,
SHUTDOWN_TEST_TIMEOUT_SEC
)
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.v1.engine.async_llm
import
AsyncLLM
MODELS
=
[
"meta-llama/Llama-3.2-1B"
]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("send_one_request", [False, True])
async def test_async_llm_delete(model: str, tensor_parallel_size: int,
                                send_one_request: bool) -> None:
    """Test that AsyncLLM frees GPU memory upon deletion.

    AsyncLLM always uses an MP client.

    Args:
      model: model under test
      tensor_parallel_size: degree of tensor parallelism
      send_one_request: send one request to engine before deleting
    """
    available_devices = cuda_device_count_stateless()
    if available_devices < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    engine_args = AsyncEngineArgs(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Instantiate AsyncLLM; make request to complete any deferred
    # initialization; then delete instance
    async_llm = AsyncLLM.from_engine_args(engine_args)
    if send_one_request:
        one_token_params = SamplingParams(
            max_tokens=1, output_kind=RequestOutputKind.DELTA)
        # The generator is consumed inline so that no local name keeps a
        # reference to the engine once ``del`` runs below.
        async for _ in async_llm.generate("Hello my name is",
                                          request_id="abc",
                                          sampling_params=one_token_params):
            pass
    del async_llm

    # Confirm all the processes are cleaned up.
    used_devices = list(range(tensor_parallel_size))
    wait_for_gpu_memory_to_clear(
        devices=used_devices,
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("send_one_request", [False, True])
def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int,
                    enable_multiprocessing: bool,
                    send_one_request: bool) -> None:
    """Check that deleting an LLM releases its GPU memory.

    TODO(andy) - LLM without multiprocessing.

    Args:
        model: model under test
        tensor_parallel_size: degree of tensor parallelism
        enable_multiprocessing: enable workers in separate process(es)
        send_one_request: send one request to engine before deleting
    """
    available_gpus = cuda_device_count_stateless()
    if tensor_parallel_size > available_gpus:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        # The multiprocessing switch is passed to the engine via env var.
        if enable_multiprocessing:
            mp_flag = "1"
        else:
            mp_flag = "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", mp_flag)

        # Build the LLM; optionally issue one request so that any deferred
        # initialization has completed before the instance is deleted.
        llm = LLM(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size,
        )
        if send_one_request:
            single_token = SamplingParams(max_tokens=1)
            llm.generate("Hello my name is", sampling_params=single_token)
        del llm

        # Confirm all the processes are cleaned up.
        gpu_ids = list(range(tensor_parallel_size))
        wait_for_gpu_memory_to_clear(
            devices=gpu_ids,
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
tests/v1/shutdown/test_forward_error.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an Error in model forward and shutdown."""
import
asyncio
import
pytest
from
tests.utils
import
wait_for_gpu_memory_to_clear
from
tests.v1.shutdown.utils
import
(
SHUTDOWN_TEST_THRESHOLD_BYTES
,
SHUTDOWN_TEST_TIMEOUT_SEC
)
from
vllm
import
LLM
,
AsyncEngineArgs
,
SamplingParams
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.exceptions
import
EngineDeadError
MODELS
=
[
"meta-llama/Llama-3.2-1B"
]
def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception after 10 calls.

    The call counter is stored on the instance as `num_calls`. Once it
    reaches the limit, tensor-parallel rank 0 raises on every further call;
    all other calls are counted and delegated to `self.model`.
    """
    good_passes = 10

    # Create the per-instance counter on first use.
    try:
        self.num_calls
    except AttributeError:
        self.num_calls = 0

    # The rank is only queried once the budget of good passes is used up.
    budget_spent = self.num_calls == good_passes
    if budget_spent and get_tensor_model_parallel_rank() == 0:
        raise Exception("Simulated illegal memory access on Rank 0!")

    self.num_calls += 1
    return self.model(*args, **kwargs)
@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int,
                                     model: str) -> None:
    """Test that AsyncLLM propagates a forward pass error and frees memory.

    AsyncLLM always uses an MP client.

    Args:
        monkeypatch: pytest fixture used to swap in the failing forward
        tensor_parallel_size: degree of tensor parallelism
        model: model under test
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    # Monkeypatch an error in the model.
    monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)
    async_llm = AsyncLLM.from_engine_args(engine_args)

    async def generate(request_id: str):
        """Drain one request; return the exception it raised, if any."""
        generator = async_llm.generate("Hello my name is",
                                       request_id=request_id,
                                       sampling_params=SamplingParams())
        try:
            async for _ in generator:
                pass
        except Exception as e:
            return e

    NUM_REQS = 3
    tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
    outputs = await asyncio.gather(*tasks)

    # Every request should get an EngineDeadError.
    for output in outputs:
        assert isinstance(output, EngineDeadError)

    # AsyncLLM should be errored.
    assert async_llm.errored

    # We should not be able to make another request.
    with pytest.raises(EngineDeadError):
        async for _ in async_llm.generate("Hello my name is",
                                          request_id="abc",
                                          sampling_params=SamplingParams()):
            raise Exception("We should not get here.")

    # Confirm all the processes are cleaned up. Use the shared threshold
    # constant (same value as the previous literal) so this test stays in
    # step with the other shutdown tests.
    wait_for_gpu_memory_to_clear(
        devices=list(range(tensor_parallel_size)),
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        timeout_s=60,
    )

    # NOTE: shutdown is handled by the API Server if an exception
    # occurs, so it is expected that we would need to call this.
    async_llm.shutdown()
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(monkeypatch, tensor_parallel_size: int,
                         enable_multiprocessing: bool, model: str) -> None:
    """Check that LLM surfaces a forward-pass failure and frees memory.

    TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
    and >1 rank
    """
    available_gpus = cuda_device_count_stateless()
    if tensor_parallel_size > available_gpus:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        # The multiprocessing switch is passed to the engine via env var.
        if enable_multiprocessing:
            mp_flag = "1"
        else:
            mp_flag = "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", mp_flag)

        # Swap in the forward method that starts failing on rank 0.
        m.setattr(LlamaForCausalLM, "forward", evil_forward)

        llm = LLM(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size,
        )

        # With multiprocessing the failure is expected as EngineDeadError;
        # otherwise any Exception is accepted.
        expected_error = (EngineDeadError
                          if enable_multiprocessing else Exception)
        with pytest.raises(expected_error):
            llm.generate("Hello my name is Robert and I")

        # Confirm all the processes are cleaned up.
        gpu_ids = list(range(tensor_parallel_size))
        wait_for_gpu_memory_to_clear(
            devices=gpu_ids,
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
tests/v1/shutdown/test_processor_error.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Test error handling in Processor. Should not impact other reqs."""
import
asyncio
import
pytest
from
tests.v1.shutdown.utils
import
SHUTDOWN_TEST_TIMEOUT_SEC
from
vllm
import
SamplingParams
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.inputs.data
import
TokensPrompt
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.exceptions
import
EngineGenerateError
MODELS
=
[
"meta-llama/Llama-3.2-1B"
]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_processor_error(model: str) -> None:
    """Check that a Processor error reaches the caller without killing
    the engine.

    Sends empty token prompts (expected to fail) and then a normal text
    prompt (expected to succeed).
    AsyncLLM always uses an MP client.
    """
    async_llm = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model=model, enforce_eager=True))

    async def generate(request_id: str):
        """Drain one bad request; return the exception it raised, if any."""
        # [] is not allowed and will raise a ValueError in Processor.
        stream = async_llm.generate(TokensPrompt([]),
                                    request_id=request_id,
                                    sampling_params=SamplingParams())
        try:
            async for _ in stream:
                pass
        except Exception as e:
            return e

    NUM_REQS = 3
    pending = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
    errors = await asyncio.gather(*pending)

    # Every request should have gotten an EngineGenerateError.
    for err in errors:
        with pytest.raises(EngineGenerateError):
            raise err

    # AsyncLLM itself should NOT be in the errored state.
    assert not async_llm.errored

    # A well-formed request should still be served.
    EXPECTED_TOKENS = 5
    delta_params = SamplingParams(max_tokens=EXPECTED_TOKENS,
                                  output_kind=RequestOutputKind.DELTA)
    outputs = []
    async for out in async_llm.generate("Hello my name is",
                                        request_id="abc",
                                        sampling_params=delta_params):
        outputs.append(out)

    # Flatten the token ids of the first completion of every chunk.
    generated_tokens = [
        token for chunk in outputs for token in chunk.outputs[0].token_ids
    ]
    assert len(generated_tokens) == EXPECTED_TOKENS

    async_llm.shutdown()
tests/v1/shutdown/test_startup_error.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""
import
pytest
from
tests.utils
import
wait_for_gpu_memory_to_clear
from
tests.v1.shutdown.utils
import
(
SHUTDOWN_TEST_THRESHOLD_BYTES
,
SHUTDOWN_TEST_TIMEOUT_SEC
)
from
vllm
import
LLM
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.v1.engine.async_llm
import
AsyncLLM
MODELS
=
[
"meta-llama/Llama-3.2-1B"
]
def evil_method(self, *args, **kwargs):
    """Evil method that raises an exception on tensor-parallel rank 0.

    On every other rank the call is forwarded to `self.model` with
    `intermediate_tensors=None` added.
    """
    on_rank_zero = get_tensor_model_parallel_rank() == 0
    if not on_rank_zero:
        return self.model(*args, **kwargs, intermediate_tensors=None)

    raise Exception("Simulated Error in startup!")
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_async_llm_startup_error(monkeypatch, model: str,
                                 tensor_parallel_size: int,
                                 failing_method: str) -> None:
    """Check that AsyncLLM surfaces an __init__ failure & frees memory.

    Covers failures during profiling (forward()) and during weight loading.
    AsyncLLM always uses an MP client.
    """
    available_gpus = cuda_device_count_stateless()
    if tensor_parallel_size > available_gpus:
        pytest.skip(reason="Not enough CUDA devices")

    # Replace the chosen model method with one that fails on rank 0.
    monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)

    engine_args = AsyncEngineArgs(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Construction itself must fail with the engine's startup error.
    with pytest.raises(Exception, match="initialization failed"):
        _ = AsyncLLM.from_engine_args(engine_args)

    # Confirm all the processes are cleaned up.
    gpu_ids = list(range(tensor_parallel_size))
    wait_for_gpu_memory_to_clear(
        devices=gpu_ids,
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int,
                           enable_multiprocessing: bool,
                           failing_method: str) -> None:
    """Test that LLM propagates an __init__ error and frees memory.

    Test profiling (forward()) and load weights failures.
    TODO(andy) - LLM without multiprocessing.

    Args:
        monkeypatch: pytest fixture used for the env var and model patch
        model: model under test
        tensor_parallel_size: degree of tensor parallelism
        enable_multiprocessing: enable workers in separate process(es)
        failing_method: name of the model method replaced by `evil_method`
    """
    if model != "meta-llama/Llama-3.2-1B":
        pytest.skip(reason="Only test meta-llama/Llama-3.2-1B")
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:

        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model. Patch through the context
        # object `m` (not the outer fixture) so the patch is undone together
        # with the env var when the `with` block exits.
        m.setattr(LlamaForCausalLM, failing_method, evil_method)

        # With multiprocessing the failure is reported as an initialization
        # error; otherwise the simulated message itself is expected.
        with pytest.raises(
                Exception,
                match="initialization failed"
                if enable_multiprocessing else "Simulated Error in startup!"):
            _ = LLM(model=model,
                    enforce_eager=True,
                    tensor_parallel_size=tensor_parallel_size)

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
tests/v1/shutdown/utils.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Shutdown test utils"""
# Value given to @pytest.mark.timeout by the shutdown tests (seconds).
SHUTDOWN_TEST_TIMEOUT_SEC = 120

# Value passed as `threshold_bytes` to wait_for_gpu_memory_to_clear (2 GiB).
SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 1024**3
tests/v1/spec_decode/test_max_len.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import
pytest
from
vllm
import
LLM
,
SamplingParams
_PROMPTS
=
[
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
,
"Repeat the following sentence 10 times: Consistency is key to mastering any skill."
,
# noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail."
,
# noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Run ngram speculative decoding with max_model_len=100 while asking
    for 100 new tokens per prompt.

    There are no assertions; the test passes if generation finishes
    without raising.
    """
    ngram_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            model="facebook/opt-125m",
            max_model_len=100,
            enforce_eager=True,  # For faster initialization.
            speculative_config=ngram_config,
        )
        # ignore_eos is set together with max_tokens=100, the same number
        # as max_model_len.
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        llm.generate(_PROMPTS, params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Run EAGLE speculative decoding with max_model_len=100 while asking
    for 100 new tokens per prompt.

    There are no assertions; the test passes if generation finishes
    without raising.
    """
    eagle_config = {
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": num_speculative_tokens,
    }

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enforce_eager=True,  # For faster initialization.
            speculative_config=eagle_config,
            max_model_len=100,
        )
        # ignore_eos is set together with max_tokens=100, the same number
        # as max_model_len.
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        llm.generate(_PROMPTS, params)
tests/v1/spec_decode/test_ngram.py
View file @
081057de
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
numpy
as
np
import
numpy
as
np
from
vllm.config
import
ModelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.v1.spec_decode.ngram_proposer
import
(
NgramProposer
,
from
vllm.v1.spec_decode.ngram_proposer
import
(
NgramProposer
,
_find_subarray_kmp
,
_find_subarray_kmp
,
_kmp_lps_array
)
_kmp_lps_array
)
...
@@ -39,50 +40,50 @@ def test_find_subarray_kmp():
...
@@ -39,50 +40,50 @@ def test_find_subarray_kmp():
def
test_ngram_proposer
():
def
test_ngram_proposer
():
proposer
=
NgramProposer
()
def
ngram_proposer
(
min_n
:
int
,
max_n
:
int
,
k
:
int
)
->
NgramProposer
:
# Dummy model config. Just to set max_model_len.
model_config
=
ModelConfig
(
model
=
"facebook/opt-125m"
,
task
=
"generate"
,
max_model_len
=
100
,
tokenizer
=
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
dtype
=
"auto"
,
seed
=
None
,
trust_remote_code
=
False
)
return
NgramProposer
(
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
speculative_config
=
SpeculativeConfig
.
from_dict
({
"prompt_lookup_min"
:
min_n
,
"prompt_lookup_max"
:
max_n
,
"num_speculative_tokens"
:
k
,
"method"
:
"ngram"
,
})))
# No match.
# No match.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
5
]),
2
,
2
,
2
).
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
5
]))
min_n
=
2
,
max_n
=
2
,
k
=
2
,
)
assert
result
is
None
assert
result
is
None
# No match for 4-gram.
# No match for 4-gram.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]),
4
,
4
,
2
).
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]))
min_n
=
4
,
max_n
=
4
,
k
=
2
,
)
assert
result
is
None
assert
result
is
None
# No match for 4-gram but match for 3-gram.
# No match for 4-gram but match for 3-gram.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]),
3
,
4
,
2
).
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]))
min_n
=
3
,
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
4
,
1
]))
assert
np
.
array_equal
(
result
,
np
.
array
([
4
,
1
]))
# Match for both 4-gram and 3-gram.
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
# In this case, the proposer should return the 4-gram match.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
3
,
4
,
2
).
propose
(
context_token_ids
=
np
.
array
([
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]),
context_token_ids
=
np
.
array
([
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]))
min_n
=
3
,
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 1]
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram.
# Match for 2-gram and 3-gram, but not 4-gram.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
3
,
4
,
5
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]),
2
,
4
,
min_n
=
2
,
2
).
propose
(
context_token_ids
=
np
.
array
([
3
,
4
,
5
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]))
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 2]
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 2]
Prev
1
…
13
14
15
16
17
18
19
20
21
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment