Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1bf43ae3
Unverified
Commit
1bf43ae3
authored
Nov 02, 2025
by
Biswa Panda
Committed by
GitHub
Nov 03, 2025
Browse files
[BugFix][LoRA] use adapter_id instead of id field of lora_request (#27728)
Signed-off-by:
Biswa Panda
<
biswa.panda@gmail.com
>
parent
0ce743f4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
3 deletions
+64
-3
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+61
-2
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+3
-1
No files found.
tests/v1/core/test_prefix_caching.py
View file @
1bf43ae3
...
...
@@ -9,7 +9,8 @@ import pytest
import
torch
import
vllm.v1.core.kv_cache_utils
as
kv_cache_utils
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
,
BlockStored
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
...
...
@@ -59,6 +60,7 @@ def make_request(
mm_hashes
:
list
[
str
]
|
None
=
None
,
prompt_logprobs
:
int
|
None
=
None
,
cache_salt
:
str
|
None
=
None
,
lora_request
:
LoRARequest
|
None
=
None
,
):
mm_features
=
[]
if
mm_positions
is
not
None
:
...
...
@@ -79,7 +81,7 @@ def make_request(
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
pooling_params
=
None
,
eos_token_id
=
100
,
lora_request
=
None
,
lora_request
=
lora_request
,
cache_salt
=
cache_salt
,
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
),
)
...
...
@@ -1337,6 +1339,63 @@ def test_kv_cache_events(blocks_to_cache: int):
assert
len
(
manager
.
block_pool
.
cached_block_hash_to_block
)
==
0
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events_with_lora(blocks_to_cache: int):
    """Check the ``lora_id`` reported by BlockStored events.

    A request carrying a LoRARequest built with ``lora_int_id=42`` must
    yield a final BlockStored event whose ``lora_id`` is 42; an otherwise
    identical request without a LoRA must yield ``lora_id is None``.  In both
    cases the event has to cover ``blocks_to_cache`` block hashes and report
    the configured block size.
    """
    block_size = 16
    # One more block than the prompt fills.
    total_blocks = blocks_to_cache + 1

    # KV cache manager with prefix caching and KV cache events turned on.
    manager = KVCacheManager(
        make_kv_cache_config(block_size, total_blocks),
        max_model_len=8192,
        enable_caching=True,
        enable_kv_cache_events=True,
    )

    def allocate_and_take_last_event(request):
        """Allocate slots for the whole prompt and return the newest event."""
        manager.allocate_slots(request, token_count)
        return manager.take_events()[-1]

    # --- Case 1: request that carries a LoRA adapter -----------------------
    adapter = LoRARequest(
        lora_name="test_lora",
        lora_int_id=42,
        lora_path="/test/path",
    )

    # The prompt fills exactly `blocks_to_cache` blocks.
    token_count = block_size * blocks_to_cache
    lora_req = make_request(
        "lora_req",
        list(range(token_count)),
        block_size,
        sha256,
        lora_request=adapter,
    )

    stored = allocate_and_take_last_event(lora_req)
    assert isinstance(stored, BlockStored)
    # 42 is the lora_int_id given above; expected to equal adapter.adapter_id.
    assert stored.lora_id == 42
    assert len(stored.block_hashes) == blocks_to_cache
    assert stored.block_size == block_size

    # Release the first request's blocks before the second case.
    manager.free(lora_req)

    # --- Case 2: same prompt, no LoRA adapter ------------------------------
    plain_req = make_request(
        "no_lora_req",
        list(range(token_count)),
        block_size,
        sha256,
    )

    stored = allocate_and_take_last_event(plain_req)
    assert isinstance(stored, BlockStored)
    # No LoRA request attached, so the event must not name an adapter.
    assert stored.lora_id is None
    assert len(stored.block_hashes) == blocks_to_cache
    assert stored.block_size == block_size
def
test_eagle_enabled_removes_last_block
():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""
...
...
vllm/v1/core/block_pool.py
View file @
1bf43ae3
...
...
@@ -259,7 +259,9 @@ class BlockPool:
num_cached_blocks
*
block_size
:
num_full_blocks
*
block_size
],
block_size
=
block_size
,
lora_id
=
request
.
lora_request
.
id
if
request
.
lora_request
else
None
,
lora_id
=
request
.
lora_request
.
adapter_id
if
request
.
lora_request
else
None
,
medium
=
MEDIUM_GPU
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment