Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5db0e637
Commit
5db0e637
authored
Sep 24, 2025
by
xuxz
Browse files
lmcache enable with cli args
parents
5086453d
f6fcc8ff
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
8 deletions
+91
-8
vllm/attention/layer.py
vllm/attention/layer.py
+8
-1
vllm/distributed/kv_transfer/__init__.py
vllm/distributed/kv_transfer/__init__.py
+2
-2
vllm/distributed/kv_transfer/kv_transfer_state.py
vllm/distributed/kv_transfer/kv_transfer_state.py
+20
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+41
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+16
-0
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+3
-1
No files found.
vllm/attention/layer.py
View file @
5db0e637
...
@@ -8,11 +8,13 @@ import torch.nn as nn
...
@@ -8,11 +8,13 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_maybe_save_kv_layer_to_connector
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_maybe_save_kv_layer_to_connector
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorRole
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
from
vllm.attention
import
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
get_lmcache_connector
,
has_kv_transfer_group
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
is_v1_kv_transfer_group
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
...
@@ -23,7 +25,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
...
@@ -23,7 +25,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
"""Attention layer.
"""Attention layer.
...
@@ -378,6 +380,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
...
@@ -378,6 +380,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
wait_for_layer_load
(
layer_name
)
connector
.
wait_for_layer_load
(
layer_name
)
get_lmcache_connector
().
wait_for_layer_load
(
layer_name
)
def
maybe_save_kv_layer_to_connector
(
def
maybe_save_kv_layer_to_connector
(
layer_name
:
str
,
layer_name
:
str
,
...
@@ -395,6 +400,8 @@ def maybe_save_kv_layer_to_connector(
...
@@ -395,6 +400,8 @@ def maybe_save_kv_layer_to_connector(
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
attn_metadata
[
layer_name
])
get_lmcache_connector
().
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
def
unified_attention
(
def
unified_attention
(
...
...
vllm/distributed/kv_transfer/__init__.py
View file @
5db0e637
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
KVConnectorBaseType
,
ensure_kv_transfer_initialized
,
get_kv_transfer_group
,
KVConnectorBaseType
,
ensure_kv_transfer_initialized
,
get_kv_transfer_group
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
get_lmcache_connector
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
__all__
=
[
__all__
=
[
"get_kv_transfer_group"
,
"has_kv_transfer_group"
,
"get_kv_transfer_group"
,
"get_lmcache_connector"
,
"has_kv_transfer_group"
,
"is_v1_kv_transfer_group"
,
"ensure_kv_transfer_initialized"
,
"is_v1_kv_transfer_group"
,
"ensure_kv_transfer_initialized"
,
"KVConnectorBaseType"
"KVConnectorBaseType"
]
]
vllm/distributed/kv_transfer/kv_transfer_state.py
View file @
5db0e637
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
copy
from
vllm
import
envs
from
vllm
import
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
KVConnectorRole
)
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorRole
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
_KV_CONNECTOR_AGENT
:
Optional
[
KVConnectorBaseType
]
=
None
_KV_CONNECTOR_AGENT
:
Optional
[
KVConnectorBaseType
]
=
None
_KV_LMCACHE_CONNECTOR_AGENT
:
Optional
[
KVConnectorBaseType
]
=
None
def
get_kv_transfer_group
()
->
KVConnectorBaseType
:
def
get_kv_transfer_group
()
->
KVConnectorBaseType
:
assert
_KV_CONNECTOR_AGENT
is
not
None
,
(
assert
_KV_CONNECTOR_AGENT
is
not
None
,
(
"disaggregated KV cache transfer parallel group is not initialized"
)
"disaggregated KV cache transfer parallel group is not initialized"
)
return
_KV_CONNECTOR_AGENT
return
_KV_CONNECTOR_AGENT
def
get_lmcache_connector
()
->
KVConnectorBaseType
:
assert
_KV_LMCACHE_CONNECTOR_AGENT
is
not
None
,
(
"LM cache transfer parallel group is not initialized"
)
return
_KV_LMCACHE_CONNECTOR_AGENT
def
has_kv_transfer_group
()
->
bool
:
def
has_kv_transfer_group
()
->
bool
:
return
_KV_CONNECTOR_AGENT
is
not
None
return
_KV_CONNECTOR_AGENT
is
not
None
...
@@ -54,6 +62,16 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
...
@@ -54,6 +62,16 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
"""
global
_KV_CONNECTOR_AGENT
global
_KV_CONNECTOR_AGENT
global
_KV_LMCACHE_CONNECTOR_AGENT
if
_KV_LMCACHE_CONNECTOR_AGENT
is
None
:
lmcache_config
=
copy
.
deepcopy
(
vllm_config
)
from
vllm.config
import
KVTransferConfig
lmcache_config
.
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"LMCacheConnectorV1"
,
kv_role
=
"kv_both"
)
lmcache_config
.
kv_transfer_config
.
engine_id
=
"ed9e943a-e455-4ed6-b88c-09ae6263f0c9"
lmcache_connector
:
LMCacheConnectorV1
=
LMCacheConnectorV1
(
lmcache_config
,
role
=
KVConnectorRole
.
WORKER
)
_KV_LMCACHE_CONNECTOR_AGENT
=
lmcache_connector
if
vllm_config
.
kv_transfer_config
is
None
:
if
vllm_config
.
kv_transfer_config
is
None
:
return
return
...
...
vllm/v1/core/sched/scheduler.py
View file @
5db0e637
...
@@ -5,11 +5,12 @@ from __future__ import annotations
...
@@ -5,11 +5,12 @@ from __future__ import annotations
import
itertools
import
itertools
import
time
import
time
import
copy
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.distributed.kv_events
import
EventPublisherFactory
,
KVEventBatch
from
vllm.distributed.kv_events
import
EventPublisherFactory
,
KVEventBatch
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
KVConnectorFactory
)
...
@@ -34,6 +35,7 @@ from vllm.v1.outputs import ModelRunnerOutput
...
@@ -34,6 +35,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -86,6 +88,12 @@ class Scheduler(SchedulerInterface):
...
@@ -86,6 +88,12 @@ class Scheduler(SchedulerInterface):
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
lmcache_config
=
copy
.
deepcopy
(
self
.
vllm_config
)
lmcache_config
.
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"LMCacheConnectorV1"
,
kv_role
=
"kv_both"
)
lmcache_config
.
kv_transfer_config
.
engine_id
=
"ed9e943a-e455-4ed6-b88c-09ae6263f0c9"
self
.
lmcache_connector
:
LMCacheConnectorV1
=
LMCacheConnectorV1
(
lmcache_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
,
self
.
kv_events_config
,
self
.
parallel_config
.
data_parallel_rank
,
self
.
parallel_config
.
data_parallel_rank
,
...
@@ -389,6 +397,11 @@ class Scheduler(SchedulerInterface):
...
@@ -389,6 +397,11 @@ class Scheduler(SchedulerInterface):
self
.
connector
.
get_num_new_matched_tokens
(
self
.
connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
request
,
num_new_local_computed_tokens
))
if
self
.
lmcache_connector
is
not
None
:
num_external_computed_tokens
,
load_kv_async
=
(
self
.
lmcache_connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
# Total computed tokens (local + external).
# Total computed tokens (local + external).
num_computed_tokens
=
(
num_new_local_computed_tokens
+
num_computed_tokens
=
(
num_new_local_computed_tokens
+
num_external_computed_tokens
)
num_external_computed_tokens
)
...
@@ -463,6 +476,13 @@ class Scheduler(SchedulerInterface):
...
@@ -463,6 +476,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens
,
num_external_computed_tokens
,
)
)
if
self
.
lmcache_connector
is
not
None
:
self
.
lmcache_connector
.
update_state_after_alloc
(
request
,
new_computed_blocks
+
new_blocks
,
num_external_computed_tokens
,
)
# Request was already popped from self.waiting
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
# unless it was re-added above due to new_blocks being None.
request
=
self
.
waiting
.
pop_request
()
request
=
self
.
waiting
.
pop_request
()
...
@@ -578,6 +598,10 @@ class Scheduler(SchedulerInterface):
...
@@ -578,6 +598,10 @@ class Scheduler(SchedulerInterface):
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
scheduler_output
.
kv_connector_metadata
=
meta
if
self
.
lmcache_connector
is
not
None
:
meta
=
self
.
lmcache_connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
events
=
self
.
kv_cache_manager
.
take_events
()
events
=
self
.
kv_cache_manager
.
take_events
()
if
events
:
if
events
:
batch
=
KVEventBatch
(
ts
=
time
.
time
(),
events
=
events
)
batch
=
KVEventBatch
(
ts
=
time
.
time
(),
events
=
events
)
...
@@ -676,6 +700,11 @@ class Scheduler(SchedulerInterface):
...
@@ -676,6 +700,11 @@ class Scheduler(SchedulerInterface):
self
.
connector
.
get_num_new_matched_tokens
(
self
.
connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
request
,
num_new_local_computed_tokens
))
if
self
.
lmcache_connector
is
not
None
:
num_external_computed_tokens
,
load_kv_async
=
(
self
.
lmcache_connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
# Total computed tokens (local + external).
# Total computed tokens (local + external).
num_computed_tokens
=
(
num_new_local_computed_tokens
+
num_computed_tokens
=
(
num_new_local_computed_tokens
+
num_external_computed_tokens
)
num_external_computed_tokens
)
...
@@ -750,6 +779,13 @@ class Scheduler(SchedulerInterface):
...
@@ -750,6 +779,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens
,
num_external_computed_tokens
,
)
)
if
self
.
lmcache_connector
is
not
None
:
self
.
lmcache_connector
.
update_state_after_alloc
(
request
,
new_computed_blocks
+
new_blocks
,
num_external_computed_tokens
,
)
# Request was already popped from self.waiting
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
# unless it was re-added above due to new_blocks being None.
request
=
self
.
waiting
.
pop_request
()
request
=
self
.
waiting
.
pop_request
()
...
@@ -994,6 +1030,10 @@ class Scheduler(SchedulerInterface):
...
@@ -994,6 +1030,10 @@ class Scheduler(SchedulerInterface):
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
scheduler_output
.
kv_connector_metadata
=
meta
if
self
.
lmcache_connector
is
not
None
:
meta
=
self
.
lmcache_connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
events
=
self
.
kv_cache_manager
.
take_events
()
events
=
self
.
kv_cache_manager
.
take_events
()
if
events
:
if
events
:
batch
=
KVEventBatch
(
ts
=
time
.
time
(),
events
=
events
)
batch
=
KVEventBatch
(
ts
=
time
.
time
(),
events
=
events
)
...
...
vllm/v1/engine/core.py
View file @
5db0e637
...
@@ -864,7 +864,7 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -864,7 +864,7 @@ class DPEngineCoreProc(EngineCoreProc):
vllm_config
.
kv_transfer_config
.
engine_id
=
(
vllm_config
.
kv_transfer_config
.
engine_id
=
(
f
"
{
vllm_config
.
kv_transfer_config
.
engine_id
}
_dp
{
local_dp_rank
}
"
f
"
{
vllm_config
.
kv_transfer_config
.
engine_id
}
_dp
{
local_dp_rank
}
"
)
)
logger
.
debug
(
"Setting kv_transfer_config.engine_id to %s"
,
logger
.
info
(
"Setting kv_transfer_config.engine_id to %s"
,
vllm_config
.
kv_transfer_config
.
engine_id
)
vllm_config
.
kv_transfer_config
.
engine_id
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
5db0e637
...
@@ -23,6 +23,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
...
@@ -23,6 +23,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config
)
get_layers_from_vllm_config
)
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
get_lmcache_connector
,
has_kv_transfer_group
)
has_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
...
@@ -73,6 +74,9 @@ from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute
...
@@ -73,6 +74,9 @@ from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute
from
..sample.logits_processor
import
LogitsProcessorManager
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorRole
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
import
xgrammar
as
xgr
...
@@ -1573,6 +1577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1573,6 +1577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_kv_transfer_group
().
clear_connector_metadata
()
get_lmcache_connector
().
clear_connector_metadata
()
self
.
eplb_step
()
self
.
eplb_step
()
return
ModelRunnerOutput
(
return
ModelRunnerOutput
(
...
@@ -1736,11 +1742,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1736,11 +1742,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Do this here to save a collective_rpc.
# Do this here to save a collective_rpc.
kv_connector
.
start_load_kv
(
get_forward_context
())
kv_connector
.
start_load_kv
(
get_forward_context
())
lmcache_connector
=
get_lmcache_connector
()
lmcache_connector
.
bind_connector_metadata
(
scheduler_output
.
kv_connector_metadata
)
lmcache_connector
.
start_load_kv
(
get_forward_context
())
@
staticmethod
@
staticmethod
def
maybe_wait_for_kv_save
()
->
None
:
def
maybe_wait_for_kv_save
()
->
None
:
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
wait_for_save
()
get_kv_transfer_group
().
wait_for_save
()
lmcache_connector
=
get_lmcache_connector
()
lmcache_connector
.
wait_for_save
()
@
staticmethod
@
staticmethod
def
get_finished_kv_transfers
(
def
get_finished_kv_transfers
(
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
...
@@ -2690,6 +2704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2690,6 +2704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
get_lmcache_connector
().
register_kv_caches
(
kv_caches
)
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
"""
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Generates the KVCacheSpec by parsing the kv cache format from each
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
5db0e637
...
@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
...
@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
from
vllm
import
envs
from
vllm
import
envs
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
,
get_lmcache_connector
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -727,6 +727,8 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -727,6 +727,8 @@ class V1ZeroModelRunner(GPUModelRunner):
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_kv_transfer_group
().
clear_connector_metadata
()
get_lmcache_connector
().
clear_connector_metadata
()
self
.
eplb_step
()
self
.
eplb_step
()
model_output
=
ZeroV1ModelRunnerOutput
(
model_output
=
ZeroV1ModelRunnerOutput
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment