Commit 081057de authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
...@@ -37,7 +37,6 @@ def make_request(request_id, ...@@ -37,7 +37,6 @@ def make_request(request_id,
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt=None,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs, multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
...@@ -311,7 +310,7 @@ def test_metrics(): ...@@ -311,7 +310,7 @@ def test_metrics():
def stats(requests, queries, hits): def stats(requests, queries, hits):
return PrefixCacheStats(requests=requests, queries=queries, hits=hits) return PrefixCacheStats(requests=requests, queries=queries, hits=hits)
metrics = PrefixCachingMetrics(interval=5) metrics = PrefixCachingMetrics(max_recent_requests=5)
assert metrics.hit_rate == 0.0 assert metrics.hit_rate == 0.0
metrics.observe(stats(1, 20, 9)) metrics.observe(stats(1, 20, 9))
...@@ -496,8 +495,7 @@ def test_allocate_with_lookahead(): ...@@ -496,8 +495,7 @@ def test_allocate_with_lookahead():
# Test case 1: Requires additional lookahead tokens # Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config, kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100, max_model_len=100)
num_preallocate_tokens=0)
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
num_tokens=3, num_tokens=3,
...@@ -507,25 +505,19 @@ def test_allocate_with_lookahead(): ...@@ -507,25 +505,19 @@ def test_allocate_with_lookahead():
# Test case 2: With precomputed blocks # Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config, kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100, max_model_len=100)
num_preallocate_tokens=4)
# num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
# required_blocks = ceil((3 + 2) /4) = 2 # required_blocks = ceil((3 + 2) /4) = 2
# total_blocks = 1 + 2 = 3
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
num_tokens=3, num_tokens=3,
num_lookahead_tokens=2, num_lookahead_tokens=2,
) )
assert len(blocks) == 3 assert len(blocks) == 2
# Test case 3: With precomputed blocks # Test case 3: With precomputed blocks
# num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
# required_blocks = ceil((3 + 4) / 4) = 2 # required_blocks = ceil((3 + 4) / 4) = 2
# total_blocks = 0 + 2 = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config, kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100, max_model_len=100)
num_preallocate_tokens=4)
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
num_tokens=3, num_tokens=3,
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv, sha256 from vllm.utils import sha256
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
...@@ -29,7 +29,6 @@ def make_request(request_id, ...@@ -29,7 +29,6 @@ def make_request(request_id,
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt=None,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs, multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
...@@ -61,7 +60,6 @@ def test_prefill(hash_algo): ...@@ -61,7 +60,6 @@ def test_prefill(hash_algo):
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
caching_hash_algo=hash_algo, caching_hash_algo=hash_algo,
num_preallocate_tokens=16,
) )
# choose the hash function according to the parameter # choose the hash function according to the parameter
...@@ -80,7 +78,7 @@ def test_prefill(hash_algo): ...@@ -80,7 +78,7 @@ def test_prefill(hash_algo):
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] assert [b.block_id for b in blocks] == [1, 2, 3, 4]
# Check full block metadata # Check full block metadata
parent_block_hash = None parent_block_hash = None
...@@ -92,8 +90,8 @@ def test_prefill(hash_algo): ...@@ -92,8 +90,8 @@ def test_prefill(hash_algo):
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata # Check partial block metadata
for block_id in (4, 5): for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
...@@ -107,12 +105,12 @@ def test_prefill(hash_algo): ...@@ -107,12 +105,12 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6, 7] assert [b.block_id for b in blocks] == [5]
for block in computed_blocks: for block in computed_blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left. # At this point, we should have 5 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 3 assert manager.block_pool.free_block_queue.num_free_blocks == 5
manager.free(req0) manager.free(req0)
manager.free(req1) manager.free(req1)
...@@ -120,14 +118,14 @@ def test_prefill(hash_algo): ...@@ -120,14 +118,14 @@ def test_prefill(hash_algo):
# All blocks should be available. # All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be # The order should be
# [unallocated (8, 9, 10)] # [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (5, 4)] # [unique_req0 (4)]
# [unique_req1 (7, 6)] # [unique_req1 (5)]
# [common (3, 2, 1)] # [common (3, 2, 1)]
assert [ assert [
b.block_id b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks() for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
# Cache hit in the common prefix when the original block is already free. # Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens) # Incomplete 1 block (6 tokens)
...@@ -139,29 +137,29 @@ def test_prefill(hash_algo): ...@@ -139,29 +137,29 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [8, 9] assert [b.block_id for b in blocks] == [6]
# Although we only have 5 free blocks, we have 8 blocks in # Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal. # the free block queue due to lazy removal.
assert manager.block_pool.free_block_queue.num_free_blocks == 5 assert manager.block_pool.free_block_queue.num_free_blocks == 6
assert all([ assert all([
b.ref_cnt == 0 b.ref_cnt == 0
for b in manager.block_pool.free_block_queue.get_all_free_blocks() for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]) ])
assert len([ assert len([
b for b in manager.block_pool.free_block_queue.get_all_free_blocks() b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]) == 5 ]) == 6
manager.free(req2) manager.free(req2)
# Cache miss and eviction. # Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9)) req3 = make_request("3", [99] * (16 * 10))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
# This block ID order also checks the eviction order. # This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1] assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None assert manager.block_pool.free_block_queue.free_list_tail is None
...@@ -178,7 +176,6 @@ def test_prefill_plp(): ...@@ -178,7 +176,6 @@ def test_prefill_plp():
make_kv_cache_config(16, 11), make_kv_cache_config(16, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=16,
) )
# the default hash function is hash # the default hash function is hash
hash_fn = hash hash_fn = hash
...@@ -197,7 +194,7 @@ def test_prefill_plp(): ...@@ -197,7 +194,7 @@ def test_prefill_plp():
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks] req0_block_hashes = [b.block_hash for b in blocks]
# Check full block metadata # Check full block metadata
...@@ -210,8 +207,8 @@ def test_prefill_plp(): ...@@ -210,8 +207,8 @@ def test_prefill_plp():
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata # Check partial block metadata
for block_id in (4, 5): for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
...@@ -226,12 +223,12 @@ def test_prefill_plp(): ...@@ -226,12 +223,12 @@ def test_prefill_plp():
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6, 7] assert [b.block_id for b in blocks] == [5]
for block in computed_blocks: for block in computed_blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left. # At this point, we should have 5 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 3 assert manager.block_pool.free_block_queue.num_free_blocks == 5
manager.free(req0) manager.free(req0)
manager.free(req1) manager.free(req1)
...@@ -239,14 +236,14 @@ def test_prefill_plp(): ...@@ -239,14 +236,14 @@ def test_prefill_plp():
# All blocks should be available. # All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be # The order should be
# [unallocated (8, 9, 10)] # [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (5, 4)] # [unique_req0 (4)]
# [unique_req1 (7, 6)] # [unique_req1 (5)]
# [common (3, 2, 1)] # [common (3, 2, 1)]
assert [ assert [
b.block_id b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks() for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
# Request #2 is a prompt-logprobs request: # Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks # NO cache hit in the common prefix; duplicates request #0 cached blocks
...@@ -262,7 +259,7 @@ def test_prefill_plp(): ...@@ -262,7 +259,7 @@ def test_prefill_plp():
block_ids = [b.block_id for b in blocks] block_ids = [b.block_id for b in blocks]
# Duplicate cached blocks have different ids but same hashes vs request #0 # Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4, 5] assert block_ids != [1, 2, 3, 4]
# Request #2 block hashes are valid since request #0 hashes are. # Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts. # Check block reference counts.
...@@ -277,7 +274,6 @@ def test_decode(): ...@@ -277,7 +274,6 @@ def test_decode():
make_kv_cache_config(16, 11), make_kv_cache_config(16, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=16,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
...@@ -291,7 +287,7 @@ def test_decode(): ...@@ -291,7 +287,7 @@ def test_decode():
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] assert [b.block_id for b in blocks] == [1, 2, 3, 4]
# Append slots without allocating a new block. # Append slots without allocating a new block.
req0.num_computed_tokens = 55 req0.num_computed_tokens = 55
...@@ -299,28 +295,18 @@ def test_decode(): ...@@ -299,28 +295,18 @@ def test_decode():
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4) new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
# Append slots without allocating a new block, but start using the # Append slots with allocating a new block.
# preallocated block.
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# 6 tokens to fill the previous block, and 10 tokens to fill # 9 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block. # the preallocated block.
for _ in range(5 + 10): for _ in range(9 + 10):
req0.append_output_token_ids(7) req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 15) new_blocks = manager.allocate_slots(req0, 19)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 1
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
# Append slots with allocating a new block.
req0.num_computed_tokens = 74
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(6 + 11):
req0.append_output_token_ids(12)
new_blocks = manager.allocate_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
def test_evict(): def test_evict():
...@@ -328,7 +314,6 @@ def test_evict(): ...@@ -328,7 +314,6 @@ def test_evict():
make_kv_cache_config(16, 11), make_kv_cache_config(16, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=16,
) )
last_token_id = 5 * 16 + 7 last_token_id = 5 * 16 + 7
...@@ -337,7 +322,7 @@ def test_evict(): ...@@ -337,7 +322,7 @@ def test_evict():
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated assert len(blocks) == 6 # 5 full + 1 partial
# 3 blocks. # 3 blocks.
req1 = make_request("1", list(range(last_token_id, req1 = make_request("1", list(range(last_token_id,
...@@ -349,7 +334,8 @@ def test_evict(): ...@@ -349,7 +334,8 @@ def test_evict():
assert len(blocks) == 3 # 3 full blocks assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16 last_token_id += 3 * 16
assert manager.block_pool.free_block_queue.num_free_blocks == 0 # 10 - (6 + 3) == 1
assert manager.block_pool.free_block_queue.num_free_blocks == 1
manager.free(req0) manager.free(req0)
manager.free(req1) manager.free(req1)
...@@ -357,7 +343,7 @@ def test_evict(): ...@@ -357,7 +343,7 @@ def test_evict():
assert [ assert [
b.block_id b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks() for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8] ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
# Touch the first 2 blocks. # Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3))) req2 = make_request("2", list(range(2 * 16 + 3)))
...@@ -365,8 +351,8 @@ def test_evict(): ...@@ -365,8 +351,8 @@ def test_evict():
assert [b.block_id for b in computed_blocks] == [1, 2] assert [b.block_id for b in computed_blocks] == [1, 2]
assert num_computed_tokens == 2 * 16 assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks) blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [7, 6] assert [b.block_id for b in blocks] == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 6 assert manager.block_pool.free_block_queue.num_free_blocks == 7
def test_hash_block_correct_reuse(): def test_hash_block_correct_reuse():
...@@ -379,7 +365,6 @@ def test_hash_block_correct_reuse(): ...@@ -379,7 +365,6 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2), make_kv_cache_config(16, 2),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=0,
) )
# Allocate 1 block and cache it. # Allocate 1 block and cache it.
...@@ -416,7 +401,6 @@ def test_computed_blocks_not_evicted(): ...@@ -416,7 +401,6 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3), make_kv_cache_config(block_size, 3),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=0,
) )
# Allocate a block and cache it. # Allocate a block and cache it.
...@@ -465,7 +449,6 @@ def test_basic_prefix_caching_disabled(): ...@@ -465,7 +449,6 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5), make_kv_cache_config(block_size, 5),
max_model_len=8192, max_model_len=8192,
enable_caching=False, enable_caching=False,
num_preallocate_tokens=0,
) )
req1 = make_request("1", list(range(10))) # 2 blocks and some more req1 = make_request("1", list(range(10))) # 2 blocks and some more
...@@ -496,40 +479,6 @@ def test_basic_prefix_caching_disabled(): ...@@ -496,40 +479,6 @@ def test_basic_prefix_caching_disabled():
assert not blocks assert not blocks
@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
    """
    Check that each allocation returns the requested block plus the
    expected number of preallocated blocks.
    """
    kv_manager = KVCacheManager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=num_preallocate_tokens,
    )
    # Number of whole blocks needed to hold the preallocated tokens.
    extra_blocks = cdiv(num_preallocate_tokens, block_size)

    request = make_request("0", list(range(block_size * 30)))
    cached_blocks, cached_tokens = kv_manager.get_computed_blocks(request)
    assert not cached_blocks
    assert cached_tokens == 0

    # Request exactly one block's worth of tokens.
    allocated = kv_manager.allocate_slots(request, block_size, cached_blocks)
    request.num_computed_tokens = block_size
    assert len(allocated) == 1 + extra_blocks

    # Treat everything as computed; the previously preallocated blocks only
    # need to be consumed when there are any (num_preallocate_tokens > 0).
    if extra_blocks > 0:
        kv_manager.allocate_slots(request, block_size * (len(allocated) - 1))
    request.num_computed_tokens = block_size * len(allocated)

    # Append one more block and expect the same padding again.
    allocated = kv_manager.allocate_slots(request, block_size)
    assert len(allocated) == 1 + extra_blocks
@pytest.mark.parametrize("hash_fn", [sha256, hash]) @pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_cache_blocks(hash_fn): def test_cache_blocks(hash_fn):
""" """
...@@ -588,7 +537,6 @@ def test_mm_prefix_caching(): ...@@ -588,7 +537,6 @@ def test_mm_prefix_caching():
make_kv_cache_config(16, 11), make_kv_cache_config(16, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=16,
) )
# Common prompt tokens (T is text tokens and P is image placeholder tokens) # Common prompt tokens (T is text tokens and P is image placeholder tokens)
...@@ -626,7 +574,7 @@ def test_mm_prefix_caching(): ...@@ -626,7 +574,7 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", ) assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks) blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# Append slots without allocating a new block. # Append slots without allocating a new block.
...@@ -667,7 +615,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): ...@@ -667,7 +615,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=0,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... | # | Common-0 | Common-1 | Common-2 | ... |
...@@ -721,7 +668,6 @@ def test_reset_prefix_cache(): ...@@ -721,7 +668,6 @@ def test_reset_prefix_cache():
make_kv_cache_config(16, 11), make_kv_cache_config(16, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
num_preallocate_tokens=0,
) )
full_block_token_ids = [i for i in range(3) for _ in range(16)] full_block_token_ids = [i for i in range(3) for _ in range(16)]
...@@ -751,3 +697,82 @@ def test_reset_prefix_cache(): ...@@ -751,3 +697,82 @@ def test_reset_prefix_cache():
assert manager.reset_prefix_cache() assert manager.reset_prefix_cache()
assert not manager.block_pool.cached_block_hash_to_block assert not manager.block_pool.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) assert all([blk.block_hash is None for blk in manager.block_pool.blocks])
def test_prefix_cache_stats_disabled():
    """prefix_cache_stats must stay None while log_stats is False."""
    kv_manager = KVCacheManager(
        make_kv_cache_config(16, 11),
        max_model_len=8192,
        enable_caching=True,
        log_stats=False,  # stats logging switched off
    )
    assert kv_manager.prefix_cache_stats is None

    # Exercise every call that checks whether log_stats is disabled.
    request = make_request("0", list(range(16)))
    cached_blocks, cached_tokens = kv_manager.get_computed_blocks(request)
    assert not cached_blocks
    assert cached_tokens == 0
    kv_manager.allocate_slots(request, 16, cached_blocks)
    kv_manager.reset_prefix_cache()

    # None of the calls above may have created a stats object.
    assert kv_manager.prefix_cache_stats is None
def test_eagle_enabled_removes_last_block():
    """Verify the computed blocks returned with Eagle enabled when the
    request length is divisible by the block size.

    A request of 3 full blocks is cached and freed; an identical request is
    then expected to hit only 1 computed block (16 tokens).
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Request with 3 full blocks (48 tokens)
    token_ids = [0] * (3 * block_size)
    req = make_request("divisible_request", token_ids)

    # Prime the cache
    computed_blocks, _ = manager.get_computed_blocks(req)
    manager.allocate_slots(req, len(token_ids), computed_blocks)
    manager.free(req)

    # New request with same tokens + Eagle enabled
    req_eagle = make_request("eagle_divisible", token_ids)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)

    # Should retain 1 block:
    # 1. Original 3 blocks → pop last hash → 2 matched blocks
    # 2. Eagle then removes the last matched block → 1 remaining
    assert len(computed_blocks) == 1
    assert num_tokens == 1 * block_size  # 16 tokens
def test_eagle_with_partial_blocks():
    """Check Eagle's computed-block result for a request that ends in a
    partially filled block."""
    block_size = 16
    kv_manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Length is not a multiple of the block size: 2 full blocks + 5 tokens.
    prompt_tokens = [0] * (2 * block_size + 5)
    warmup_req = make_request("partial_block_test", prompt_tokens)

    # Populate the prefix cache, then release the request.
    warmup_blocks, _ = kv_manager.get_computed_blocks(warmup_req)
    kv_manager.allocate_slots(warmup_req, len(prompt_tokens), warmup_blocks)
    kv_manager.free(warmup_req)

    # Issue an identical request against the warmed cache.
    eagle_req = make_request("partial_eagle", prompt_tokens)
    hit_blocks, hit_tokens = kv_manager.get_computed_blocks(eagle_req)

    # 2 full blocks match originally; Eagle removes 1, leaving 1.
    assert len(hit_blocks) == 1
    assert hit_tokens == 1 * block_size
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional from typing import Optional
from unittest.mock import Mock
import pytest import pytest
import torch import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -25,6 +27,11 @@ def create_scheduler( ...@@ -25,6 +27,11 @@ def create_scheduler(
enable_prefix_caching: Optional[bool] = None, enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0, long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False, disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
) -> Scheduler: ) -> Scheduler:
'''Create scheduler under test. '''Create scheduler under test.
...@@ -39,12 +46,15 @@ def create_scheduler( ...@@ -39,12 +46,15 @@ def create_scheduler(
Returns: Returns:
:class:`Scheduler` instance :class:`Scheduler` instance
''' '''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens, max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold, long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input, disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
) )
model_config = ModelConfig( model_config = ModelConfig(
model=model, model=model,
...@@ -60,31 +70,42 @@ def create_scheduler( ...@@ -60,31 +70,42 @@ def create_scheduler(
'enable_prefix_caching': enable_prefix_caching 'enable_prefix_caching': enable_prefix_caching
}) })
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=16, block_size=block_size,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
swap_space=0, swap_space=0,
cache_dtype="auto", cache_dtype="auto",
**kwargs_cache, **kwargs_cache,
) )
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig( vllm_config = VllmConfig(
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
) )
kv_cache_config = KVCacheConfig( kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests num_blocks=num_blocks, # A large number of blocks to hold all requests
tensors={}, tensors={},
kv_cache_groups=[ kv_cache_groups=[
KVCacheGroupSpec(['layer'], KVCacheGroupSpec(['layer'],
FullAttentionSpec(16, 1, 1, torch.float32, False)) FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
], ],
) )
cache_config.num_gpu_blocks = 10000 cache_config.num_gpu_blocks = num_blocks
return Scheduler( return Scheduler(
scheduler_config, vllm_config=vllm_config,
model_config,
cache_config,
lora_config=None,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
log_stats=True, log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config), structured_output_manager=StructuredOutputManager(vllm_config),
...@@ -111,7 +132,6 @@ def create_requests(num_requests: int, ...@@ -111,7 +132,6 @@ def create_requests(num_requests: int,
mm_inputs = None mm_inputs = None
request = Request( request = Request(
request_id=f"{i}", request_id=f"{i}",
prompt=None,
prompt_token_ids=[i] * num_tokens, prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params, sampling_params=sampling_params,
multi_modal_inputs=mm_inputs, multi_modal_inputs=mm_inputs,
...@@ -286,6 +306,7 @@ def test_no_mm_input_chunking(): ...@@ -286,6 +306,7 @@ def test_no_mm_input_chunking():
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
disable_chunked_mm_input=True, disable_chunked_mm_input=True,
max_model_len=2048,
) )
mm_positions = [[PlaceholderRange(offset=400, length=800)]] mm_positions = [[PlaceholderRange(offset=400, length=800)]]
requests = create_requests(num_requests=1, requests = create_requests(num_requests=1,
...@@ -414,7 +435,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -414,7 +435,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
def test_stop_via_update_from_output(): def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output""" """Test stopping behavior through update_from_output"""
scheduler = create_scheduler() scheduler = create_scheduler(num_speculative_tokens=1)
# Test case 1: Stop on EOS token # Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10) requests = create_requests(num_requests=2, max_tokens=10)
...@@ -422,7 +443,6 @@ def test_stop_via_update_from_output(): ...@@ -422,7 +443,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=[],
...@@ -466,7 +486,7 @@ def test_stop_via_update_from_output(): ...@@ -466,7 +486,7 @@ def test_stop_via_update_from_output():
assert list(requests[1].output_token_ids) == [10, 11] assert list(requests[1].output_token_ids) == [10, 11]
# Test case 2: Stop on custom stop token # Test case 2: Stop on custom stop token
scheduler = create_scheduler() scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, requests = create_requests(num_requests=2,
max_tokens=10, max_tokens=10,
stop_token_ids=[42, 43]) stop_token_ids=[42, 43])
...@@ -474,7 +494,6 @@ def test_stop_via_update_from_output(): ...@@ -474,7 +494,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=[],
...@@ -518,13 +537,12 @@ def test_stop_via_update_from_output(): ...@@ -518,13 +537,12 @@ def test_stop_via_update_from_output():
assert list(requests[1].output_token_ids) == [13, 14] assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens # Test case 3: Stop on max tokens
scheduler = create_scheduler() scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=2) requests = create_requests(num_requests=2, max_tokens=2)
for req in requests: for req in requests:
req.num_computed_tokens = req.num_tokens req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=[],
...@@ -568,13 +586,12 @@ def test_stop_via_update_from_output(): ...@@ -568,13 +586,12 @@ def test_stop_via_update_from_output():
assert list(requests[1].output_token_ids) == [13] assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag # Test case 4: Ignore EOS flag
scheduler = create_scheduler() scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10) requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0] scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0]) scheduler.running.append(requests[0])
scheduler.scheduled_req_ids.add(requests[0].request_id)
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
...@@ -671,13 +688,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], ...@@ -671,13 +688,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
@pytest.mark.parametrize( @pytest.mark.parametrize(
"spec_tokens,output_tokens,expected", "spec_tokens,output_tokens,expected",
[ [
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
([[1]], [[1, 2]], (1, 1)), # single token sequence (2, 3, 3, [2, 1])), # multiple sequences
([[]], [[5]], (0, 0)), # empty sequence ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
(6, 3)), # multiple mismatches (2, 6, 3, [2, 1, 0])), # multiple mismatches
]) ])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding. """Test scheduling behavior with speculative decoding.
...@@ -686,7 +704,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -686,7 +704,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
1. Speculated tokens get scheduled correctly 1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens 2. Spec decoding stats properly count number of draft and accepted tokens
""" """
scheduler = create_scheduler() num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = [] req_ids = []
req_to_index = {} req_to_index = {}
...@@ -759,5 +778,390 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -759,5 +778,390 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
else: else:
assert scheduler_stats.spec_decoding_stats is not None assert scheduler_stats.spec_decoding_stats is not None
stats = scheduler_stats.spec_decoding_stats stats = scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == expected[0] assert stats.num_drafts == expected[0]
assert stats.num_accepted_tokens == expected[1] assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3]
def _assert_right_scheduler_output(
    output: SchedulerOutput,
    num_requests: int,
    expected_num_scheduled_tokens: int,
):
    """Validate a SchedulerOutput produced after a remote KV cache hit.

    Asserts that ``output.kv_connector_metadata.requests`` holds exactly
    ``num_requests`` entries and that every value in
    ``output.num_scheduled_tokens`` equals
    ``expected_num_scheduled_tokens``.
    """
    # The kv_connector_metadata must have been injected.
    connector_requests = output.kv_connector_metadata.requests
    assert len(connector_requests) == num_requests

    # Only num_tokens - matched_num_new_tokens should be scheduled.
    for scheduled in output.num_scheduled_tokens.values():
        assert scheduled == expected_num_scheduled_tokens
def _assert_right_kv_cache_manager(
    scheduler: Scheduler,
    req_ids: list[str],
    num_tokens: int,
    block_size: int,
    num_requests: int,
    num_total_blocks: int,
):
    """Check whether KVCacheManager is correct after allocate.

    For every id in ``req_ids`` the manager must report
    ``num_tokens // block_size`` cached blocks, allocated blocks and block
    hashes. The block pool must additionally have exactly
    ``num_requests`` times that many blocks fewer free than
    ``num_total_blocks``.

    ``num_tokens`` must be a multiple of ``block_size``: the checks compare
    against a whole number of blocks per request.
    """
    # Make the whole-block assumption explicit instead of letting a
    # fractional block count leak into the free-block comparison below.
    assert num_tokens % block_size == 0, (
        f"num_tokens ({num_tokens}) must be a multiple of "
        f"block_size ({block_size})")
    blocks_per_req = num_tokens // block_size
    kv_cache_manager = scheduler.kv_cache_manager

    # Make sure the request stats are right.
    for req_id in req_ids:
        blocks = kv_cache_manager.req_to_blocks[req_id]
        hashes = kv_cache_manager.req_to_block_hashes[req_id]
        assert kv_cache_manager.num_cached_block[req_id] == blocks_per_req
        assert len(blocks) == blocks_per_req
        assert len(hashes) == blocks_per_req

    # Make sure we actually touched all the blocks. Integer arithmetic
    # keeps this an int == int comparison.
    assert (kv_cache_manager.block_pool.get_num_free_blocks() ==
            num_total_blocks - num_requests * blocks_per_req)
def _step_until_done(
    scheduler: Scheduler,
    output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
):
    """Drive schedule() / update_from_output() until everything finished.

    The ``output`` passed in is consumed first. After that the loop keeps
    scheduling and feeding ``model_runner_output`` back until an update
    reports a ``finish_reason`` for every one of its outputs.
    """
    # Consume the step that the caller has already scheduled.
    scheduler.update_from_output(output, model_runner_output)

    while True:
        # Schedule + a few iterations until stopping.
        output = scheduler.schedule()
        assert len(scheduler.running)

        # We should be in the decode phase now.
        for scheduled in output.num_scheduled_tokens.values():
            assert scheduled == 1
        assert len(output.kv_connector_metadata.requests) == 0

        step_outputs = scheduler.update_from_output(output,
                                                    model_runner_output)
        # Stop once no output of this step is still unfinished.
        if all(eco.finish_reason is not None
               for eco in step_outputs.outputs):
            break
def test_kv_connector_basic():
    """
    Check that a Scheduler with a KVConnector schedules tokens, allocates
    memory, and cleans up requests as expected under normal operation.
    """

    # Setup Scheduler.
    scheduler = create_scheduler(enable_prefix_caching=True,
                                 use_kv_connector=True)
    NUM_TOTAL_BLOCKS = (
        scheduler.kv_cache_manager.block_pool.get_num_free_blocks())
    BLOCK_SIZE = scheduler.cache_config.block_size

    # Mock External Cache Hit.
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler.connector.get_num_new_matched_tokens = Mock(
        name="method", return_value=NUM_MATCHED_NEW_TOKENS)

    ######################################################
    # FIRST SET OF REQUESTS - External Hit Only
    NUM_REQUESTS = 2
    NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
    MAX_TOKENS = 3
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    for request in requests:
        scheduler.add_request(request)
    req_ids = [request.request_id for request in requests]
    req_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # Ensure ScheduleOutput is correct.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens should be scheduled.
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS,
                                   BLOCK_SIZE, NUM_REQUESTS,
                                   NUM_TOTAL_BLOCKS)

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
            NUM_TOTAL_BLOCKS)

    ######################################################
    # SECOND SET OF REQUESTS - Local And External Hit
    NUM_TOKENS_PREFIX = NUM_TOKENS
    # We will get a local prefix cache hit for the first
    # NUM_TOKENS_PREFIX tokens since they are used above.
    NUM_TOKENS = NUM_TOKENS_PREFIX * 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    for request in requests:
        scheduler.add_request(request)
    req_ids = [request.request_id for request in requests]
    req_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # We should get a local cache hit of NUM_TOKENS_PREFIX and
    # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output=output,
        num_requests=NUM_REQUESTS,
        # Just the incremental tokens after local + remote cache hit.
        expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX -
                                       NUM_MATCHED_NEW_TOKENS),
    )

    # Ensure KVCacheManager is correct.
    _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS,
                                   BLOCK_SIZE, NUM_REQUESTS,
                                   NUM_TOTAL_BLOCKS)

    # Continue Generation until done.
    _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
    _ = scheduler.schedule()
    # Confirm we clean up the memory properly.
    assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
            NUM_TOTAL_BLOCKS)
def test_kv_connector_unable_to_allocate():
    """
    Check that a Scheduler with a KVConnector copes with a request that
    cannot be allocated (running out of blocks in allocate_slots()).
    """

    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 4
    NUM_BLOCKS = 10
    scheduler = create_scheduler(enable_prefix_caching=True,
                                 use_kv_connector=True,
                                 block_size=BLOCK_SIZE,
                                 num_blocks=NUM_BLOCKS)
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
    scheduler.connector.get_num_new_matched_tokens = Mock(
        name="method", return_value=NUM_MATCHED_NEW_TOKENS)

    # Create two requests. The second request will not be able to
    # allocate slots because it will not have enough blocks.
    NUM_REQUESTS = 2
    NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
    MAX_TOKENS = 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    for request in requests:
        scheduler.add_request(request)
    req_ids = [request.request_id for request in requests]
    req_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # The requests run one after the other: during the first pass one
    # request is still waiting, during the second pass nothing is.
    for num_waiting in (1, 0):
        # Just one request should be running.
        output = scheduler.schedule()
        _assert_right_scheduler_output(
            output,
            num_requests=1,
            expected_num_scheduled_tokens=(NUM_TOKENS -
                                           NUM_MATCHED_NEW_TOKENS))
        assert len(scheduler.running) == 1
        assert len(scheduler.waiting) == num_waiting

        # All memory should be freed once the running request is done;
        # the waiting queue is unchanged by that.
        _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
        assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
                == NUM_BLOCKS - 1)
        assert len(scheduler.running) == 0
        assert len(scheduler.waiting) == num_waiting
def test_kv_connector_handles_preemption():
    """
    Check that a Scheduler with a KVConnector handles preemption: both
    requests are scheduled at first, one of them is preempted when a new
    block is needed, and it is re-scheduled after the other one finishes.
    """

    # Setup Scheduler With Mock External Cache Hit.
    BLOCK_SIZE = 2
    # NOTE: there is 1 null block, so this is 6 blocks.
    NUM_BLOCKS = 7
    scheduler = create_scheduler(enable_prefix_caching=True,
                                 use_kv_connector=True,
                                 block_size=BLOCK_SIZE,
                                 num_blocks=NUM_BLOCKS)
    NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
    scheduler.connector.get_num_new_matched_tokens = Mock(
        name="method", return_value=NUM_MATCHED_NEW_TOKENS)

    # Create two requests.
    # Both can be scheduled at first, but the second request
    # will be preempted and re-scheduled.
    NUM_REQUESTS = 2
    NUM_TOKENS = BLOCK_SIZE * 2 + 1
    MAX_TOKENS = BLOCK_SIZE * 2
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)
    for request in requests:
        scheduler.add_request(request)
    req_ids = [request.request_id for request in requests]
    req_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )

    # All can be scheduled - 1st token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 2 remote kv cache hits.
        num_requests=2,
        expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
    )
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # All can be scheduled - 2nd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 2
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

    # This will generate a new block and cause a preemption - 3rd token.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Only 1 can be scheduled - 4th (and last token).
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1
    # All memory should be freed since nothing is running.
    assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
            NUM_BLOCKS - 1)

    # Restarts the preempted request - generate 3rd token.
    # This will have a local and remote cache hit.
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # 1 remote kv_cache hit!
        num_requests=1,
        # Only 1 block was preempted and there is a single
        # remote hit. So only single new token is scheduled.
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Only 1 can be scheduled - 4th (and last token).
    output = scheduler.schedule()
    _assert_right_scheduler_output(
        output,
        # no connector_metadata
        num_requests=0,
        expected_num_scheduled_tokens=1,
    )
    assert len(scheduler.running) == 1
    _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
    assert len(scheduler.running) == 0
    # All memory should be freed since nothing is running.
    assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
            NUM_BLOCKS - 1)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from ...utils import fork_new_process_for_each_test
def test_cascade_attention(example_system_message, monkeypatch): @fork_new_process_for_each_test
@pytest.mark.parametrize("attn_backend",
["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
......
...@@ -44,18 +44,20 @@ def test_prompts(): ...@@ -44,18 +44,20 @@ def test_prompts():
@pytest.fixture @pytest.fixture
def sampling_config(): def sampling_config():
# Only support greedy for now
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
@pytest.fixture @pytest.fixture
def model_name(): def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct" return "meta-llama/Llama-3.1-8B-Instruct"
@pytest.fixture
def eagle_model_name(): def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3-Instruct-8B" return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
def eagle3_model_name():
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def test_ngram_correctness( def test_ngram_correctness(
...@@ -102,12 +104,13 @@ def test_ngram_correctness( ...@@ -102,12 +104,13 @@ def test_ngram_correctness(
del spec_llm del spec_llm
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness( def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]], test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams, sampling_config: SamplingParams,
model_name: str, model_name: str,
eagle_model_name: str, use_eagle3: bool,
): ):
''' '''
Compare the outputs of a original LLM and a speculative LLM Compare the outputs of a original LLM and a speculative LLM
...@@ -116,18 +119,22 @@ def test_eagle_correctness( ...@@ -116,18 +119,22 @@ def test_eagle_correctness(
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=1024) ref_llm = LLM(model=model_name, max_model_len=2048)
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM( spec_llm = LLM(
model=model_name, model=model_name,
trust_remote_code=True,
speculative_config={ speculative_config={
"method": "eagle", "method": "eagle3" if use_eagle3 else "eagle",
"model": eagle_model_name, "model": spec_model_name,
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"max_model_len": 2048,
}, },
max_model_len=1024, max_model_len=2048,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
...@@ -140,7 +147,7 @@ def test_eagle_correctness( ...@@ -140,7 +147,7 @@ def test_eagle_correctness(
print(f"ref_output: {ref_output.outputs[0].text}") print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly # Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy. # Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs)) assert matches > int(0.66 * len(ref_outputs))
del spec_llm del spec_llm
...@@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: ...@@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
tokenizer=tokenizer, tokenizer=tokenizer,
tokenizer_group=init_tokenizer_from_configs( tokenizer_group=init_tokenizer_from_configs(
vllm_config.model_config, vllm_config.scheduler_config, vllm_config.model_config, vllm_config.scheduler_config,
vllm_config.parallel_config, vllm_config.lora_config), vllm_config.lora_config),
vllm_config=vllm_config, vllm_config=vllm_config,
full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
......
...@@ -3,16 +3,19 @@ ...@@ -3,16 +3,19 @@
import asyncio import asyncio
from contextlib import ExitStack from contextlib import ExitStack
from typing import Optional from typing import Optional
from unittest.mock import MagicMock
import pytest import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
if not current_platform.is_cuda(): if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", pytest.skip(reason="V1 currently only supported on CUDA.",
...@@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int, ...@@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
# Assert only the last output has the finished flag set # Assert only the last output has the finished flag set
assert all(not out.finished for out in outputs[:-1]) assert all(not out.finished for out in outputs[:-1])
assert outputs[-1].finished assert outputs[-1].finished
class MockLoggingStatLogger(LoggingStatLogger):
    """LoggingStatLogger whose ``log`` attribute is a MagicMock.

    Lets tests assert on calls made to ``log``.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 engine_index: int = 0) -> None:
        super().__init__(vllm_config, engine_index)
        # Set on the instance after the parent initializer has run, so
        # the mock is what ``self.log`` resolves to.
        self.log = MagicMock()
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.

    If a customized logger is provided at the init, it should
    be used directly.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(
            TEXT_ENGINE_ARGS,
            stat_loggers=[MockLoggingStatLogger],
        )
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        # Exactly one group holding exactly one logger is expected, and
        # that logger's mocked ``log`` must have been called once.
        logger_groups = engine.stat_loggers
        assert len(logger_groups) == 1
        first_group = logger_groups[0]
        assert len(first_group) == 1
        first_group[0].log.assert_called_once()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy import copy
import threading
import time import time
import uuid import uuid
from concurrent.futures import Future from concurrent.futures import Future, ThreadPoolExecutor
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -32,8 +31,7 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids ...@@ -32,8 +31,7 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request() -> EngineCoreRequest: def make_request() -> EngineCoreRequest:
return EngineCoreRequest( return EngineCoreRequest(
request_id=uuid.uuid4(), request_id=str(uuid.uuid4()),
prompt=PROMPT,
prompt_token_ids=PROMPT_TOKENS, prompt_token_ids=PROMPT_TOKENS,
mm_inputs=None, mm_inputs=None,
mm_hashes=None, mm_hashes=None,
...@@ -244,33 +242,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -244,33 +242,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
self, kv_cache_configs: list[KVCacheConfig]) -> None: self, kv_cache_configs: list[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs) super().initialize_from_config(kv_cache_configs)
# This executor actually can only run 1 batch at a time # Create a thread pool with a single worker
self.semaphore = threading.Semaphore(1) self.thread_pool = ThreadPoolExecutor(max_workers=1)
def execute_model( def execute_model(
self, self,
scheduler_output, scheduler_output,
) -> Future[ModelRunnerOutput]: ) -> Future[ModelRunnerOutput]:
"""Make execute_model non-blocking.""" """Make execute_model non-blocking."""
future: Future[ModelRunnerOutput] = Future()
def _thread_wrapper(scheduler_output, future): def _execute():
with self.semaphore: output = self.collective_rpc("execute_model",
output = self.collective_rpc("execute_model", args=(scheduler_output, ))
args=(scheduler_output, )) # Make a copy because output[0] may be reused
# Make a copy because output[0] may be reused # by the next batch.
# by the next batch. return copy.deepcopy(output[0])
output = copy.deepcopy(output[0])
future.set_result(output)
threading.Thread(target=_thread_wrapper, # Use the thread pool instead of creating a new thread
args=(scheduler_output, future)).start() return self.thread_pool.submit(_execute)
return future
@property @property
def max_concurrent_batches(self) -> int: def max_concurrent_batches(self) -> int:
return 2 return 2
def shutdown(self):
if hasattr(self, 'thread_pool'):
self.thread_pool.shutdown(wait=False)
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
...@@ -299,14 +297,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -299,14 +297,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Schedule Batch 1: (10, req0) # Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue() is None assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 1 assert engine_core.batch_queue.qsize() == 1
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 10
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue() is None assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 2 assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 2
assert scheduler_output.num_scheduled_tokens[1] == 8
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[0].num_computed_tokens == 12
assert engine_core.scheduler.requests[1].num_computed_tokens == 8
assert engine_core.scheduler.get_num_unfinished_requests() == 2 assert engine_core.scheduler.get_num_unfinished_requests() == 2
# Loop through both requests. # Batch queue is full. Finish Batch 1.
while engine_core.scheduler.get_num_unfinished_requests() == 2: engine_core.step_with_batch_queue()
engine_core.step_with_batch_queue()
# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# because it is in the decoding stage now.
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[1] == 4
# Reaching here when got the result of the first request. # Batch queue is full. Finish Batch 2. Get first token of req0.
while engine_core.scheduler.get_num_unfinished_requests() == 1: output = engine_core.step_with_batch_queue()
engine_core.step_with_batch_queue() assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
# Schedule Batch 4: (1, req0).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1.
output = engine_core.step_with_batch_queue()
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
# Schedule Batch 5: (1, req1).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[1] == 1
# Loop until req0 is finished.
step = 0
req_id = 0
expected_num_tokens = [
engine_core.scheduler.requests[0].num_tokens + 1,
engine_core.scheduler.requests[1].num_tokens + 1,
]
while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue()
if step % 2 == 0:
# Even steps consumes an output.
assert output is not None
assert len(output.outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
else:
# Odd steps schedules a new batch.
assert output is None
step += 1
...@@ -35,7 +35,6 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids ...@@ -35,7 +35,6 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request(params: SamplingParams) -> EngineCoreRequest: def make_request(params: SamplingParams) -> EngineCoreRequest:
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt=PROMPT,
prompt_token_ids=PROMPT_TOKENS, prompt_token_ids=PROMPT_TOKENS,
mm_inputs=None, mm_inputs=None,
mm_hashes=None, mm_hashes=None,
......
...@@ -50,7 +50,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -50,7 +50,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
# Make N requests. # Make N requests.
requests = [ requests = [
EngineCoreRequest(request_id=f"request-{idx}", EngineCoreRequest(request_id=f"request-{idx}",
prompt=prompt,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_inputs=None,
...@@ -64,14 +63,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -64,14 +63,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
output_kind=request_output_kind, output_kind=request_output_kind,
stop=[], stop=[],
include_stop_str_in_output=False, include_stop_str_in_output=False,
)) for idx, (prompt, prompt_tokens) in enumerate( ))
zip(dummy_test_vectors.prompt_strings, for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
dummy_test_vectors.prompt_tokens))
] ]
# Add requests to the detokenizer. # Add requests to the detokenizer.
for request in requests: for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request) output_processor.add_request(request, prompt)
gen_strings = {} gen_strings = {}
gen_tokens = {} gen_tokens = {}
...@@ -398,7 +396,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -398,7 +396,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
] ]
requests = [ requests = [
EngineCoreRequest(request_id=request_id_list[idx], EngineCoreRequest(request_id=request_id_list[idx],
prompt=prompt,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_inputs=None,
...@@ -414,14 +411,13 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -414,14 +411,13 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
include_stop_str_in_output=False, include_stop_str_in_output=False,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs, prompt_logprobs=num_prompt_logprobs,
)) for idx, (prompt, prompt_tokens) in enumerate( ))
zip(dummy_test_vectors.prompt_strings, for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
dummy_test_vectors.prompt_tokens))
] ]
# Add requests to the detokenizer. # Add requests to the detokenizer.
for request in requests: for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request) output_processor.add_request(request, prompt)
gen_tokens = {} gen_tokens = {}
gen_logprobs = {} gen_logprobs = {}
...@@ -562,7 +558,6 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -562,7 +558,6 @@ def test_stop_token(include_stop_str_in_output: bool,
request_id = "request-0" request_id = "request-0"
request = EngineCoreRequest( request = EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt=prompt_string,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_inputs=None,
...@@ -583,7 +578,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -583,7 +578,7 @@ def test_stop_token(include_stop_str_in_output: bool,
)) ))
# Add request to the detokenizer. # Add request to the detokenizer.
output_processor.add_request(request) output_processor.add_request(request, prompt_string)
# Loop over engine core steps; run output processor # Loop over engine core steps; run output processor
gen_string = "" gen_string = ""
...@@ -659,7 +654,6 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -659,7 +654,6 @@ def test_stop_string(include_stop_str_in_output: bool,
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=request_id_list[idx], request_id=request_id_list[idx],
prompt=prompt,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_inputs=None,
...@@ -675,14 +669,13 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -675,14 +669,13 @@ def test_stop_string(include_stop_str_in_output: bool,
include_stop_str_in_output=include_stop_str_in_output, include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=None, prompt_logprobs=None,
)) for idx, (prompt, prompt_tokens) in enumerate( ))
zip(dummy_test_vectors.prompt_strings, for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
dummy_test_vectors.prompt_tokens))
] ]
# Add requests to the detokenizer. # Add requests to the detokenizer.
for request in requests: for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request) output_processor.add_request(request, prompt)
gen_strings = {} gen_strings = {}
gen_tokens = {} gen_tokens = {}
...@@ -774,7 +767,6 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -774,7 +767,6 @@ def test_iteration_stats(dummy_test_vectors):
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=f"request-{idx}", request_id=f"request-{idx}",
prompt=prompt,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_inputs=None,
...@@ -783,15 +775,13 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -783,15 +775,13 @@ def test_iteration_stats(dummy_test_vectors):
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
) for idx, (prompt, prompt_tokens) in enumerate( ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
zip(dummy_test_vectors.prompt_strings,
dummy_test_vectors.prompt_tokens))
] ]
# Add all requests except one to the OutputProcessor. # Add all requests except one to the OutputProcessor.
num_active = len(dummy_test_vectors.generation_tokens) - 1 num_active = len(dummy_test_vectors.generation_tokens) - 1
for request in requests[:num_active]: for request in requests[:num_active]:
output_processor.add_request(request) output_processor.add_request(request, None)
inactive_request = requests[num_active] inactive_request = requests[num_active]
# First iteration has 2 prefills. # First iteration has 2 prefills.
...@@ -817,7 +807,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -817,7 +807,7 @@ def test_iteration_stats(dummy_test_vectors):
assert iteration_stats.num_generation_tokens == num_active assert iteration_stats.num_generation_tokens == num_active
# Add a new request - prefill and 2 decodes in this step. # Add a new request - prefill and 2 decodes in this step.
output_processor.add_request(inactive_request) output_processor.add_request(inactive_request, None)
num_active += 1 num_active += 1
outputs = engine_core.get_outputs()[:num_active] outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats() iteration_stats = IterationStats()
...@@ -921,3 +911,84 @@ async def test_request_output_collector(): ...@@ -921,3 +911,84 @@ async def test_request_output_collector():
# Cumulative logprobs should be the last one. # Cumulative logprobs should be the last one.
cumulative_logprob_expected = 1.0 * num_to_put cumulative_logprob_expected = 1.0 * num_to_put
assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
@pytest.mark.asyncio
async def test_cumulative_output_collector_n():
    """Test that the collector handles multiple outputs keyed by index.

    Two RequestOutputs for the same request are put before a single get();
    the result is expected to hold one completion per distinct index.
    """

    def make_completion(index, text, token_ids):
        # All completions in this test share the same empty metadata.
        return CompletionOutput(
            index=index,
            text=text,
            token_ids=token_ids,
            cumulative_logprob=None,
            logprobs=None,
            finish_reason=None,
        )

    def make_request_output(completions):
        return RequestOutput(
            request_id="my-request-id",
            prompt=None,
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=completions,
            finished=False,
        )

    collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
    pending = [
        make_request_output([
            make_completion(0, "a", [0]),
            make_completion(1, "b", [1]),
        ]),
        make_request_output([
            make_completion(0, "ab", [0, 1]),
            make_completion(2, "c", [2]),
        ]),
    ]
    for request_output in pending:
        collector.put(request_output)

    # Get the output and check that the text and token_ids are correct.
    result = await collector.get()

    # We are expecting
    # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}]
    assert len(result.outputs) == 3

    def matching(index):
        return [item for item in result.outputs if item.index == index]

    # Index 0 appears in both puts.
    index_zero = matching(0)
    assert len(index_zero) == 1
    assert index_zero[0].text == "ab"

    # Index 1 only appears in the first put.
    index_one = matching(1)
    assert len(index_one) == 1
    assert index_one[0].text == "b"
    assert index_one[0].token_ids == [1]

    # Index 2 only appears in the second put.
    index_two = matching(2)
    assert len(index_two) == 1
    assert index_two[0].text == "c"
...@@ -8,8 +8,7 @@ import torch ...@@ -8,8 +8,7 @@ import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( from vllm.transformers_utils.tokenizer_group import TokenizerGroup
BaseTokenizerGroup)
from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.outputs import LogprobsLists, LogprobsTensors
...@@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors( ...@@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors(
class DummyOutputProcessorTestVectors: class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests""" """Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType tokenizer: GeneralTokenizerType
tokenizer_group: BaseTokenizerGroup tokenizer_group: TokenizerGroup
vllm_config: EngineArgs vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]] prompt_tokens: list[list[int]]
......
...@@ -47,6 +47,14 @@ def sample_json_schema(): ...@@ -47,6 +47,14 @@ def sample_json_schema():
"type": "string", "type": "string",
} }
}, },
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"email": {
"type": "string",
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
},
"work_history": { "work_history": {
"type": "array", "type": "array",
"items": { "items": {
...@@ -56,17 +64,20 @@ def sample_json_schema(): ...@@ -56,17 +64,20 @@ def sample_json_schema():
"type": "string" "type": "string"
}, },
"duration": { "duration": {
"type": "number" "type": "number",
"minimum": 0.0,
"maximum": 100.0, # Numeric range
}, },
"position": { "position": {
"type": "string" "type": "string"
} }
}, },
"required": ["company", "position"] "required": ["company", "duration", "position"]
} }
} }
}, },
"required": ["name", "age", "skills", "work_history"] "required":
["name", "age", "skills", "grade", "email", "work_history"]
} }
...@@ -78,27 +89,18 @@ def unsupported_json_schema(): ...@@ -78,27 +89,18 @@ def unsupported_json_schema():
"properties": { "properties": {
"score": { "score": {
"type": "integer", "type": "integer",
"minimum": 0, "multipleOf": 5 # Numeric multiple
"maximum": 100 # Numeric range
},
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"email": {
"type": "string",
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
}, },
"tags": { "tags": {
"type": "array", "type": "array",
"items": { "items": {
"type": "string", "type": "string",
"pattern": "minLength": 10,
"^[a-z]{1,10}$" # Combining length and pattern restrictions "maxLength": 20
} }
} }
}, },
"required": ["score", "grade", "email", "tags"] "required": ["score", "tags"]
} }
......
...@@ -13,6 +13,7 @@ from pydantic import BaseModel ...@@ -13,6 +13,7 @@ from pydantic import BaseModel
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
...@@ -63,10 +64,13 @@ def test_structured_output( ...@@ -63,10 +64,13 @@ def test_structured_output(
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
# Don't use eager execution on TPUs because we want to test for no
# recompilation at runtime
enforce_eager = bool(not current_platform.is_tpu())
# Use a single LLM instance for several scenarios to # Use a single LLM instance for several scenarios to
# speed up the test suite. # speed up the test suite.
llm = LLM(model=model_name, llm = LLM(model=model_name,
enforce_eager=True, enforce_eager=enforce_eager,
max_model_len=1024, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend, guided_decoding_backend=guided_decoding_backend,
tokenizer_mode=tokenizer_mode) tokenizer_mode=tokenizer_mode)
...@@ -346,6 +350,7 @@ def test_structured_output( ...@@ -346,6 +350,7 @@ def test_structured_output(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema)) guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate( outputs = llm.generate(
prompts="Generate a description of a frog using 50 characters.", prompts="Generate a description of a frog using 50 characters.",
sampling_params=sampling_params, sampling_params=sampling_params,
...@@ -364,6 +369,106 @@ def test_structured_output( ...@@ -364,6 +369,106 @@ def test_structured_output(
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema) jsonschema.validate(instance=output_json, schema=json_schema)
#
# Test 11: Generate structured output using structural_tag format
#
structural_tag_config = {
"type":
"structural_tag",
"structures": [{
"begin": "<function=get_weather>",
"schema": {
"type": "object",
"properties": {
"city": {
"type": "string"
}
}
},
"end": "</function>"
}],
"triggers": ["<function="]
}
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=100,
guided_decoding=GuidedDecodingParams(
structural_tag=json.dumps(structural_tag_config)))
prompt = """
You have access to the following function to retrieve the weather in a city:
{
"name": "get_weather",
"parameters": {
"city": {
"param_type": "string",
"description": "The city to get the weather for",
"required": True
}
}
}
If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name
as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.
Given the previous instructions, what is the weather in New York City?
"""
# Change this once other backends support structural_tag
if guided_decoding_backend.startswith("xgrammar"):
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
else:
outputs = []
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
# Search for function call pattern in the response
function_call_pattern = r'<function=get_weather>(.*?)</function>'
matches = re.findall(function_call_pattern, generated_text)
if not matches:
print(f"Warning: No function calls found in response: "
f"{generated_text!r}")
continue
# Take the first function call if multiple are found
json_str = matches[0]
try:
json_content = json.loads(json_str)
assert "city" in json_content
assert isinstance(json_content["city"], str)
print(f"Found valid function call: {generated_text!r}")
except (json.JSONDecodeError, AssertionError) as e:
pytest.fail("Invalid function call format: "
f"{generated_text!r}\nError: {str(e)}")
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, tokenizer_mode", @pytest.mark.parametrize("model_name, tokenizer_mode",
...@@ -386,13 +491,21 @@ def test_structured_output_auto_mode( ...@@ -386,13 +491,21 @@ def test_structured_output_auto_mode(
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
prompts = ("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}")
# This would fail with the default of "xgrammar", but in "auto" # This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically. # we will handle fallback automatically.
outputs = llm.generate(prompts=("Give an example JSON object for a grade " outputs = llm.generate(prompts=prompts,
"that fits this schema: "
f"{unsupported_json_schema}"),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
# Make sure `auto` backend handling doesn't mess up sampling_params
# and that we can reuse it without error.
outputs.extend(
llm.generate(prompts=prompts,
sampling_params=sampling_params,
use_tqdm=True))
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
assert output is not None assert output is not None
...@@ -404,3 +517,59 @@ def test_structured_output_auto_mode( ...@@ -404,3 +517,59 @@ def test_structured_output_auto_mode(
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    """Check the guidance backend with the no-additional-properties option.

    The prompt asks for keys a1..a20 while the schema only declares a1..a3;
    the generated JSON must validate and must not contain a4, a5 or a6.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    backend = 'guidance:no-additional-properties,disable-any-whitespace'
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend=backend)

    declared_keys = ['a1', 'a2', 'a3']
    schema = {
        'type': 'object',
        'properties': {key: {
            'type': 'string'
        }
                       for key in declared_keys},
        'required': list(declared_keys),
    }

    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
        "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend):
        # Run one guided generation and return the parsed, validated JSON.
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=256,
            guided_decoding=GuidedDecodingParams(json=schema,
                                                 backend=backend),
        )
        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
        assert outputs is not None

        generated_text = outputs[0].outputs[0].text
        assert generated_text is not None

        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
        return parsed_json

    # Same backend string as the one the LLM was constructed with.
    generated = generate_with_backend(backend)

    for expected_key in ("a1", "a2", "a3"):
        assert expected_key in generated
    for undeclared_key in ("a4", "a5", "a6"):
        assert undeclared_key not in generated
# SPDX-License-Identifier: Apache-2.0
"""Test that GPU memory is freed when an LLM or AsyncLLM is deleted."""
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["meta-llama/Llama-3.2-1B"]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("send_one_request", [False, True])
async def test_async_llm_delete(model: str, tensor_parallel_size: int,
                                send_one_request: bool) -> None:
    """Test that AsyncLLM frees GPU memory upon deletion.

    AsyncLLM always uses an MP client.

    Args:
      model: model under test
      tensor_parallel_size: degree of tensor parallelism
      send_one_request: send one request to engine before deleting
    """
    if tensor_parallel_size > cuda_device_count_stateless():
        pytest.skip(reason="Not enough CUDA devices")

    engine_args = AsyncEngineArgs(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Instantiate AsyncLLM; make request to complete any deferred
    # initialization; then delete instance
    async_llm = AsyncLLM.from_engine_args(engine_args)
    if send_one_request:
        one_token_params = SamplingParams(max_tokens=1,
                                          output_kind=RequestOutputKind.DELTA)
        async for _ in async_llm.generate("Hello my name is",
                                          request_id="abc",
                                          sampling_params=one_token_params):
            pass
    del async_llm

    # Confirm all the processes are cleaned up.
    gpu_ids = list(range(tensor_parallel_size))
    wait_for_gpu_memory_to_clear(
        devices=gpu_ids,
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("send_one_request", [False, True])
def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int,
                    enable_multiprocessing: bool,
                    send_one_request: bool) -> None:
    """Test that LLM frees GPU memory upon deletion.

    TODO(andy) - LLM without multiprocessing.

    Args:
      model: model under test
      tensor_parallel_size: degree of tensor parallelism
      enable_multiprocessing: enable workers in separate process(es)
      send_one_request: send one request to engine before deleting
    """
    if tensor_parallel_size > cuda_device_count_stateless():
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        # The env var takes the string "1"/"0", not a bool.
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING",
                 "1" if enable_multiprocessing else "0")

        # Instantiate LLM; make request to complete any deferred
        # initialization; then delete instance
        llm = LLM(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size,
        )
        if send_one_request:
            one_token_params = SamplingParams(max_tokens=1)
            llm.generate("Hello my name is", sampling_params=one_token_params)
        del llm

        # Confirm all the processes are cleaned up.
        gpu_ids = list(range(tensor_parallel_size))
        wait_for_gpu_memory_to_clear(
            devices=gpu_ids,
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an Error in model forward and shutdown."""
import asyncio
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError
MODELS = ["meta-llama/Llama-3.2-1B"]
def evil_forward(self, *args, **kwargs):
    """Forward replacement that raises on rank 0 once 10 calls have passed.

    The call count is stored on the instance as ``num_calls``. While the
    count is below the limit (or on ranks other than 0) the call is
    delegated to ``self.model`` and the counter is incremented.
    """
    good_passes = 10

    # Lazily create the per-instance counter on first use.
    calls_so_far = getattr(self, "num_calls", 0)
    self.num_calls = calls_so_far

    # The rank is only queried once the good passes are used up.
    if calls_so_far == good_passes and get_tensor_model_parallel_rank() == 0:
        raise Exception("Simulated illegal memory access on Rank 0!")

    self.num_calls = calls_so_far + 1
    return self.model(*args, **kwargs)
@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int,
                                     model: str) -> None:
    """Test that AsyncLLM propagates a forward pass error and frees memory.

    AsyncLLM always uses an MP client.

    Args:
      monkeypatch: pytest fixture used to install the failing forward
      tensor_parallel_size: degree of tensor parallelism
      model: model under test
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    # Monkeypatch an error in the model.
    monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)
    async_llm = AsyncLLM.from_engine_args(engine_args)

    async def generate(request_id: str):
        # Return the exception raised while generating, or None if the
        # request completed without one.
        generator = async_llm.generate("Hello my name is",
                                       request_id=request_id,
                                       sampling_params=SamplingParams())
        try:
            async for _ in generator:
                pass
        except Exception as e:
            return e
        return None

    try:
        NUM_REQS = 3
        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
        outputs = await asyncio.gather(*tasks)

        # Every request should get an EngineDeadError.
        for output in outputs:
            assert isinstance(output, EngineDeadError)

        # AsyncLLM should be errored.
        assert async_llm.errored

        # We should not be able to make another request.
        with pytest.raises(EngineDeadError):
            async for _ in async_llm.generate(
                    "Hello my name is",
                    request_id="abc",
                    sampling_params=SamplingParams()):
                raise Exception("We should not get here.")

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
            timeout_s=60,
        )
    finally:
        # NOTE: shutdown is handled by the API Server if an exception
        # occurs, so it is expected that we would need to call this.
        # Running it in `finally` keeps a failed assertion above from
        # leaving the engine alive for the tests that follow.
        async_llm.shutdown()
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(monkeypatch, tensor_parallel_size: int,
                         enable_multiprocessing: bool, model: str) -> None:
    """Test that LLM propagates a forward pass error and frees memory.

    TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
    and >1 rank
    """
    if tensor_parallel_size > cuda_device_count_stateless():
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        # The env var takes the string "1"/"0", not a bool.
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING",
                 "1" if enable_multiprocessing else "0")

        # Monkeypatch an error in the model.
        m.setattr(LlamaForCausalLM, "forward", evil_forward)

        llm = LLM(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size,
        )

        # With multiprocessing the failure surfaces as EngineDeadError;
        # otherwise any Exception is accepted.
        if enable_multiprocessing:
            expected_error = EngineDeadError
        else:
            expected_error = Exception
        with pytest.raises(expected_error):
            llm.generate("Hello my name is Robert and I")

        # Confirm all the processes are cleaned up.
        gpu_ids = list(range(tensor_parallel_size))
        wait_for_gpu_memory_to_clear(
            devices=gpu_ids,
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
# SPDX-License-Identifier: Apache-2.0
"""Test error handling in Processor. Should not impact other reqs."""
import asyncio
import pytest
from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError
MODELS = ["meta-llama/Llama-3.2-1B"]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_processor_error(model: str) -> None:
    """Test that AsyncLLM propagates a processor error.

    Test empty tokens prompt (failure) and non-empty prompt (no failure.)
    AsyncLLM always uses an MP client.

    Args:
      model: model under test
    """
    engine_args = AsyncEngineArgs(model=model, enforce_eager=True)
    async_llm = AsyncLLM.from_engine_args(engine_args)

    async def generate(request_id: str):
        # Return the exception raised while generating, or None if the
        # request completed without one.
        # [] is not allowed and will raise a ValueError in Processor.
        generator = async_llm.generate(TokensPrompt([]),
                                       request_id=request_id,
                                       sampling_params=SamplingParams())
        try:
            async for _ in generator:
                pass
        except Exception as e:
            return e
        return None

    try:
        NUM_REQS = 3
        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
        outputs = await asyncio.gather(*tasks)

        # Every request should get an EngineGenerateError. An isinstance
        # check also gives a readable failure if a request returned None.
        for output in outputs:
            assert isinstance(output, EngineGenerateError)

        # AsyncLLM should NOT be errored: a bad request must not take
        # down the engine.
        assert not async_llm.errored

        # This should be no problem.
        EXPECTED_TOKENS = 5
        outputs = []
        async for out in async_llm.generate(
                "Hello my name is",
                request_id="abc",
                sampling_params=SamplingParams(
                    max_tokens=EXPECTED_TOKENS,
                    output_kind=RequestOutputKind.DELTA)):
            outputs.append(out)

        # With DELTA output each item carries only its own new tokens,
        # so concatenating them gives the full generation.
        generated_tokens = []
        for out in outputs:
            generated_tokens.extend(out.outputs[0].token_ids)
        assert len(generated_tokens) == EXPECTED_TOKENS
    finally:
        # Always release the engine, even when an assertion above fails.
        async_llm.shutdown()
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["meta-llama/Llama-3.2-1B"]
def evil_method(self, *args, **kwargs):
    """Replacement method that raises on tensor-parallel rank 0.

    On any other rank the call is delegated to ``self.model`` with
    ``intermediate_tensors=None`` added to the keyword arguments.
    """
    on_rank_zero = get_tensor_model_parallel_rank() == 0
    if not on_rank_zero:
        return self.model(*args, **kwargs, intermediate_tensors=None)

    raise Exception("Simulated Error in startup!")
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_async_llm_startup_error(monkeypatch, model: str,
                                 tensor_parallel_size: int,
                                 failing_method: str) -> None:
    """Test that AsyncLLM propagates an __init__ error & frees memory.

    Test profiling (forward()) and load weights failures.
    AsyncLLM always uses an MP client.
    """
    if tensor_parallel_size > cuda_device_count_stateless():
        pytest.skip(reason="Not enough CUDA devices")

    # Monkeypatch an error in the model.
    monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)

    engine_args = AsyncEngineArgs(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Confirm we get an exception.
    with pytest.raises(Exception, match="initialization failed"):
        _ = AsyncLLM.from_engine_args(engine_args)

    # Confirm all the processes are cleaned up.
    gpu_ids = list(range(tensor_parallel_size))
    wait_for_gpu_memory_to_clear(
        devices=gpu_ids,
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int,
                           enable_multiprocessing: bool,
                           failing_method: str) -> None:
    """Test that LLM propagates an __init__ error and frees memory.

    Test profiling (forward()) and load weights failures.
    TODO(andy) - LLM without multiprocessing.

    Args:
      monkeypatch: pytest fixture used to set the env var and install
        the failing method
      model: model under test
      tensor_parallel_size: degree of tensor parallelism
      enable_multiprocessing: enable workers in separate process(es)
      failing_method: name of the LlamaForCausalLM method to replace
    """
    if model != "meta-llama/Llama-3.2-1B":
        pytest.skip(reason="Only test meta-llama/Llama-3.2-1B")
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model. Patch through the scoped
        # context `m` (not the outer fixture) so the patch is undone
        # together with the env var when this block exits.
        m.setattr(LlamaForCausalLM, failing_method, evil_method)

        # With multiprocessing the error is reported by the parent as an
        # initialization failure; otherwise the simulated error itself
        # is expected.
        expected_message = ("initialization failed" if enable_multiprocessing
                            else "Simulated Error in startup!")
        with pytest.raises(Exception, match=expected_message):
            _ = LLM(model=model,
                    enforce_eager=True,
                    tensor_parallel_size=tensor_parallel_size)

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
# SPDX-License-Identifier: Apache-2.0
"""Shutdown test utils"""
# Per-test timeout in seconds; the shutdown tests pass it to
# pytest.mark.timeout.
SHUTDOWN_TEST_TIMEOUT_SEC = 120
# GPU memory threshold in bytes (2 GiB); the shutdown tests pass it as
# threshold_bytes to wait_for_gpu_memory_to_clear.
SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import pytest
from vllm import LLM, SamplingParams
# Prompts fed to llm.generate() by the max-length tests in this module.
_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
"Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Run n-gram speculative decoding with max_tokens equal to
    max_model_len (both 100) and EOS ignored.

    There are no explicit assertions; the test passes when generation
    completes without raising.
    """
    ngram_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }

    with monkeypatch.context() as env:
        env.setenv("VLLM_USE_V1", "1")

        engine = LLM(
            model="facebook/opt-125m",
            max_model_len=100,
            enforce_eager=True,  # For faster initialization.
            speculative_config=ngram_config,
        )
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        engine.generate(_PROMPTS, params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Run EAGLE speculative decoding with max_tokens equal to
    max_model_len (both 100) and EOS ignored.

    There are no explicit assertions; the test passes when generation
    completes without raising.
    """
    eagle_config = {
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": num_speculative_tokens,
    }

    with monkeypatch.context() as env:
        env.setenv("VLLM_USE_V1", "1")

        engine = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enforce_eager=True,  # For faster initialization.
            speculative_config=eagle_config,
            max_model_len=100,
        )
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        engine.generate(_PROMPTS, params)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import numpy as np import numpy as np
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp, _find_subarray_kmp,
_kmp_lps_array) _kmp_lps_array)
...@@ -39,50 +40,50 @@ def test_find_subarray_kmp(): ...@@ -39,50 +40,50 @@ def test_find_subarray_kmp():
def test_ngram_proposer():
    """Check NgramProposer.propose() on small hand-built token sequences.

    Each case builds a proposer from (prompt_lookup_min, prompt_lookup_max,
    num_speculative_tokens) and checks the proposed draft tokens, or None
    when no n-gram match exists.
    """

    def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        """Build an NgramProposer for the given n-gram range and draft length.

        min_n and max_n are passed as prompt_lookup_min / prompt_lookup_max,
        and k as num_speculative_tokens.
        """
        # Dummy model config. Just to set max_model_len.
        model_config = ModelConfig(model="facebook/opt-125m",
                                   task="generate",
                                   max_model_len=100,
                                   tokenizer="facebook/opt-125m",
                                   tokenizer_mode="auto",
                                   dtype="auto",
                                   seed=None,
                                   trust_remote_code=False)
        speculative_config = SpeculativeConfig.from_dict({
            "prompt_lookup_min": min_n,
            "prompt_lookup_max": max_n,
            "num_speculative_tokens": k,
            "method": "ngram",
        })
        return NgramProposer(
            vllm_config=VllmConfig(model_config=model_config,
                                   speculative_config=speculative_config))

    # No match.
    result = ngram_proposer(
        2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
    assert result is None

    # No match for 4-gram.
    result = ngram_proposer(
        4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert result is None

    # No match for 4-gram but match for 3-gram.
    result = ngram_proposer(
        3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert np.array_equal(result, np.array([4, 1]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    result = ngram_proposer(3, 4, 2).propose(
        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]

    # Match for 2-gram and 3-gram, but not 4-gram.
    result = ngram_proposer(
        2, 4,
        2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment