Unverified Commit 4505849b authored by inkcherry's avatar inkcherry Committed by GitHub
Browse files

[ROCm][PD] add moriio kv connector. (#29304)


Signed-off-by: default avatarinkcherry <mingzhi.liu@amd.com>
parent db07433c
...@@ -11,6 +11,8 @@ ARG FA_BRANCH="0e60e394" ...@@ -11,6 +11,8 @@ ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6af8b687" ARG AITER_BRANCH="6af8b687"
ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="2d02c6a9"
ARG MORI_REPO="https://github.com/ROCm/mori.git"
#TODO: When patch has been upstreamed, switch to the main repo/branch #TODO: When patch has been upstreamed, switch to the main repo/branch
# ARG RIXL_BRANCH="<TODO>" # ARG RIXL_BRANCH="<TODO>"
...@@ -31,6 +33,7 @@ ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: ...@@ -31,6 +33,7 @@ ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151 ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV AITER_ROCM_ARCH=gfx942;gfx950 ENV AITER_ROCM_ARCH=gfx942;gfx950
ENV MORI_GPU_ARCHS=gfx942;gfx950
# Required for RCCL in ROCm7.1 # Required for RCCL in ROCm7.1
ENV HSA_NO_SCRATCH_RECLAIM=1 ENV HSA_NO_SCRATCH_RECLAIM=1
...@@ -44,7 +47,7 @@ ENV DEBIAN_FRONTEND=noninteractive ...@@ -44,7 +47,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies # Install Python and other dependencies
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \ && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
&& for i in 1 2 3; do \ && for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \ add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
...@@ -86,6 +89,18 @@ RUN cd /opt/rocm/share/amd_smi \ ...@@ -86,6 +89,18 @@ RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist && pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
# Build stage: produce a MORI wheel from source on top of the shared base image.
FROM base AS build_mori
# Re-declare the build args so their values are visible inside this stage.
ARG MORI_BRANCH
ARG MORI_REPO
# Install the wheels staged by the build_pytorch stage (bind-mounted at /install).
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
RUN git clone ${MORI_REPO}
# Check out the pinned revision, fetch submodules, build the wheel, then list it
# so the step fails if no wheel was produced.
# NOTE(review): the /app/mori path assumes the base stage's WORKDIR is /app - confirm.
RUN cd mori \
&& git checkout ${MORI_BRANCH} \
&& git submodule update --init --recursive \
&& python3 setup.py bdist_wheel --dist-dir=dist && ls /app/mori/dist/*.whl
# Stage the wheel in /app/install, the directory later stages mount to copy wheels from.
RUN mkdir -p /app/install && cp /app/mori/dist/*.whl /app/install
### ###
### Pytorch build ### Pytorch build
...@@ -253,6 +268,8 @@ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ ...@@ -253,6 +268,8 @@ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
cp /install/*.whl /app/debs cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
cp /install/*.whl /app/debs cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_mori,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_rixl,src=/app/install/,target=/install \ RUN --mount=type=bind,from=build_rixl,src=/app/install/,target=/install \
cp /install/*.whl /app/debs cp /install/*.whl /app/debs
......
...@@ -100,7 +100,22 @@ Currently, there are no pre-built ROCm wheels. ...@@ -100,7 +100,22 @@ Currently, there are no pre-built ROCm wheels.
- The validated `$AITER_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base). - The validated `$AITER_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base).
4. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps: 4. If you want to use MORI for EP or PD disaggregation, you can install [MORI](https://github.com/ROCm/mori) using the following steps:
```bash
git clone https://github.com/ROCm/mori.git
cd mori
git checkout $MORI_BRANCH_OR_COMMIT
git submodule sync; git submodule update --init --recursive
MORI_GPU_ARCHS="gfx942;gfx950" python3 -m pip install .
```
!!! note
- You will need to configure `$MORI_BRANCH_OR_COMMIT` for your purpose.
- The validated `$MORI_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base).
5. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps:
???+ console "Commands" ???+ console "Commands"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import logging
import os
import re
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request
# Module-level logger; its level is forced to DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Registered prefill ("P") and decode ("D") instances. Entries are appended by
# the registration listener and read by the request handler.
prefill_instances: list[dict] = []
decode_instances: list[dict] = []
# Number of requests handled so far; used to pick instances round-robin.
request_nums = 0
app = Quart(__name__)
# Captures the dotted-quad IPv4 host and the port that follow "//" in an address.
IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")
# Transfer mode shared by all registered instances; set by the first registration.
TRANSFER_TYPE = None
def _append_whole_dict_unique(target_list, data_dict):
    """Append ``data_dict`` to ``target_list`` unless an equal entry exists.

    Two entries are considered equal when all of their items except the
    ``"index"`` key match.

    The first instance ever registered fixes the module-wide ``TRANSFER_TYPE``
    (taken from ``data_dict["transfer_mode"]``, ``"unknown"`` when absent).
    Later instances must report the same mode.

    Args:
        target_list: List of already registered instance dicts (mutated).
        data_dict: Registration payload of the new instance.

    Returns:
        ``True`` if the entry was appended, ``False`` if it was a duplicate.

    Raises:
        ValueError: If the instance's transfer mode differs from the mode
            already in use. The instance is NOT appended in that case.
    """
    global TRANSFER_TYPE

    new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
    for existed in target_list:
        existed_filtered = {k: v for k, v in existed.items() if k != "index"}
        if existed_filtered == new_filtered:
            return False

    # Validate before mutating any state so a rejected instance never ends up
    # in the routing list.
    transfer_mode = data_dict.get("transfer_mode", "unknown")
    if TRANSFER_TYPE is not None and transfer_mode != TRANSFER_TYPE:
        raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")

    logger.info("Registering instance: %s", data_dict)
    target_list.append(data_dict)
    if TRANSFER_TYPE is None:
        TRANSFER_TYPE = transfer_mode
        logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
    return True
# Guards the instance lists and the request counter across the listener
# thread and the request handlers.
_list_lock = threading.RLock()


def _listen_for_register(hostname, port):
    """Receive instance registrations on a ZMQ ROUTER socket, forever.

    Each message is expected to be two frames (peer identity, msgpack payload).
    Payloads with ``type == "register"`` and ``role`` ``"P"`` / ``"D"`` are
    added to ``prefill_instances`` / ``decode_instances``; every other message
    type (e.g. ``"HELLO"``) is ignored.

    A malformed or rejected message is logged and skipped so that a single bad
    peer cannot terminate the discovery thread.

    Args:
        hostname: Interface/host to bind to.
        port: TCP port to bind to.
    """
    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")
    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    targets_by_role = {"P": prefill_instances, "D": decode_instances}

    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue

        # Socket errors are allowed to propagate; only payload problems are
        # handled below.
        frames = router_socket.recv_multipart()
        try:
            _remote_addr, msg = frames
            data = msgpack.loads(msg)
            if data["type"] != "register":
                continue
            target = targets_by_role.get(data["role"])
            if target is None:
                continue
            if "request_address" not in data:
                logger.warning("Ignoring registration without request_address")
                continue
            # Duplicate detection is done by _append_whole_dict_unique.
            with _list_lock:
                _append_whole_dict_unique(target, data)
        except Exception:
            logger.exception("Ignoring invalid registration message")
def start_service_discovery(hostname, port):
    """Launch the registration listener on a daemon thread.

    Args:
        hostname: Host to bind to; a falsy value falls back to
            ``socket.gethostname()``.
        port: TCP port to bind to; must not be 0.

    Returns:
        The started ``threading.Thread``.

    Raises:
        ValueError: If ``port`` is 0.
    """
    bind_host = hostname if hostname else socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    listener = threading.Thread(
        target=_listen_for_register,
        args=(bind_host, port),
        daemon=True,
    )
    listener.start()
    return listener
async def send_request_to_prefill(
    endpoint, req_data, request_id, d_endpoint, dip, dport, selected_prefill_dp_rank
):
    """POST a one-token, non-streaming request to the prefill instance.

    Args:
        endpoint: URL of the prefill instance.
        req_data: Client request payload; it is copied, never modified.
        request_id: Value sent as the ``X-Request-Id`` header.
        d_endpoint: Registration dict of the chosen decode instance; its
            ``handshake_port`` and ``notify_port`` are forwarded.
        dip: Decode instance host.
        dport: Decode instance port.
        selected_prefill_dp_rank: Data-parallel rank to target, or ``None`` to
            omit the ``X-data-parallel-rank`` header.

    Returns:
        The decoded JSON body of the prefill response.

    Raises:
        RuntimeError: If the prefill instance answers with a non-200 status.
    """
    # Work on a private copy so the caller's payload is left untouched.
    req_data_copy = copy.deepcopy(req_data)
    req_data_copy.setdefault("kv_transfer_params", {}).update(
        {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_handshake_port": d_endpoint["handshake_port"],
            "remote_notify_port": d_endpoint["notify_port"],
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": dip,
            "remote_port": dport,
        }
    )
    # Prefill only needs to produce a single token and must not stream.
    req_data_copy["stream"] = False
    req_data_copy["max_tokens"] = 1
    if "max_completion_tokens" in req_data_copy:
        req_data_copy["max_completion_tokens"] = 1
    req_data_copy.pop("stream_options", None)

    # NOTE(review): 6 * 6000 * 6000 seconds is roughly 6.8 years; possibly
    # 6 * 60 * 60 was intended - confirm before changing.
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    ) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id,
        }
        if selected_prefill_dp_rank is not None:
            headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank)
        async with session.post(
            url=endpoint, json=req_data_copy, headers=headers
        ) as response:
            if response.status != 200:
                raise RuntimeError(
                    "send_request_to_prefill response.status != 200, "
                    f"response.status = {response.status}"
                )
            return await response.json()
async def start_decode_request(endpoint, req_data, request_id):
    """Open a session and send the decode request without reading the body.

    The caller takes ownership of the returned session and must close it
    (``stream_decode_response`` does so once streaming ends).

    Args:
        endpoint: URL of the decode instance.
        req_data: JSON payload to send.
        request_id: Value sent as the ``X-Request-Id`` header.

    Returns:
        A ``(session, response)`` tuple.
    """
    # NOTE(review): 6 * 6000 * 6000 seconds is roughly 6.8 years; possibly
    # 6 * 60 * 60 was intended - confirm before changing.
    session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    )
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    try:
        response = await session.post(url=endpoint, json=req_data, headers=headers)
    except BaseException:
        # Nobody else holds the session yet; close it so a failed or
        # cancelled request does not leak the connection pool.
        await session.close()
        raise
    return session, response
async def stream_decode_response(session, response, request_id):
    """Yield the decode response body in chunks of up to 1024 bytes.

    The session is closed when the generator finishes, whether it completed
    normally or raised.

    Args:
        session: Client session that produced ``response``; closed on exit.
        response: Response object whose body is streamed.
        request_id: Unused by this function.

    Raises:
        RuntimeError: If the response status is not 200.
    """
    try:
        if response.status != 200:
            raise RuntimeError(
                f"decode response.status != 200, status = {response.status}"
            )
        async for piece in response.content.iter_chunked(1024):
            yield piece
    finally:
        await session.close()
async def send_request_to_decode(endpoint, req_data, request_id):
    """POST to the decode instance and yield its body in 1024-byte chunks.

    Args:
        endpoint: URL of the decode instance.
        req_data: JSON payload to send.
        request_id: Value sent as the ``X-Request-Id`` header.

    Raises:
        RuntimeError: If the decode instance answers with a non-200 status.
    """
    # NOTE(review): 6 * 6000 * 6000 seconds is roughly 6.8 years; possibly
    # 6 * 60 * 60 was intended - confirm before changing.
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    ) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id,
        }
        async with session.post(
            url=endpoint, json=req_data, headers=headers
        ) as response:
            if response.status != 200:
                raise RuntimeError(
                    "send_request_to_decode response.status != 200, "
                    f"response.status = {response.status}"
                )
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
def example_round_robin_dp_loader(request_number, dp_size):
    """Pick a data-parallel rank round-robin.

    Args:
        request_number: Sequence number of the request being routed.
        dp_size: Number of data-parallel ranks; must be positive.

    Returns:
        ``request_number % dp_size``.
    """
    # Use the argument rather than the module-level request counter so the
    # result depends only on the inputs.
    return request_number % dp_size
@app.route("/v1/completions", methods=["POST"])
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request():
    """Route one completion request through a prefill and a decode instance.

    A prefill and a decode instance are chosen round-robin from the registered
    lists. The request is sent to the prefill instance (limited to one token),
    then to the decode instance, whose response body is streamed back.

    In ``"READ"`` transfer mode the prefill response is awaited first and its
    ``remote_engine_id`` / ``remote_block_ids`` are forwarded to the decode
    instance. In any other mode the prefill request runs in the background.

    Returns:
        The streamed decode response; a 503 response when no prefill or decode
        instance is registered; a 500 response on any other error.
    """
    global request_nums

    def extract_ip_port_fast(url):
        """Return ``(ip, port)`` strings parsed from ``url``."""
        match = IP_PORT_PATTERN.search(url)
        if not match:
            raise ValueError(f"Invalid URL format: {url}")
        return match.groups()

    def log_prefill_failure(task):
        """Retrieve and log the exception of a background prefill task."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Prefill request failed: %s", exc)

    try:
        # Take a consistent snapshot: the listener thread may append to the
        # instance lists and other requests bump the counter concurrently.
        with _list_lock:
            request_nums += 1
            request_index = request_nums
            prefill_snapshot = list(prefill_instances)
            decode_snapshot = list(decode_instances)

        req_data = await request.get_json()
        request_id = str(uuid.uuid4())

        if not prefill_snapshot or not decode_snapshot:
            return await make_response(
                (
                    "Service Unavailable: No prefill or decode instances are registered.",
                    503,
                )
            )

        prefill_instance_endpoint = prefill_snapshot[
            request_index % len(prefill_snapshot)
        ]
        decode_instance_endpoint = decode_snapshot[request_index % len(decode_snapshot)]

        selected_prefill_dp_rank = None
        if prefill_instance_endpoint["dp_size"] > 1:
            # Divide by the number of prefill instances (not by the number of
            # keys in the chosen instance's dict).
            selected_prefill_dp_rank = example_round_robin_dp_loader(
                request_index // len(prefill_snapshot),
                prefill_instance_endpoint["dp_size"],
            )

        dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
        ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])

        req_data_to_prefill = copy.deepcopy(req_data)
        req_data_to_prefill["kv_transfer_params"] = {
            "remote_dp_size": decode_instance_endpoint["dp_size"],
            "remote_tp_size": decode_instance_endpoint["tp_size"],
        }

        send_prefill_task = asyncio.create_task(
            send_request_to_prefill(
                prefill_instance_endpoint["request_address"],
                req_data_to_prefill,
                request_id,
                decode_instance_endpoint,
                dip,
                dport,
                selected_prefill_dp_rank,
            )
        )

        # The prefill instance generates the first token, so the decode
        # instance gets one token less. Clients may omit max_tokens.
        if req_data.get("max_tokens") is not None:
            req_data["max_tokens"] -= 1

        req_data["kv_transfer_params"] = {
            "do_remote_decode": False,
            "do_remote_prefill": True,
            "remote_handshake_port": prefill_instance_endpoint["handshake_port"],
            "remote_notify_port": prefill_instance_endpoint["notify_port"],
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": ip,
            "remote_port": port,
        }

        if TRANSFER_TYPE == "READ":
            # In read mode, prefill and decode are executed serially.
            prefill_response = await send_prefill_task
            prefill_kv_params = prefill_response["kv_transfer_params"]
            req_data["kv_transfer_params"]["remote_engine_id"] = prefill_kv_params[
                "remote_engine_id"
            ]
            req_data["kv_transfer_params"]["remote_block_ids"] = prefill_kv_params[
                "remote_block_ids"
            ]
        else:
            # The task is never awaited in this mode; make sure a failure is
            # retrieved and logged instead of being dropped.
            send_prefill_task.add_done_callback(log_prefill_failure)

        req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
            "dp_size"
        ]
        req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[
            "tp_size"
        ]
        if selected_prefill_dp_rank is not None:
            req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank

        session, decode_response = await start_decode_request(
            decode_instance_endpoint["request_address"], req_data, request_id
        )
        stream_generator = stream_decode_response(session, decode_response, request_id)
        return await make_response(stream_generator)
    except Exception as e:
        logger.exception("An error occurred while handling the request: %s", e)
        return await make_response(
            (
                f"Internal Server Error: {e!s}",
                500,
            )
        )
if __name__ == "__main__":
    # Start the ZMQ registration listener before serving HTTP.
    discovery_thread = start_service_discovery("0.0.0.0", 36367)
    app.debug = True
    app.config.update(BODY_TIMEOUT=360000, RESPONSE_TIMEOUT=360000)
    app.run(host="0.0.0.0", port=10001)
    # Reached only after the HTTP server has stopped.
    discovery_thread.join()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import os
from unittest.mock import MagicMock, patch
import msgspec
import pytest
import torch
import zmq
from tests.conftest import _find_free_port
from vllm.config import (
CacheConfig,
DeviceConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
MoRIIOAgentMetadata,
MoRIIOConnectorMetadata,
MoRIIOConstants,
zmq_ctx,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
KVConnectorRole,
MoRIIOConnector,
MoRIIOConnectorWorker,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
)
from .utils import create_request, create_scheduler
# Probe optional dependencies without importing them.
aiter_available = importlib.util.find_spec("aiter") is not None
mori_available = importlib.util.find_spec("mori") is not None
# Skip the whole module unless running on ROCm with the mori package present.
# Tests that additionally need aiter carry their own skipif marker.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and mori_available),
    reason="MoRIIO is only available on ROCm with the mori package installed",
)
@pytest.fixture
def mock_parallel_groups():
    """Patch the parallel-state helpers of the MoRIIO modules for one rank.

    Yields the mocked group object (rank 0, local rank 0, world size 1) that
    ``get_world_group`` and ``get_tp_group`` return while the fixture is active.
    """
    fake_group = MagicMock()
    fake_group.rank = 0
    fake_group.local_rank = 0
    fake_group.world_size = 1

    # NOTE(review): the tensor-parallel world size is mocked as 0 although the
    # group reports world_size == 1 - confirm this is intentional.
    common_patches = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common",
        get_tensor_model_parallel_rank=MagicMock(return_value=0),
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
    )
    connector_patches = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
        get_world_group=MagicMock(return_value=fake_group),
        get_tp_group=MagicMock(return_value=fake_group),
    )
    with common_patches, connector_patches:
        yield fake_group
def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789):
"""Setup KV transfer parameters for a request."""
request.kv_transfer_params.update(
{
"remote_notify_port": fake_port,
"remote_block_ids": None,
"remote_host": remote_host,
"remote_port": fake_port,
"remote_handshake_port": fake_port,
"remote_engine_id": "test_engine",
}
)
return request
class FakeMorIIOWrapper:
    """Test stand-in for MoRIIOWrapper: every method is a no-op returning None."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def set_moriio_engine(self, moriio_engine):
        """No-op."""

    def set_backend_type(self, backend_type):
        """No-op."""

    def get_agent_metadata(self):
        """No-op."""

    def register_remote_engine(self, remote_packed_engine_metadata):
        """No-op."""

    def register_local_tensor(self, tensor: torch.Tensor):
        """No-op."""

    def get_unpack_memory_metadata(self, packed_memory_metadata):
        """No-op."""

    def build_session(self, local_memory_metadata, remote_memory_metadata):
        """No-op."""

    def read_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op."""

    def write_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op."""

    def write_remote_data_single(
        self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
    ):
        """No-op."""

    def waiting_for_transfer_complete(self):
        """No-op."""

    def async_wait_reqid(self):
        """No-op."""

    def _handle_message(self, msg: bytes):
        """No-op."""

    def _handle_structured_message(self, data: dict):
        """No-op."""

    def _handle_completion_message(self, msg: str):
        """No-op."""

    def send_notify(self, req_ids, remote_ip, remote_port):
        """No-op."""

    def pop_finished_req_ids(self):
        """No-op."""

    def pop_finished_write_req_ids(self):
        """No-op."""

    def shutdown(self):
        """No-op."""
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
    """Connector worker subclass used by the tests in this module."""

    # Fake remote engine id for testing.
    REMOTE_ENGINE_ID = "remote_engine"

    def __init__(
        self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
    ):
        # hand_shake_latency and kv_cache_layout are accepted but neither
        # stored nor forwarded; everything else goes to the real worker.
        super().__init__(*args, **kwargs)
def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
    max_model_len: int = 10000,
    enable_chunked_prefill: bool = True,
    enable_permute_local_kv: bool = False,
    role="kv_consumer",
) -> VllmConfig:
    """Build a CPU-device VllmConfig wired to the MoRIIOConnector for tests.

    ``role`` is passed as the ``kv_role`` of the KV transfer config. Prefix
    caching is always enabled in the cache config.
    """
    return VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            max_model_len=max_model_len,
            enable_chunked_prefill=enable_chunked_prefill,
            is_encoder_decoder=False,
        ),
        model_config=ModelConfig(
            model=model,
            trust_remote_code=True,
            dtype="bfloat16",
            seed=42,
        ),
        cache_config=CacheConfig(
            block_size=block_size,
            gpu_memory_utilization=0.9,
            swap_space=0,
            cache_dtype="auto",
            enable_prefix_caching=True,
        ),
        kv_transfer_config=KVTransferConfig(
            kv_connector="MoRIIOConnector",
            kv_role=role,
            enable_permute_local_kv=enable_permute_local_kv,
        ),
        device_config=DeviceConfig("cpu"),
    )
@pytest.fixture
def moriio_read_mode():
    """Set VLLM_MORIIO_CONNECTOR_READ_MODE to "True" for the test's duration.

    The previous value of the variable (or its absence) is restored afterwards
    so the fixture does not leak state into, or erase state from, the rest of
    the test session.
    """
    env_name = "VLLM_MORIIO_CONNECTOR_READ_MODE"
    previous = os.environ.get(env_name)
    os.environ[env_name] = "True"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = previous
def test_write_mode_saves_local_block_ids():
    """Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save."""
    # Producer-side scheduler.
    vllm_config = create_vllm_config(role="kv_producer")
    scheduler = create_scheduler(vllm_config)

    # Prompt length: two full blocks plus half a block.
    block_size = vllm_config.cache_config.block_size
    num_full_blocks = 2
    num_tokens = int(block_size * (num_full_blocks + 0.5))

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Attach fake remote endpoint parameters.
    request = _setup_kv_transfer_request(request)

    # Scheduling produces the connector metadata under test.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None, "kv_connector_metadata is None"
    assert isinstance(metadata, MoRIIOConnectorMetadata)
    assert len(metadata.reqs_to_save) == 1, "Unexpected number of reqs_to_save"
    assert len(metadata.reqs_to_recv) == 0, "Unexpected number of reqs_to_recv"
    assert len(metadata.reqs_to_send) == 0, "Unexpected number of reqs_to_send"
    assert request_id in metadata.reqs_to_save, "Request ID not in reqs_to_save"

    # Saved block ids must match the blocks the cache manager allocated.
    req_meta = metadata.reqs_to_save[request_id]
    local_ids = req_meta.local_block_ids
    allocated = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block_id, block in zip(local_ids, allocated):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_write_mode_with_chunked_prefill_saves_local_block_ids():
    """Write mode with chunked prefill still records correct local block ids."""
    # The prompt is 2.5x the per-step token budget, so it takes three
    # scheduling steps to prefill.
    max_num_batched_tokens = 64
    num_tokens = max_num_batched_tokens * 2 + max_num_batched_tokens // 2

    vllm_config = create_vllm_config(
        max_num_batched_tokens=max_num_batched_tokens, role="kv_producer"
    )
    block_size = vllm_config.cache_config.block_size
    scheduler = create_scheduler(vllm_config)

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Attach fake remote endpoint parameters.
    request = _setup_kv_transfer_request(request)

    # Expected (save, recv, send) counts per scheduling step: the request is
    # only reported for saving on the last chunk.
    expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)]
    kv_connector_metadata = None
    for expected_save, expected_recv, expected_send in expected_counts:
        kv_connector_metadata = scheduler.schedule().kv_connector_metadata
        assert len(kv_connector_metadata.reqs_to_save) == expected_save
        assert len(kv_connector_metadata.reqs_to_recv) == expected_recv
        assert len(kv_connector_metadata.reqs_to_send) == expected_send

    assert kv_connector_metadata is not None, "kv_connector_metadata is None"
    assert request_id in kv_connector_metadata.reqs_to_save, (
        "Request ID not in reqs_to_save"
    )

    # Saved block ids must match the blocks the cache manager allocated.
    req_meta = kv_connector_metadata.reqs_to_save[request_id]
    local_ids = req_meta.local_block_ids
    allocated = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block_id, block in zip(local_ids, allocated):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_read_mode_loads_remote_block_ids(moriio_read_mode):
    """Read mode loads remote block ids into local cache mapping."""
    # Consumer-side scheduler.
    vllm_config = create_vllm_config(role="kv_consumer")
    scheduler = create_scheduler(vllm_config)

    # Prompt length: two full blocks plus half a block.
    block_size = vllm_config.cache_config.block_size
    num_full_blocks = 2
    num_tokens = int(block_size * (num_full_blocks + 0.5))

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=False,
        do_remote_prefill=True,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    type_manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    block_list = type_manager.req_to_blocks[request_id]

    request = _setup_kv_transfer_request(request)
    # Use the request's current block list as the remote blocks to fetch.
    request.kv_transfer_params["remote_block_ids"] = block_list

    # Scheduling produces the connector metadata under test.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None, "kv_connector_metadata is None"
    assert isinstance(metadata, MoRIIOConnectorMetadata), (
        "kv_connector_metadata is not MoRIIOConnectorMetadata"
    )
    assert len(metadata.reqs_to_save) == 0, "Unexpected number of reqs_to_save"
    assert len(metadata.reqs_to_recv) == 1, "Unexpected number of reqs_to_recv"
    assert len(metadata.reqs_to_send) == 0, "Unexpected number of reqs_to_send"
    assert request_id in metadata.reqs_to_recv, "Request ID not in reqs_to_recv"

    # Received block ids must match the blocks the cache manager allocated.
    req_meta = metadata.reqs_to_recv[request_id]
    local_ids = req_meta.local_block_ids
    allocated = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block_id, block in zip(local_ids, allocated):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
@pytest.mark.skipif(
    not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_register_kv_caches(mock_parallel_groups):
    """Test that MoRIIOConnector.register_kv_caches correctly registers kv caches."""
    ROLE = "kv_consumer"
    IP = get_ip()
    vllm_config = create_vllm_config(role=ROLE)
    DEFAULT_PORT = 6301
    TP_RANK = 0
    DP_RANK = 0

    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    # Allocate KV cache tensors with the shape the attention backend expects.
    cache_shape = AiterFlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*cache_shape, dtype=torch.float16)
    # layer0 and layer2 deliberately share one tensor.
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    event_patch = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event"
    )
    thread_patch = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread"
    )
    with event_patch, thread_patch:
        # Build the connector with a fake proxy configuration.
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(
            {
                "proxy_ip": "127.0.0.1",
                "proxy_ping_port": 12345,
                "http_port": 12346,
            }
        )
        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeMorIIOConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0
        )

        from mori.io import MemoryDesc

        connector.register_kv_caches(kv_caches)

        # Each layer's first registered descriptor must point at the tensor
        # that was passed in for that layer.
        for layer_name, tensor in kv_caches.items():
            descriptor = MemoryDesc.unpack(
                connector.connector_worker.layer_name_to_local_kv_cache_metadata[
                    layer_name
                ][0]
            )
            assert tensor.data_ptr() == descriptor.data

        # The engine key encodes role, IP, port and the tp/dp ranks.
        expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
        layer0_descriptor = MemoryDesc.unpack(
            connector.connector_worker.layer_name_to_local_kv_cache_metadata[
                "layer0"
            ][0]
        )
        assert layer0_descriptor.engine_key == expected_engine_key
@pytest.mark.skipif(
    not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_moriio_handshake_returns_metadata(mock_parallel_groups):
    """MoRIIO handshake socket returns valid agent metadata over ZMQ."""
    ROLE = "kv_consumer"
    vllm_config = create_vllm_config(role=ROLE)

    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    # Allocate KV cache tensors with the shape the attention backend expects.
    cache_shape = AiterFlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    wrapper_patch = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
        FakeMorIIOWrapper,
    )
    with wrapper_patch:
        handshake_port = _find_free_port()

        # Build the connector with a fake proxy configuration and the free
        # handshake port.
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(
            {
                "proxy_ip": "127.0.0.1",
                "proxy_ping_port": 12345,
                "http_port": 12346,
                "handshake_port": handshake_port,
            }
        )
        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
        connector.register_kv_caches(kv_caches)

        # Ask the handshake socket for the agent metadata.
        path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
        with zmq_ctx(zmq.DEALER, path) as sock:
            sock.send(MoRIIOConstants.GET_META_MSG)
            received_frame = sock.recv_multipart()

            # The reply must be an empty delimiter frame followed by the payload.
            if len(received_frame) != 2 or received_frame[0] != b"":
                raise ValueError(f"Unexpected frame! {received_frame = }")

            metadata_bytes = received_frame[1]
            metadata = msgspec.msgpack.Decoder(MoRIIOAgentMetadata).decode(
                metadata_bytes
            )
            assert isinstance(metadata, MoRIIOAgentMetadata), (
                "Decoded metadata is not MoRIIOAgentMetadata"
            )
...@@ -179,6 +179,12 @@ KVConnectorFactory.register_connector( ...@@ -179,6 +179,12 @@ KVConnectorFactory.register_connector(
"MultiConnector", "MultiConnector",
) )
# Register the MoRIIO connector under the name "MoRIIOConnector"; the
# implementing module path and class name are passed as strings.
KVConnectorFactory.register_connector(
"MoRIIOConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
"MoRIIOConnector",
)
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"OffloadingConnector", "OffloadingConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import threading
import time
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import msgspec
import torch
import zmq
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.utils.network_utils import (
get_ip,
get_open_port,
make_zmq_socket,
)
if TYPE_CHECKING:
pass
from dataclasses import field
from enum import Enum
logger = init_logger(__name__)
# Type aliases.
# NOTE(review): the meaning of Transfer's (int, float) pair is not visible here - TODO document.
Transfer = tuple[int, float]
EngineId = str
ReqId = str
@dataclass
class WriteTask:
request_id: str
dst_engine_id: str
local_block_ids: list[int]
remote_block_ids_hint: list[int] | None
layer_name: str
event: torch.cuda.Event
remote_notify_port: int
remote_ip: str
enqueue_time: float = field(default_factory=time.perf_counter)
retried: int = 0
@dataclass
class LayerTransferPlan:
    """Transfer plan for a single layer of one request."""

    request_id: str
    layer_name: str
    sess_idx: int
    # Three lists of offsets/sizes.
    # NOTE(review): presumably parallel, one entry per transfer - confirm.
    transfer_local_offsets: list[int]
    transfer_remote_offsets: list[int]
    transfer_sizes: list[int]
    use_batch: bool = True
@dataclass
class RemoteAllocInfo:
"""Information about remote block allocation."""
block_ids: list[int]
writes_done: int = 0
decode_dp_rank: int = 0
transfer_offset: tuple[list[int], list[int], list[int]] | None = None
class ROLE(Enum):
    """Role of the connector; NOTINIT is the value before a role is set."""

    PRODUCER = "producer"
    CONSUMER = "consumer"
    NOTINIT = "notinit"
class MoRIIOAgentMetadata(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # dict=True: kept from the original, which notes it is required for
    # @cached_property.
    dict=True,
):
    """Agent metadata struct; the handshake test in this change decodes it from msgpack."""

    engine_id: str
    agent_metadata: bytes
    kv_caches_base_addr: list[int]
    num_blocks: int
    block_len: int
    attn_backend_name: str
class RoleManager:
"""Manages role state across the connector."""
_instance: Optional["RoleManager"] = None
_lock = threading.Lock()
def __init__(self) -> None:
self._role: ROLE = ROLE.NOTINIT
@classmethod
def get_instance(cls) -> "RoleManager":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def set_role(self, role: ROLE) -> None:
"""Set the current role."""
with self._lock:
self._role = role
def get_role(self) -> ROLE:
"""Get the current role."""
return self._role
def set_role(role: ROLE):
    """Set the role on the shared RoleManager instance."""
    manager = RoleManager.get_instance()
    manager.set_role(role)
def get_role() -> ROLE:
    """Return the role stored on the shared RoleManager instance."""
    manager = RoleManager.get_instance()
    return manager.get_role()
class MoRIIOMode(Enum):
    """Transfer mode of the MoRIIO connector."""

    READ = "read"
    WRITE = "write"
class MoRIIOError(Exception):
    """Root of the MoRIIO exception hierarchy."""
class HandshakeError(MoRIIOError):
    """Raised when a handshake fails."""
class TransferError(MoRIIOError):
    """Raised when a transfer fails."""
def get_moriio_mode() -> MoRIIOMode:
    """Return READ when VLLM_MORIIO_CONNECTOR_READ_MODE is truthy, else WRITE."""
    use_read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
    logger.debug("MoRIIO Connector read_mode: %s", use_read_mode)
    return MoRIIOMode.READ if use_read_mode else MoRIIOMode.WRITE
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
    """Return the port offset ``dp_rank * tp_size + tp_rank`` for a rank pair."""
    dp_base = dp_rank * tp_size
    return dp_base + tp_rank
@dataclass
class MoRIIOConfig:
    """Network and parallelism settings of one MoRIIO connector instance."""

    local_ip: str
    local_kv_port: int
    proxy_ip: str
    local_ping_port: int
    proxy_ping_port: int
    http_port: int
    handshake_port: int
    notify_port: int
    tp_rank: int
    dp_rank: int
    dp_size: int
    tp_size: int

    @classmethod
    def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
        """Build the config from ``vllm_config`` and the parallel state.

        Port roles:
            local_ping_port  - outgoing heartbeat to the proxy
            proxy_ping_port  - the remote proxy's heartbeat ingress port
            http_port        - the instance's HTTP service endpoint
            local_kv_port    - service port for the mori engine
            notify_port      - synchronizes stages between prefill and decode
            handshake_port   - initial handshake between mori engines

        ``proxy_ip``, ``proxy_ping_port``, ``http_port``, ``handshake_port``
        and ``notify_port`` are read from ``kv_connector_extra_config``; a
        missing key raises ``KeyError``. The two local ports are picked with
        ``get_open_port()``.

        TODO: merge notify_port and handshake_port to simplify port
        management and support non-contiguous ports.
        """
        assert vllm_config.kv_transfer_config is not None, (
            "kv_transfer_config must be set for MoRIIOConnector"
        )
        extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config
        parallel_config = vllm_config.parallel_config

        tp_rank = get_tensor_model_parallel_rank()
        dp_rank = parallel_config.data_parallel_rank
        base_notify_port = int(extra_config["notify_port"])
        dp_size = parallel_config.data_parallel_size
        tp_size = get_tensor_model_parallel_world_size()

        # NOTE(review): tp_size is not passed here, so get_port_offset uses
        # its default of 1 and the offset is dp_rank + tp_rank. Confirm this
        # matches how peers compute the notify port before changing it.
        notify_port = base_notify_port + get_port_offset(dp_rank, tp_rank)

        return cls(
            local_ip=get_ip(),
            local_kv_port=get_open_port(),
            proxy_ip=extra_config["proxy_ip"],
            local_ping_port=get_open_port(),
            proxy_ping_port=int(extra_config["proxy_ping_port"]),
            http_port=int(extra_config["http_port"]),
            handshake_port=int(extra_config["handshake_port"]),
            notify_port=notify_port,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            dp_size=dp_size,
            tp_size=tp_size,
        )
class MoRIIOConstants:
    """Constants for MoRIIO connector."""

    # ZMQ message types
    GET_META_MSG = b"get_meta_msg"
    POP_DONE_RECV = b"pop_done_recv"
    OVER = b"OVER"
    # NOTE(review): presumably a request-id prefix for completion requests —
    # confirm against the code that uses it.
    COMPLETION_PREFIX = "cmpl"

    # NOTE(review): presumably seconds between pings — confirm units.
    PING_INTERVAL = 5
    MAX_PING_RETRIES = 100
    # Default ports are stored as strings, not ints.
    DEFAULT_HANDSHAKE_PORT = "6301"
    DEFAULT_NOTIFY_PORT = "61005"

    # NOTE(review): presumably expressed in seconds — confirm units.
    VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600
@dataclass
class ReqMeta:
    """Metadata for a single request."""

    # Block ids on the local side, supplied by the caller of add_new_req.
    local_block_ids: list[int]
    # The remote_* fields are copied from the request's kv_transfer_params by
    # MoRIIOConnectorMetadata.add_new_req.
    remote_block_ids: list[int]
    remote_host: str
    remote_port: int
    remote_handshake_port: int
    remote_notify_port: int
    remote_engine_id: str
    # add_new_req falls back to 1 when kv_transfer_params has no "tp_size".
    tp_size: int
    # add_new_req falls back to 1 when kv_transfer_params has no
    # "remote_dp_size".
    remote_dp_size: int
class MoRIIOConnectorMetadata(KVConnectorMetadata):
    """Connector metadata grouping requests by the action they need.

    Attributes:
        reqs_to_recv: Requests added with ``write_mode=False``.
        reqs_to_save: Requests added with ``write_mode=True``.
        reqs_to_send: Request id mapped to a float (shown as ``expiry`` in
            the repr).
    """

    def __init__(self) -> None:
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        self.reqs_to_send: dict[ReqId, float] = {}

    @staticmethod
    def _describe_req(req_id: ReqId, req_meta: ReqMeta) -> str:
        """Render one request entry for ``__repr__``."""
        return (
            f"{req_id = },{req_meta.local_block_ids = },"
            f"{req_meta.remote_host = },{req_meta.remote_port = },"
            f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
        )

    def __repr__(self) -> str:
        """Return a debug string with one labelled section per request dict.

        Each section is rendered independently, so the class name appears
        once and every entry sits under the label of the dict it came from.
        """
        recv = "; ".join(
            self._describe_req(req_id, req_meta)
            for req_id, req_meta in self.reqs_to_recv.items()
        )
        save = "; ".join(
            self._describe_req(req_id, req_meta)
            for req_id, req_meta in self.reqs_to_save.items()
        )
        send = "; ".join(
            f"{req_id = },{expiry = }"
            for req_id, expiry in self.reqs_to_send.items()
        )
        return (
            "MoRIIOConnectorMetadata:"
            f"reqs_to_recv:[{recv}],"
            f"reqs_to_save:[{save}],"
            f"reqs_to_send:[{send}]"
        )

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        write_mode: bool = False,
    ) -> None:
        """Register a request built from ``kv_transfer_params``.

        Args:
            request_id: Key under which the request is stored.
            local_block_ids: Local block ids for the request.
            kv_transfer_params: Must contain ``remote_block_ids``,
                ``remote_engine_id``, ``remote_host``, ``remote_port``,
                ``remote_handshake_port`` and ``remote_notify_port``.
                ``tp_size`` and ``remote_dp_size`` are optional and default
                to 1.
            write_mode: When true the request is stored in ``reqs_to_save``,
                otherwise in ``reqs_to_recv``.

        Raises:
            KeyError: If a required key is missing from ``kv_transfer_params``.
        """
        _req = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=kv_transfer_params["remote_block_ids"],
            remote_engine_id=kv_transfer_params["remote_engine_id"],
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
            remote_handshake_port=kv_transfer_params["remote_handshake_port"],
            remote_notify_port=kv_transfer_params["remote_notify_port"],
            tp_size=kv_transfer_params.get("tp_size", 1),
            remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
        )
        if write_mode:
            self.reqs_to_save[request_id] = _req
        else:
            self.reqs_to_recv[request_id] = _req
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Yield a ZMQ socket backed by a private context.

    Args:
        socket_type: One of ``zmq.ROUTER``, ``zmq.REQ`` or ``zmq.DEALER``.
        addr: Path/address handed to ``make_zmq_socket``.

    Yields:
        The socket returned by ``make_zmq_socket``; ``bind`` is requested
        only for ROUTER sockets.

    Raises:
        ValueError: If ``socket_type`` is not one of the accepted types.
    """
    accepted_types = (zmq.ROUTER, zmq.REQ, zmq.DEALER)
    if socket_type not in accepted_types:
        raise ValueError(f"Unexpected socket type: {socket_type}")

    # If creating the context raises there is nothing to clean up, so it can
    # safely sit outside the try block.
    ctx = zmq.Context()  # type: ignore[attr-defined]
    try:
        should_bind = socket_type == zmq.ROUTER
        yield make_zmq_socket(
            ctx=ctx, path=addr, socket_type=socket_type, bind=should_bind
        )
    finally:
        # Destroy the context without lingering on unsent messages.
        ctx.destroy(linger=0)
...@@ -204,6 +204,10 @@ if TYPE_CHECKING: ...@@ -204,6 +204,10 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
VLLM_MORIIO_QP_PER_TRANSFER: int = 1
VLLM_MORIIO_POST_BATCH_SIZE: int = -1
VLLM_MORIIO_NUM_WORKERS: int = 1
VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
...@@ -1383,6 +1387,20 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1383,6 +1387,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")
), ),
# Controls the read mode for the Mori-IO connector
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1")
),
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
),
# Controls the post-processing batch size for the Mori-IO connector
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
),
# Controls the number of workers for Mori operations for the Mori-IO connector
"VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")),
# Timeout (in seconds) for MooncakeConnector in PD disaggregated setup. # Timeout (in seconds) for MooncakeConnector in PD disaggregated setup.
"VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int( "VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480") os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment