Unverified commit 6afa679c, authored by Richard Huo, committed by GitHub
Browse files

chore: KVBM pip wheel (#3826)


Signed-off-by: Anant Sharma <anants@nvidia.com>
Co-authored-by: Anant Sharma <anants@nvidia.com>
parent e5c109d8
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
from kvbm._core import BlockManager as BlockManager
from kvbm._core import KvbmLeader as KvbmLeader
from kvbm._core import KvbmWorker as KvbmWorker
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Optional
class Layer:
    """
    A single layer of a KV cache block.
    """

    def __dlpack__(
        self,
        stream: Optional[Any] = None,
        max_version: Optional[Any] = None,
        dl_device: Optional[Any] = None,
        copy: Optional[bool] = None,
    ) -> Any:
        """
        Return a DLPack capsule for this layer.
        """
        ...

    def __dlpack_device__(self) -> Any:
        """
        Return the DLPack device of this layer.
        """
        ...
class Block:
    """
    A KV cache block.

    A block is an indexable, iterable container of `Layer` objects.
    """

    def __len__(self) -> int:
        """
        Get the number of layers in the block
        """
        ...

    def __getitem__(self, index: int) -> 'Layer':
        """
        Get a layer by index
        """
        ...

    def __iter__(self) -> 'Block':
        """
        Get an iterator over the layers
        """
        ...

    def __next__(self) -> 'Layer':
        """
        Get the next layer in the iterator
        """
        # Iterating a block yields its layers, so the element type is Layer
        # (consistent with __getitem__ and to_list), not Block.
        ...

    def to_list(self) -> List['Layer']:
        """
        Get a list of layers
        """
        ...

    def __dlpack__(
        self,
        stream: Optional[Any] = None,
        max_version: Optional[Any] = None,
        dl_device: Optional[Any] = None,
        copy: Optional[bool] = None,
    ) -> Any:
        """
        Get a dlpack capsule of the block

        Exception raised if the block is not contiguous
        """
        ...

    def __dlpack_device__(self) -> Any:
        """
        Get the dlpack device of the block
        """
        ...
class BlockList:
    """
    An indexable, iterable list of KV cache blocks.
    """

    def __len__(self) -> int:
        """
        Return how many blocks the list holds.
        """
        ...

    def __getitem__(self, index: int) -> Block:
        """
        Return the block stored at position `index`.
        """
        ...

    def __iter__(self) -> 'BlockList':
        """
        Return an iterator over the blocks.
        """
        ...

    def __next__(self) -> Block:
        """
        Return the next block from the iterator.
        """
        ...

    def to_list(self) -> List[Block]:
        """
        Return the blocks as a plain Python list.
        """
        ...
class BlockManager:
    """
    Manager for KV cache blocks.
    """

    def __init__(
        self,
        worker_id: int,
        num_layer: int,
        page_size: int,
        inner_dim: int,
        dtype: Optional[str] = None,
        host_num_blocks: Optional[int] = None,
        device_num_blocks: Optional[int] = None,
        device_id: int = 0
    ) -> None:
        """
        Construct a `BlockManager`.

        Parameters:
        -----------
        worker_id: int
            ID of the worker that owns this block manager
        num_layer: int
            How many layers the model has
        page_size: int
            Page size used for blocks
        inner_dim: int
            Size of the inner dimension
        dtype: Optional[str]
            Data type name (e.g., 'fp16', 'bf16', 'fp32'); 'fp16' is used when None
        host_num_blocks: Optional[int]
            How many host blocks to allocate; None means no host blocks
        device_num_blocks: Optional[int]
            How many device blocks to allocate; None means no device blocks
        device_id: int
            CUDA device ID; 0 by default
        """
        ...

    def allocate_host_blocks_blocking(self, count: int) -> BlockList:
        """
        Allocate host blocks, blocking until the call completes.

        Parameters:
        -----------
        count: int
            How many blocks to allocate

        Returns:
        --------
        BlockList
            The allocated blocks
        """
        ...

    async def allocate_host_blocks(self, count: int) -> BlockList:
        """
        Allocate host blocks (awaitable variant).

        Parameters:
        -----------
        count: int
            How many blocks to allocate

        Returns:
        --------
        BlockList
            The allocated blocks
        """
        ...

    def allocate_device_blocks_blocking(self, count: int) -> BlockList:
        """
        Allocate device blocks, blocking until the call completes.

        Parameters:
        -----------
        count: int
            How many blocks to allocate

        Returns:
        --------
        BlockList
            The allocated blocks
        """
        ...

    async def allocate_device_blocks(self, count: int) -> BlockList:
        """
        Allocate device blocks (awaitable variant).

        Parameters:
        -----------
        count: int
            How many blocks to allocate

        Returns:
        --------
        BlockList
            The allocated blocks
        """
        ...
class KvbmCacheManager:
    """
    KV cache manager used for VLLM.
    """

    def __init__(
        self,
        block_manager: BlockManager,
    ) -> None:
        """
        Construct a `KvbmCacheManager` from a `BlockManager`.
        """
        ...
class KvbmRequest:
    """
    A KV cache request.
    """

    def __init__(
        self,
        request_id: int,
        tokens: List[int],
        block_size: int,
    ) -> None:
        """
        Construct a `KvbmRequest` from a request ID, its tokens and a block size.
        """
        ...
...@@ -2,8 +2,13 @@ ...@@ -2,8 +2,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List from typing import List, Optional
from kvbm import KvbmLeader
from kvbm.trtllm_integration.rust import KvbmRequest
from kvbm.trtllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from kvbm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput
from kvbm.utils import is_dyn_runtime_enabled
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import ( from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
KvCacheConnectorScheduler, KvCacheConnectorScheduler,
SchedulerOutput, SchedulerOutput,
...@@ -11,19 +16,20 @@ from tensorrt_llm._torch.pyexecutor.kv_cache_connector import ( ...@@ -11,19 +16,20 @@ from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from dynamo.llm import KvbmLeader DistributedRuntime = None
from dynamo.llm.trtllm_integration.rust import KvbmRequest if is_dyn_runtime_enabled():
from dynamo.llm.trtllm_integration.rust import ( from dynamo.runtime import DistributedRuntime
KvConnectorLeader as RustKvConnectorLeader,
)
from dynamo.llm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput
from dynamo.runtime import DistributedRuntime
class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
def __init__(self, llm_args: TorchLlmArgs): def __init__(self, llm_args: TorchLlmArgs):
super().__init__(llm_args) super().__init__(llm_args)
self.drt = DistributedRuntime.detached()
drt: Optional[object] = None
if is_dyn_runtime_enabled():
drt = DistributedRuntime.detached()
self.drt = drt
mappings = self._llm_args.parallel_config.to_mapping() mappings = self._llm_args.parallel_config.to_mapping()
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional
# Keeping this import is important because it runs the code in nixl’s __init__.py
# to set up the Nixl plugin path.
import nixl # noqa: F401
import torch import torch
from kvbm.trtllm_integration.rust import KvConnectorWorker as RustKvConnectorWorker
from kvbm.utils import is_dyn_runtime_enabled
from tensorrt_llm import logger from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import KvCacheConnectorWorker from tensorrt_llm._torch.pyexecutor.kv_cache_connector import KvCacheConnectorWorker
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from dynamo.llm.trtllm_integration.rust import ( DistributedRuntime = None
KvConnectorWorker as RustKvConnectorWorker, if is_dyn_runtime_enabled():
) from dynamo.runtime import DistributedRuntime
from dynamo.runtime import DistributedRuntime
class DynamoKVBMConnectorWorker(KvCacheConnectorWorker): class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
...@@ -31,7 +37,11 @@ class DynamoKVBMConnectorWorker(KvCacheConnectorWorker): ...@@ -31,7 +37,11 @@ class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
def __init__(self, llm_args: TorchLlmArgs): def __init__(self, llm_args: TorchLlmArgs):
super().__init__(llm_args) super().__init__(llm_args)
self.drt = DistributedRuntime.detached() drt: Optional[object] = None
if is_dyn_runtime_enabled():
drt = DistributedRuntime.detached()
self.drt = drt
mappings = self._llm_args.parallel_config.to_mapping() mappings = self._llm_args.parallel_config.to_mapping()
self.rank = mappings.rank self.rank = mappings.rank
......
...@@ -7,7 +7,7 @@ Loader for the Rust-based TensorRT-LLM integration objects, using objects from _ ...@@ -7,7 +7,7 @@ Loader for the Rust-based TensorRT-LLM integration objects, using objects from _
try: try:
# TODO: use TRTLLM own integration module # TODO: use TRTLLM own integration module
from dynamo._core import _vllm_integration from kvbm._core import _vllm_integration
# Runtime - dynamically loaded classes from Rust extension # Runtime - dynamically loaded classes from Rust extension
KvbmRequest = getattr(_vllm_integration, "KvbmRequest") KvbmRequest = getattr(_vllm_integration, "KvbmRequest")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
def is_dyn_runtime_enabled() -> bool:
    """
    Report whether the DYN_RUNTIME_ENABLED_KVBM environment variable is switched on.

    The variable counts as enabled when its value, after stripping surrounding
    whitespace and lower-casing, is '1' or 'true'. An unset variable or any
    other value yields False.

    DYN_RUNTIME_ENABLED_KVBM indicates if KVBM should use the existing
    DistributedRuntime in the current environment.

    WARNING: Calling DistributedRuntime.detached() can crash the entire process if
    dependencies are not satisfied, and it cannot be caught with try/except in Python.

    TODO: Make DistributedRuntime.detached() raise a catchable Python exception and
    avoid crashing the process.
    """
    raw_flag = os.environ.get("DYN_RUNTIME_ENABLED_KVBM", "")
    normalized = raw_flag.strip().lower()
    return normalized == "1" or normalized == "true"
...@@ -26,9 +26,9 @@ if TYPE_CHECKING: ...@@ -26,9 +26,9 @@ if TYPE_CHECKING:
from vllm.v1.request import Request from vllm.v1.request import Request
# from dynamo.llm.vllm_integration.kv_cache_utils import KvbmCacheBlocks # from kvbm.vllm_integration.kv_cache_utils import KvbmCacheBlocks
from dynamo.llm.vllm_integration.connector_leader import KvConnectorLeader from kvbm.vllm_integration.connector_leader import KvConnectorLeader
from dynamo.llm.vllm_integration.connector_worker import KvConnectorWorker from kvbm.vllm_integration.connector_worker import KvConnectorWorker
EngineId = str EngineId = str
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from kvbm.vllm_integration.connector.dynamo_connector import DynamoConnector
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1, LMCacheConnectorV1,
...@@ -14,8 +15,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( ...@@ -14,8 +15,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import NixlConnector from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import NixlConnector
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from dynamo.llm.vllm_integration.connector.dynamo_connector import DynamoConnector
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......
...@@ -20,19 +20,23 @@ if TYPE_CHECKING: ...@@ -20,19 +20,23 @@ if TYPE_CHECKING:
from vllm.v1.request import Request from vllm.v1.request import Request
# from dynamo.llm.vllm_integration.kv_cache_utils import KvbmCacheBlocks # from kvbm.vllm_integration.kv_cache_utils import KvbmCacheBlocks
# from dynamo.llm.vllm_integration.rust import BlockManager, KvbmRequest # from kvbm.vllm_integration.rust import BlockManager, KvbmRequest
# from dynamo.llm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader # from kvbm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
# from dynamo.llm.vllm_integration.rust import ( # from kvbm.vllm_integration.rust import (
# KvConnectorMetadata as RustKvConnectorMetadata, # KvConnectorMetadata as RustKvConnectorMetadata,
# ) # )
# from dynamo.llm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput # from kvbm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput
from dynamo.llm import KvbmLeader from kvbm import KvbmLeader
from dynamo.llm.vllm_integration.rust import KvbmRequest from kvbm.utils import is_dyn_runtime_enabled
from dynamo.llm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader from kvbm.vllm_integration.rust import KvbmRequest
from dynamo.llm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput from kvbm.vllm_integration.rust import KvConnectorLeader as RustKvConnectorLeader
from dynamo.runtime import DistributedRuntime from kvbm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput
DistributedRuntime = None
if is_dyn_runtime_enabled():
from dynamo.runtime import DistributedRuntime
class DynamoConnectorMetadata(KVConnectorMetadata): class DynamoConnectorMetadata(KVConnectorMetadata):
...@@ -51,12 +55,14 @@ class KvConnectorLeader: ...@@ -51,12 +55,14 @@ class KvConnectorLeader:
""" """
def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs): def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs):
drt = kwargs.get("drt", None) drt: Optional[object] = kwargs.get("drt")
if drt is None:
self.drt = DistributedRuntime.detached() if drt is None and is_dyn_runtime_enabled():
drt = DistributedRuntime.detached()
else: else:
self.drt = drt drt = None
self.drt = drt
self.vllm_config = vllm_config self.vllm_config = vllm_config
world_size = vllm_config.parallel_config.world_size world_size = vllm_config.parallel_config.world_size
......
...@@ -9,7 +9,11 @@ from __future__ import annotations ...@@ -9,7 +9,11 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
# Keeping this import is important because it runs the code in nixl’s __init__.py
# to set up the Nixl plugin path when there is no pre-defined NIXL_PLUGIN_DIR
import nixl # noqa: F401
import torch import torch
from kvbm.utils import is_dyn_runtime_enabled
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
...@@ -21,15 +25,18 @@ if TYPE_CHECKING: ...@@ -21,15 +25,18 @@ if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
# from dynamo.llm.vllm_integration.kv_cache_utils import KvbmCacheBlocks # from kvbm.vllm_integration.kv_cache_utils import KvbmCacheBlocks
# from dynamo.llm.vllm_integration.rust import BlockManager # from kvbm.vllm_integration.rust import BlockManager
# from dynamo.llm.vllm_integration.rust import ( # from kvbm.vllm_integration.rust import (
# KvConnectorMetadata as RustKvConnectorMetadata, # KvConnectorMetadata as RustKvConnectorMetadata,
# KvConnectorWorker as RustKvConnectorWorker, # KvConnectorWorker as RustKvConnectorWorker,
# ) # )
from dynamo.llm.vllm_integration.rust import KvConnectorWorker as RustKvConnectorWorker from kvbm.vllm_integration.rust import KvConnectorWorker as RustKvConnectorWorker
from dynamo.runtime import DistributedRuntime
DistributedRuntime = None
if is_dyn_runtime_enabled():
from dynamo.runtime import DistributedRuntime
class DynamoConnectorMetadata(KVConnectorMetadata): class DynamoConnectorMetadata(KVConnectorMetadata):
...@@ -40,11 +47,14 @@ class DynamoConnectorMetadata(KVConnectorMetadata): ...@@ -40,11 +47,14 @@ class DynamoConnectorMetadata(KVConnectorMetadata):
class KvConnectorWorker: class KvConnectorWorker:
def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs): def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs):
drt = kwargs.get("drt", None) drt: Optional[object] = kwargs.get("drt")
if drt is None:
self.drt = DistributedRuntime.detached() if drt is None and is_dyn_runtime_enabled():
drt = DistributedRuntime.detached()
else: else:
self.drt = drt drt = None
self.drt = drt
self.vllm_config = vllm_config self.vllm_config = vllm_config
self._connector = RustKvConnectorWorker(self.drt, engine_id) self._connector = RustKvConnectorWorker(self.drt, engine_id)
......
...@@ -26,10 +26,10 @@ if TYPE_CHECKING: ...@@ -26,10 +26,10 @@ if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
from dynamo.llm.vllm_integration.kv_cache_utils import KvbmCacheBlocks from kvbm.vllm_integration.kv_cache_utils import KvbmCacheBlocks
from dynamo.llm.vllm_integration.rust import BlockManager from kvbm.vllm_integration.rust import BlockManager
from dynamo.llm.vllm_integration.rust import KvbmCacheManager as RustKvbmCacheManager from kvbm.vllm_integration.rust import KvbmCacheManager as RustKvbmCacheManager
from dynamo.llm.vllm_integration.rust import KvbmRequest, SlotUpdate from kvbm.vllm_integration.rust import KvbmRequest, SlotUpdate
class KvbmCacheManager(KVConnectorBase_V1): class KvbmCacheManager(KVConnectorBase_V1):
......
...@@ -9,11 +9,10 @@ from __future__ import annotations ...@@ -9,11 +9,10 @@ from __future__ import annotations
from typing import List from typing import List
from kvbm.vllm_integration.rust import BlockState, BlockStates, KvbmBlockList
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.core.kv_cache_utils import KVCacheBlock
from dynamo.llm.vllm_integration.rust import BlockState, BlockStates, KvbmBlockList
# from vllm.logger import init_logger # from vllm.logger import init_logger
# logger = init_logger(__name__) # logger = init_logger(__name__)
......
...@@ -6,7 +6,7 @@ Loader for the Rust-based vLLM integration objects. ...@@ -6,7 +6,7 @@ Loader for the Rust-based vLLM integration objects.
""" """
try: try:
from dynamo._core import _vllm_integration from kvbm._core import _vllm_integration
# Runtime - dynamically loaded classes from Rust extension # Runtime - dynamically loaded classes from Rust extension
KvbmCacheManager = getattr(_vllm_integration, "KvbmCacheManager") KvbmCacheManager = getattr(_vllm_integration, "KvbmCacheManager")
...@@ -20,7 +20,7 @@ try: ...@@ -20,7 +20,7 @@ try:
KvConnectorLeader = getattr(_vllm_integration, "PyKvConnectorLeader") KvConnectorLeader = getattr(_vllm_integration, "PyKvConnectorLeader")
SchedulerOutput = getattr(_vllm_integration, "SchedulerOutput") SchedulerOutput = getattr(_vllm_integration, "SchedulerOutput")
from dynamo.llm import BlockManager from kvbm import BlockManager
except ImportError: except ImportError:
print("Failed to import Dynamo KVBM. vLLM integration will not be available.") print("Failed to import Dynamo KVBM. vLLM integration will not be available.")
......
...@@ -9,39 +9,11 @@ use dynamo_llm::block_manager::block::{ ...@@ -9,39 +9,11 @@ use dynamo_llm::block_manager::block::{
use dynamo_llm::block_manager::kv_consolidator::KvEventConsolidatorConfig; use dynamo_llm::block_manager::kv_consolidator::KvEventConsolidatorConfig;
use dynamo_llm::block_manager::offload::filter::FrequencyFilter; use dynamo_llm::block_manager::offload::filter::FrequencyFilter;
use dynamo_llm::block_manager::{BasicMetadata, BlockParallelismStrategy}; use dynamo_llm::block_manager::{BasicMetadata, BlockParallelismStrategy};
use dynamo_runtime::DistributedRuntime;
use pyo3::PyResult; use pyo3::PyResult;
use std::time::Duration; use std::time::Duration;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
/// Builds the optional disk offload filter from the process environment.
///
/// If `DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER` is set to exactly `"true"` or `"1"`,
/// the filter is disabled and `Ok(None)` is returned. In every other case
/// (including an unset or unreadable variable) a `FrequencyFilter` is created
/// with fixed parameters and returned inside an `Arc`.
///
/// # Errors
///
/// Propagates any error returned by `FrequencyFilter::new`.
fn create_disk_offload_filter(
    cancel_token: &CancellationToken,
    runtime: &tokio::runtime::Handle,
) -> Result<Option<Arc<FrequencyFilter>>> {
    // Only the exact strings "true" and "1" switch the filter off.
    let filter_disabled = matches!(
        std::env::var("DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER").as_deref(),
        Ok("true" | "1")
    );
    if filter_disabled {
        return Ok(None);
    }

    // TODO: These values seem plausible for most use cases, but we need to figure out a better way to configure them.
    let filter = FrequencyFilter::new(
        2,
        Duration::from_secs(600),
        1_000_000,
        cancel_token.child_token(),
        runtime.clone(),
    )?;

    Ok(Some(Arc::new(filter)))
}
mod controller; mod controller;
mod distributed; mod distributed;
...@@ -73,11 +45,39 @@ type VllmController = Arc< ...@@ -73,11 +45,39 @@ type VllmController = Arc<
>, >,
>; >;
/// Constructs the disk offload filter unless it has been disabled by environment.
///
/// Returns `Ok(None)` when `DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER` equals `"true"`
/// or `"1"`; otherwise returns a freshly built `FrequencyFilter` wrapped in an
/// `Arc`. An unset or unreadable variable leaves the filter enabled.
///
/// # Errors
///
/// Propagates any error returned by `FrequencyFilter::new`.
fn create_disk_offload_filter(
    cancel_token: &CancellationToken,
    runtime: &tokio::runtime::Handle,
) -> Result<Option<Arc<FrequencyFilter>>> {
    // The comparison is exact: no trimming or case folding is applied to the value.
    let opted_out = std::env::var("DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER")
        .is_ok_and(|value| value == "true" || value == "1");

    if opted_out {
        Ok(None)
    } else {
        // TODO: These values seem plausible for most use cases, but we need to figure out a better way to configure them.
        let built = FrequencyFilter::new(
            2,
            Duration::from_secs(600),
            1_000_000,
            cancel_token.child_token(),
            runtime.clone(),
        )?;
        Ok(Some(Arc::new(built)))
    }
}
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
pub struct BlockManager { pub struct BlockManager {
inner: VllmBlockManager, inner: VllmBlockManager,
drt: DistributedRuntime, _drt: Option<Arc<DistributedRuntime>>,
_controller: Option<VllmController>, _controller: Option<VllmController>,
} }
...@@ -126,19 +126,17 @@ impl BlockManager { ...@@ -126,19 +126,17 @@ impl BlockManager {
if leader.num_host_blocks() > 0 { if leader.num_host_blocks() > 0 {
tracing::info!("Using {} host blocks", leader.num_host_blocks()); tracing::info!("Using {} host blocks", leader.num_host_blocks());
let mut host_layout_config = let mut host_layout_config =
dynamo_llm::block_manager::KvManagerLayoutConfig::builder() dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
.num_blocks(leader.num_host_blocks()) .num_blocks(leader.num_host_blocks())
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)); .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded));
if leader.num_disk_blocks() > 0 { if leader.num_disk_blocks() > 0
if let Some(filter) = && let Some(filter) =
create_disk_offload_filter(&cancel_token, &rt.inner().runtime().primary()) create_disk_offload_filter(&cancel_token, &get_current_tokio_handle())
.map_err(to_pyerr)? .map_err(to_pyerr)?
{ {
host_layout_config = host_layout_config.offload_filter(Some(filter)); host_layout_config = host_layout_config.offload_filter(Some(filter));
}
} }
config = config.host_layout(host_layout_config.build().map_err(to_pyerr)?); config = config.host_layout(host_layout_config.build().map_err(to_pyerr)?);
...@@ -181,7 +179,7 @@ impl BlockManager { ...@@ -181,7 +179,7 @@ impl BlockManager {
// ) // )
}; };
let rt = drt.inner().runtime().primary(); let rt = get_current_tokio_handle();
let config = config.build().map_err(to_pyerr)?; let config = config.build().map_err(to_pyerr)?;
Ok(BlockManager { Ok(BlockManager {
...@@ -197,7 +195,7 @@ impl BlockManager { ...@@ -197,7 +195,7 @@ impl BlockManager {
.await .await
}) })
.map_err(to_pyerr)?, .map_err(to_pyerr)?,
drt, _drt: drt,
_controller: None, _controller: None,
}) })
} }
...@@ -213,11 +211,7 @@ impl BlockManager { ...@@ -213,11 +211,7 @@ impl BlockManager {
} }
let block_manager = self.inner.clone(); let block_manager = self.inner.clone();
let controller = self let controller = get_current_tokio_handle()
.drt
.inner()
.runtime()
.primary()
.block_on(controller::Controller::new( .block_on(controller::Controller::new(
block_manager, block_manager,
component.inner.clone(), component.inner.clone(),
...@@ -280,6 +274,7 @@ impl BlockManagerBuilder { ...@@ -280,6 +274,7 @@ impl BlockManagerBuilder {
self.disable_device_pool = yes; self.disable_device_pool = yes;
self self
} }
pub fn kvbm_metrics( pub fn kvbm_metrics(
mut self, mut self,
metrics: dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics, metrics: dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics,
...@@ -339,12 +334,11 @@ impl BlockManagerBuilder { ...@@ -339,12 +334,11 @@ impl BlockManagerBuilder {
.num_blocks(leader_inner.num_host_blocks()) .num_blocks(leader_inner.num_host_blocks())
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)); .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded));
if leader_inner.num_disk_blocks() > 0 { if leader_inner.num_disk_blocks() > 0
if let Some(filter) = && let Some(filter) =
create_disk_offload_filter(&cancel_token, &drt.inner().runtime().primary())? create_disk_offload_filter(&cancel_token, &get_current_tokio_handle())?
{ {
host_layout_config = host_layout_config.offload_filter(Some(filter)); host_layout_config = host_layout_config.offload_filter(Some(filter));
}
} }
config = config.host_layout(host_layout_config.build()?); config = config.host_layout(host_layout_config.build()?);
...@@ -382,7 +376,7 @@ impl BlockManagerBuilder { ...@@ -382,7 +376,7 @@ impl BlockManagerBuilder {
Ok(BlockManager { Ok(BlockManager {
inner, inner,
drt, _drt: drt,
_controller: None, _controller: None,
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment