Unverified Commit 96b5004b authored by ibifrost's avatar ibifrost Committed by GitHub
Browse files

[KVConnector] Support 3FS KVConnector (#37636)


Signed-off-by: default avatarwuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: default avataribifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: default avatarSimon Mo <simon.mo@hey.com>
parent 98e1a43a
......@@ -1013,6 +1013,7 @@ package_data = {
"model_executor/layers/quantization/utils/configs/*.json",
"entrypoints/serve/instrumentator/static/*.js",
"entrypoints/serve/instrumentator/static/*.css",
"distributed/kv_transfer/kv_connector/v1/hf3fs/utils/*.cpp",
]
}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for resource management in hf3fs_client.py: constructor failure cleanup
and idempotent close(). Tests use mock to replace real I/O operations
(hf3fs_fuse.io, SharedMemory, os, CUDA).
Requires hf3fs_fuse.io to be installed; skipped otherwise.
"""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
HF3FS_AVAILABLE = True
try:
from hf3fs_fuse.io import ( # noqa: F401
deregister_fd,
extract_mount_point,
make_ioring,
make_iovec,
register_fd,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client import (
Hf3fsClient,
)
except Exception:
HF3FS_AVAILABLE = False
requires_hf3fs = pytest.mark.skipif(
not HF3FS_AVAILABLE,
reason="hf3fs_fuse.io is not available on this machine",
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeShm:
"""Shared-memory stub matching the multiprocessing.shared_memory.SharedMemory
interface used by Hf3fsClient:
Attributes accessed by the constructor:
.buf – memoryview / buffer-protocol object consumed by torch.frombuffer
Methods called during normal lifetime:
.unlink() – called right after the iovec is set up
.close() – called in _release_resources()
"""
def __init__(self, size: int = 1024):
self._data = bytearray(size)
self.buf = memoryview(self._data)
self.closed = False
self.close_call_count = 0
self.unlink_call_count = 0
def close(self):
self.closed = True
self.close_call_count += 1
def unlink(self):
self.unlink_call_count += 1
# ===========================================================================
# TestHf3fsClientResourceManagement
# ===========================================================================
@requires_hf3fs
class TestHf3fsClientResourceManagement:
"""Tests for constructor failure cleanup and idempotent close()."""
_MOD = "vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client"
# ------------------------------------------------------------------
# Helper: build a minimal Hf3fsClient bypassing all real I/O so that
# we can fully control its internal state.
# ------------------------------------------------------------------
def _make_client(self, tmp_path):
"""Return a fully-mocked Hf3fsClient with controllable internals."""
fake_shm_r = _FakeShm()
fake_shm_w = _FakeShm()
patcher_list: list[Any] = [
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
patch(f"{self._MOD}.make_ioring", return_value=MagicMock()),
patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
patch(
"multiprocessing.shared_memory.SharedMemory",
side_effect=[fake_shm_r, fake_shm_w],
),
patch("os.open", return_value=99),
patch("os.ftruncate"),
patch("os.close"),
patch("os.fsync"),
patch("torch.cuda.Stream", return_value=MagicMock()),
patch("torch.frombuffer", return_value=MagicMock()),
patch("torch.empty", return_value=MagicMock()),
]
for p in patcher_list:
p.start()
try:
client = Hf3fsClient(
path=str(tmp_path / "test.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
finally:
for p in patcher_list:
p.stop()
# Manually point internal handles to our controllable fakes so that
# assertions after close() can inspect them directly.
client.shm_r = fake_shm_r
client.shm_w = fake_shm_w
client.file = 99
return client, fake_shm_r, fake_shm_w
# ------------------------------------------------------------------
# close() idempotency
# ------------------------------------------------------------------
def test_close_idempotent_and_handles_cleared(self, tmp_path):
"""Multiple close() calls must not raise; deregister_fd called exactly
once, all handles set to None, shm.close() invoked."""
client, shm_r, shm_w = self._make_client(tmp_path)
with (
patch(f"{self._MOD}.deregister_fd") as mock_dereg,
patch("os.close"),
):
client.close() # first close
client.close() # second close — must be no-op
client.close() # third close — must be no-op
assert client._closed is True
assert mock_dereg.call_count == 1, (
f"deregister_fd called {mock_dereg.call_count} times; expected 1"
)
for attr in ("iov_r", "iov_w", "ior_r", "ior_w", "shm_r", "shm_w", "file"):
assert getattr(client, attr) is None, f"{attr} should be None after close()"
assert shm_r.closed is True
assert shm_w.closed is True
def test_flush_after_close_is_noop(self, tmp_path):
"""flush() after close() must silently do nothing (no fsync call)."""
client, _, _ = self._make_client(tmp_path)
with (
patch(f"{self._MOD}.deregister_fd"),
patch("os.close"),
patch("os.fsync") as mock_fsync,
):
client.close()
client.flush()
mock_fsync.assert_not_called()
# ------------------------------------------------------------------
# Constructor failure leaves no leaked resources
# ------------------------------------------------------------------
def test_constructor_failure_after_file_open_cleans_file(self, tmp_path):
"""If the constructor raises after os.open(), the fd must be closed."""
with (
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(
f"{self._MOD}.extract_mount_point",
side_effect=RuntimeError("mount point not found"),
),
patch("os.open", return_value=55),
patch("os.ftruncate"),
patch("os.close") as mock_os_close,
patch("torch.cuda.Stream", return_value=MagicMock()),
pytest.raises(RuntimeError, match="mount point not found"),
):
Hf3fsClient(
path=str(tmp_path / "fail.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
mock_os_close.assert_called_once_with(55)
def test_constructor_failure_after_shm_alloc_closes_shm(self, tmp_path):
"""Constructor raises after SharedMemory creation → both shm objects closed."""
fake_shm_r = _FakeShm()
fake_shm_w = _FakeShm()
with (
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
patch(
"multiprocessing.shared_memory.SharedMemory",
side_effect=[fake_shm_r, fake_shm_w],
),
patch("os.open", return_value=66),
patch("os.ftruncate"),
patch("os.close"),
patch("torch.frombuffer", return_value=MagicMock()),
patch("torch.empty", return_value=MagicMock()),
patch(
f"{self._MOD}.make_ioring",
side_effect=RuntimeError("ioring init failed"),
),
patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
patch("torch.cuda.Stream", return_value=MagicMock()),
pytest.raises(RuntimeError, match="ioring init failed"),
):
Hf3fsClient(
path=str(tmp_path / "fail2.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
assert fake_shm_r.closed is True, (
"shm_r was not closed after constructor failure"
)
assert fake_shm_w.closed is True, (
"shm_w was not closed after constructor failure"
)
def test_constructor_failure_does_not_close_unallocated_shm(self, tmp_path):
"""Failure before SharedMemory is created must not raise AttributeError
or TypeError from cleanup."""
with (
patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
patch(f"{self._MOD}.register_fd"),
patch(f"{self._MOD}.deregister_fd"),
patch(
f"{self._MOD}.extract_mount_point",
side_effect=RuntimeError("early failure"),
),
patch("os.open", return_value=77),
patch("os.ftruncate"),
patch("os.close"),
patch("torch.cuda.Stream", return_value=MagicMock()),
pytest.raises(RuntimeError, match="early failure"),
):
Hf3fsClient(
path=str(tmp_path / "early_fail.bin"),
size=1024,
bytes_per_page=256,
entries=4,
)
# ------------------------------------------------------------------
# _release_resources on already-cleared state must be a no-op
# ------------------------------------------------------------------
def test_release_resources_on_empty_state_is_safe(self, tmp_path):
"""_release_resources() on a fully-cleared client must not raise."""
client, _, _ = self._make_client(tmp_path)
with (
patch(f"{self._MOD}.deregister_fd"),
patch("os.close"),
):
client.close() # clears all handles
with (
patch(f"{self._MOD}.deregister_fd") as mock_dereg2,
patch("os.close") as mock_os_close2,
):
client._release_resources() # must not raise
mock_dereg2.assert_not_called()
mock_os_close2.assert_not_called()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for HF3FS KV Connector high-level components:
- TestHf3fsMockClient : file-backed mock client I/O correctness
- TestHF3FSKVConnectorStats: metric collection, aggregation, serialisation
"""
import os
from unittest.mock import MagicMock
import pytest
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_connector import (
HF3FSKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.hf3fs_mock_client import (
Hf3fsClient as MockHf3fsClient,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@pytest.fixture
def hf3fs_stats():
"""Fresh HF3FSKVConnectorStats instance."""
return HF3FSKVConnectorStats()
def _make_cuda_event():
"""Return a real CUDA event when available, otherwise a MagicMock."""
if torch.cuda.is_available():
return torch.cuda.Event()
return MagicMock()
# ===========================================================================
# TestHf3fsMockClient
# ===========================================================================
class TestHf3fsMockClient:
"""Tests for hf3fs_mock_client.Hf3fsClient (file-backend mock)."""
def test_init_creates_file(self, tmp_path):
"""Initializing the client should create the backing file."""
path = str(tmp_path / "test_file")
client = MockHf3fsClient(path=path, size=4096, bytes_per_page=512, entries=4)
assert os.path.exists(path), "Backing file should be created on init"
assert os.path.getsize(path) == 4096
client.close()
@pytest.mark.parametrize(
"dtype, bytes_per_page",
[
(torch.float32, 512),
(torch.float16, 256),
(torch.bfloat16, 256),
],
ids=["float32", "float16", "bfloat16"],
)
def test_batch_write_and_read_dtype(self, tmp_path, dtype, bytes_per_page):
"""Write a tensor of the given dtype and verify round-trip correctness."""
path = str(tmp_path / f"rw_{dtype}")
client = MockHf3fsClient(
path=path, size=bytes_per_page * 8, bytes_per_page=bytes_per_page, entries=4
)
elem_size = torch.tensor([], dtype=dtype).element_size()
numel = bytes_per_page // elem_size
tensor_write = torch.arange(numel, dtype=dtype)
event = _make_cuda_event()
results = client.batch_write([0], [tensor_write], event)
assert results == [bytes_per_page], f"Write should succeed, got {results}"
tensor_read = torch.zeros(numel, dtype=dtype)
results = client.batch_read([0], [tensor_read])
assert results == [bytes_per_page], f"Read should succeed, got {results}"
assert torch.equal(tensor_write, tensor_read), (
"Read tensor should match written tensor"
)
client.close()
def test_batch_read_empty_file_returns_error(self, tmp_path):
"""Reading out-of-bounds offset should return -1."""
bytes_per_page = 128
size = bytes_per_page * 4
path = str(tmp_path / "empty_read")
client = MockHf3fsClient(
path=path, size=size, bytes_per_page=bytes_per_page, entries=4
)
numel = bytes_per_page // 4
tensor_read = torch.zeros(numel, dtype=torch.float32)
results = client.batch_read([size], [tensor_read]) # offset == size => OOB
assert results[0] == -1, "Out-of-bounds read should return -1"
client.close()
def test_batch_write_out_of_bounds_returns_error(self, tmp_path):
"""Writing at an offset beyond file size should return -1."""
bytes_per_page = 128
size = bytes_per_page * 4
path = str(tmp_path / "oob_write")
client = MockHf3fsClient(
path=path, size=size, bytes_per_page=bytes_per_page, entries=4
)
numel = bytes_per_page // 4
tensor = torch.ones(numel, dtype=torch.float32)
event = _make_cuda_event()
results = client.batch_write([size], [tensor], event) # OOB offset
assert results[0] == -1, "Out-of-bounds write should return -1"
client.close()
def test_multiple_tensors_rw(self, tmp_path):
"""Write multiple tensors at different offsets, then read all back."""
bytes_per_page = 128
n = 4
path = str(tmp_path / "multi_rw")
client = MockHf3fsClient(
path=path,
size=bytes_per_page * n * 2,
bytes_per_page=bytes_per_page,
entries=8,
)
tensors_write = [
torch.full((bytes_per_page // 4,), float(i), dtype=torch.float32)
for i in range(n)
]
offsets = [i * bytes_per_page for i in range(n)]
event = _make_cuda_event()
results = client.batch_write(offsets, tensors_write, event)
assert all(r == bytes_per_page for r in results)
tensors_read = [
torch.zeros(bytes_per_page // 4, dtype=torch.float32) for _ in range(n)
]
results = client.batch_read(offsets, tensors_read)
assert all(r == bytes_per_page for r in results)
for i, (tw, tr) in enumerate(zip(tensors_write, tensors_read)):
assert torch.allclose(tw, tr), f"Tensor {i} mismatch after round-trip"
client.close()
def test_flush_and_close_no_error(self, tmp_path):
"""flush() and close() should not raise exceptions."""
path = str(tmp_path / "flush_close")
client = MockHf3fsClient(path=path, size=1024, bytes_per_page=128, entries=4)
client.flush()
client.close()
# ===========================================================================
# TestHF3FSKVConnectorStats
# ===========================================================================
class TestHF3FSKVConnectorStats:
"""Tests for HF3FSKVConnectorStats metric collection and aggregation."""
def test_initial_is_empty(self, hf3fs_stats):
"""Fresh stats object should report is_empty() == True."""
assert hf3fs_stats.is_empty() is True
@pytest.mark.parametrize(
"task_type, duration_key",
[
("Saved", "save_duration"),
("Loaded", "load_duration"),
],
ids=["save", "load"],
)
def test_record_success_duration(self, hf3fs_stats, task_type, duration_key):
"""Recording a successful task should update duration list and total count."""
hf3fs_stats.record_success_task_duration(task_type, 0.5)
assert not hf3fs_stats.is_empty()
assert len(hf3fs_stats.data[duration_key]) == 1
assert hf3fs_stats.data[duration_key][0] == pytest.approx(0.5)
assert hf3fs_stats.data["num_transfer_task"] == 1
@pytest.mark.parametrize(
"task_type, failed_key",
[
("Saved", "num_failed_save"),
("Loaded", "num_failed_load"),
],
ids=["save", "load"],
)
def test_record_failed_task(self, hf3fs_stats, task_type, failed_key):
"""Recording a failed task should increment the corresponding counter."""
hf3fs_stats.record_failed_task_count(task_type)
assert hf3fs_stats.data[failed_key] == 1
assert hf3fs_stats.data["num_transfer_task"] == 1
def test_aggregate_two_stats(self):
"""aggregate() should merge save/load duration lists and sum counters."""
stats1 = HF3FSKVConnectorStats()
stats1.record_success_task_duration("Saved", 0.1)
stats1.record_success_task_duration("Loaded", 0.2)
stats2 = HF3FSKVConnectorStats()
stats2.record_success_task_duration("Saved", 0.3)
stats2.record_failed_task_count("Loaded")
stats1.aggregate(stats2)
assert stats1.data["save_duration"] == pytest.approx([0.1, 0.3])
assert stats1.data["load_duration"] == pytest.approx([0.2])
assert stats1.data["num_failed_load"] == 1
assert stats1.data["num_transfer_task"] == 4
def test_reduce_with_data(self):
"""reduce() computes correct averages when data is present."""
stats = HF3FSKVConnectorStats()
stats.record_success_task_duration("Saved", 1.0)
stats.record_success_task_duration("Saved", 3.0)
result = stats.reduce()
assert result["Num save task success"] == pytest.approx(2.0, rel=0.01)
assert result["Num save task failed"] == pytest.approx(0.0, rel=0.01)
assert result["Avg save duration (ms)"] == pytest.approx(2000.0, rel=0.01)
def test_clone_and_reset(self, hf3fs_stats):
"""clone_and_reset() returns a copy with data and resets the original."""
hf3fs_stats.record_success_task_duration("Saved", 0.7)
hf3fs_stats.record_success_task_duration("Loaded", 0.4)
clone = hf3fs_stats.clone_and_reset()
assert clone.data["num_transfer_task"] == 2
assert hf3fs_stats.is_empty()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for HF3FS metadata server data structures and allocation logic:
- RankFileMetadata : page allocation / release primitives
- KeyMetadata : per-key rank-page tracking and completion detection
- GlobalMetadataState : coordinated allocation with cache-hit semantics
"""
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_metadata_server import (
GlobalMetadataState,
KeyMetadata,
RankFileMetadata,
)
# ===========================================================================
# TestRankFileMetadata
# ===========================================================================
class TestRankFileMetadata:
"""Unit tests for RankFileMetadata page allocation primitives."""
@pytest.mark.parametrize(
"alloc_count, expected_pages",
[(3, 3), (5, 0)],
ids=["alloc_partial", "alloc_exceeds"],
)
def test_allocate_pages(self, alloc_count, expected_pages):
"""allocate_pages returns correct pages or empty list when insufficient."""
rank_meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
pages = rank_meta.allocate_pages(alloc_count)
assert len(pages) == expected_pages
if expected_pages > 0:
rank_meta.release_pages(pages)
assert rank_meta.get_free_page_count() == 3
def test_release_pages_restores_count(self):
"""Releasing allocated pages returns them to the free pool."""
rank_meta = RankFileMetadata(rank_id=0, num_pages=4, free_pages=list(range(4)))
pages = rank_meta.allocate_pages(2)
assert rank_meta.get_free_page_count() == 2
rank_meta.release_pages(pages)
assert rank_meta.get_free_page_count() == 4
def test_release_pages_no_duplicates(self):
"""Releasing the same page twice must not create duplicates."""
rank_meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
rank_meta.allocate_pages(1) # takes page 0
rank_meta.release_pages([0])
rank_meta.release_pages([0]) # second release of the same page
assert rank_meta.get_free_page_count() == 3
# ===========================================================================
# TestKeyMetadata
# ===========================================================================
class TestKeyMetadata:
"""Unit tests for KeyMetadata completion tracking."""
def test_is_complete_false_until_all_ranks(self):
"""is_complete() returns True only when all ranks confirmed."""
key_meta = KeyMetadata(key="k", rank_to_page={}, tp_world_size=2)
assert key_meta.is_complete() is False
key_meta.add_rank_page(0, 5)
assert key_meta.is_complete() is False
key_meta.add_rank_page(1, 10)
assert key_meta.is_complete() is True
def test_get_rank_page_returns_none_for_missing_rank(self):
"""get_rank_page() returns None when the rank has no entry."""
key_meta = KeyMetadata(key="k", rank_to_page={0: 3}, tp_world_size=2)
assert key_meta.get_rank_page(0) == 3
assert key_meta.get_rank_page(1) is None
def test_get_all_pages(self):
"""get_all_pages() returns all (rank, page) pairs."""
key_meta = KeyMetadata(key="k", rank_to_page={0: 1, 1: 2}, tp_world_size=2)
pairs = key_meta.get_all_pages()
assert set(pairs) == {(0, 1), (1, 2)}
# ===========================================================================
# TestGlobalMetadataStateAllocation
# ===========================================================================
class TestGlobalMetadataStateAllocation:
"""Tests for GlobalMetadataState allocation and cache-hit semantics."""
def test_uninitialized_rank_raises_on_allocate(self):
"""allocate_pages_for_keys raises ValueError for unknown rank."""
state = GlobalMetadataState()
with pytest.raises((ValueError, Exception)):
state.allocate_pages_for_keys(99, [("key", "")])
def test_uninitialized_rank_raises_on_get_locations(self):
"""get_key_locations raises ValueError for unknown rank."""
state = GlobalMetadataState()
with pytest.raises((ValueError, Exception)):
state.get_key_locations(99, ["any_key"])
def test_basic_allocation_and_confirm(self):
"""Allocating a page and confirming it marks the key as complete."""
state = GlobalMetadataState()
state.initialize_rank(0, 4)
results = state.allocate_pages_for_keys(0, [("K", "")])
assert results["K"] >= 0
state.confirm_write_for_keys(0, [("K", results["K"])])
assert state.batch_key_exists(["K"]) == [True]
locations = state.get_key_locations(0, ["K"])
assert locations == [results["K"]]
def test_allocate_pages_cache_hit_does_not_leak_pages(self):
"""Cache-hit key must not consume a page from the free pool;
the pre-allocated slot must be returned before reusing the existing page.
"""
state = GlobalMetadataState()
state.initialize_rank(0, 5) # 5 free pages: [0,1,2,3,4]
# Simulate a key that has already been fully written and confirmed.
state.key_metadata["K_cached"] = KeyMetadata(
key="K_cached", rank_to_page={0: 2}, tp_world_size=1
)
free_before = state.rank_metadata[0].get_free_page_count() # 5
results = state.allocate_pages_for_keys(0, [("K_cached", ""), ("K_new", "")])
free_after = state.rank_metadata[0].get_free_page_count()
# Cache-hit key must reuse its existing page.
assert results["K_cached"] == 2, (
f"Cache-hit key should reuse page 2, got {results['K_cached']}"
)
# New key must receive a valid page.
assert results["K_new"] >= 0, (
f"New key should get a valid page, got {results['K_new']}"
)
# Exactly one page consumed from the free pool.
assert free_before - free_after == 1, (
f"Expected 1 page consumed, got delta={free_before - free_after}"
)
def test_allocate_pages_all_cache_hits_frees_all_slots(self):
"""When every key in the batch is a cache hit, no pages are consumed."""
state = GlobalMetadataState()
state.initialize_rank(0, 5)
for key, page in (("K1", 0), ("K2", 1)):
state.key_metadata[key] = KeyMetadata(
key=key, rank_to_page={0: page}, tp_world_size=1
)
free_before = state.rank_metadata[0].get_free_page_count()
results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
free_after = state.rank_metadata[0].get_free_page_count()
assert results["K1"] == 0
assert results["K2"] == 1
assert free_after == free_before, (
f"All-cache-hit batch must not consume free pages; "
f"before={free_before}, after={free_after}"
)
def test_allocate_returns_minus_one_when_pool_exhausted(self):
"""If the free pool is exhausted, all new keys receive -1."""
state = GlobalMetadataState()
state.initialize_rank(0, 1) # only 1 free page
results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
# allocate_pages uses all-or-nothing: 2 needed but only 1 available → []
assert all(v == -1 for v in results.values()), f"Expected all -1, got {results}"
def test_confirm_write_releases_pages(self):
"""confirm_write_for_keys with pages_to_release returns them to pool."""
state = GlobalMetadataState()
state.initialize_rank(0, 3)
results = state.allocate_pages_for_keys(0, [("K", "")])
page = results["K"]
free_after_alloc = state.rank_metadata[0].get_free_page_count()
state.confirm_write_for_keys(0, [("K", page)], pages_to_release=[page])
free_after_release = state.rank_metadata[0].get_free_page_count()
assert free_after_release == free_after_alloc + 1
......@@ -211,15 +211,18 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
"MooncakeConnector",
)
KVConnectorFactory.register_connector(
"FlexKVConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector",
"FlexKVConnectorV1",
)
KVConnectorFactory.register_connector(
"SimpleCPUOffloadConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector",
"SimpleCPUOffloadConnector",
)
KVConnectorFactory.register_connector(
"HF3FSKVConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_connector",
"HF3FSKVConnector",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import multiprocessing
import os
import threading
from functools import wraps
from pathlib import Path
import torch
import torch.utils.cpp_extension
from torch.utils.cpp_extension import load
root = Path(__file__).parent.resolve()
cuda_include_path = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")
hf3fs_utils = load(
name="hf3fs_utils",
sources=[f"{root}/utils/hf3fs_utils.cpp"],
extra_include_paths=[cuda_include_path],
)
logger = logging.getLogger(__name__)
HF3FS_AVAILABLE = True
try:
from hf3fs_fuse.io import (
deregister_fd,
extract_mount_point,
make_ioring,
make_iovec,
register_fd,
)
except ImportError:
HF3FS_AVAILABLE = False
def rsynchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.rlock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
def wsynchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.wlock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
class Hf3fsClient:
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
"""Initialize the HF3FS client with hf3fs_fuse.
Args:
path: Path to the file used for storage
size: Total size of the storage file in bytes
bytes_per_page: Size of each page in bytes
entries: Maximum number of concurrent operations
"""
if not HF3FS_AVAILABLE:
raise ImportError(
"hf3fs_fuse.io is not available. Please install the hf3fs_fuse package."
)
self.path = path
self.size = size
self.bytes_per_page = bytes_per_page
self.entries = entries
self._closed = False
self.file = None
self.shm_r = None
self.shm_w = None
self.ior_r = None
self.ior_w = None
self.iov_r = None
self.iov_w = None
try:
# Create the file if it doesn't exist and set its size
self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
os.ftruncate(self.file, size)
register_fd(self.file)
self.hf3fs_mount_point = extract_mount_point(path)
self.bs = self.bytes_per_page
self.shm_r = multiprocessing.shared_memory.SharedMemory(
size=self.bs * self.entries, create=True
)
self.shm_w = multiprocessing.shared_memory.SharedMemory(
size=self.bs * self.entries, create=True
)
self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)
numel = self.bs * self.entries
self.r_pinned = torch.empty(
numel,
dtype=torch.uint8,
device="cpu",
pin_memory=True,
)
self.w_pinned = torch.empty(
numel,
dtype=torch.uint8,
device="cpu",
pin_memory=True,
)
self.numa = -1
self.ior_r = make_ioring(
self.hf3fs_mount_point,
self.entries,
for_read=True,
timeout=1,
numa=self.numa,
)
self.ior_w = make_ioring(
self.hf3fs_mount_point,
self.entries,
for_read=False,
timeout=1,
numa=self.numa,
)
self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
self.shm_r.unlink()
self.shm_w.unlink()
self.rlock = threading.RLock()
self.wlock = threading.RLock()
self.stream = torch.cuda.Stream()
self.stream_ptr_int = self.stream.cuda_stream
except Exception:
self._release_resources()
raise
logger.debug(
"Initialized HF3FS client with file: %s, size: %s bytes", path, size
)
def _release_resources(self) -> None:
"""Release all acquired resources safely"""
# iov must be released before ioring and shm
for attr in ("iov_r", "iov_w", "ior_r", "ior_w"):
obj = getattr(self, attr, None)
if obj is not None:
del obj
setattr(self, attr, None)
for attr in ("shm_r", "shm_w"):
shm = getattr(self, attr, None)
if shm is not None:
try:
shm.close()
except Exception as e:
logger.warning("Failed to close %s: %s", attr, e)
setattr(self, attr, None)
if self.file is not None:
try:
deregister_fd(self.file)
except Exception as e:
logger.warning("deregister_fd failed: %s", e)
try:
os.close(self.file)
except OSError as e:
logger.warning("os.close failed: %s", e)
self.file = None
@rsynchronized()
def batch_read(self, offsets: list[int], tensors: list[torch.Tensor]) -> list[int]:
"""Read data from the file at specified offsets into tensors.
Args:
offsets: List of byte offsets to read from
tensors: List of tensors to read data into
Returns:
List of operation results (0 for success, non-zero for error)
"""
self.check(offsets, tensors)
assert self.ior_r is not None
assert self.iov_r is not None
# prepare
current = 0
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
self.ior_r.prepare(
self.iov_r[current : current + size], True, self.file, offset
)
current += size
# submit
ionum = len(offsets)
resv = self.ior_r.submit().wait(min_results=ionum)
# results
with torch.cuda.stream(self.stream):
hf3fs_utils.read_shm(
self.shm_r_tensor, self.r_pinned, tensors, self.stream_ptr_int
)
results = [res.result for res in resv]
return results
@wsynchronized()
def batch_write(
self, offsets: list[int], tensors: list[torch.Tensor], event: torch.cuda.Event
) -> list[int]:
"""Write data from tensors to the file at specified offsets.
Args:
offsets: List of byte offsets to write to
tensors: List of tensors containing data to write
Returns:
List of operation results (0 for success, non-zero for error)
"""
self.check(offsets, tensors)
assert self.ior_w is not None
assert self.iov_w is not None
# prepare
with torch.cuda.stream(self.stream):
self.stream.wait_event(event)
hf3fs_utils.write_shm(
tensors, self.shm_w_tensor, self.w_pinned, self.stream_ptr_int
)
current = 0
for offset, tensor in zip(offsets, tensors):
size = tensor.numel() * tensor.itemsize
self.ior_w.prepare(
self.iov_w[current : current + size], False, self.file, offset
)
current += size
# submit
ionum = len(offsets)
resv = self.ior_w.submit().wait(min_results=ionum)
# results
results = [res.result for res in resv]
return results
def check(self, offsets: list[int], tensors: list[torch.Tensor]) -> None:
sizes = [t.numel() * t.itemsize for t in tensors]
if any(
[
len(offsets) > self.entries,
len(offsets) != len(sizes),
any(
offset < 0 or offset + size > self.size
for offset, size in zip(offsets, sizes)
),
any(size > self.bytes_per_page for size in sizes),
]
):
self.close()
raise ValueError("Hf3fsClient.check Failed")
def get_size(self) -> int:
"""Get the total size of the storage file.
Returns:
Size of the file in bytes
"""
return self.size
def close(self) -> None:
"""Close the client and clean up resources."""
if self._closed:
return
self._closed = True
self._release_resources()
def flush(self) -> None:
"""Flush any pending writes to disk."""
if not self._closed and self.file is not None:
os.fsync(self.file)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
HF3FS Metadata Server with key-based organization.
"""
import argparse
import logging
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
try:
import orjson
HAS_ORJSON = True
except ImportError:
import json as orjson # type: ignore
HAS_ORJSON = False
import requests
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import ORJSONResponse
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@dataclass
class RankFileMetadata:
"""Manages file page allocation for a single rank."""
rank_id: int
num_pages: int
free_pages: list[int]
def allocate_pages(self, num_pages: int) -> list[int]:
"""Allocate specified number of free pages."""
if len(self.free_pages) < num_pages:
return []
allocated = self.free_pages[:num_pages]
self.free_pages = self.free_pages[num_pages:]
return allocated
def release_pages(self, page_indices: list[int]) -> None:
"""Release pages back to free pool."""
for page_idx in page_indices:
if page_idx not in self.free_pages:
self.free_pages.append(page_idx)
def get_free_page_count(self) -> int:
"""Get current number of free pages."""
return len(self.free_pages)
@dataclass
class KeyMetadata:
"""Manages metadata for a single key across multiple ranks."""
key: str
rank_to_page: dict[int, int] # rank -> allocated page index
tp_world_size: int
def add_rank_page(self, rank: int, page_index: int) -> None:
"""Add page allocation for a specific rank."""
self.rank_to_page[rank] = page_index
def get_all_pages(self) -> list[tuple[int, int]]:
"""Get all (rank, page) pairs for this key."""
return [(rank, page) for rank, page in self.rank_to_page.items()]
def get_rank_page(self, rank: int) -> int | None:
"""Get page index for a specific rank."""
return self.rank_to_page.get(rank)
def is_complete(self) -> bool:
"""Check if all ranks in the TP world have allocated pages."""
return len(self.rank_to_page) == self.tp_world_size
class GlobalMetadataState:
"""Manages global metadata state across all ranks and keys."""
def __init__(self):
self.global_lock = threading.RLock()
self.rank_metadata: dict[int, RankFileMetadata] = {}
self.key_metadata: dict[str, KeyMetadata] = {}
def clear(self) -> None:
"""Clear all metadata state."""
with self.global_lock:
self.rank_metadata.clear()
self.key_metadata.clear()
logger.info("Cleared all metadata state")
def initialize_rank(self, rank: int, num_pages: int) -> None:
"""Initialize a new rank with specified number of pages."""
with self.global_lock:
if rank not in self.rank_metadata:
self.rank_metadata[rank] = RankFileMetadata(
rank, num_pages, list(range(num_pages))
)
logger.info("Initialized rank %s with %s pages", rank, num_pages)
def allocate_pages_for_keys(
self, rank: int, keys: list[tuple[str, str]]
) -> dict[str, int]:
"""Allocate one page for each key on the specified rank.
Args:
rank: Rank ID to allocate pages on
keys: List of keys to allocate pages for
Returns:
Dictionary mapping key -> allocated page index
"""
with self.global_lock:
if rank not in self.rank_metadata:
raise ValueError(f"Rank {rank} not initialized")
# Batch allocate pages for all keys
num_pages_needed = len(keys)
allocated_pages = self.rank_metadata[rank].allocate_pages(num_pages_needed)
if len(allocated_pages) < num_pages_needed:
logger.warning(
"Rank %s only allocated %s pages for %s keys",
rank,
len(allocated_pages),
num_pages_needed,
)
allocation_results = {}
for i, (key, prefix_key) in enumerate(keys):
if key in self.key_metadata:
key_meta = self.key_metadata[key]
if key_meta.is_complete() and rank in key_meta.rank_to_page:
# key is already fully written, reuse the existing page
# and release the allocated pages back to the free pool.
if i < len(allocated_pages):
self.rank_metadata[rank].release_pages([allocated_pages[i]])
allocation_results[key] = key_meta.rank_to_page[rank]
continue
if i < len(allocated_pages):
allocation_results[key] = allocated_pages[i]
else:
allocation_results[key] = -1 # No pages available
return allocation_results
def confirm_write_for_keys(
self,
rank: int,
key_confirmations: list[tuple[str, int]],
pages_to_release: list[int] | None = None,
) -> None:
"""Confirm write operations for keys and update metadata.
Args:
rank: Rank ID that confirmed the writes
key_confirmations: List of (key, page_index) tuples
pages_to_release: List of page indices to release back to free pool
"""
with self.global_lock:
# Confirm successful writes
for key, page_index in key_confirmations:
if key not in self.key_metadata:
# Need to determine tp_world_size from rank_metadata
tp_world_size = len(self.rank_metadata)
self.key_metadata[key] = KeyMetadata(key, {}, tp_world_size)
# Add confirmed page to key metadata
self.key_metadata[key].add_rank_page(rank, page_index)
# Release specified pages back to free pool
if pages_to_release:
self.rank_metadata[rank].release_pages(pages_to_release)
logger.debug(
"Released %s pages on rank %s: %s",
len(pages_to_release),
rank,
pages_to_release,
)
def batch_key_exists(self, keys: list[str]) -> list[bool]:
"""Check if keys exist in metadata and all ranks have confirmed writes.
Args:
keys: List of keys to check
Returns:
List of boolean values indicating key existence and completion
"""
with self.global_lock:
results = []
for key in keys:
if key not in self.key_metadata:
results.append(False)
else:
# Check if all ranks in the TP world have confirmed writes
key_meta = self.key_metadata[key]
results.append(key_meta.is_complete())
return results
def get_key_locations(self, rank: int, keys: list[str]) -> list[int | None]:
"""Get page indices for keys on a specific rank.
Args:
rank: Rank ID to query
keys: List of keys to look up
Returns:
List of page indices in the same order as input keys (None if key not found)
"""
with self.global_lock:
if rank not in self.rank_metadata:
raise ValueError(f"Rank {rank} not initialized")
results = []
for key in keys:
if key in self.key_metadata:
key_meta = self.key_metadata[key]
if key_meta.is_complete():
page_index = key_meta.get_rank_page(rank)
else:
page_index = None
results.append(page_index)
else:
results.append(None)
return results
class Hf3fsMetadataServer:
"""HF3FS Metadata Server with improved key-based organization."""
def __init__(self, persistence_path: str | None = None, save_interval: int = 60):
self.state = GlobalMetadataState()
if HAS_ORJSON:
self.app = FastAPI(default_response_class=ORJSONResponse)
else:
self.app = FastAPI()
self._setup_routes()
async def _read_json(self, request: Request) -> dict:
"""Parse request JSON using orjson if available."""
body = await request.body()
return orjson.loads(body)
def _json_response(self, content: dict):
"""Return ORJSONResponse when available to bypass jsonable_encoder."""
if HAS_ORJSON:
return ORJSONResponse(content)
else:
return content
def _setup_routes(self):
"""Setup FastAPI routes for new API design."""
self.app.post("/rank/{rank}/initialize")(self.initialize_rank)
self.app.post("/keys/batch_allocate")(self.batch_allocate_pages_for_keys)
self.app.post("/keys/confirm_write")(self.confirm_write_for_keys)
self.app.post("/keys/batch_exists")(self.batch_key_exists)
self.app.post("/keys/get_locations")(self.get_key_locations)
self.app.post("/clear")(self.clear)
async def initialize_rank(self, rank: int, request: Request):
"""Initialize a rank with specified number of pages."""
data = await self._read_json(request)
role = data.get("role", "worker")
num_pages = data.get("num_pages", 0)
if role == "scheduler":
return self._json_response(
{"message": "Scheduler role does not require initialization"}
)
if role == "worker" and num_pages > 0:
self.state.initialize_rank(rank, num_pages)
return self._json_response(
{"message": f"Rank {rank} initialized with {num_pages} pages"}
)
else:
raise HTTPException(
status_code=400, detail="Invalid initialization parameters"
)
async def batch_allocate_pages_for_keys(self, request: Request):
"""Allocate one page for each key on a specific rank."""
data = await self._read_json(request)
rank = data.get("rank")
keys = data.get("keys", [])
# Validate input format
if rank is None or not isinstance(keys, list):
raise HTTPException(
status_code=400, detail="Invalid request format: need 'rank' and 'keys'"
)
try:
# Perform allocation
results = self.state.allocate_pages_for_keys(rank, keys)
# Convert results to response format
response = {"rank": rank, "results": list(results.items())}
return self._json_response(response)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Allocation failed: {str(e)}"
) from e
async def confirm_write_for_keys(self, request: Request):
"""Confirm write operations for keys."""
data = await self._read_json(request)
rank = data.get("rank")
confirmations = data.get("confirmations", [])
pages_to_release = data.get("pages_to_release", [])
# Validate input format
if rank is None or not isinstance(confirmations, list):
raise HTTPException(
status_code=400,
detail="Invalid request format: need 'rank' and 'confirmations'",
)
try:
self.state.confirm_write_for_keys(rank, confirmations, pages_to_release)
return Response(status_code=204)
except Exception as e:
logger.error("Confirm write for keys failed: %s", e)
raise HTTPException(
status_code=500, detail=f"Confirmation failed: {str(e)}"
) from e
async def batch_key_exists(self, request: Request):
"""Check if multiple keys exist in metadata."""
data = await self._read_json(request)
keys = data.get("keys", [])
if not isinstance(keys, list):
raise HTTPException(status_code=400, detail="Invalid keys format")
try:
exists_results = self.state.batch_key_exists(keys)
return self._json_response({"exists": exists_results})
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Key existence check failed: {str(e)}"
) from e
async def get_key_locations(self, request: Request):
"""Get page indices for keys on a specific rank."""
data = await self._read_json(request)
rank = data.get("rank")
keys = data.get("keys", [])
# Validate input format
if rank is None or not isinstance(keys, list):
raise HTTPException(
status_code=400, detail="Invalid request format: need 'rank' and 'keys'"
)
try:
# Get key locations
locations = self.state.get_key_locations(rank, keys)
return self._json_response({"locations": locations})
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to get key locations: {str(e)}"
) from e
async def clear(self, request: Request):
"""Clear the metadata server."""
self.state.clear()
return Response(status_code=204)
def run(self, host: str = "0.0.0.0", port: int = 18000):
"""Run the metadata server."""
import uvicorn
logger.info("Starting improved metadata server on http://%s:%s", host, port)
uvicorn.run(self.app, host=host, port=port)
# --- Client implementation ---
class Hf3fsMetadataInterface(ABC):
"""Interface for HF3FS metadata operations."""
@abstractmethod
def initialize(self, rank: int, num_pages: int = 0, role: str = "worker") -> None:
"""Initialize the metadata service with specified number of pages."""
pass
@abstractmethod
def allocate_pages_for_keys(
self, rank: int, keys: list[tuple[str, str]]
) -> list[tuple[str, int]]:
"""Allocate one page for each key on the specified rank."""
pass
@abstractmethod
def confirm_write_for_keys(
self,
rank: int,
key_confirmations: list[tuple[str, int]],
pages_to_release: list[int] | None = None,
) -> None:
"""Confirm write operations for keys and optionally release pages."""
pass
@abstractmethod
def batch_key_exists(self, keys: list[str]) -> list[bool]:
"""Check if keys exist and are complete across all ranks."""
pass
@abstractmethod
def get_key_locations(self, rank: int, keys: list[str]) -> list[int]:
"""Get page indices for keys on a specific rank."""
pass
class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
"""Global HTTP metadata client for HF3FS."""
def __init__(self, base_url: str = "http://localhost:18000", max_retries: int = 3):
self.base_url = base_url.rstrip("/")
self._session = requests.Session()
retry_strategy = Retry(
total=max_retries,
backoff_factor=0.3,
status_forcelist=[500, 502, 503, 504],
allowed_methods=["GET", "POST"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self._session.mount("http://", adapter)
def _post(self, endpoint: str, json_data: dict) -> dict:
"""Make POST request to metadata server."""
try:
url = f"{self.base_url}/{endpoint}"
headers = {"Content-Type": "application/json"}
if HAS_ORJSON:
payload = orjson.dumps(json_data)
else:
import json
payload = json.dumps(json_data).encode("utf-8")
response = self._session.post(url, data=payload, headers=headers)
response.raise_for_status()
if response.status_code == 204 or not response.content:
return {}
if HAS_ORJSON:
return orjson.loads(response.content)
else:
return response.json()
except requests.exceptions.RequestException as e:
logger.error("Failed to POST to %s after retries: %s", endpoint, e)
raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
def initialize(self, rank: int, num_pages: int = 0, role: str = "worker") -> None:
"""Initialize a rank with specified number of pages."""
self._post(f"rank/{rank}/initialize", {"num_pages": num_pages, "role": role})
def allocate_pages_for_keys(
self, rank: int, keys: list[tuple[str, str]]
) -> list[tuple[str, int]]:
"""Allocate pages for keys on the specified rank."""
response = self._post("keys/batch_allocate", {"rank": rank, "keys": keys})
# Convert response to expected format
return response.get("results", {})
def confirm_write_for_keys(
self,
rank: int,
key_confirmations: list[tuple[str, int]],
pages_to_release: list[int] | None = None,
) -> None:
"""Confirm write operations for keys and optionally release pages."""
payload = {
"rank": rank,
"confirmations": key_confirmations,
"pages_to_release": pages_to_release or [],
}
self._post("keys/confirm_write", payload)
def batch_key_exists(self, keys: list[str]) -> list[bool]:
"""Check if keys exist and are complete across all ranks."""
response = self._post("keys/batch_exists", {"keys": keys})
return response.get("exists", [])
def get_key_locations(self, rank: int, keys: list[str]) -> list[int]:
"""Get page indices for keys on a specific rank."""
response = self._post("keys/get_locations", {"rank": rank, "keys": keys})
return response.get("locations", [])
def run_metadata_server(
host: str = "0.0.0.0",
port: int = 18000,
):
"""Run the improved HF3FS metadata server."""
server = Hf3fsMetadataServer()
server.run(host=host, port=port)
# --- Main Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Improved HF3FS Metadata Server")
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Host to bind the server to."
)
parser.add_argument(
"--port", type=int, default=18000, help="Port to run the server on."
)
args = parser.parse_args()
run_metadata_server(args.host, args.port)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from dataclasses import dataclass, field
from typing import Optional
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.v1.request import Request
class AtomicCounter:
"""Thread-safe atomic counter for round-robin operations."""
def __init__(self, n: int):
assert n > 0, "Counter size must be positive"
self._n = n
self._value = 0
self._lock = threading.Lock()
def next(self) -> int:
"""Get next value in round-robin fashion."""
with self._lock:
current = self._value
self._value = (current + 1) % self._n
return current
@dataclass
class LoadBlockInfo:
"""Operation for loading blocks from external storage."""
num_computed_blocks: int
num_blocks_to_load: int
need_fetch_block_ids: list[int]
@dataclass
class SaveBlockInfo:
"""Operation for saving blocks to external storage."""
skip_leading_blocks: int
@dataclass
class RequestSchedulingState:
"""Unified request scheduling state management."""
request_id: str
request: Request | None = None
# Token and block tracking
token_ids: list[int] = field(default_factory=list)
allocated_block_ids: list[int] = field(default_factory=list)
num_saved_blocks: int = 0
# Load operation info
load_op: LoadBlockInfo | None = None
# Scheduling phase
phase: str = "NEW" # NEW -> WAITING_TO_LOAD -> ACTIVE -> FINISHED
def needs_loading(self) -> bool:
"""Check if request needs loading."""
return self.load_op is not None and self.load_op.num_blocks_to_load > 0
def is_ready_to_load(self) -> bool:
"""Check if request is ready for loading."""
return self.phase == "WAITING_TO_LOAD" and self.needs_loading()
def update_tokens_and_blocks(self, new_token_ids: list[int], new_block_ids) -> None:
"""Update with new tokens and blocks."""
if new_token_ids:
self.token_ids.extend(new_token_ids)
if new_block_ids is not None:
normalized_block_ids = self._normalize_block_ids(new_block_ids)
self.allocated_block_ids.extend(normalized_block_ids)
def _normalize_block_ids(self, block_ids) -> list[int]:
"""Normalize block_ids to list format."""
if not block_ids:
return []
if isinstance(block_ids, tuple):
return block_ids[0] if block_ids else []
if isinstance(block_ids, list):
return block_ids
return []
@dataclass
class HF3FSRequestMetadata:
"""Metadata for a single request in HF3FS connector."""
request_id: str
token_ids: list[int]
block_ids: list[int]
load_block_op: LoadBlockInfo | None = None
save_block_op: SaveBlockInfo | None = None
@staticmethod
def from_scheduling_state(
state: "RequestSchedulingState",
block_size: int,
load_op: LoadBlockInfo | None = None,
skip_leading_blocks: int | None = None,
) -> Optional["HF3FSRequestMetadata"]:
"""Create request metadata from scheduling state."""
token_count = len(state.token_ids)
total_blocks = token_count // block_size
skip_blocks = (
state.num_saved_blocks
if skip_leading_blocks is None
else skip_leading_blocks
)
new_blocks_to_save = total_blocks - state.num_saved_blocks
if new_blocks_to_save <= 0 and load_op is None:
return None
state.num_saved_blocks = total_blocks
return HF3FSRequestMetadata(
request_id=state.request_id,
token_ids=state.token_ids,
block_ids=state.allocated_block_ids,
load_block_op=load_op,
save_block_op=SaveBlockInfo(skip_leading_blocks=skip_blocks),
)
class HF3FSConnectorMetadata(KVConnectorMetadata):
"""Container for HF3FS connector metadata."""
def __init__(self):
self.requests: list[HF3FSRequestMetadata] = []
def add_request(self, request_metadata: HF3FSRequestMetadata) -> None:
"""Add request to metadata."""
self.requests.append(request_metadata)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
@triton.jit
def kv_cache_scatter_kernel(
kv_cache_ptrs_ptr,
source_ptr,
token_indices_ptr,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE: tl.constexpr,
):
layer_idx = tl.program_id(0)
token_pos = tl.program_id(1)
if layer_idx >= num_layers or token_pos >= num_tokens_in_block:
return
token_idx = tl.load(token_indices_ptr + token_pos)
kv_cache_ptr = tl.cast(tl.load(kv_cache_ptrs_ptr + layer_idx), source_ptr.dtype)
if token_idx >= total_token_in_kvcache:
return
if is_mla:
# MLA format: source [num_layers, num_tokens_in_block, hidden_size]
# MLA format: target [total_token_in_kvcache, hidden_size] (per layer)
source_offset = (layer_idx * num_tokens_in_block + token_pos) * hidden_size
target_offset = token_idx * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val = tl.load(source_ptr + source_offset + offset, mask=mask)
tl.store(kv_cache_ptr + target_offset + offset, val, mask=mask)
else:
# MHA format: source [num_layers, 2, num_tokens_in_block, hidden_size]
# MHA format: target [2, total_token_in_kvcache, hidden_size]
source_offset_k = (
layer_idx * num_tokens_in_block * 2 + token_pos
) * hidden_size
source_offset_v = (
layer_idx * num_tokens_in_block * 2 + num_tokens_in_block + token_pos
) * hidden_size
target_offset_k = token_idx * hidden_size
target_offset_v = (total_token_in_kvcache + token_idx) * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val_k = tl.load(source_ptr + source_offset_k + offset, mask=mask)
val_v = tl.load(source_ptr + source_offset_v + offset, mask=mask)
tl.store(kv_cache_ptr + target_offset_k + offset, val_k, mask=mask)
tl.store(kv_cache_ptr + target_offset_v + offset, val_v, mask=mask)
@triton.jit
def kv_cache_gather_kernel(
kv_cache_ptrs_ptr,
dst_ptr,
token_indices_ptr,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE: tl.constexpr,
):
layer_idx = tl.program_id(0)
token_pos = tl.program_id(1)
if layer_idx >= num_layers or token_pos >= num_tokens_in_block:
return
token_idx = tl.load(token_indices_ptr + token_pos)
kv_cache_ptr = tl.cast(tl.load(kv_cache_ptrs_ptr + layer_idx), dst_ptr.dtype)
if token_idx >= total_token_in_kvcache:
return
if is_mla:
# MLA format: source [total_token_in_kvcache, hidden_size] (per layer)
# MLA format: dst [num_layers, num_tokens_in_block, hidden_size]
kvcache_offset = token_idx * hidden_size
dst_offset = (layer_idx * num_tokens_in_block + token_pos) * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val = tl.load(kv_cache_ptr + kvcache_offset + offset, mask=mask)
tl.store(dst_ptr + dst_offset + offset, val, mask=mask)
else:
# MHA format: source [2, total_token_in_kvcache, hidden_size]
# MHA format: dst [num_layers, 2, num_tokens_in_block, hidden_size]
dst_offset_k = (layer_idx * num_tokens_in_block * 2 + token_pos) * hidden_size
dst_offset_v = (
layer_idx * num_tokens_in_block * 2 + num_tokens_in_block + token_pos
) * hidden_size
kvcache_offset_k = token_idx * hidden_size
kvcache_offset_v = (total_token_in_kvcache + token_idx) * hidden_size
for i in range(0, hidden_size, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
val_k = tl.load(kv_cache_ptr + kvcache_offset_k + offset, mask=mask)
val_v = tl.load(kv_cache_ptr + kvcache_offset_v + offset, mask=mask)
tl.store(dst_ptr + dst_offset_k + offset, val_k, mask=mask)
tl.store(dst_ptr + dst_offset_v + offset, val_v, mask=mask)
def scatter_kv_caches(
kv_caches_ptrs: torch.Tensor,
total_token_in_kvcache: int,
src_tensor: torch.Tensor,
token_indices: list[int],
is_mla: bool = False,
) -> None:
"""Scatter KV cache data from source tensor to KV cache storage.
Args:
kv_caches_ptrs: Tensor of KV cache pointers (one per layer)
total_token_in_kvcache: Total number of tokens in KV cache
src_tensor: Source tensor containing data to scatter
- MHA format: [num_layers, 2, num_tokens_in_block, hidden_size]
- MLA format: [num_layers, num_tokens_in_block, hidden_size]
token_indices: List of token positions to update
is_mla: Whether using MLA model format
"""
num_layers = len(kv_caches_ptrs)
num_tokens_in_block = len(token_indices)
if is_mla:
# MLA: src_tensor is [num_layers, num_tokens_in_block, hidden_size]
assert len(src_tensor.shape) == 3, (
f"MLA src_tensor should be 3D, got {src_tensor.shape}"
)
hidden_size = src_tensor.shape[2]
else:
# MHA: src_tensor is [num_layers, 2, num_tokens_in_block, hidden_size]
assert len(src_tensor.shape) == 4, (
f"MHA src_tensor should be 4D, got {src_tensor.shape}"
)
hidden_size = src_tensor.shape[3]
device = src_tensor.device
token_indices_tensor = torch.tensor(
token_indices, dtype=torch.int32, device="cpu"
).to(device, non_blocking=True)
grid = (num_layers, num_tokens_in_block)
BLOCK_SIZE = 128
kv_cache_scatter_kernel[grid](
kv_caches_ptrs,
src_tensor,
token_indices_tensor,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE=BLOCK_SIZE,
)
def gather_kv_caches(
kv_caches_ptrs: torch.Tensor,
total_token_in_kvcache: int,
dst_tensor: torch.Tensor,
token_indices: list[int],
is_mla: bool = False,
) -> None:
"""Gather KV cache data from KV cache storage to destination tensor.
Args:
kv_caches_ptrs: Tensor of KV cache pointers (one per layer)
total_token_in_kvcache: Total number of tokens in KV cache
dst_tensor: Destination tensor to store gathered data
- MHA format: [num_layers, 2, num_tokens_in_block, hidden_size]
- MLA format: [num_layers, num_tokens_in_block, hidden_size]
token_indices: List of token positions to gather
is_mla: Whether using MLA model format
"""
num_layers = kv_caches_ptrs.shape[0]
num_tokens_in_block = len(token_indices)
if is_mla:
# MLA: dst_tensor is [num_layers, num_tokens_in_block, hidden_size]
assert len(dst_tensor.shape) == 3, (
f"MLA dst_tensor should be 3D, got {dst_tensor.shape}"
)
assert dst_tensor.shape[0] == num_layers, (
f"Layer count mismatch: {dst_tensor.shape[0]} vs {num_layers}"
)
assert dst_tensor.shape[1] == num_tokens_in_block, (
f"Token count mismatch: {dst_tensor.shape[1]} vs {num_tokens_in_block}"
)
hidden_size = dst_tensor.shape[2]
else:
# MHA: dst_tensor is [num_layers, 2, num_tokens_in_block, hidden_size]
assert len(dst_tensor.shape) == 4, (
f"MHA dst_tensor should be 4D, got {dst_tensor.shape}"
)
assert dst_tensor.shape[0] == num_layers, (
f"Layer count mismatch: {dst_tensor.shape[0]} vs {num_layers}"
)
assert dst_tensor.shape[1] == 2, (
f"MHA should have 2 (K,V) components, got {dst_tensor.shape[1]}"
)
assert dst_tensor.shape[2] == num_tokens_in_block, (
f"Token count mismatch: {dst_tensor.shape[2]} vs {num_tokens_in_block}"
)
hidden_size = dst_tensor.shape[3]
device = dst_tensor.device
token_indices_tensor = torch.tensor(
token_indices, dtype=torch.int32, device="cpu"
).to(device, non_blocking=True)
grid = (num_layers, num_tokens_in_block)
BLOCK_SIZE = 128
kv_cache_gather_kernel[grid](
kv_caches_ptrs,
dst_tensor,
token_indices_tensor,
num_tokens_in_block,
hidden_size,
total_token_in_kvcache,
num_layers,
is_mla,
BLOCK_SIZE=BLOCK_SIZE,
)
class CopyBufferAllocator:
"""Memory pool for tensor buffers to avoid frequent allocation/deallocation."""
def __init__(
self, device: torch.device, dtype: torch.dtype, shape: list, max_count: int
):
self._shape = shape
self._max_count = max_count
self._device = device
self._free_buffers = [
torch.empty(shape, dtype=dtype, device=device) for _ in range(max_count)
]
self._inuse_count = 0
def alloc_buffer(self, count: int) -> list[torch.Tensor] | None:
"""Allocate buffers from the pool."""
if count == 0:
return []
if self._inuse_count + count <= self._max_count:
self._inuse_count += count
result = self._free_buffers[-count:]
del self._free_buffers[-count:]
return result
return None
def free_buffer(self, buffers: list[torch.Tensor]) -> None:
"""Return buffers to the pool."""
if not buffers:
return
if self._inuse_count >= len(buffers):
self._inuse_count -= len(buffers)
self._free_buffers.extend(buffers)
else:
raise RuntimeError("Attempted to free more buffers than allocated")
logger = init_logger(__name__)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import torch
logger = logging.getLogger(__name__)
HF3FS_AVAILABLE = True
class Hf3fsClient:
"""Mock HF3FS client using file backend for debugging and testing."""
def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
self._size = size
self._bytes_per_page = bytes_per_page
self._entries = entries
self._file_path = path
self._ensure_file_exists()
logger.debug("Initialized mock HF3FS client: %s (%d bytes)", path, size)
def _ensure_file_exists(self) -> None:
"""Create file if it doesn't exist."""
if not os.path.exists(self._file_path):
with open(self._file_path, "w+b") as f:
f.truncate(self._size)
def batch_read(self, offsets: list[int], tensors: list[torch.Tensor]) -> list[int]:
"""Read data from file at specified offsets into tensors."""
results = []
try:
with open(self._file_path, "rb") as f:
for offset, tensor in zip(offsets, tensors):
num_bytes = tensor.numel() * tensor.element_size()
if offset < 0 or offset + num_bytes > self._size:
results.append(-1)
continue
f.seek(offset)
buffer_data = f.read(num_bytes)
if len(buffer_data) == num_bytes == self._bytes_per_page:
tensor_data = self._convert_buffer_to_tensor(
buffer_data, tensor.dtype
)
tensor.copy_(
tensor_data.reshape(tensor.shape).to(tensor.device)
)
results.append(self._bytes_per_page)
else:
logger.error(
"Read size mismatch: got %d, expected %d",
len(buffer_data),
num_bytes,
)
results.append(-1)
except Exception as e:
logger.error("Batch read error: %s", e)
results.extend([-1] * (len(offsets) - len(results)))
return results
def _convert_buffer_to_tensor(
self, buffer_data: bytes, dtype: torch.dtype
) -> torch.Tensor:
"""Convert buffer data to tensor with proper dtype handling."""
if dtype == torch.bfloat16:
tensor_data = torch.frombuffer(buffer_data, dtype=torch.uint16)
return tensor_data.view(dtype=torch.bfloat16)
else:
return torch.frombuffer(buffer_data, dtype=dtype)
def batch_write(
self, offsets: list[int], tensors: list[torch.Tensor], event: torch.cuda.Event
) -> list[int]:
"""Write data from tensors to file at specified offsets."""
results = []
try:
torch.cuda.current_stream().wait_event(event)
# Convert tensors to bytes
data_bytes_list = [self._tensor_to_bytes(tensor) for tensor in tensors]
# Write to file
with open(self._file_path, "r+b") as f:
for offset, data_bytes in zip(offsets, data_bytes_list):
if offset < 0 or offset + len(data_bytes) > self._size:
results.append(-1)
continue
f.seek(offset)
bytes_written = f.write(data_bytes)
if bytes_written == len(data_bytes) == self._bytes_per_page:
results.append(self._bytes_per_page)
else:
logger.error(
"Write size mismatch: wrote %d, expected %d",
bytes_written,
self._bytes_per_page,
)
results.append(-1)
except Exception as e:
logger.error("Batch write error: %s", e)
results.extend([-1] * (len(offsets) - len(results)))
return results
def _tensor_to_bytes(self, tensor: torch.Tensor) -> bytes:
"""Convert tensor to bytes with proper dtype handling."""
cpu_tensor = tensor.cpu()
if cpu_tensor.dtype == torch.bfloat16:
return cpu_tensor.view(dtype=torch.uint16).numpy().tobytes()
else:
return cpu_tensor.numpy().tobytes()
def get_size(self) -> int:
"""Get the total size of the storage file."""
return self._size
def close(self) -> None:
"""Close the client (no-op for file backend)."""
pass
def flush(self) -> None:
"""Flush any pending writes (no-op for file backend)."""
pass
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <cstring>
#include <vector>
void read_shm(const torch::Tensor& shm, const torch::Tensor& pin,
std::vector<torch::Tensor> dst, uint64_t stream_ptr) {
py::gil_scoped_release release;
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
// Copy from shared memory to pinned memory
char* shm_ptr = static_cast<char*>(shm.data_ptr());
char* src_ptr = static_cast<char*>(pin.data_ptr());
std::memcpy(src_ptr, shm_ptr, shm.numel() * shm.element_size());
// Copy from pinned memory to GPU tensors
size_t current = 0;
for (size_t i = 0; i < dst.size(); ++i) {
auto& t = dst[i];
size_t t_bytes = t.numel() * t.element_size();
char* dst_ptr = static_cast<char*>(t.data_ptr());
cudaMemcpyAsync(dst_ptr, src_ptr + current, t_bytes, cudaMemcpyHostToDevice,
stream);
current += t_bytes;
}
cudaStreamSynchronize(stream);
}
void write_shm(const std::vector<torch::Tensor> src, torch::Tensor& shm,
const torch::Tensor& pin, uint64_t stream_ptr) {
py::gil_scoped_release release;
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
// Copy from GPU tensors to pinned memory
char* dst_ptr = static_cast<char*>(pin.data_ptr());
size_t current = 0;
for (size_t i = 0; i < src.size(); ++i) {
auto& t = src[i];
size_t t_bytes = t.numel() * t.element_size();
char* src_ptr = static_cast<char*>(t.data_ptr());
cudaMemcpyAsync(dst_ptr + current, src_ptr, t_bytes, cudaMemcpyDeviceToHost,
stream);
current += t_bytes;
}
cudaStreamSynchronize(stream);
// Copy from pinned memory to shared memory
char* shm_ptr = static_cast<char*>(shm.data_ptr());
std::memcpy(shm_ptr, dst_ptr, shm.numel() * shm.element_size());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("read_shm", &read_shm, "Read tensors from shared memory");
m.def("write_shm", &write_shm, "Write tensors to shared memory");
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment