Unverified Commit 96b5004b authored by ibifrost's avatar ibifrost Committed by GitHub
Browse files

[KVConnector] Support 3FS KVConnector (#37636)


Signed-off-by: default avatarwuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: default avataribifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: default avatarSimon Mo <simon.mo@hey.com>
parent 98e1a43a
...@@ -1013,6 +1013,7 @@ package_data = { ...@@ -1013,6 +1013,7 @@ package_data = {
"model_executor/layers/quantization/utils/configs/*.json", "model_executor/layers/quantization/utils/configs/*.json",
"entrypoints/serve/instrumentator/static/*.js", "entrypoints/serve/instrumentator/static/*.js",
"entrypoints/serve/instrumentator/static/*.css", "entrypoints/serve/instrumentator/static/*.css",
"distributed/kv_transfer/kv_connector/v1/hf3fs/utils/*.cpp",
] ]
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for resource management in hf3fs_client.py: constructor failure cleanup
and idempotent close(). Tests use mock to replace real I/O operations
(hf3fs_fuse.io, SharedMemory, os, CUDA).
Requires hf3fs_fuse.io to be installed; skipped otherwise.
"""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
HF3FS_AVAILABLE = True
try:
from hf3fs_fuse.io import ( # noqa: F401
deregister_fd,
extract_mount_point,
make_ioring,
make_iovec,
register_fd,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client import (
Hf3fsClient,
)
except Exception:
HF3FS_AVAILABLE = False
requires_hf3fs = pytest.mark.skipif(
not HF3FS_AVAILABLE,
reason="hf3fs_fuse.io is not available on this machine",
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeShm:
"""Shared-memory stub matching the multiprocessing.shared_memory.SharedMemory
interface used by Hf3fsClient:
Attributes accessed by the constructor:
.buf – memoryview / buffer-protocol object consumed by torch.frombuffer
Methods called during normal lifetime:
.unlink() – called right after the iovec is set up
.close() – called in _release_resources()
"""
def __init__(self, size: int = 1024):
self._data = bytearray(size)
self.buf = memoryview(self._data)
self.closed = False
self.close_call_count = 0
self.unlink_call_count = 0
def close(self):
self.closed = True
self.close_call_count += 1
def unlink(self):
self.unlink_call_count += 1
# ===========================================================================
# TestHf3fsClientResourceManagement
# ===========================================================================
@requires_hf3fs
class TestHf3fsClientResourceManagement:
"""Tests for constructor failure cleanup and idempotent close()."""
_MOD = "vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client"
# ------------------------------------------------------------------
# Helper: build a minimal Hf3fsClient bypassing all real I/O so that
# we can fully control its internal state.
# ------------------------------------------------------------------
def _make_client(self, tmp_path):
"""Return a fully-mocked Hf3fsClient with controllable internals."""
fake_shm_r = _FakeShm()
fake_shm_w = _FakeShm()
patcher_list: list[Any] = [
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
patch(f"{self._MOD}.make_ioring", return_value=MagicMock()),
patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
patch(
"multiprocessing.shared_memory.SharedMemory",
side_effect=[fake_shm_r, fake_shm_w],
),
patch("os.open", return_value=99),
patch("os.ftruncate"),
patch("os.close"),
patch("os.fsync"),
patch("torch.cuda.Stream", return_value=MagicMock()),
patch("torch.frombuffer", return_value=MagicMock()),
patch("torch.empty", return_value=MagicMock()),
]
for p in patcher_list:
p.start()
try:
client = Hf3fsClient(
path=str(tmp_path / "test.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
finally:
for p in patcher_list:
p.stop()
# Manually point internal handles to our controllable fakes so that
# assertions after close() can inspect them directly.
client.shm_r = fake_shm_r
client.shm_w = fake_shm_w
client.file = 99
return client, fake_shm_r, fake_shm_w
# ------------------------------------------------------------------
# close() idempotency
# ------------------------------------------------------------------
def test_close_idempotent_and_handles_cleared(self, tmp_path):
"""Multiple close() calls must not raise; deregister_fd called exactly
once, all handles set to None, shm.close() invoked."""
client, shm_r, shm_w = self._make_client(tmp_path)
with (
patch(f"{self._MOD}.deregister_fd") as mock_dereg,
patch("os.close"),
):
client.close() # first close
client.close() # second close — must be no-op
client.close() # third close — must be no-op
assert client._closed is True
assert mock_dereg.call_count == 1, (
f"deregister_fd called {mock_dereg.call_count} times; expected 1"
)
for attr in ("iov_r", "iov_w", "ior_r", "ior_w", "shm_r", "shm_w", "file"):
assert getattr(client, attr) is None, f"{attr} should be None after close()"
assert shm_r.closed is True
assert shm_w.closed is True
def test_flush_after_close_is_noop(self, tmp_path):
"""flush() after close() must silently do nothing (no fsync call)."""
client, _, _ = self._make_client(tmp_path)
with (
patch(f"{self._MOD}.deregister_fd"),
patch("os.close"),
patch("os.fsync") as mock_fsync,
):
client.close()
client.flush()
mock_fsync.assert_not_called()
# ------------------------------------------------------------------
# Constructor failure leaves no leaked resources
# ------------------------------------------------------------------
def test_constructor_failure_after_file_open_cleans_file(self, tmp_path):
"""If the constructor raises after os.open(), the fd must be closed."""
with (
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(
f"{self._MOD}.extract_mount_point",
side_effect=RuntimeError("mount point not found"),
),
patch("os.open", return_value=55),
patch("os.ftruncate"),
patch("os.close") as mock_os_close,
patch("torch.cuda.Stream", return_value=MagicMock()),
pytest.raises(RuntimeError, match="mount point not found"),
):
Hf3fsClient(
path=str(tmp_path / "fail.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
mock_os_close.assert_called_once_with(55)
def test_constructor_failure_after_shm_alloc_closes_shm(self, tmp_path):
"""Constructor raises after SharedMemory creation → both shm objects closed."""
fake_shm_r = _FakeShm()
fake_shm_w = _FakeShm()
with (
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
patch(
"multiprocessing.shared_memory.SharedMemory",
side_effect=[fake_shm_r, fake_shm_w],
),
patch("os.open", return_value=66),
patch("os.ftruncate"),
patch("os.close"),
patch("torch.frombuffer", return_value=MagicMock()),
patch("torch.empty", return_value=MagicMock()),
patch(
f"{self._MOD}.make_ioring",
side_effect=RuntimeError("ioring init failed"),
),
patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
patch("torch.cuda.Stream", return_value=MagicMock()),
pytest.raises(RuntimeError, match="ioring init failed"),
):
Hf3fsClient(
path=str(tmp_path / "fail2.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
assert fake_shm_r.closed is True, (
"shm_r was not closed after constructor failure"
)
assert fake_shm_w.closed is True, (
"shm_w was not closed after constructor failure"
)
def test_constructor_failure_does_not_close_unallocated_shm(self, tmp_path):
"""Failure before SharedMemory is created must not raise AttributeError
or TypeError from cleanup."""
with (
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(
f"{self._MOD}.extract_mount_point",
side_effect=RuntimeError("early failure"),
),
patch("os.open", return_value=77),
patch("os.ftruncate"),
patch("os.close"),
patch("torch.cuda.Stream", return_value=MagicMock()),
pytest.raises(RuntimeError, match="early failure"),
):
Hf3fsClient(
path=str(tmp_path / "early_fail.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
# ------------------------------------------------------------------
# _release_resources on already-cleared state must be a no-op
# ------------------------------------------------------------------
def test_release_resources_on_empty_state_is_safe(self, tmp_path):
"""_release_resources() on a fully-cleared client must not raise."""
client, _, _ = self._make_client(tmp_path)
with (
patch(f"{self._MOD}.deregister_fd"),
patch("os.close"),
):
client.close() # clears all handles
with (
patch(f"{self._MOD}.deregister_fd") as mock_dereg2,
patch("os.close") as mock_os_close2,
):
client._release_resources() # must not raise
mock_dereg2.assert_not_called()
mock_os_close2.assert_not_called()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for HF3FS KV Connector high-level components:
- TestHf3fsMockClient : file-backed mock client I/O correctness
- TestHF3FSKVConnectorStats: metric collection, aggregation, serialisation
"""
import os
from unittest.mock import MagicMock
import pytest
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_connector import (
HF3FSKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.hf3fs_mock_client import (
Hf3fsClient as MockHf3fsClient,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@pytest.fixture
def hf3fs_stats():
"""Fresh HF3FSKVConnectorStats instance."""
return HF3FSKVConnectorStats()
def _make_cuda_event():
"""Return a real CUDA event when available, otherwise a MagicMock."""
if torch.cuda.is_available():
return torch.cuda.Event()
return MagicMock()
# ===========================================================================
# TestHf3fsMockClient
# ===========================================================================
class TestHf3fsMockClient:
"""Tests for hf3fs_mock_client.Hf3fsClient (file-backend mock)."""
def test_init_creates_file(self, tmp_path):
"""Initializing the client should create the backing file."""
path = str(tmp_path / "test_file")
client = MockHf3fsClient(path=path, size=4096, bytes_per_page=512, entries=4)
assert os.path.exists(path), "Backing file should be created on init"
assert os.path.getsize(path) == 4096
client.close()
@pytest.mark.parametrize(
"dtype, bytes_per_page",
[
(torch.float32, 512),
(torch.float16, 256),
(torch.bfloat16, 256),
],
ids=["float32", "float16", "bfloat16"],
)
def test_batch_write_and_read_dtype(self, tmp_path, dtype, bytes_per_page):
"""Write a tensor of the given dtype and verify round-trip correctness."""
path = str(tmp_path / f"rw_{dtype}")
client = MockHf3fsClient(
path=path, size=bytes_per_page * 8, bytes_per_page=bytes_per_page, entries=4
)
elem_size = torch.tensor([], dtype=dtype).element_size()
numel = bytes_per_page // elem_size
tensor_write = torch.arange(numel, dtype=dtype)
event = _make_cuda_event()
results = client.batch_write([0], [tensor_write], event)
assert results == [bytes_per_page], f"Write should succeed, got {results}"
tensor_read = torch.zeros(numel, dtype=dtype)
results = client.batch_read([0], [tensor_read])
assert results == [bytes_per_page], f"Read should succeed, got {results}"
assert torch.equal(tensor_write, tensor_read), (
"Read tensor should match written tensor"
)
client.close()
def test_batch_read_empty_file_returns_error(self, tmp_path):
"""Reading out-of-bounds offset should return -1."""
bytes_per_page = 128
size = bytes_per_page * 4
path = str(tmp_path / "empty_read")
client = MockHf3fsClient(
path=path, size=size, bytes_per_page=bytes_per_page, entries=4
)
numel = bytes_per_page // 4
tensor_read = torch.zeros(numel, dtype=torch.float32)
results = client.batch_read([size], [tensor_read]) # offset == size => OOB
assert results[0] == -1, "Out-of-bounds read should return -1"
client.close()
def test_batch_write_out_of_bounds_returns_error(self, tmp_path):
"""Writing at an offset beyond file size should return -1."""
bytes_per_page = 128
size = bytes_per_page * 4
path = str(tmp_path / "oob_write")
client = MockHf3fsClient(
path=path, size=size, bytes_per_page=bytes_per_page, entries=4
)
numel = bytes_per_page // 4
tensor = torch.ones(numel, dtype=torch.float32)
event = _make_cuda_event()
results = client.batch_write([size], [tensor], event) # OOB offset
assert results[0] == -1, "Out-of-bounds write should return -1"
client.close()
def test_multiple_tensors_rw(self, tmp_path):
"""Write multiple tensors at different offsets, then read all back."""
bytes_per_page = 128
n = 4
path = str(tmp_path / "multi_rw")
client = MockHf3fsClient(
path=path,
size=bytes_per_page * n * 2,
bytes_per_page=bytes_per_page,
entries=8,
)
tensors_write = [
torch.full((bytes_per_page // 4,), float(i), dtype=torch.float32)
for i in range(n)
]
offsets = [i * bytes_per_page for i in range(n)]
event = _make_cuda_event()
results = client.batch_write(offsets, tensors_write, event)
assert all(r == bytes_per_page for r in results)
tensors_read = [
torch.zeros(bytes_per_page // 4, dtype=torch.float32) for _ in range(n)
]
results = client.batch_read(offsets, tensors_read)
assert all(r == bytes_per_page for r in results)
for i, (tw, tr) in enumerate(zip(tensors_write, tensors_read)):
assert torch.allclose(tw, tr), f"Tensor {i} mismatch after round-trip"
client.close()
def test_flush_and_close_no_error(self, tmp_path):
"""flush() and close() should not raise exceptions."""
path = str(tmp_path / "flush_close")
client = MockHf3fsClient(path=path, size=1024, bytes_per_page=128, entries=4)
client.flush()
client.close()
# ===========================================================================
# TestHF3FSKVConnectorStats
# ===========================================================================
class TestHF3FSKVConnectorStats:
"""Tests for HF3FSKVConnectorStats metric collection and aggregation."""
def test_initial_is_empty(self, hf3fs_stats):
"""Fresh stats object should report is_empty() == True."""
assert hf3fs_stats.is_empty() is True
@pytest.mark.parametrize(
"task_type, duration_key",
[
("Saved", "save_duration"),
("Loaded", "load_duration"),
],
ids=["save", "load"],
)
def test_record_success_duration(self, hf3fs_stats, task_type, duration_key):
"""Recording a successful task should update duration list and total count."""
hf3fs_stats.record_success_task_duration(task_type, 0.5)
assert not hf3fs_stats.is_empty()
assert len(hf3fs_stats.data[duration_key]) == 1
assert hf3fs_stats.data[duration_key][0] == pytest.approx(0.5)
assert hf3fs_stats.data["num_transfer_task"] == 1
@pytest.mark.parametrize(
"task_type, failed_key",
[
("Saved", "num_failed_save"),
("Loaded", "num_failed_load"),
],
ids=["save", "load"],
)
def test_record_failed_task(self, hf3fs_stats, task_type, failed_key):
"""Recording a failed task should increment the corresponding counter."""
hf3fs_stats.record_failed_task_count(task_type)
assert hf3fs_stats.data[failed_key] == 1
assert hf3fs_stats.data["num_transfer_task"] == 1
def test_aggregate_two_stats(self):
"""aggregate() should merge save/load duration lists and sum counters."""
stats1 = HF3FSKVConnectorStats()
stats1.record_success_task_duration("Saved", 0.1)
stats1.record_success_task_duration("Loaded", 0.2)
stats2 = HF3FSKVConnectorStats()
stats2.record_success_task_duration("Saved", 0.3)
stats2.record_failed_task_count("Loaded")
stats1.aggregate(stats2)
assert stats1.data["save_duration"] == pytest.approx([0.1, 0.3])
assert stats1.data["load_duration"] == pytest.approx([0.2])
assert stats1.data["num_failed_load"] == 1
assert stats1.data["num_transfer_task"] == 4
def test_reduce_with_data(self):
"""reduce() computes correct averages when data is present."""
stats = HF3FSKVConnectorStats()
stats.record_success_task_duration("Saved", 1.0)
stats.record_success_task_duration("Saved", 3.0)
result = stats.reduce()
assert result["Num save task success"] == pytest.approx(2.0, rel=0.01)
assert result["Num save task failed"] == pytest.approx(0.0, rel=0.01)
assert result["Avg save duration (ms)"] == pytest.approx(2000.0, rel=0.01)
def test_clone_and_reset(self, hf3fs_stats):
"""clone_and_reset() returns a copy with data and resets the original."""
hf3fs_stats.record_success_task_duration("Saved", 0.7)
hf3fs_stats.record_success_task_duration("Loaded", 0.4)
clone = hf3fs_stats.clone_and_reset()
assert clone.data["num_transfer_task"] == 2
assert hf3fs_stats.is_empty()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for HF3FS metadata server data structures and allocation logic:
- RankFileMetadata : page allocation / release primitives
- KeyMetadata : per-key rank-page tracking and completion detection
- GlobalMetadataState : coordinated allocation with cache-hit semantics
"""
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_metadata_server import (
GlobalMetadataState,
KeyMetadata,
RankFileMetadata,
)
# ===========================================================================
# TestRankFileMetadata
# ===========================================================================
class TestRankFileMetadata:
"""Unit tests for RankFileMetadata page allocation primitives."""
@pytest.mark.parametrize(
"alloc_count, expected_pages",
[(3, 3), (5, 0)],
ids=["alloc_partial", "alloc_exceeds"],
)
def test_allocate_pages(self, alloc_count, expected_pages):
"""allocate_pages returns correct pages or empty list when insufficient."""
rank_meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
pages = rank_meta.allocate_pages(alloc_count)
assert len(pages) == expected_pages
if expected_pages > 0:
rank_meta.release_pages(pages)
assert rank_meta.get_free_page_count() == 3
def test_release_pages_restores_count(self):
"""Releasing allocated pages returns them to the free pool."""
rank_meta = RankFileMetadata(rank_id=0, num_pages=4, free_pages=list(range(4)))
pages = rank_meta.allocate_pages(2)
assert rank_meta.get_free_page_count() == 2
rank_meta.release_pages(pages)
assert rank_meta.get_free_page_count() == 4
def test_release_pages_no_duplicates(self):
"""Releasing the same page twice must not create duplicates."""
rank_meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
rank_meta.allocate_pages(1) # takes page 0
rank_meta.release_pages([0])
rank_meta.release_pages([0]) # second release of the same page
assert rank_meta.get_free_page_count() == 3
# ===========================================================================
# TestKeyMetadata
# ===========================================================================
class TestKeyMetadata:
"""Unit tests for KeyMetadata completion tracking."""
def test_is_complete_false_until_all_ranks(self):
"""is_complete() returns True only when all ranks confirmed."""
key_meta = KeyMetadata(key="k", rank_to_page={}, tp_world_size=2)
assert key_meta.is_complete() is False
key_meta.add_rank_page(0, 5)
assert key_meta.is_complete() is False
key_meta.add_rank_page(1, 10)
assert key_meta.is_complete() is True
def test_get_rank_page_returns_none_for_missing_rank(self):
"""get_rank_page() returns None when the rank has no entry."""
key_meta = KeyMetadata(key="k", rank_to_page={0: 3}, tp_world_size=2)
assert key_meta.get_rank_page(0) == 3
assert key_meta.get_rank_page(1) is None
def test_get_all_pages(self):
"""get_all_pages() returns all (rank, page) pairs."""
key_meta = KeyMetadata(key="k", rank_to_page={0: 1, 1: 2}, tp_world_size=2)
pairs = key_meta.get_all_pages()
assert set(pairs) == {(0, 1), (1, 2)}
# ===========================================================================
# TestGlobalMetadataStateAllocation
# ===========================================================================
class TestGlobalMetadataStateAllocation:
"""Tests for GlobalMetadataState allocation and cache-hit semantics."""
def test_uninitialized_rank_raises_on_allocate(self):
"""allocate_pages_for_keys raises ValueError for unknown rank."""
state = GlobalMetadataState()
with pytest.raises((ValueError, Exception)):
state.allocate_pages_for_keys(99, [("key", "")])
def test_uninitialized_rank_raises_on_get_locations(self):
"""get_key_locations raises ValueError for unknown rank."""
state = GlobalMetadataState()
with pytest.raises((ValueError, Exception)):
state.get_key_locations(99, ["any_key"])
def test_basic_allocation_and_confirm(self):
"""Allocating a page and confirming it marks the key as complete."""
state = GlobalMetadataState()
state.initialize_rank(0, 4)
results = state.allocate_pages_for_keys(0, [("K", "")])
assert results["K"] >= 0
state.confirm_write_for_keys(0, [("K", results["K"])])
assert state.batch_key_exists(["K"]) == [True]
locations = state.get_key_locations(0, ["K"])
assert locations == [results["K"]]
def test_allocate_pages_cache_hit_does_not_leak_pages(self):
"""Cache-hit key must not consume a page from the free pool;
the pre-allocated slot must be returned before reusing the existing page.
"""
state = GlobalMetadataState()
state.initialize_rank(0, 5) # 5 free pages: [0,1,2,3,4]
# Simulate a key that has already been fully written and confirmed.
state.key_metadata["K_cached"] = KeyMetadata(
key="K_cached", rank_to_page={0: 2}, tp_world_size=1
)
free_before = state.rank_metadata[0].get_free_page_count() # 5
results = state.allocate_pages_for_keys(0, [("K_cached", ""), ("K_new", "")])
free_after = state.rank_metadata[0].get_free_page_count()
# Cache-hit key must reuse its existing page.
assert results["K_cached"] == 2, (
f"Cache-hit key should reuse page 2, got {results['K_cached']}"
)
# New key must receive a valid page.
assert results["K_new"] >= 0, (
f"New key should get a valid page, got {results['K_new']}"
)
# Exactly one page consumed from the free pool.
assert free_before - free_after == 1, (
f"Expected 1 page consumed, got delta={free_before - free_after}"
)
def test_allocate_pages_all_cache_hits_frees_all_slots(self):
"""When every key in the batch is a cache hit, no pages are consumed."""
state = GlobalMetadataState()
state.initialize_rank(0, 5)
for key, page in (("K1", 0), ("K2", 1)):
state.key_metadata[key] = KeyMetadata(
key=key, rank_to_page={0: page}, tp_world_size=1
)
free_before = state.rank_metadata[0].get_free_page_count()
results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
free_after = state.rank_metadata[0].get_free_page_count()
assert results["K1"] == 0
assert results["K2"] == 1
assert free_after == free_before, (
f"All-cache-hit batch must not consume free pages; "
f"before={free_before}, after={free_after}"
)
def test_allocate_returns_minus_one_when_pool_exhausted(self):
"""If the free pool is exhausted, all new keys receive -1."""
state = GlobalMetadataState()
state.initialize_rank(0, 1) # only 1 free page
results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
# allocate_pages uses all-or-nothing: 2 needed but only 1 available → []
assert all(v == -1 for v in results.values()), f"Expected all -1, got {results}"
def test_confirm_write_releases_pages(self):
"""confirm_write_for_keys with pages_to_release returns them to pool."""
state = GlobalMetadataState()
state.initialize_rank(0, 3)
results = state.allocate_pages_for_keys(0, [("K", "")])
page = results["K"]
free_after_alloc = state.rank_metadata[0].get_free_page_count()
state.confirm_write_for_keys(0, [("K", page)], pages_to_release=[page])
free_after_release = state.rank_metadata[0].get_free_page_count()
assert free_after_release == free_after_alloc + 1
...@@ -211,15 +211,18 @@ KVConnectorFactory.register_connector( ...@@ -211,15 +211,18 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector", "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
"MooncakeConnector", "MooncakeConnector",
) )
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"FlexKVConnectorV1", "FlexKVConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector", "vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector",
"FlexKVConnectorV1", "FlexKVConnectorV1",
) )
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"SimpleCPUOffloadConnector", "SimpleCPUOffloadConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector", "vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector",
"SimpleCPUOffloadConnector", "SimpleCPUOffloadConnector",
) )
KVConnectorFactory.register_connector(
"HF3FSKVConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_connector",
"HF3FSKVConnector",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import multiprocessing
import os
import threading
from functools import wraps
from pathlib import Path
import torch
import torch.utils.cpp_extension
from torch.utils.cpp_extension import load
root = Path(__file__).parent.resolve()
cuda_include_path = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")
hf3fs_utils = load(
name="hf3fs_utils",
sources=[f"{root}/utils/hf3fs_utils.cpp"],
extra_include_paths=[cuda_include_path],
)
logger = logging.getLogger(__name__)
HF3FS_AVAILABLE = True
try:
from hf3fs_fuse.io import (
deregister_fd,
extract_mount_point,
make_ioring,
make_iovec,
register_fd,
)
except ImportError:
HF3FS_AVAILABLE = False
def rsynchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.rlock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
def wsynchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.wlock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
class Hf3fsClient:
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
"""Initialize the HF3FS client with hf3fs_fuse.
Args:
path: Path to the file used for storage
size: Total size of the storage file in bytes
bytes_per_page: Size of each page in bytes
entries: Maximum number of concurrent operations
"""
if not HF3FS_AVAILABLE:
raise ImportError(
"hf3fs_fuse.io is not available. Please install the hf3fs_fuse package."
)
self.path = path
self.size = size
self.bytes_per_page = bytes_per_page
self.entries = entries
self._closed = False
self.file = None
self.shm_r = None
self.shm_w = None
self.ior_r = None
self.ior_w = None
self.iov_r = None
self.iov_w = None
try:
# Create the file if it doesn't exist and set its size
self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
os.ftruncate(self.file, size)
register_fd(self.file)
self.hf3fs_mount_point = extract_mount_point(path)
self.bs = self.bytes_per_page
self.shm_r = multiprocessing.shared_memory.SharedMemory(
size=self.bs * self.entries, create=True
)
self.shm_w = multiprocessing.shared_memory.SharedMemory(
size=self.bs * self.entries, create=True
)
self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)
numel = self.bs * self.entries
self.r_pinned = torch.empty(
numel,
dtype=torch.uint8,
device="cpu",
pin_memory=True,
)
self.w_pinned = torch.empty(
numel,
dtype=torch.uint8,
device="cpu",
pin_memory=True,
)
self.numa = -1
self.ior_r = make_ioring(
self.hf3fs_mount_point,
self.entries,
for_read=True,
timeout=1,
numa=self.numa,
)
self.ior_w = make_ioring(
self.hf3fs_mount_point,
self.entries,
for_read=False,
timeout=1,
numa=self.numa,
)
self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
self.shm_r.unlink()
self.shm_w.unlink()
self.rlock = threading.RLock()
self.wlock = threading.RLock()
self.stream = torch.cuda.Stream()
self.stream_ptr_int = self.stream.cuda_stream
except Exception:
self._release_resources()
raise
logger.debug(
"Initialized HF3FS client with file: %s, size: %s bytes", path, size
)
def _release_resources(self) -> None:
"""Release all acquired resources safely"""
# iov must be released before ioring and shm
for attr in ("iov_r", "iov_w", "ior_r", "ior_w"):
obj = getattr(self, attr, None)
if obj is not None:
del obj
setattr(self, attr, None)
for attr in ("shm_r", "shm_w"):
shm = getattr(self, attr, None)
if shm is not None:
try:
shm.close()
except Exception as e:
logger.warning("Failed to close %s: %s", attr, e)
setattr(self, attr, None)
if self.file is not None:
try:
deregister_fd(self.file)
except Exception as e:
logger.warning("deregister_fd failed: %s", e)
try:
os.close(self.file)
except OSError as e:
logger.warning("os.close failed: %s", e)
self.file = None
@rsynchronized()
def batch_read(self, offsets: list[int], tensors: list[torch.Tensor]) -> list[int]:
"""Read data from the file at specified offsets into tensors.
Args:
offsets: List of byte offsets to read from
tensors: List of tensors to read data into
Returns:
List of operation results (0 for success, non-zero for error)
"""
self.check(offsets, tensors)
assert self.ior_r is not None
assert self.iov_r is not None
# prepare
current = 0
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
self.ior_r.prepare(
self.iov_r[current : current + size], True, self.file, offset
)
current += size
# submit
ionum = len(offsets)
resv = self.ior_r.submit().wait(min_results=ionum)
# results
with torch.cuda.stream(self.stream):
hf3fs_utils.read_shm(
self.shm_r_tensor, self.r_pinned, tensors, self.stream_ptr_int
)
results = [res.result for res in resv]
return results
@wsynchronized()
def batch_write(
self, offsets: list[int], tensors: list[torch.Tensor], event: torch.cuda.Event
) -> list[int]:
"""Write data from tensors to the file at specified offsets.
Args:
offsets: List of byte offsets to write to
tensors: List of tensors containing data to write
Returns:
List of operation results (0 for success, non-zero for error)
"""
self.check(offsets, tensors)
assert self.ior_w is not None
assert self.iov_w is not None
# prepare
with torch.cuda.stream(self.stream):
self.stream.wait_event(event)
hf3fs_utils.write_shm(
tensors, self.shm_w_tensor, self.w_pinned, self.stream_ptr_int
)
current = 0
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
self.ior_w.prepare(
self.iov_w[current : current + size], False, self.file, offset
)
current += size
# submit
ionum = len(offsets)
resv = self.ior_w.submit().wait(min_results=ionum)
# results
results = [res.result for res in resv]
return results
def check(self, offsets: list[int], tensors: list[torch.Tensor]) -> None:
sizes = [t.numel() * t.itemsize for t in tensors]
if any(
[
len(offsets) > self.entries,
len(offsets) != len(sizes),
any(
offset < 0 or offset + size > self.size
for offset, size in zip(offsets, sizes)
),
any(size > self.bytes_per_page for size in sizes),
]
):
self.close()
raise ValueError("Hf3fsClient.check Failed")
def get_size(self) -> int:
"""Get the total size of the storage file.
Returns:
Size of the file in bytes
"""
return self.size
def close(self) -> None:
"""Close the client and clean up resources."""
if self._closed:
return
self._closed = True
self._release_resources()
def flush(self) -> None:
"""Flush any pending writes to disk."""
if not self._closed and self.file is not None:
os.fsync(self.file)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
HF3FS KV Connector Implementation for vLLM.
This module implements a KV connector that uses
the 3FS for storing and retrieving KV cache data.
Key components:
1. HF3FSConnector: Main connector implementation
2.1 AsyncOperationManager: Manages async save/load operations with background threads
2.2 HF3FSConnectorMetadata: Container for connector metadata
3. HF3FSMetadataServer: Mini Metadata server for HF3FS connector
4. HF3FSClient: 3FS Client Implementation
"""
import atexit
import concurrent
import copy
import hashlib
import os
import queue
import signal
import threading
import time
from concurrent.futures import Future
from dataclasses import dataclass
from queue import Empty
from typing import Any, Optional
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_metadata_server import (
Hf3fsGlobalMetadataClient as Hf3fsMetadataClient,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils import (
gather_scatter_helper,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.common import (
AtomicCounter,
HF3FSConnectorMetadata,
HF3FSRequestMetadata,
LoadBlockInfo,
RequestSchedulingState,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.gather_scatter_helper import ( # noqa: E501
CopyBufferAllocator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.request import Request
HF3FS_AVAILABLE = True
Hf3fsClient = None
try:
from hf3fs_fuse.io import deregister_fd # noqa: F401
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client import (
Hf3fsClient as _RealClient,
)
Hf3fsClient = _RealClient
except Exception:
HF3FS_AVAILABLE = False
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.hf3fs_mock_client import ( # noqa: E501
Hf3fsClient as _MockClient,
)
Hf3fsClient = _MockClient # type: ignore
# Constants
DEFAULT_MAX_IO_ENTRIES = 8
logger = init_logger(__name__)
# ============================================================================
# Async Operation Management
# ============================================================================
class AsyncOperationManager:
"""
Manages async save/load operations with background threads.
"""
def __init__(self, connector: "HF3FSKVConnector"):
# Store connector reference and extract commonly used attributes
self._connector = connector
self._device = connector._device
self._dtype = connector._dtype
self._shape_per_page = connector._shape_per_page
self._bytes_per_page = connector._bytes_per_page
self._rank = connector._rank
self._numjobs = connector._numjobs
self._max_device_buffer_count = connector._max_device_buffer_count
# Operation tracking
self._save_futures: dict[str, list[Future]] = {}
self._load_futures: dict[str, Future] = {}
self._pending_finished_requests: set[str] = set()
# Initialize resources
self._init_cuda_resources()
self._init_worker_threads()
# Metrics
self.hf3fs_stats = HF3FSKVConnectorStats()
logger.info("AsyncOperationManager initialized for rank %d", self._rank)
def _init_cuda_resources(self) -> None:
"""Initialize CUDA streams, events and buffer allocators."""
# CUDA streams for async operations
self._save_stream = torch.cuda.Stream()
self._load_stream = torch.cuda.Stream()
self._save_event = torch.cuda.Event()
# Buffer allocators for data copying
self._save_buffer_allocator = CopyBufferAllocator(
self._device,
self._dtype,
self._shape_per_page,
self._max_device_buffer_count,
)
self._load_buffer_allocator = CopyBufferAllocator(
self._device,
self._dtype,
self._shape_per_page,
self._max_device_buffer_count,
)
def _init_worker_threads(self) -> None:
"""Initialize worker threads and I/O executor."""
# Thread synchronization
self._stop_event = threading.Event()
self._save_queue: queue.Queue[Any] = queue.Queue()
self._load_queue: queue.Queue[Any] = queue.Queue()
# I/O thread pool
self._io_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=self._numjobs,
thread_name_prefix=f"HF3FS-Rank{self._rank}",
)
# Background worker threads
self._save_thread = threading.Thread(target=self._save_worker, daemon=True)
self._load_thread = threading.Thread(target=self._load_worker, daemon=True)
self._save_thread.start()
self._load_thread.start()
def submit_save_operation(self, request_id: str, block_ids, block_hashes) -> Future:
"""Submit a save operation for async execution."""
future: Future[Any] = Future()
main_stream_event = torch.cuda.Event()
main_stream_event.record()
task = (request_id, block_ids, block_hashes, future, main_stream_event)
self._save_queue.put(task)
if request_id not in self._save_futures:
self._save_futures[request_id] = []
self._save_futures[request_id].append(future)
return future
def submit_load_operation(self, request_id: str, block_ids, block_hashes) -> Future:
"""Submit a load operation for async execution."""
future: Future[Any] = Future()
task = (request_id, block_ids, block_hashes, future)
self._load_queue.put(task)
self._load_futures[request_id] = future
return future
def get_finished_operations(
self, finished_req_ids: set[str]
) -> tuple[set[str], set[str]]:
completed_saves = self._check_completed_saves(finished_req_ids)
completed_loads = self._check_completed_loads()
if completed_saves or completed_loads:
logger.info(
"HF3FS Connector Completed: %d saves, %d loads operations",
len(completed_saves),
len(completed_loads),
)
return completed_saves, completed_loads
def _check_completed_saves(self, finished_req_ids: set[str]) -> set[str]:
"""Check for completed save operations."""
completed = set()
# Check pending finished requests first
for request_id in list(self._pending_finished_requests):
if request_id in self._save_futures and self._all_saves_done(request_id):
completed.add(request_id)
self._save_futures.pop(request_id)
self._pending_finished_requests.remove(request_id)
# Process newly finished requests
for request_id in finished_req_ids:
if request_id in self._save_futures:
if self._all_saves_done(request_id):
completed.add(request_id)
self._save_futures.pop(request_id)
else:
self._pending_finished_requests.add(request_id)
else:
completed.add(request_id)
return completed
def _check_completed_loads(self) -> set[str]:
"""Check for completed load operations."""
completed = set()
for request_id in list(self._load_futures):
if self._load_futures[request_id].done():
completed.add(request_id)
self._load_futures.pop(request_id)
return completed
def _all_saves_done(self, request_id: str) -> bool:
"""Check if all save operations for a request are completed."""
return all(future.done() for future in self._save_futures[request_id])
def _save_worker(self) -> None:
"""Background worker for handling save operations."""
torch.accelerator.set_device_index(self._device)
while not self._stop_event.is_set():
try:
task = self._save_queue.get(block=True, timeout=1)
self._handle_save_task(task)
except Empty:
continue
except Exception as e:
logger.error("Save worker error: %s", e)
def _load_worker(self) -> None:
"""Background worker for handling load operations."""
torch.accelerator.set_device_index(self._device)
while not self._stop_event.is_set():
try:
task = self._load_queue.get(block=True, timeout=1)
self._handle_load_task(task)
except Empty:
continue
except Exception as e:
logger.error("Load worker error: %s", e)
def _handle_save_task(self, task) -> None:
"""Handle individual save task with proper stream synchronization."""
request_id, block_ids, block_hashes, future, main_stream_event = task
start_time = time.perf_counter()
buffers = None
try:
# Step1: Allocate storage pages
key_pairs = [(hash_val, "") for hash_val in block_hashes]
allocation_results = (
self._connector._metadata_client.allocate_pages_for_keys(
self._rank, key_pairs
)
)
if any(result[1] < 0 for result in allocation_results):
return self._fail_task(
"Saved", "Page allocation failed", request_id, future
)
page_indices = [result[1] for result in allocation_results]
offsets = [idx * self._bytes_per_page for idx in page_indices]
# Step2: Allocate buffers and gather KV cache data
buffers = self._save_buffer_allocator.alloc_buffer(len(block_ids))
if buffers is None:
return self._fail_task(
"Saved",
f"Buffer allocation failed for {len(block_ids)} blocks",
request_id,
future,
)
# Synchronize streams and gather data
with torch.cuda.stream(self._save_stream):
self._save_stream.wait_event(main_stream_event) # Wait for main stream
self._connector._gather_or_scatter_kv_caches(
block_ids, buffers, "gather"
)
save_stream_event = torch.cuda.Event()
save_stream_event.record(self._save_stream) # Record gather completion
# Step3: Write data in batches
write_futures = []
for i in range(0, len(offsets), DEFAULT_MAX_IO_ENTRIES):
batch_offsets = offsets[i : i + DEFAULT_MAX_IO_ENTRIES]
batch_buffers = buffers[i : i + DEFAULT_MAX_IO_ENTRIES]
client = self._connector._clients[self._connector._ac.next()]
write_future = self._io_executor.submit(
client.batch_write, batch_offsets, batch_buffers, save_stream_event
)
write_futures.append(write_future)
# Check write results
write_success = all(
result == self._bytes_per_page
for write_future in write_futures
for result in write_future.result()
)
# Step4: Confirm writes to metadata server
if write_success:
written_keys = list(zip(block_hashes, page_indices))
self._connector._metadata_client.confirm_write_for_keys(
self._rank, written_keys, []
)
self._save_buffer_allocator.free_buffer(buffers)
return self._succeed_task(
"Saved", start_time, request_id, len(block_ids), future
)
else:
self._connector._metadata_client.confirm_write_for_keys(
self._rank, [], page_indices
)
self._save_buffer_allocator.free_buffer(buffers)
return self._fail_task(
"Saved", "Write operation failed", request_id, future
)
except Exception as e:
if buffers is not None:
self._save_buffer_allocator.free_buffer(buffers)
return self._fail_task(
"Saved", f"Task execution error: {e}", request_id, future
)
def _handle_load_task(self, task) -> None:
"""Handle individual load task."""
request_id, block_ids, block_hashes, future = task
start_time = time.perf_counter()
buffers = None
try:
# Step1: Get block locations from metadata server
page_indices = self._connector._metadata_client.get_key_locations(
self._rank, block_hashes
)
if any(idx is None for idx in page_indices):
return self._fail_task("Loaded", "Blocks not found", request_id, future)
# Allocate read buffer
buffers = self._load_buffer_allocator.alloc_buffer(len(block_ids))
if buffers is None:
return self._fail_task(
"Loaded",
f"Buffer allocation failed for {len(block_ids)} blocks",
request_id,
future,
)
# Step2: Read data in batches
offsets = [idx * self._bytes_per_page for idx in page_indices]
read_futures = []
for i in range(0, len(offsets), DEFAULT_MAX_IO_ENTRIES):
batch_offsets = offsets[i : i + DEFAULT_MAX_IO_ENTRIES]
batch_buffers = buffers[i : i + DEFAULT_MAX_IO_ENTRIES]
client = self._connector._clients[self._connector._ac.next()]
read_future = self._io_executor.submit(
client.batch_read, batch_offsets, batch_buffers
)
read_futures.append(read_future)
# Check read results
read_success = all(
result == self._bytes_per_page
for read_future in read_futures
for result in read_future.result()
)
if not read_success:
self._load_buffer_allocator.free_buffer(buffers)
return self._fail_task(
"Loaded", "Read operation failed", request_id, future
)
# Step3: Scatter data back to KV cache
with torch.cuda.stream(self._load_stream):
self._connector._gather_or_scatter_kv_caches(
block_ids, buffers, "scatter"
)
self._load_stream.synchronize()
self._load_buffer_allocator.free_buffer(buffers)
return self._succeed_task(
"Loaded", start_time, request_id, len(block_ids), future
)
except Exception as e:
if buffers is not None:
self._load_buffer_allocator.free_buffer(buffers)
return self._fail_task(
"Loaded", f"Task execution error: {e}", request_id, future
)
def _fail_task(
self, operation: str, error_msg: str, request_id: str, future: Future
) -> None:
"""Helper to fail task with error logging."""
logger.error(
"%s for %s request %s",
error_msg,
operation,
request_id,
)
self.hf3fs_stats.record_failed_task_count(operation)
future.set_result(False)
def _succeed_task(
self,
operation: str,
start_time: float,
request_id: str,
block_count: int,
future: Future,
) -> None:
"""Helper to succeed task with logging."""
duration = time.perf_counter() - start_time
logger.info(
"%s %s: %d blocks in %.2fs",
operation,
request_id,
block_count,
duration,
)
self.hf3fs_stats.record_success_task_duration(operation, duration)
future.set_result(True)
def shutdown(self) -> None:
"""Clean shutdown of all background threads and resources."""
self._stop_event.set()
self._save_thread.join()
self._load_thread.join()
self._io_executor.shutdown(wait=True)
logger.info("AsyncOperationManager shutdown completed")
# ============================================================================
# HF3FS Connector
# ============================================================================
class HF3FSKVConnector(KVConnectorBase_V1):
"""HF3FS KV Connector implementation."""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
# Core configuration
self._vllm_config = vllm_config
self._role = role
self._block_size = vllm_config.cache_config.block_size
self._use_mla = vllm_config.model_config.use_mla
self._model_config = vllm_config.model_config
logger.info("Using MLA: %s", self._use_mla)
# HF3FS configuration
kv_config = vllm_config.kv_transfer_config
assert kv_config is not None
self._storage_path = kv_config.get_from_extra_config(
"hf3fs_storage_path", "/vllm-workspace/mnt/hf3fs"
)
self._metadata_server_url = kv_config.get_from_extra_config(
"hf3fs_metadata_server_url", "http://localhost:18000"
)
self._file_size = kv_config.get_from_extra_config(
"hf3fs_file_size", 1024 * 1024 * 1024
)
self._numjobs = kv_config.get_from_extra_config("hf3fs_client_numjobs", 16)
self._max_device_buffer_count = kv_config.get_from_extra_config(
"hf3fs_max_device_buffer_count", 128
)
self._max_device_buffer_count = max(
self._max_device_buffer_count, self._numjobs * DEFAULT_MAX_IO_ENTRIES
)
if self._role == KVConnectorRole.SCHEDULER:
self._scheduling_states: dict[str, RequestSchedulingState] = {}
self._metadata_client = Hf3fsMetadataClient()
self._metadata_client.initialize(0, role="scheduler")
atexit.register(self.close)
signal.signal(signal.SIGINT, lambda sig, frame: self.close())
signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
logger.info(
"HF3FSKVConnector initialized: path=%s, role=%s",
self._storage_path,
self._role.name,
)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
self._kv_caches = kv_caches
self._setup_kv_cache_config()
self._setup_storage_clients()
self._async_manager = AsyncOperationManager(self)
def _setup_kv_cache_config(self):
first_cache = next(iter(self._kv_caches.values()))
self._device = first_cache.device
self._dtype = first_cache.dtype
element_size = first_cache.element_size()
if self._use_mla:
assert len(first_cache.shape) == 3, "MLA format should have 3 dimensions"
# MLA format: [num_blocks, block_size, head_size]
num_blocks, block_size, head_size = first_cache.shape
num_heads = 1
else:
# MHA format: [2, num_blocks, block_size, num_heads, head_size]
_, num_blocks, block_size, num_heads, head_size = first_cache.shape
self._local_total_tokens = num_blocks * block_size
self._local_block_size = block_size
if self._use_mla:
layer_block_size = block_size * head_size * element_size
self._bytes_per_page = layer_block_size * len(self._kv_caches)
self._shape_per_page = [
len(self._kv_caches),
block_size,
head_size,
]
else:
layer_block_size = 2 * block_size * num_heads * head_size * element_size
self._bytes_per_page = layer_block_size * len(self._kv_caches)
self._shape_per_page = [
len(self._kv_caches),
2,
block_size,
num_heads * head_size,
]
self._kvcache_ptrs = torch.tensor(
[cache.data_ptr() for cache in self._kv_caches.values()],
dtype=torch.int64,
device=self._device,
)
def _setup_storage_clients(self):
os.makedirs(self._storage_path, exist_ok=True)
self._rank = get_tensor_model_parallel_rank()
file_path = os.path.join(
self._storage_path, f"hf3fs_vllm_data_file_{self._rank}"
)
try:
# Initialize HF3FS clients
self._ac = AtomicCounter(self._numjobs)
assert Hf3fsClient is not None
self._clients = [
Hf3fsClient(
path=file_path,
size=self._file_size,
bytes_per_page=self._bytes_per_page,
entries=DEFAULT_MAX_IO_ENTRIES,
)
for _ in range(self._numjobs)
]
# Initialize metadata client
num_pages = self._file_size // self._bytes_per_page
self._metadata_client = Hf3fsMetadataClient()
self._metadata_client.initialize(self._rank, num_pages, role="worker")
except Exception as e:
logger.error("HF3FS client initialization failed: %s", e)
raise
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs,
) -> None:
"""HF3FSConnector does not do layerwise saving."""
pass
def wait_for_save(self) -> None:
metadata = self._get_connector_metadata()
if not isinstance(metadata, HF3FSConnectorMetadata):
logger.error("Invalid metadata type: %s", type(metadata))
return
for request in metadata.requests:
if request.save_block_op is None:
continue
skip_blocks = request.save_block_op.skip_leading_blocks
block_hashes = self._generate_block_hashes(request.token_ids, skip_blocks)
block_ids = request.block_ids[skip_blocks : skip_blocks + len(block_hashes)]
for i in range(0, len(block_ids), self._max_device_buffer_count):
batch_block_ids = block_ids[i : i + self._max_device_buffer_count]
batch_block_hashes = block_hashes[i : i + self._max_device_buffer_count]
self._async_manager.submit_save_operation(
request.request_id, batch_block_ids, batch_block_hashes
)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
metadata = self._get_connector_metadata()
if not isinstance(metadata, HF3FSConnectorMetadata):
logger.error("Invalid metadata type for loading")
return
for request in metadata.requests:
if request.load_block_op is None:
continue
load_op = request.load_block_op
block_ids = request.block_ids[: load_op.num_blocks_to_load]
block_hashes = self._generate_block_hashes(
request.token_ids, load_op.num_computed_blocks, len(block_ids)
)
for i in range(0, len(block_ids), self._max_device_buffer_count):
batch_block_ids = block_ids[i : i + self._max_device_buffer_count]
batch_block_hashes = block_hashes[i : i + self._max_device_buffer_count]
self._async_manager.submit_load_operation(
request.request_id, batch_block_ids, batch_block_hashes
)
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
return self._async_manager.get_finished_operations(finished_req_ids)
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
# Clear stats for next iteration
if (
hasattr(self, "_async_manager")
and not self._async_manager.hf3fs_stats.is_empty()
):
return self._async_manager.hf3fs_stats.clone_and_reset()
return None
############################################################
# Scheduler Side Methods
############################################################
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
return True, None
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
"""Get number of new tokens that can be loaded from external cache."""
try:
state = self._get_or_create_scheduling_state(request.request_id)
state.request = request
assert request.prompt_token_ids is not None
num_tokens_to_check = self._align_to_block_size(
len(request.prompt_token_ids) - 1
)
if num_tokens_to_check <= num_computed_tokens:
state.load_op = LoadBlockInfo(
num_computed_blocks=num_computed_tokens // self._block_size,
num_blocks_to_load=0,
need_fetch_block_ids=[],
)
return 0, False
token_ids_to_check = request.prompt_token_ids[:num_tokens_to_check]
block_hashes = self._generate_block_hashes(token_ids_to_check, 0)
# Check existence
exists_results = self._metadata_client.batch_key_exists(block_hashes)
# Count consecutive matches
matched_blocks = next(
(i for i, exists in enumerate(exists_results) if not exists),
len(exists_results),
)
matched_tokens = matched_blocks * self._block_size
new_hit_tokens = max(0, matched_tokens - num_computed_tokens)
# Store load operation
state.load_op = LoadBlockInfo(
num_computed_blocks=num_computed_tokens // self._block_size,
num_blocks_to_load=new_hit_tokens // self._block_size,
need_fetch_block_ids=[],
)
logger.info(
(
"Token matching for %s: "
"%d matched (%d blocks), "
"%d new hits, "
"prompt len %d"
),
request.request_id,
matched_tokens,
matched_blocks,
new_hit_tokens,
len(request.prompt_token_ids),
)
return new_hit_tokens, new_hit_tokens > 0
except Exception as e:
logger.error(
"Error calculating matches for request %s: %s", request.request_id, e
)
return 0, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
) -> None:
"""Update state after block allocation."""
state = self._get_or_create_scheduling_state(request.request_id)
state.request = request
if num_external_tokens <= 0 or not state.needs_loading():
return
# Validate block allocation
assert state.load_op is not None
expected_blocks = state.load_op.num_blocks_to_load
actual_blocks = num_external_tokens // self._block_size
assert actual_blocks == expected_blocks, (
f"Block count mismatch for {request.request_id}: "
f"expected {expected_blocks}, got {actual_blocks}"
)
# Update load operation with allocated block IDs
if actual_blocks > 0:
local_block_ids = blocks.get_unhashed_block_ids()
state.load_op.need_fetch_block_ids.extend(local_block_ids)
state.phase = "WAITING_TO_LOAD"
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
"""Build connector metadata for scheduling step."""
metadata = HF3FSConnectorMetadata()
for request_id in scheduler_output.finished_req_ids:
self._scheduling_states.pop(request_id, None)
# Process requests by phase
self._process_waiting_to_load_requests(metadata)
self._process_new_requests(scheduler_output, metadata)
self._process_cached_requests(scheduler_output, metadata)
return metadata
def _process_waiting_to_load_requests(
self, metadata: HF3FSConnectorMetadata
) -> None:
"""Process requests waiting to load."""
for state in list(self._scheduling_states.values()):
if not state.is_ready_to_load():
continue
assert state.load_op is not None
assert (
state.request is not None and state.request.prompt_token_ids is not None
)
# Create load request metadata
num_cached_blocks = (
state.load_op.num_computed_blocks + state.load_op.num_blocks_to_load
)
num_tokens_to_compute = num_cached_blocks * self._block_size
# Initialize token_ids and allocated_block_ids for loading
state.token_ids = state.request.prompt_token_ids[
:num_tokens_to_compute
].copy()
state.allocated_block_ids = state.load_op.need_fetch_block_ids.copy()
request_metadata = HF3FSRequestMetadata.from_scheduling_state(
state, self._block_size, state.load_op, num_cached_blocks
)
if request_metadata:
metadata.add_request(request_metadata)
state.phase = "ACTIVE"
def _process_new_requests(
self, scheduler_output: SchedulerOutput, metadata: HF3FSConnectorMetadata
) -> None:
"""Process new requests."""
for request in scheduler_output.scheduled_new_reqs:
state = self._get_or_create_scheduling_state(request.req_id)
# Calculate tokens to compute
num_tokens_to_compute = (
request.num_computed_tokens
+ scheduler_output.num_scheduled_tokens[request.req_id]
)
self._initialize_state_from_new_request(
state, request, num_tokens_to_compute
)
# Create save metadata (skip cached blocks if any)
num_cached_blocks = None
if state.load_op:
num_cached_blocks = (
state.load_op.num_computed_blocks + state.load_op.num_blocks_to_load
)
request_metadata = HF3FSRequestMetadata.from_scheduling_state(
state, self._block_size, None, num_cached_blocks
)
if request_metadata:
metadata.add_request(request_metadata)
state.phase = "ACTIVE"
def _process_cached_requests(
self, scheduler_output: SchedulerOutput, metadata: HF3FSConnectorMetadata
) -> None:
"""Process cached requests."""
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, request_id in enumerate(cached_reqs.req_ids):
state = self._get_or_create_scheduling_state(request_id)
assert state.request is not None
# Update with new tokens and blocks
num_new_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_current_tokens = len(state.token_ids)
new_token_ids = state.request.all_token_ids[
num_current_tokens : num_current_tokens + num_new_tokens
]
new_block_ids = cached_reqs.new_block_ids[i]
state.update_tokens_and_blocks(new_token_ids, new_block_ids)
# Create save metadata
request_metadata = HF3FSRequestMetadata.from_scheduling_state(
state, self._block_size, None
)
if request_metadata:
metadata.add_request(request_metadata)
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return (
HF3FSKVConnectorStats(data=data)
if data is not None
else HF3FSKVConnectorStats()
)
@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> KVConnectorPromMetrics:
return HF3FSPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)
def close(self) -> None:
try:
if hasattr(self, "_async_manager"):
self._async_manager.shutdown()
if hasattr(self, "_clients"):
for client in self._clients:
client.close()
logger.info("HF3FS clients closed")
except Exception as e:
logger.error("Connector shutdown error: %s", e)
############################################################
# Utility Methods
############################################################
def _get_or_create_scheduling_state(
self, request_id: str
) -> RequestSchedulingState:
"""Get existing or create new scheduling state."""
if request_id not in self._scheduling_states:
self._scheduling_states[request_id] = RequestSchedulingState(
request_id=request_id
)
return self._scheduling_states[request_id]
def _initialize_state_from_new_request(
self, state: RequestSchedulingState, request, num_tokens_to_compute: int
) -> None:
"""Initialize state from new request data."""
# Handle different block_ids formats in vLLM 0.9.0+
if isinstance(request.block_ids[0], list):
unfolded_block_ids = request.block_ids[0].copy()
else:
unfolded_block_ids = request.block_ids.copy()
state.token_ids = request.prompt_token_ids[:num_tokens_to_compute].copy()
state.allocated_block_ids = unfolded_block_ids
state.num_saved_blocks = 0
def _generate_block_hashes(
self,
token_ids: list[int],
start_block_id: int,
max_blocks_count: int | None = None,
) -> list[str]:
"""Generate block hashes for token sequence."""
block_hashes = []
previous_hash = ""
for start_idx in range(0, len(token_ids), self._block_size):
if start_idx + self._block_size > len(token_ids):
break
end_idx = start_idx + self._block_size
block_hash = self._compute_prefix_hash(
token_ids[start_idx:end_idx], previous_hash
)
block_index = start_idx // self._block_size
if block_index >= start_block_id:
block_hashes.append(block_hash)
if max_blocks_count and len(block_hashes) >= max_blocks_count:
break
previous_hash = block_hash
return block_hashes
def _gather_or_scatter_kv_caches(
self, block_ids: list[int], block_buffers, operation: str
):
for buffer_tensor, block_id in zip(block_buffers, block_ids):
start_idx = block_id * self._local_block_size
token_indices = list(range(start_idx, start_idx + self._local_block_size))
if operation == "gather":
gather_scatter_helper.gather_kv_caches(
self._kvcache_ptrs,
self._local_total_tokens,
buffer_tensor,
token_indices,
is_mla=self._use_mla,
)
else:
gather_scatter_helper.scatter_kv_caches(
self._kvcache_ptrs,
self._local_total_tokens,
buffer_tensor,
token_indices,
is_mla=self._use_mla,
)
def _compute_prefix_hash(
self, token_ids: list[int], previous_hash: str = ""
) -> str:
"""Compute prefix hash for token block."""
combined_string = f"{previous_hash}_{token_ids}"
return hashlib.md5(combined_string.encode()).hexdigest()
def _align_to_block_size(self, num_tokens: int) -> int:
"""Align token count to block size."""
return (num_tokens // self._block_size) * self._block_size
@dataclass
class HF3FSKVConnectorStats(KVConnectorStats):
"""Container for transfer performance metrics"""
def __post_init__(self):
if not self.data:
# Empty container init, no data is passed in.
self.reset()
def reset(self):
# Must be serializable
self.data: dict[str, Any] = {
"save_duration": [],
"load_duration": [],
"num_failed_save": 0,
"num_failed_load": 0,
"num_transfer_task": 0,
}
def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
if not other.is_empty():
for k, v in other.data.items():
accumulator = self.data[k]
if isinstance(accumulator, list):
accumulator.extend(v)
else: # int
self.data[k] += v
return self
def reduce(self) -> dict[str, int | float]:
# Compute compact representative stats suitable for CLI logging
if self.is_empty():
return {
"Num transfers task": 0,
"Num save task success": 0,
"Num save task failed": 0,
"Num load task success": 0,
"Num load task failed": 0,
"Avg save duration (ms)": 0,
"P90 save duration (ms)": 0,
"Avg load duration (ms)": 0,
"P90 load duration (ms)": 0,
}
num_success_save = len(self.data["save_duration"] or [])
num_success_load = len(self.data["load_duration"] or [])
num_failed_save = self.data["num_failed_save"]
num_failed_load = self.data["num_failed_load"]
if num_success_save == 0:
save_duration = np.zeros(1)
else:
save_duration = np.asarray(self.data["save_duration"])
if num_success_load == 0:
load_duration = np.zeros(1)
else:
load_duration = np.asarray(self.data["load_duration"])
return {
"Num transfers task": self.data["num_transfer_task"],
"Num save task success": num_success_save,
"Num save task failed": num_failed_save,
"Num load task success": num_success_load,
"Num load task failed": num_failed_load,
"Avg save duration (ms)": round(save_duration.mean() * 1e3, 3),
"P90 save duration (ms)": round(np.percentile(save_duration, 90) * 1e3, 3),
"Avg load duration (ms)": round(load_duration.mean() * 1e3, 3),
"P90 load duration (ms)": round(np.percentile(load_duration, 90) * 1e3, 3),
}
def is_empty(self) -> bool:
return self.data["num_transfer_task"] == 0
def record_success_task_duration(self, operation, duration):
if operation == "Saved":
self.data["save_duration"].append(duration)
elif operation == "Loaded":
self.data["load_duration"].append(duration)
self.data["num_transfer_task"] += 1
def record_failed_task_count(self, operation):
if operation == "Saved":
self.data["num_failed_save"] += 1
elif operation == "Loaded":
self.data["num_failed_load"] += 1
self.data["num_transfer_task"] += 1
def clone_and_reset(self):
old = copy.copy(self)
self.reset()
return old
class HF3FSPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
buckets = [
0.001,
0.005,
0.01,
0.025,
0.05,
0.075,
0.1,
0.2,
0.3,
0.5,
0.75,
1.0,
5.0,
]
hf3fs_save_duration = self._histogram_cls(
name="vllm:hf3fs_save_duration_seconds",
documentation="Histogram of save duration for HF3FSKVConnector.",
buckets=buckets,
labelnames=labelnames,
)
self.hf3fs_save_duration = create_metric_per_engine(
hf3fs_save_duration, self.per_engine_labelvalues
)
hf3fs_load_duration = self._histogram_cls(
name="vllm:hf3fs_load_duration_seconds",
documentation="Histogram of load duration for HF3FSKVConnector.",
buckets=buckets,
labelnames=labelnames,
)
self.hf3fs_load_duration = create_metric_per_engine(
hf3fs_load_duration, self.per_engine_labelvalues
)
hf3fs_num_failed_save = self._counter_cls(
name="vllm:hf3fs_num_failed_save",
documentation="Number of failed HF3FS KV save.",
labelnames=labelnames,
)
self.hf3fs_num_failed_save = create_metric_per_engine(
hf3fs_num_failed_save, self.per_engine_labelvalues
)
hf3fs_num_failed_load = self._counter_cls(
name="vllm:hf3fs_num_failed_load",
documentation="Number of failed HF3FS KV load.",
labelnames=labelnames,
)
self.hf3fs_num_failed_load = create_metric_per_engine(
hf3fs_num_failed_load, self.per_engine_labelvalues
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for prom_obj, list_item_key in zip(
[
self.hf3fs_save_duration,
self.hf3fs_load_duration,
],
[
"save_duration",
"load_duration",
],
):
for list_item in transfer_stats_data[list_item_key]:
prom_obj[engine_idx].observe(list_item)
for counter_obj, counter_item_key in zip(
[
self.hf3fs_num_failed_save,
self.hf3fs_num_failed_load,
],
[
"num_failed_save",
"num_failed_load",
],
):
counter_obj[engine_idx].inc(transfer_stats_data[counter_item_key])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
HF3FS Metadata Server with key-based organization.
"""
import argparse
import logging
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
try:
import orjson
HAS_ORJSON = True
except ImportError:
import json as orjson # type: ignore
HAS_ORJSON = False
import requests
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import ORJSONResponse
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@dataclass
class RankFileMetadata:
"""Manages file page allocation for a single rank."""
rank_id: int
num_pages: int
free_pages: list[int]
def allocate_pages(self, num_pages: int) -> list[int]:
"""Allocate specified number of free pages."""
if len(self.free_pages) < num_pages:
return []
allocated = self.free_pages[:num_pages]
self.free_pages = self.free_pages[num_pages:]
return allocated
def release_pages(self, page_indices: list[int]) -> None:
"""Release pages back to free pool."""
for page_idx in page_indices:
if page_idx not in self.free_pages:
self.free_pages.append(page_idx)
def get_free_page_count(self) -> int:
"""Get current number of free pages."""
return len(self.free_pages)
@dataclass
class KeyMetadata:
"""Manages metadata for a single key across multiple ranks."""
key: str
rank_to_page: dict[int, int] # rank -> allocated page index
tp_world_size: int
def add_rank_page(self, rank: int, page_index: int) -> None:
"""Add page allocation for a specific rank."""
self.rank_to_page[rank] = page_index
def get_all_pages(self) -> list[tuple[int, int]]:
"""Get all (rank, page) pairs for this key."""
return [(rank, page) for rank, page in self.rank_to_page.items()]
def get_rank_page(self, rank: int) -> int | None:
"""Get page index for a specific rank."""
return self.rank_to_page.get(rank)
def is_complete(self) -> bool:
"""Check if all ranks in the TP world have allocated pages."""
return len(self.rank_to_page) == self.tp_world_size
class GlobalMetadataState:
"""Manages global metadata state across all ranks and keys."""
def __init__(self):
self.global_lock = threading.RLock()
self.rank_metadata: dict[int, RankFileMetadata] = {}
self.key_metadata: dict[str, KeyMetadata] = {}
def clear(self) -> None:
"""Clear all metadata state."""
with self.global_lock:
self.rank_metadata.clear()
self.key_metadata.clear()
logger.info("Cleared all metadata state")
def initialize_rank(self, rank: int, num_pages: int) -> None:
"""Initialize a new rank with specified number of pages."""
with self.global_lock:
if rank not in self.rank_metadata:
self.rank_metadata[rank] = RankFileMetadata(
rank, num_pages, list(range(num_pages))
)
logger.info("Initialized rank %s with %s pages", rank, num_pages)
def allocate_pages_for_keys(
self, rank: int, keys: list[tuple[str, str]]
) -> dict[str, int]:
"""Allocate one page for each key on the specified rank.
Args:
rank: Rank ID to allocate pages on
keys: List of keys to allocate pages for
Returns:
Dictionary mapping key -> allocated page index
"""
with self.global_lock:
if rank not in self.rank_metadata:
raise ValueError(f"Rank {rank} not initialized")
# Batch allocate pages for all keys
num_pages_needed = len(keys)
allocated_pages = self.rank_metadata[rank].allocate_pages(num_pages_needed)
if len(allocated_pages) < num_pages_needed:
logger.warning(
"Rank %s only allocated %s pages for %s keys",
rank,
len(allocated_pages),
num_pages_needed,
)
allocation_results = {}
for i, (key, prefix_key) in enumerate(keys):
if key in self.key_metadata:
key_meta = self.key_metadata[key]
if key_meta.is_complete() and rank in key_meta.rank_to_page:
# key is already fully written, reuse the existing page
# and release the allocated pages back to the free pool.
if i < len(allocated_pages):
self.rank_metadata[rank].release_pages([allocated_pages[i]])
allocation_results[key] = key_meta.rank_to_page[rank]
continue
if i < len(allocated_pages):
allocation_results[key] = allocated_pages[i]
else:
allocation_results[key] = -1 # No pages available
return allocation_results
def confirm_write_for_keys(
self,
rank: int,
key_confirmations: list[tuple[str, int]],
pages_to_release: list[int] | None = None,
) -> None:
"""Confirm write operations for keys and update metadata.
Args:
rank: Rank ID that confirmed the writes
key_confirmations: List of (key, page_index) tuples
pages_to_release: List of page indices to release back to free pool
"""
with self.global_lock:
# Confirm successful writes
for key, page_index in key_confirmations:
if key not in self.key_metadata:
# Need to determine tp_world_size from rank_metadata
tp_world_size = len(self.rank_metadata)
self.key_metadata[key] = KeyMetadata(key, {}, tp_world_size)
# Add confirmed page to key metadata
self.key_metadata[key].add_rank_page(rank, page_index)
# Release specified pages back to free pool
if pages_to_release:
self.rank_metadata[rank].release_pages(pages_to_release)
logger.debug(
"Released %s pages on rank %s: %s",
len(pages_to_release),
rank,
pages_to_release,
)
def batch_key_exists(self, keys: list[str]) -> list[bool]:
"""Check if keys exist in metadata and all ranks have confirmed writes.
Args:
keys: List of keys to check
Returns:
List of boolean values indicating key existence and completion
"""
with self.global_lock:
results = []
for key in keys:
if key not in self.key_metadata:
results.append(False)
else:
# Check if all ranks in the TP world have confirmed writes
key_meta = self.key_metadata[key]
results.append(key_meta.is_complete())
return results
def get_key_locations(self, rank: int, keys: list[str]) -> list[int | None]:
"""Get page indices for keys on a specific rank.
Args:
rank: Rank ID to query
keys: List of keys to look up
Returns:
List of page indices in the same order as input keys (None if key not found)
"""
with self.global_lock:
if rank not in self.rank_metadata:
raise ValueError(f"Rank {rank} not initialized")
results = []
for key in keys:
if key in self.key_metadata:
key_meta = self.key_metadata[key]
if key_meta.is_complete():
page_index = key_meta.get_rank_page(rank)
else:
page_index = None
results.append(page_index)
else:
results.append(None)
return results
class Hf3fsMetadataServer:
"""HF3FS Metadata Server with improved key-based organization."""
def __init__(self, persistence_path: str | None = None, save_interval: int = 60):
self.state = GlobalMetadataState()
if HAS_ORJSON:
self.app = FastAPI(default_response_class=ORJSONResponse)
else:
self.app = FastAPI()
self._setup_routes()
async def _read_json(self, request: Request) -> dict:
"""Parse request JSON using orjson if available."""
body = await request.body()
return orjson.loads(body)
def _json_response(self, content: dict):
"""Return ORJSONResponse when available to bypass jsonable_encoder."""
if HAS_ORJSON:
return ORJSONResponse(content)
else:
return content
def _setup_routes(self):
"""Setup FastAPI routes for new API design."""
self.app.post("/rank/{rank}/initialize")(self.initialize_rank)
self.app.post("/keys/batch_allocate")(self.batch_allocate_pages_for_keys)
self.app.post("/keys/confirm_write")(self.confirm_write_for_keys)
self.app.post("/keys/batch_exists")(self.batch_key_exists)
self.app.post("/keys/get_locations")(self.get_key_locations)
self.app.post("/clear")(self.clear)
async def initialize_rank(self, rank: int, request: Request):
"""Initialize a rank with specified number of pages."""
data = await self._read_json(request)
role = data.get("role", "worker")
num_pages = data.get("num_pages", 0)
if role == "scheduler":
return self._json_response(
{"message": "Scheduler role does not require initialization"}
)
if role == "worker" and num_pages > 0:
self.state.initialize_rank(rank, num_pages)
return self._json_response(
{"message": f"Rank {rank} initialized with {num_pages} pages"}
)
else:
raise HTTPException(
status_code=400, detail="Invalid initialization parameters"
)
async def batch_allocate_pages_for_keys(self, request: Request):
"""Allocate one page for each key on a specific rank."""
data = await self._read_json(request)
rank = data.get("rank")
keys = data.get("keys", [])
# Validate input format
if rank is None or not isinstance(keys, list):
raise HTTPException(
status_code=400, detail="Invalid request format: need 'rank' and 'keys'"
)
try:
# Perform allocation
results = self.state.allocate_pages_for_keys(rank, keys)
# Convert results to response format
response = {"rank": rank, "results": list(results.items())}
return self._json_response(response)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Allocation failed: {str(e)}"
) from e
async def confirm_write_for_keys(self, request: Request):
"""Confirm write operations for keys."""
data = await self._read_json(request)
rank = data.get("rank")
confirmations = data.get("confirmations", [])
pages_to_release = data.get("pages_to_release", [])
# Validate input format
if rank is None or not isinstance(confirmations, list):
raise HTTPException(
status_code=400,
detail="Invalid request format: need 'rank' and 'confirmations'",
)
try:
self.state.confirm_write_for_keys(rank, confirmations, pages_to_release)
return Response(status_code=204)
except Exception as e:
logger.error("Confirm write for keys failed: %s", e)
raise HTTPException(
status_code=500, detail=f"Confirmation failed: {str(e)}"
) from e
async def batch_key_exists(self, request: Request):
"""Check if multiple keys exist in metadata."""
data = await self._read_json(request)
keys = data.get("keys", [])
if not isinstance(keys, list):
raise HTTPException(status_code=400, detail="Invalid keys format")
try:
exists_results = self.state.batch_key_exists(keys)
return self._json_response({"exists": exists_results})
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Key existence check failed: {str(e)}"
) from e
async def get_key_locations(self, request: Request):
"""Get page indices for keys on a specific rank."""
data = await self._read_json(request)
rank = data.get("rank")
keys = data.get("keys", [])
# Validate input format
if rank is None or not isinstance(keys, list):
raise HTTPException(
status_code=400, detail="Invalid request format: need 'rank' and 'keys'"
)
try:
# Get key locations
locations = self.state.get_key_locations(rank, keys)
return self._json_response({"locations": locations})
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to get key locations: {str(e)}"
) from e
async def clear(self, request: Request):
"""Clear the metadata server."""
self.state.clear()
return Response(status_code=204)
def run(self, host: str = "0.0.0.0", port: int = 18000):
"""Run the metadata server."""
import uvicorn
logger.info("Starting improved metadata server on http://%s:%s", host, port)
uvicorn.run(self.app, host=host, port=port)
# --- Client implementation ---
class Hf3fsMetadataInterface(ABC):
"""Interface for HF3FS metadata operations."""
@abstractmethod
def initialize(self, rank: int, num_pages: int = 0, role: str = "worker") -> None:
"""Initialize the metadata service with specified number of pages."""
pass
@abstractmethod
def allocate_pages_for_keys(
self, rank: int, keys: list[tuple[str, str]]
) -> list[tuple[str, int]]:
"""Allocate one page for each key on the specified rank."""
pass
@abstractmethod
def confirm_write_for_keys(
self,
rank: int,
key_confirmations: list[tuple[str, int]],
pages_to_release: list[int] | None = None,
) -> None:
"""Confirm write operations for keys and optionally release pages."""
pass
@abstractmethod
def batch_key_exists(self, keys: list[str]) -> list[bool]:
"""Check if keys exist and are complete across all ranks."""
pass
@abstractmethod
def get_key_locations(self, rank: int, keys: list[str]) -> list[int]:
"""Get page indices for keys on a specific rank."""
pass
class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
"""Global HTTP metadata client for HF3FS."""
def __init__(self, base_url: str = "http://localhost:18000", max_retries: int = 3):
self.base_url = base_url.rstrip("/")
self._session = requests.Session()
retry_strategy = Retry(
total=max_retries,
backoff_factor=0.3,
status_forcelist=[500, 502, 503, 504],
allowed_methods=["GET", "POST"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self._session.mount("http://", adapter)
def _post(self, endpoint: str, json_data: dict) -> dict:
"""Make POST request to metadata server."""
try:
url = f"{self.base_url}/{endpoint}"
headers = {"Content-Type": "application/json"}
if HAS_ORJSON:
payload = orjson.dumps(json_data)
else:
import json
payload = json.dumps(json_data).encode("utf-8")
response = self._session.post(url, data=payload, headers=headers)
response.raise_for_status()
if response.status_code == 204 or not response.content:
return {}
if HAS_ORJSON:
return orjson.loads(response.content)
else:
return response.json()
except requests.exceptions.RequestException as e:
logger.error("Failed to POST to %s after retries: %s", endpoint, e)
raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
def initialize(self, rank: int, num_pages: int = 0, role: str = "worker") -> None:
"""Initialize a rank with specified number of pages."""
self._post(f"rank/{rank}/initialize", {"num_pages": num_pages, "role": role})
def allocate_pages_for_keys(
self, rank: int, keys: list[tuple[str, str]]
) -> list[tuple[str, int]]:
"""Allocate pages for keys on the specified rank."""
response = self._post("keys/batch_allocate", {"rank": rank, "keys": keys})
# Convert response to expected format
return response.get("results", {})
def confirm_write_for_keys(
self,
rank: int,
key_confirmations: list[tuple[str, int]],
pages_to_release: list[int] | None = None,
) -> None:
"""Confirm write operations for keys and optionally release pages."""
payload = {
"rank": rank,
"confirmations": key_confirmations,
"pages_to_release": pages_to_release or [],
}
self._post("keys/confirm_write", payload)
def batch_key_exists(self, keys: list[str]) -> list[bool]:
"""Check if keys exist and are complete across all ranks."""
response = self._post("keys/batch_exists", {"keys": keys})
return response.get("exists", [])
def get_key_locations(self, rank: int, keys: list[str]) -> list[int]:
"""Get page indices for keys on a specific rank."""
response = self._post("keys/get_locations", {"rank": rank, "keys": keys})
return response.get("locations", [])
def run_metadata_server(
host: str = "0.0.0.0",
port: int = 18000,
):
"""Run the improved HF3FS metadata server."""
server = Hf3fsMetadataServer()
server.run(host=host, port=port)
# --- Main Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Improved HF3FS Metadata Server")
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Host to bind the server to."
)
parser.add_argument(
"--port", type=int, default=18000, help="Port to run the server on."
)
args = parser.parse_args()
run_metadata_server(args.host, args.port)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from dataclasses import dataclass, field
from typing import Optional
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.v1.request import Request
class AtomicCounter:
"""Thread-safe atomic counter for round-robin operations."""
def __init__(self, n: int):
assert n > 0, "Counter size must be positive"
self._n = n
self._value = 0
self._lock = threading.Lock()
def next(self) -> int:
"""Get next value in round-robin fashion."""
with self._lock:
current = self._value
self._value = (current + 1) % self._n
return current
@dataclass
class LoadBlockInfo:
"""Operation for loading blocks from external storage."""
num_computed_blocks: int
num_blocks_to_load: int
need_fetch_block_ids: list[int]
@dataclass
class SaveBlockInfo:
"""Operation for saving blocks to external storage."""
skip_leading_blocks: int
@dataclass
class RequestSchedulingState:
"""Unified request scheduling state management."""
request_id: str
request: Request | None = None
# Token and block tracking
token_ids: list[int] = field(default_factory=list)
allocated_block_ids: list[int] = field(default_factory=list)
num_saved_blocks: int = 0
# Load operation info
load_op: LoadBlockInfo | None = None
# Scheduling phase
phase: str = "NEW" # NEW -> WAITING_TO_LOAD -> ACTIVE -> FINISHED
def needs_loading(self) -> bool:
"""Check if request needs loading."""
return self.load_op is not None and self.load_op.num_blocks_to_load > 0
def is_ready_to_load(self) -> bool:
"""Check if request is ready for loading."""
return self.phase == "WAITING_TO_LOAD" and self.needs_loading()
def update_tokens_and_blocks(self, new_token_ids: list[int], new_block_ids) -> None:
"""Update with new tokens and blocks."""
if new_token_ids:
self.token_ids.extend(new_token_ids)
if new_block_ids is not None:
normalized_block_ids = self._normalize_block_ids(new_block_ids)
self.allocated_block_ids.extend(normalized_block_ids)
def _normalize_block_ids(self, block_ids) -> list[int]:
"""Normalize block_ids to list format."""
if not block_ids:
return []
if isinstance(block_ids, tuple):
return block_ids[0] if block_ids else []
if isinstance(block_ids, list):
return block_ids
return []
@dataclass
class HF3FSRequestMetadata:
"""Metadata for a single request in HF3FS connector."""
request_id: str
token_ids: list[int]
block_ids: list[int]
load_block_op: LoadBlockInfo | None = None
save_block_op: SaveBlockInfo | None = None
@staticmethod
def from_scheduling_state(
state: "RequestSchedulingState",
block_size: int,
load_op: LoadBlockInfo | None = None,
skip_leading_blocks: int | None = None,
) -> Optional["HF3FSRequestMetadata"]:
"""Create request metadata from scheduling state."""
token_count = len(state.token_ids)
total_blocks = token_count // block_size
skip_blocks = (
state.num_saved_blocks
if skip_leading_blocks is None
else skip_leading_blocks
)
new_blocks_to_save = total_blocks - state.num_saved_blocks
if new_blocks_to_save <= 0 and load_op is None:
return None
state.num_saved_blocks = total_blocks
return HF3FSRequestMetadata(
request_id=state.request_id,
token_ids=state.token_ids,
block_ids=state.allocated_block_ids,
load_block_op=load_op,
save_block_op=SaveBlockInfo(skip_leading_blocks=skip_blocks),
)
class HF3FSConnectorMetadata(KVConnectorMetadata):
"""Container for HF3FS connector metadata."""
def __init__(self):
self.requests: list[HF3FSRequestMetadata] = []
def add_request(self, request_metadata: HF3FSRequestMetadata) -> None:
"""Add request to metadata."""
self.requests.append(request_metadata)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
@triton.jit
def kv_cache_scatter_kernel(
kv_cache_ptrs_ptr,
source_ptr,
token_indices_ptr,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE: tl.constexpr,
):
layer_idx = tl.program_id(0)
token_pos = tl.program_id(1)
if layer_idx >= num_layers or token_pos >= num_tokens_in_block:
return
token_idx = tl.load(token_indices_ptr + token_pos)
kv_cache_ptr = tl.cast(tl.load(kv_cache_ptrs_ptr + layer_idx), source_ptr.dtype)
if token_idx >= total_token_in_kvcache:
return
if is_mla:
# MLA format: source [num_layers, num_tokens_in_block, hidden_size]
# MLA format: target [total_token_in_kvcache, hidden_size] (per layer)
source_offset = (layer_idx * num_tokens_in_block + token_pos) * hidden_size
target_offset = token_idx * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val = tl.load(source_ptr + source_offset + offset, mask=mask)
tl.store(kv_cache_ptr + target_offset + offset, val, mask=mask)
else:
# MHA format: source [num_layers, 2, num_tokens_in_block, hidden_size]
# MHA format: target [2, total_token_in_kvcache, hidden_size]
source_offset_k = (
layer_idx * num_tokens_in_block * 2 + token_pos
) * hidden_size
source_offset_v = (
layer_idx * num_tokens_in_block * 2 + num_tokens_in_block + token_pos
) * hidden_size
target_offset_k = token_idx * hidden_size
target_offset_v = (total_token_in_kvcache + token_idx) * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val_k = tl.load(source_ptr + source_offset_k + offset, mask=mask)
val_v = tl.load(source_ptr + source_offset_v + offset, mask=mask)
tl.store(kv_cache_ptr + target_offset_k + offset, val_k, mask=mask)
tl.store(kv_cache_ptr + target_offset_v + offset, val_v, mask=mask)
@triton.jit
def kv_cache_gather_kernel(
kv_cache_ptrs_ptr,
dst_ptr,
token_indices_ptr,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE: tl.constexpr,
):
layer_idx = tl.program_id(0)
token_pos = tl.program_id(1)
if layer_idx >= num_layers or token_pos >= num_tokens_in_block:
return
token_idx = tl.load(token_indices_ptr + token_pos)
kv_cache_ptr = tl.cast(tl.load(kv_cache_ptrs_ptr + layer_idx), dst_ptr.dtype)
if token_idx >= total_token_in_kvcache:
return
if is_mla:
# MLA format: source [total_token_in_kvcache, hidden_size] (per layer)
# MLA format: dst [num_layers, num_tokens_in_block, hidden_size]
kvcache_offset = token_idx * hidden_size
dst_offset = (layer_idx * num_tokens_in_block + token_pos) * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val = tl.load(kv_cache_ptr + kvcache_offset + offset, mask=mask)
tl.store(dst_ptr + dst_offset + offset, val, mask=mask)
else:
# MHA format: source [2, total_token_in_kvcache, hidden_size]
# MHA format: dst [num_layers, 2, num_tokens_in_block, hidden_size]
dst_offset_k = (layer_idx * num_tokens_in_block * 2 + token_pos) * hidden_size
dst_offset_v = (
layer_idx * num_tokens_in_block * 2 + num_tokens_in_block + token_pos
) * hidden_size
kvcache_offset_k = token_idx * hidden_size
kvcache_offset_v = (total_token_in_kvcache + token_idx) * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val_k = tl.load(kv_cache_ptr + kvcache_offset_k + offset, mask=mask)
val_v = tl.load(kv_cache_ptr + kvcache_offset_v + offset, mask=mask)
tl.store(dst_ptr + dst_offset_k + offset, val_k, mask=mask)
tl.store(dst_ptr + dst_offset_v + offset, val_v, mask=mask)
def scatter_kv_caches(
kv_caches_ptrs: torch.Tensor,
total_token_in_kvcache: int,
src_tensor: torch.Tensor,
token_indices: list[int],
is_mla: bool = False,
) -> None:
"""Scatter KV cache data from source tensor to KV cache storage.
Args:
kv_caches_ptrs: Tensor of KV cache pointers (one per layer)
total_token_in_kvcache: Total number of tokens in KV cache
src_tensor: Source tensor containing data to scatter
- MHA format: [num_layers, 2, num_tokens_in_block, hidden_size]
- MLA format: [num_layers, num_tokens_in_block, hidden_size]
token_indices: List of token positions to update
is_mla: Whether using MLA model format
"""
num_layers = len(kv_caches_ptrs)
num_tokens_in_block = len(token_indices)
if is_mla:
# MLA: src_tensor is [num_layers, num_tokens_in_block, hidden_size]
assert len(src_tensor.shape) == 3, (
f"MLA src_tensor should be 3D, got {src_tensor.shape}"
)
hidden_size = src_tensor.shape[2]
else:
# MHA: src_tensor is [num_layers, 2, num_tokens_in_block, hidden_size]
assert len(src_tensor.shape) == 4, (
f"MHA src_tensor should be 4D, got {src_tensor.shape}"
)
hidden_size = src_tensor.shape[3]
device = src_tensor.device
token_indices_tensor = torch.tensor(
token_indices, dtype=torch.int32, device="cpu"
).to(device, non_blocking=True)
grid = (num_layers, num_tokens_in_block)
BLOCK_SIZE = 128
kv_cache_scatter_kernel[grid](
kv_caches_ptrs,
src_tensor,
token_indices_tensor,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE=BLOCK_SIZE,
)
def gather_kv_caches(
kv_caches_ptrs: torch.Tensor,
total_token_in_kvcache: int,
dst_tensor: torch.Tensor,
token_indices: list[int],
is_mla: bool = False,
) -> None:
"""Gather KV cache data from KV cache storage to destination tensor.
Args:
kv_caches_ptrs: Tensor of KV cache pointers (one per layer)
total_token_in_kvcache: Total number of tokens in KV cache
dst_tensor: Destination tensor to store gathered data
- MHA format: [num_layers, 2, num_tokens_in_block, hidden_size]
- MLA format: [num_layers, num_tokens_in_block, hidden_size]
token_indices: List of token positions to gather
is_mla: Whether using MLA model format
"""
num_layers = kv_caches_ptrs.shape[0]
num_tokens_in_block = len(token_indices)
if is_mla:
# MLA: dst_tensor is [num_layers, num_tokens_in_block, hidden_size]
assert len(dst_tensor.shape) == 3, (
f"MLA dst_tensor should be 3D, got {dst_tensor.shape}"
)
assert dst_tensor.shape[0] == num_layers, (
f"Layer count mismatch: {dst_tensor.shape[0]} vs {num_layers}"
)
assert dst_tensor.shape[1] == num_tokens_in_block, (
f"Token count mismatch: {dst_tensor.shape[1]} vs {num_tokens_in_block}"
)
hidden_size = dst_tensor.shape[2]
else:
# MHA: dst_tensor is [num_layers, 2, num_tokens_in_block, hidden_size]
assert len(dst_tensor.shape) == 4, (
f"MHA dst_tensor should be 4D, got {dst_tensor.shape}"
)
assert dst_tensor.shape[0] == num_layers, (
f"Layer count mismatch: {dst_tensor.shape[0]} vs {num_layers}"
)
assert dst_tensor.shape[1] == 2, (
f"MHA should have 2 (K,V) components, got {dst_tensor.shape[1]}"
)
assert dst_tensor.shape[2] == num_tokens_in_block, (
f"Token count mismatch: {dst_tensor.shape[2]} vs {num_tokens_in_block}"
)
hidden_size = dst_tensor.shape[3]
device = dst_tensor.device
token_indices_tensor = torch.tensor(
token_indices, dtype=torch.int32, device="cpu"
).to(device, non_blocking=True)
grid = (num_layers, num_tokens_in_block)
BLOCK_SIZE = 128
kv_cache_gather_kernel[grid](
kv_caches_ptrs,
dst_tensor,
token_indices_tensor,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE=BLOCK_SIZE,
)
class CopyBufferAllocator:
"""Memory pool for tensor buffers to avoid frequent allocation/deallocation."""
def __init__(
self, device: torch.device, dtype: torch.dtype, shape: list, max_count: int
):
self._shape = shape
self._max_count = max_count
self._device = device
self._free_buffers = [
torch.empty(shape, dtype=dtype, device=device) for _ in range(max_count)
]
self._inuse_count = 0
def alloc_buffer(self, count: int) -> list[torch.Tensor] | None:
"""Allocate buffers from the pool."""
if count == 0:
return []
if self._inuse_count + count <= self._max_count:
self._inuse_count += count
result = self._free_buffers[-count:]
del self._free_buffers[-count:]
return result
return None
def free_buffer(self, buffers: list[torch.Tensor]) -> None:
"""Return buffers to the pool."""
if not buffers:
return
if self._inuse_count >= len(buffers):
self._inuse_count -= len(buffers)
self._free_buffers.extend(buffers)
else:
raise RuntimeError("Attempted to free more buffers than allocated")
logger = init_logger(__name__)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import torch
logger = logging.getLogger(__name__)
HF3FS_AVAILABLE = True
class Hf3fsClient:
"""Mock HF3FS client using file backend for debugging and testing."""
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
self._size = size
self._bytes_per_page = bytes_per_page
self._entries = entries
self._file_path = path
self._ensure_file_exists()
logger.debug("Initialized mock HF3FS client: %s (%d bytes)", path, size)
def _ensure_file_exists(self) -> None:
"""Create file if it doesn't exist."""
if not os.path.exists(self._file_path):
with open(self._file_path, "w+b") as f:
f.truncate(self._size)
def batch_read(self, offsets: list[int], tensors: list[torch.Tensor]) -> list[int]:
"""Read data from file at specified offsets into tensors."""
results = []
try:
with open(self._file_path, "rb") as f:
for offset, tensor in zip(offsets, tensors):
num_bytes = tensor.numel() * tensor.element_size()
if offset < 0 or offset + num_bytes > self._size:
results.append(-1)
continue
f.seek(offset)
buffer_data = f.read(num_bytes)
if len(buffer_data) == num_bytes == self._bytes_per_page:
tensor_data = self._convert_buffer_to_tensor(
buffer_data, tensor.dtype
)
tensor.copy_(
tensor_data.reshape(tensor.shape).to(tensor.device)
)
results.append(self._bytes_per_page)
else:
logger.error(
"Read size mismatch: got %d, expected %d",
len(buffer_data),
num_bytes,
)
results.append(-1)
except Exception as e:
logger.error("Batch read error: %s", e)
results.extend([-1] * (len(offsets) - len(results)))
return results
def _convert_buffer_to_tensor(
self, buffer_data: bytes, dtype: torch.dtype
) -> torch.Tensor:
"""Convert buffer data to tensor with proper dtype handling."""
if dtype == torch.bfloat16:
tensor_data = torch.frombuffer(buffer_data, dtype=torch.uint16)
return tensor_data.view(dtype=torch.bfloat16)
else:
return torch.frombuffer(buffer_data, dtype=dtype)
def batch_write(
self, offsets: list[int], tensors: list[torch.Tensor], event: torch.cuda.Event
) -> list[int]:
"""Write data from tensors to file at specified offsets."""
results = []
try:
torch.cuda.current_stream().wait_event(event)
# Convert tensors to bytes
data_bytes_list = [self._tensor_to_bytes(tensor) for tensor in tensors]
# Write to file
with open(self._file_path, "r+b") as f:
for offset, data_bytes in zip(offsets, data_bytes_list):
if offset < 0 or offset + len(data_bytes) > self._size:
results.append(-1)
continue
f.seek(offset)
bytes_written = f.write(data_bytes)
if bytes_written == len(data_bytes) == self._bytes_per_page:
results.append(self._bytes_per_page)
else:
logger.error(
"Write size mismatch: wrote %d, expected %d",
bytes_written,
self._bytes_per_page,
)
results.append(-1)
except Exception as e:
logger.error("Batch write error: %s", e)
results.extend([-1] * (len(offsets) - len(results)))
return results
def _tensor_to_bytes(self, tensor: torch.Tensor) -> bytes:
"""Convert tensor to bytes with proper dtype handling."""
cpu_tensor = tensor.cpu()
if cpu_tensor.dtype == torch.bfloat16:
return cpu_tensor.view(dtype=torch.uint16).numpy().tobytes()
else:
return cpu_tensor.numpy().tobytes()
def get_size(self) -> int:
"""Get the total size of the storage file."""
return self._size
def close(self) -> None:
"""Close the client (no-op for file backend)."""
pass
def flush(self) -> None:
"""Flush any pending writes (no-op for file backend)."""
pass
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <cstring>
#include <vector>
void read_shm(const torch::Tensor& shm, const torch::Tensor& pin,
std::vector<torch::Tensor> dst, uint64_t stream_ptr) {
py::gil_scoped_release release;
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
// Copy from shared memory to pinned memory
char* shm_ptr = static_cast<char*>(shm.data_ptr());
char* src_ptr = static_cast<char*>(pin.data_ptr());
std::memcpy(src_ptr, shm_ptr, shm.numel() * shm.element_size());
// Copy from pinned memory to GPU tensors
size_t current = 0;
for (size_t i = 0; i < dst.size(); ++i) {
auto& t = dst[i];
size_t t_bytes = t.numel() * t.element_size();
char* dst_ptr = static_cast<char*>(t.data_ptr());
cudaMemcpyAsync(dst_ptr, src_ptr + current, t_bytes, cudaMemcpyHostToDevice,
stream);
current += t_bytes;
}
cudaStreamSynchronize(stream);
}
void write_shm(const std::vector<torch::Tensor> src, torch::Tensor& shm,
const torch::Tensor& pin, uint64_t stream_ptr) {
py::gil_scoped_release release;
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
// Copy from GPU tensors to pinned memory
char* dst_ptr = static_cast<char*>(pin.data_ptr());
size_t current = 0;
for (size_t i = 0; i < src.size(); ++i) {
auto& t = src[i];
size_t t_bytes = t.numel() * t.element_size();
char* src_ptr = static_cast<char*>(t.data_ptr());
cudaMemcpyAsync(dst_ptr + current, src_ptr, t_bytes, cudaMemcpyDeviceToHost,
stream);
current += t_bytes;
}
cudaStreamSynchronize(stream);
// Copy from pinned memory to shared memory
char* shm_ptr = static_cast<char*>(shm.data_ptr());
std::memcpy(shm_ptr, dst_ptr, shm.numel() * shm.element_size());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("read_shm", &read_shm, "Read tensors from shared memory");
m.def("write_shm", &write_shm, "Write tensors to shared memory");
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment