"vscode:/vscode.git/clone" did not exist on "1cf814fe8e683bc2bff9a9ba3c25458783e71d75"
Unverified Commit 3dde8619 authored by pansicheng, committed by GitHub

Conditionally import HiCacheHF3FS (#8598)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent b7170cc8
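
This change defers the storage-backend imports to the branch that actually selects each backend, so the hierarchical cache controller module can be imported even when optional dependencies (e.g. Mooncake or HF3FS support) are not installed. A condensed sketch of the pattern follows; `_load_backend_class` is an illustrative helper, not the real code, which performs these imports inline inside `HiCacheController.__init__`:

```python
# Sketch only: mirrors the lazy-import pattern applied in this commit.
# Each optional backend is imported inside the branch that selects it,
# so a missing optional dependency only fails when that backend is requested.
def _load_backend_class(storage_backend: str):
    if storage_backend == "file":
        from sglang.srt.mem_cache.hicache_storage import HiCacheFile
        return HiCacheFile
    if storage_backend == "mooncake":
        # requires the Mooncake extras to be installed
        from sglang.srt.mem_cache.mooncake_store.mooncake_store import MooncakeStore
        return MooncakeStore
    if storage_backend == "hf3fs":
        # requires HF3FS support to be installed
        from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
        return HiCacheHF3FS
    raise ValueError(f"unsupported storage backend: {storage_backend}")
```

The same idea applies to `get_hash_str_mooncake` and `get_tensor_model_parallel_rank`, which are now imported alongside the backends that need them.
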
@@ -25,12 +25,6 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
-from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
-from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
-    MooncakeStore,
-    get_hash_str_mooncake,
-)
-from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
 logger = logging.getLogger(__name__)
@@ -251,16 +245,7 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
-            # create a new communication group for synchronizing storage operations across TP workers
-            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.prefetch_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
@@ -271,11 +256,19 @@
                 self.storage_backend = HiCacheNixl()
                 self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
                 self.storage_backend = MooncakeStore()
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             elif storage_backend == "hf3fs":
+                from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )
                 rank = get_tensor_model_parallel_rank()
                 bytes_per_page = (
@@ -293,6 +286,16 @@
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
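
For context, the relocated block above sets up extra gloo-backed process groups that `HiCacheController` uses to synchronize prefetch and backup operations across tensor-parallel workers on the CPU side. A minimal, self-contained sketch of that idea is below; it assumes `torch.distributed` is initialized by the caller, and `init_storage_sync_groups` is an illustrative helper, not part of sglang:

```python
# Minimal, single-process sketch (not sglang code): create side-channel gloo
# process groups like the block this commit moves after backend initialization.
# The gloo backend lets TP workers coordinate CPU-side storage operations
# without going through the communicator used for model compute.
import os

import torch.distributed as dist


def init_storage_sync_groups(tp_group=None):
    # Assumes torch.distributed has already been initialized by the caller.
    tp_world_size = dist.get_world_size(group=tp_group)
    prefetch_tp_group = backup_tp_group = None
    if tp_world_size > 1:
        group_ranks = dist.get_process_group_ranks(tp_group)
        prefetch_tp_group = dist.new_group(group_ranks, backend="gloo")
        backup_tp_group = dist.new_group(group_ranks, backend="gloo")
    return prefetch_tp_group, backup_tp_group


if __name__ == "__main__":
    # Single-rank demo so the sketch runs standalone; with one rank no extra
    # groups are created and both return values stay None.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(init_storage_sync_groups())
    dist.destroy_process_group()
```
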