Unverified commit 39512aba, authored by Sage, committed by GitHub
Browse files

[Prefix Cache] Include lora_name in BlockStored event for deterministic...


[Prefix Cache] Include lora_name in BlockStored event for deterministic KV-cache reconstruction (#27577)
Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
Co-authored-by: Sage <80211083+sagiahrac@users.noreply.github.com>
parent 0f35429a
...@@ -28,8 +28,14 @@ class BlockStored(KVCacheEvent): ...@@ -28,8 +28,14 @@ class BlockStored(KVCacheEvent):
parent_block_hash: ExternalBlockHash | None parent_block_hash: ExternalBlockHash | None
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: int | None lora_id: int | None
"""Deprecated: use `lora_name` for KV block key hash.
Retained for backward compatibility.
"""
medium: str | None medium: str | None
lora_name: str | None
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
......
...@@ -733,6 +733,7 @@ def test_kv_cache_events( ...@@ -733,6 +733,7 @@ def test_kv_cache_events(
) )
assert event.parent_block_hash is None, "Parent block hash should be None" assert event.parent_block_hash is None, "Parent block hash should be None"
assert event.lora_id is None, "Lora id should be None" assert event.lora_id is None, "Lora id should be None"
assert event.lora_name is None, "Lora name should be None"
assert len(event.token_ids) == num_blocks * block_size, ( assert len(event.token_ids) == num_blocks * block_size, (
"Token ids should be the same as the custom tokens" "Token ids should be the same as the custom tokens"
) )
......
...@@ -25,6 +25,7 @@ def mock_lmcache_engine_event(): ...@@ -25,6 +25,7 @@ def mock_lmcache_engine_event():
lora_id, lora_id,
block_size, block_size,
medium, medium,
lora_name,
): ):
self.block_hashes = block_hashes self.block_hashes = block_hashes
self.parent_block_hash = parent_block_hash self.parent_block_hash = parent_block_hash
...@@ -32,6 +33,7 @@ def mock_lmcache_engine_event(): ...@@ -32,6 +33,7 @@ def mock_lmcache_engine_event():
self.lora_id = lora_id self.lora_id = lora_id
self.block_size = block_size self.block_size = block_size
self.medium = medium self.medium = medium
self.lora_name = lora_name
return MockEvent( return MockEvent(
block_hashes=["hash1", "hash2"], block_hashes=["hash1", "hash2"],
...@@ -40,6 +42,7 @@ def mock_lmcache_engine_event(): ...@@ -40,6 +42,7 @@ def mock_lmcache_engine_event():
lora_id=None, lora_id=None,
block_size=16, block_size=16,
medium="GPU", medium="GPU",
lora_name=None,
) )
...@@ -109,6 +112,7 @@ class TestGetKVConnectorKVCacheEvents: ...@@ -109,6 +112,7 @@ class TestGetKVConnectorKVCacheEvents:
assert events[0].lora_id is None assert events[0].lora_id is None
assert events[0].block_size == 16 assert events[0].block_size == 16
assert events[0].medium == "GPU" assert events[0].medium == "GPU"
assert events[0].lora_name is None
def test_converts_multiple_events(self, mock_connector): def test_converts_multiple_events(self, mock_connector):
"""Test conversion of multiple events from lmcache engine format.""" """Test conversion of multiple events from lmcache engine format."""
...@@ -121,6 +125,7 @@ class TestGetKVConnectorKVCacheEvents: ...@@ -121,6 +125,7 @@ class TestGetKVConnectorKVCacheEvents:
self.lora_id = None self.lora_id = None
self.block_size = 16 self.block_size = 16
self.medium = "GPU" self.medium = "GPU"
self.lora_name = None
events = [MockEvent(i) for i in range(5)] events = [MockEvent(i) for i in range(5)]
mock_connector._lmcache_engine.get_kv_events.return_value = events mock_connector._lmcache_engine.get_kv_events.return_value = events
...@@ -150,6 +155,7 @@ class TestGetKVConnectorKVCacheEvents: ...@@ -150,6 +155,7 @@ class TestGetKVConnectorKVCacheEvents:
self.lora_id = 42 self.lora_id = 42
self.block_size = 32 self.block_size = 32
self.medium = "DISK" self.medium = "DISK"
self.lora_name = "lora_example"
mock_connector._lmcache_engine.get_kv_events.return_value = [ mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventWithLora() MockEventWithLora()
...@@ -166,6 +172,7 @@ class TestGetKVConnectorKVCacheEvents: ...@@ -166,6 +172,7 @@ class TestGetKVConnectorKVCacheEvents:
assert event.lora_id == 42 assert event.lora_id == 42
assert event.block_size == 32 assert event.block_size == 32
assert event.medium == "DISK" assert event.medium == "DISK"
assert event.lora_name == "lora_example"
def test_handles_none_parent_block_hash(self, mock_connector): def test_handles_none_parent_block_hash(self, mock_connector):
"""Test handling of events with None parent_block_hash.""" """Test handling of events with None parent_block_hash."""
...@@ -178,6 +185,7 @@ class TestGetKVConnectorKVCacheEvents: ...@@ -178,6 +185,7 @@ class TestGetKVConnectorKVCacheEvents:
self.lora_id = None self.lora_id = None
self.block_size = 16 self.block_size = 16
self.medium = "GPU" self.medium = "GPU"
self.lora_name = None
mock_connector._lmcache_engine.get_kv_events.return_value = [ mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventNoParent() MockEventNoParent()
...@@ -223,6 +231,7 @@ class TestUpdateConnectorOutput: ...@@ -223,6 +231,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
kv_events.add_events([event]) kv_events.add_events([event])
...@@ -243,6 +252,7 @@ class TestUpdateConnectorOutput: ...@@ -243,6 +252,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
existing_events.add_events([event1]) existing_events.add_events([event1])
existing_events.add_events([event1]) # Simulate 2 workers reporting existing_events.add_events([event1]) # Simulate 2 workers reporting
...@@ -258,6 +268,7 @@ class TestUpdateConnectorOutput: ...@@ -258,6 +268,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
new_events.add_events([event2]) new_events.add_events([event2])
...@@ -288,6 +299,7 @@ class TestUpdateConnectorOutput: ...@@ -288,6 +299,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
new_events.add_events([event]) new_events.add_events([event])
...@@ -309,6 +321,7 @@ class TestUpdateConnectorOutput: ...@@ -309,6 +321,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
events1.add_events([event1]) events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1) output1 = KVConnectorOutput(kv_cache_events=events1)
...@@ -323,6 +336,7 @@ class TestUpdateConnectorOutput: ...@@ -323,6 +336,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
events2.add_events([event2]) events2.add_events([event2])
output2 = KVConnectorOutput(kv_cache_events=events2) output2 = KVConnectorOutput(kv_cache_events=events2)
...@@ -337,6 +351,7 @@ class TestUpdateConnectorOutput: ...@@ -337,6 +351,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
events3.add_events([event3]) events3.add_events([event3])
output3 = KVConnectorOutput(kv_cache_events=events3) output3 = KVConnectorOutput(kv_cache_events=events3)
...@@ -358,6 +373,7 @@ class TestUpdateConnectorOutput: ...@@ -358,6 +373,7 @@ class TestUpdateConnectorOutput:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
events1.add_events([event1]) events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1) output1 = KVConnectorOutput(kv_cache_events=events1)
...@@ -397,6 +413,7 @@ class TestTakeEvents: ...@@ -397,6 +413,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
event2 = BlockStored( event2 = BlockStored(
block_hashes=["hash2"], block_hashes=["hash2"],
...@@ -405,6 +422,7 @@ class TestTakeEvents: ...@@ -405,6 +422,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
kv_events.add_events([event1, event2]) kv_events.add_events([event1, event2])
mock_connector._kv_cache_events = kv_events mock_connector._kv_cache_events = kv_events
...@@ -431,6 +449,7 @@ class TestTakeEvents: ...@@ -431,6 +449,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
uncommon_event = BlockStored( uncommon_event = BlockStored(
block_hashes=["hash_uncommon"], block_hashes=["hash_uncommon"],
...@@ -439,6 +458,7 @@ class TestTakeEvents: ...@@ -439,6 +458,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
# All 3 workers report common_event # All 3 workers report common_event
...@@ -469,6 +489,7 @@ class TestTakeEvents: ...@@ -469,6 +489,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
kv_events1.add_events([event1]) kv_events1.add_events([event1])
mock_connector._kv_cache_events = kv_events1 mock_connector._kv_cache_events = kv_events1
...@@ -491,6 +512,7 @@ class TestTakeEvents: ...@@ -491,6 +512,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
kv_events2.add_events([event2]) kv_events2.add_events([event2])
mock_connector._kv_cache_events = kv_events2 mock_connector._kv_cache_events = kv_events2
...@@ -510,6 +532,7 @@ class TestTakeEvents: ...@@ -510,6 +532,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
event2 = BlockStored( event2 = BlockStored(
block_hashes=["hash2"], block_hashes=["hash2"],
...@@ -518,6 +541,7 @@ class TestTakeEvents: ...@@ -518,6 +541,7 @@ class TestTakeEvents:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
# Worker 1 reports event1 # Worker 1 reports event1
...@@ -572,6 +596,7 @@ class TestIntegrationScenarios: ...@@ -572,6 +596,7 @@ class TestIntegrationScenarios:
self.lora_id = None self.lora_id = None
self.block_size = 16 self.block_size = 16
self.medium = "GPU" self.medium = "GPU"
self.lora_name = None
# Worker 1 # Worker 1
mock_connector._lmcache_engine.get_kv_events.return_value = [ mock_connector._lmcache_engine.get_kv_events.return_value = [
...@@ -628,6 +653,7 @@ class TestIntegrationScenarios: ...@@ -628,6 +653,7 @@ class TestIntegrationScenarios:
self.lora_id = None self.lora_id = None
self.block_size = 16 self.block_size = 16
self.medium = "GPU" self.medium = "GPU"
self.lora_name = None
for cycle in range(3): for cycle in range(3):
# Get events # Get events
...@@ -667,6 +693,7 @@ class TestIntegrationScenarios: ...@@ -667,6 +693,7 @@ class TestIntegrationScenarios:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
worker1_unique_event = BlockStored( worker1_unique_event = BlockStored(
...@@ -676,6 +703,7 @@ class TestIntegrationScenarios: ...@@ -676,6 +703,7 @@ class TestIntegrationScenarios:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
worker2_unique_event = BlockStored( worker2_unique_event = BlockStored(
...@@ -685,6 +713,7 @@ class TestIntegrationScenarios: ...@@ -685,6 +713,7 @@ class TestIntegrationScenarios:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
worker3_unique_event = BlockStored( worker3_unique_event = BlockStored(
...@@ -694,6 +723,7 @@ class TestIntegrationScenarios: ...@@ -694,6 +723,7 @@ class TestIntegrationScenarios:
block_size=16, block_size=16,
lora_id=None, lora_id=None,
medium="GPU", medium="GPU",
lora_name=None,
) )
# Create events for each worker # Create events for each worker
......
...@@ -528,6 +528,7 @@ def test_offloading_connector(request_runner): ...@@ -528,6 +528,7 @@ def test_offloading_connector(request_runner):
assert event.token_ids == [] assert event.token_ids == []
assert event.parent_block_hash is None assert event.parent_block_hash is None
assert event.lora_id is None assert event.lora_id is None
assert event.lora_name is None
event = events[1] event = events[1]
assert isinstance(event, BlockRemoved) assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6]) assert event.block_hashes == to_hashes([4, 5, 6])
......
...@@ -51,8 +51,14 @@ class BlockStored(KVCacheEvent): ...@@ -51,8 +51,14 @@ class BlockStored(KVCacheEvent):
parent_block_hash: ExternalBlockHash | None parent_block_hash: ExternalBlockHash | None
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: int | None lora_id: int | None
"""Deprecated: use `lora_name` for KV block key hash.
Retained for backward compatibility.
"""
medium: str | None medium: str | None
lora_name: str | None
def __hash__(self) -> int: def __hash__(self) -> int:
return hash( return hash(
......
...@@ -234,6 +234,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): ...@@ -234,6 +234,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
lora_id=e.lora_id, lora_id=e.lora_id,
block_size=e.block_size, block_size=e.block_size,
medium=e.medium, medium=e.medium,
lora_name=e.lora_name,
) )
for e in events for e in events
] ]
......
...@@ -406,6 +406,7 @@ class OffloadingConnectorScheduler: ...@@ -406,6 +406,7 @@ class OffloadingConnectorScheduler:
lora_id=None, lora_id=None,
block_size=event.block_size, block_size=event.block_size,
medium=event.medium, medium=event.medium,
lora_name=None,
) )
......
...@@ -290,6 +290,9 @@ class BlockPool: ...@@ -290,6 +290,9 @@ class BlockPool:
if request.lora_request if request.lora_request
else None, else None,
medium=MEDIUM_GPU, medium=MEDIUM_GPU,
lora_name=request.lora_request.name
if request.lora_request
else None,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment