Unverified commit 60662849, authored by Juncheng Gu, committed by GitHub
Browse files

[P/D] Support CPU Transfer in NixlConnector (#18293)


Signed-off-by: Juncheng Gu <juncgu@gmail.com>
Signed-off-by: Richard Liu <ricliu@google.com>
Co-authored-by: Richard Liu <39319471+richardsliu@users.noreply.github.com>
Co-authored-by: Richard Liu <ricliu@google.com>
parent 1e9ea8e6
...@@ -10,6 +10,7 @@ jinja2>=3.1.6 ...@@ -10,6 +10,7 @@ jinja2>=3.1.6
ray[default] ray[default]
ray[data] ray[data]
setuptools==78.1.0 setuptools==78.1.0
nixl==0.3.0
# Install torch_xla # Install torch_xla
--pre --pre
......
#!/bin/bash
# Accuracy test for prefill/decode (P/D) disaggregation: a non-disaggregated
# baseline server is run first and its outputs are saved, then the P/D
# servers are run behind a proxy and their outputs are compared with the
# saved baseline. Every setting below can be overridden from the environment.
set -xe
# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}
# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}
# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
# Kill the remaining background jobs on Ctrl+C (SIGINT), on SIGTERM and on
# exit. The PID list is checked first because `kill` without arguments is a
# usage error, which would otherwise be raised whenever no job is left.
# ${pids} is left unquoted on purpose: it holds one PID per line.
trap 'pids=$(jobs -pr); [[ -z "${pids}" ]] || kill ${pids}' SIGINT SIGTERM EXIT
#######################################
# Waits for vLLM server to start, polling it with curl once per second.
# Arguments:
#   $1 - server host
#   $2 - server port
#   $3 - optional maximum wait in seconds (default: 1200)
# Returns:
#   0 as soon as curl succeeds, 1 if the wait timed out first.
#######################################
wait_for_server() {
  local host=$1
  local port=$2
  local max_wait=${3:-1200}
  # host and port are handed to the child shell as positional parameters
  # instead of being pasted into its script text, so their content can never
  # be interpreted as shell syntax.
  if timeout "${max_wait}" bash -c '
    until curl -s "$1:$2/v1/completions" > /dev/null; do
      sleep 1
    done' _ "${host}" "${port}"; then
    return 0
  fi
  return 1
}
# Cleanup function: force-kills (SIGKILL) every process on the local machine
# that `pgrep python` matches. It is also called on the normal code path, not
# only after Ctrl+C, despite the wording of the first message.
cleanup() {
  echo "Caught Ctrl+C, cleaning up..."
  # Cleanup commands
  # -r makes xargs skip the kill when pgrep found nothing; without it
  # `kill -9` would be run with no PID at all. `|| true` is intentional: a
  # failed kill must not abort the script under `set -e`.
  pgrep python | xargs -r kill -9 || true
  # pkill -f python || true
  echo "Cleanup complete. Exiting."
}
#######################################
# Starts the non-disaggregated baseline `vllm serve` on BASELINE_HOST over
# ssh, as a background job.
# Globals:
#   CONDA_PATH, CONDA_ENV_NAME, MODEL_NAME, BASELINE_HOST, BASELINE_PORT,
#   MAX_MODEL_LEN, BLOCK_SIZE (read); BASELINE_BASE_CMD (written)
# Outputs:
#   Writes the remote command to stdout before running it.
#######################################
launch_baseline() {
  # Every continued line ends in " \" so that arguments stay separated even
  # if the following line carries no indentation.
  BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve ${MODEL_NAME} \
    --host ${BASELINE_HOST} \
    --port ${BASELINE_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --gpu-memory-utilization 0.5 \
    --disable-log-requests \
    --enforce-eager"
  # Quoted, so the log shows exactly the string that is sent to the host.
  printf '%s\n' "${BASELINE_BASE_CMD}"
  ssh -tt "${BASELINE_HOST}" "${BASELINE_BASE_CMD}" &
}
#######################################
# Starts the prefill and the decode `vllm serve` instances over ssh (as
# background jobs) and blocks until both answer HTTP requests.
# Globals:
#   CONDA_PATH, CONDA_ENV_NAME, MODEL_NAME, MAX_MODEL_LEN, BLOCK_SIZE,
#   PREFILL_HOST, PREFILL_PORT, PREFILL_NIXL_SIDE_PORT, DECODE_HOST,
#   DECODE_PORT (read); PREFILL_BASE_CMD, DECODE_BASE_CMD (written)
# Outputs:
#   Writes both remote commands to stdout before running them.
#######################################
launch_pd() {
  # Same connector settings for both instances; the JSON is wrapped in single
  # quotes inside the command so the remote shell passes it through verbatim.
  local kv_transfer_config='{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu"}'
  # Every continued line ends in " \" so that arguments stay separated even
  # if the following line carries no indentation.
  PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
    VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve ${MODEL_NAME} \
    --host ${PREFILL_HOST} \
    --port ${PREFILL_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --disable-log-requests \
    --kv-transfer-config '${kv_transfer_config}'"
  # The decode command does not set the VLLM_NIXL_SIDE_CHANNEL_* variables.
  DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve ${MODEL_NAME} \
    --host ${DECODE_HOST} \
    --port ${DECODE_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --disable-log-requests \
    --kv-transfer-config '${kv_transfer_config}'"
  # Quoted, so the log shows exactly the strings that are sent to the hosts.
  printf '%s\n' "${PREFILL_BASE_CMD}"
  printf '%s\n' "${DECODE_BASE_CMD}"
  sleep 2
  # execute on hosts
  ssh -tt "${PREFILL_HOST}" "${PREFILL_BASE_CMD}" &
  ssh -tt "${DECODE_HOST}" "${DECODE_BASE_CMD}" &
  sleep 1
  wait_for_server "${PREFILL_HOST}" "${PREFILL_PORT}"
  sleep 1
  wait_for_server "${DECODE_HOST}" "${DECODE_PORT}"
  sleep 1
}
#######################################
# Starts toy_proxy_server.py on PROXY_HOST over ssh, as a background job,
# pointing it at the prefill and decode instances.
# Globals:
#   CONDA_PATH, CONDA_ENV_NAME, EXP_ROOT, PREFILL_HOST, PREFILL_PORT,
#   DECODE_HOST, DECODE_PORT, PROXY_HOST, PROXY_PORT (read);
#   PROXY_BASE_CMD (written)
# Outputs:
#   Writes the remote command to stdout before running it.
#######################################
launch_pd_proxy() {
  PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    python3 ${EXP_ROOT}/toy_proxy_server.py \
    --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
    --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
    --host=${PROXY_HOST} --port ${PROXY_PORT}"
  # Quoted, so the log shows exactly the string that is sent to the host.
  printf '%s\n' "${PROXY_BASE_CMD}"
  ssh -tt "${PROXY_HOST}" "${PROXY_BASE_CMD}" &
}
#######################################
# Runs test_disagg_accuracy.py against one service.
# Globals:
#   EXP_ROOT, MODEL_NAME, OUTPUT_FILE (read)
# Arguments:
#   $1 - service URL passed as --service_url
#   $2 - mode passed as --mode
# Returns:
#   The exit status of the python3 command.
#######################################
run_tests() {
  local service_url=$1
  local mode=$2
  # Everything is quoted so that a checkout path or output file containing
  # spaces still reaches python3 as a single argument.
  python3 "${EXP_ROOT}/test_disagg_accuracy.py" \
    --service_url="${service_url}" \
    --model_name="${MODEL_NAME}" \
    --mode="${mode}" \
    --file_name="${OUTPUT_FILE}"
}
# run non-disagg. baseline & save outputs
launch_baseline
sleep 2
wait_for_server "${BASELINE_HOST}" "${BASELINE_PORT}"
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
# Stop the baseline before the P/D servers are started.
cleanup
sleep 10
# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
echo "-----P/D success----"
# Quoted and guarded with `--` so a path with spaces or a leading dash is
# removed as one file.
rm -- "${OUTPUT_FILE}"
cleanup
exit 0
\ No newline at end of file
#!/bin/bash
# Edge-case test for prefill/decode (P/D) disaggregation: starts the prefill
# and decode servers plus the proxy, then runs test_edge_cases.py against
# them. Every setting below can be overridden from the environment.
set -xe
# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}
# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}
# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
# Kill the remaining background jobs on Ctrl+C (SIGINT), on SIGTERM and on
# exit. The PID list is checked first because `kill` without arguments is a
# usage error, which would otherwise be raised whenever no job is left.
# ${pids} is left unquoted on purpose: it holds one PID per line.
trap 'pids=$(jobs -pr); [[ -z "${pids}" ]] || kill ${pids}' SIGINT SIGTERM EXIT
#######################################
# Waits for vLLM server to start, polling it with curl once per second.
# Arguments:
#   $1 - server host
#   $2 - server port
#   $3 - optional maximum wait in seconds (default: 1200)
# Returns:
#   0 as soon as curl succeeds, 1 if the wait timed out first.
#######################################
wait_for_server() {
  local host=$1
  local port=$2
  local max_wait=${3:-1200}
  # host and port are handed to the child shell as positional parameters
  # instead of being pasted into its script text, so their content can never
  # be interpreted as shell syntax.
  if timeout "${max_wait}" bash -c '
    until curl -s "$1:$2/v1/completions" > /dev/null; do
      sleep 1
    done' _ "${host}" "${port}"; then
    return 0
  fi
  return 1
}
# Cleanup function: force-kills (SIGKILL) every process on the local machine
# that `pgrep python` matches.
cleanup() {
  echo "Caught Ctrl+C, cleaning up..."
  # Cleanup commands
  # -r makes xargs skip the kill when pgrep found nothing; without it
  # `kill -9` would be run with no PID at all. `|| true` is intentional: a
  # failed kill must not abort the script under `set -e`.
  pgrep python | xargs -r kill -9 || true
  # pkill -f python || true
  echo "Cleanup complete. Exiting."
}
#######################################
# Starts the prefill and the decode `vllm serve` instances over ssh (as
# background jobs) and blocks until both answer HTTP requests.
# Globals:
#   CONDA_PATH, CONDA_ENV_NAME, MODEL_NAME, MAX_MODEL_LEN, BLOCK_SIZE,
#   PREFILL_HOST, PREFILL_PORT, PREFILL_NIXL_SIDE_PORT, DECODE_HOST,
#   DECODE_PORT (read); PREFILL_BASE_CMD, DECODE_BASE_CMD (written)
# Outputs:
#   Writes both remote commands to stdout before running them.
#######################################
launch_pd() {
  # Same connector settings for both instances; the JSON is wrapped in single
  # quotes inside the command so the remote shell passes it through verbatim.
  local kv_transfer_config='{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu"}'
  # Every continued line ends in " \" so that arguments stay separated even
  # if the following line carries no indentation.
  PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
    VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve ${MODEL_NAME} \
    --host ${PREFILL_HOST} \
    --port ${PREFILL_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --disable-log-requests \
    --kv-transfer-config '${kv_transfer_config}'"
  # The decode command does not set the VLLM_NIXL_SIDE_CHANNEL_* variables.
  DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve ${MODEL_NAME} \
    --host ${DECODE_HOST} \
    --port ${DECODE_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --disable-log-requests \
    --kv-transfer-config '${kv_transfer_config}'"
  # Quoted, so the log shows exactly the strings that are sent to the hosts.
  printf '%s\n' "${PREFILL_BASE_CMD}"
  printf '%s\n' "${DECODE_BASE_CMD}"
  sleep 2
  # execute on hosts
  ssh -tt "${PREFILL_HOST}" "${PREFILL_BASE_CMD}" &
  ssh -tt "${DECODE_HOST}" "${DECODE_BASE_CMD}" &
  sleep 1
  wait_for_server "${PREFILL_HOST}" "${PREFILL_PORT}"
  sleep 1
  wait_for_server "${DECODE_HOST}" "${DECODE_PORT}"
  sleep 1
}
#######################################
# Starts toy_proxy_server.py on PROXY_HOST over ssh, as a background job,
# pointing it at the prefill and decode instances.
# Globals:
#   CONDA_PATH, CONDA_ENV_NAME, EXP_ROOT, PREFILL_HOST, PREFILL_PORT,
#   DECODE_HOST, DECODE_PORT, PROXY_HOST, PROXY_PORT (read);
#   PROXY_BASE_CMD (written)
# Outputs:
#   Writes the remote command to stdout before running it.
#######################################
launch_pd_proxy() {
  PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    python3 ${EXP_ROOT}/toy_proxy_server.py \
    --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
    --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
    --host=${PROXY_HOST} --port ${PROXY_PORT}"
  # Quoted, so the log shows exactly the string that is sent to the host.
  printf '%s\n' "${PROXY_BASE_CMD}"
  ssh -tt "${PROXY_HOST}" "${PROXY_BASE_CMD}" &
}
# Start the P/D servers and the proxy, then run the edge-case tests against
# them. Hosts and ports are handed to pytest through its environment.
launch_pd
launch_pd_proxy
sleep 10
PREFILL_HOST="${PREFILL_HOST}" \
  PREFILL_PORT="${PREFILL_PORT}" \
  DECODE_HOST="${DECODE_HOST}" \
  DECODE_PORT="${DECODE_PORT}" \
  PROXY_HOST="${PROXY_HOST}" \
  PROXY_PORT="${PROXY_PORT}" \
  python -m pytest -s -v \
  "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py"
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import os
import time
import openai
import requests
# Value passed as max_tokens for every completion request.
MAX_OUTPUT_LEN = 30
# Prompts sent to the server. Each entry is a single string built from the
# adjacent literals; the prompt text is also used as the key under which the
# corresponding output is stored and compared.
SAMPLE_PROMPTS = (
    "Red Hat is the best company in the world to work for because it works on "
    "open source software, which means that all the contributions are "
    "delivered to the community. As a result, when working on projects like "
    "vLLM we are able to meet many amazing people from various organizations "
    "like AMD, Google, NVIDIA, ",
    "We hold these truths to be self-evident, that all men are created equal, "
    "that they are endowed by their Creator with certain unalienable Rights, "
    "that among these are Life, Liberty and the pursuit of Happiness.--That "
    "to secure these rights, Governments are instituted among Men, deriving "
    "their just powers from the consent of the governed, ",
)
def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
    """
    Checks if the server is ready by sending GET requests to ``url``.

    The URL is requested as given; the caller is responsible for passing the
    full health-check address.

    Args:
        url (str): The URL to request.
        timeout (int): Timeout in seconds for each request.
        retries (int): Maximum number of attempts.

    Returns:
        bool: True as soon as one attempt returns HTTP 200, False if every
        attempt failed.
    """
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=timeout)
            if response.status_code == 200:
                return True
            # Both pieces must be f-strings, otherwise the placeholder is
            # printed literally instead of the status code.
            print(f"Attempt {attempt + 1}: Server returned status code "
                  f"{response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
        time.sleep(1)  # Wait before retrying
    return False
def run_simple_prompt(base_url: str, model_name: str,
                      input_prompt: str) -> str:
    """Request one completion for ``input_prompt`` and return its text.

    The request uses temperature 0.0, seed 42 and at most MAX_OUTPUT_LEN
    tokens; only the text of the first returned choice is used.
    """
    api = openai.OpenAI(api_key="EMPTY", base_url=base_url)
    request = {
        "model": model_name,
        "prompt": input_prompt,
        "max_tokens": MAX_OUTPUT_LEN,
        "temperature": 0.0,
        "seed": 42,
    }
    result = api.completions.create(**request)
    first_choice = result.choices[0]
    return first_choice.text
def main():
    """
    Run every prompt in SAMPLE_PROMPTS against a server and either save the
    outputs or compare them with previously saved ones.

    Command-line arguments:
        --service_url (required): base URL of the service.
        --model_name (required): model name sent with each request.
        --mode: "baseline" (default) saves the outputs to --file_name as
            JSON; any other value is treated as disaggregated mode, which
            loads --file_name and asserts that the outputs match exactly.
        --file_name: output file, ".vllm_output.txt" by default.

    Raises:
        ValueError: in disaggregated mode, if the baseline file is missing.
        RuntimeError: if the health check fails.
        OSError: if the output file cannot be written or read.
        AssertionError: if the outputs differ from the baseline.
    """
    parser = argparse.ArgumentParser(description="vLLM client script")

    parser.add_argument(
        "--service_url",  # URL of the service under test
        type=str,
        required=True,
        help="The vLLM service URL.")

    parser.add_argument(
        "--model_name",  # Model to query
        type=str,
        required=True,
        help="model_name",
    )

    parser.add_argument(
        "--mode",  # Save (baseline) or verify (anything else)
        type=str,
        default="baseline",
        help="mode: baseline==non-disagg, or disagg",
    )

    parser.add_argument(
        "--file_name",  # Where the baseline outputs are stored
        type=str,
        default=".vllm_output.txt",
        help="the file that saves the output tokens ",
    )

    args = parser.parse_args()

    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")

    if args.mode == "baseline":
        # non-disagg
        health_check_url = f"{args.service_url}/health"
    else:
        # disagg proxy
        health_check_url = f"{args.service_url}/healthcheck"
        # Fail before any request is sent if there is nothing to compare to.
        if not os.path.exists(args.file_name):
            raise ValueError(
                f"In disagg mode, the output file {args.file_name} from "
                "non-disagg. baseline does not exist.")

    service_url = f"{args.service_url}/v1"

    if not check_vllm_server(health_check_url):
        raise RuntimeError(
            f"vllm server: {args.service_url} is not ready yet!")

    output_strs = dict()
    for prompt in SAMPLE_PROMPTS:
        output_str = run_simple_prompt(base_url=service_url,
                                       model_name=args.model_name,
                                       input_prompt=prompt)
        print(f"Prompt: {prompt}, output: {output_str}")
        output_strs[prompt] = output_str

    if args.mode == "baseline":
        # baseline: save outputs
        try:
            with open(args.file_name, 'w') as json_file:
                json.dump(output_strs, json_file, indent=4)
        except OSError as e:
            print(f"Error writing to file: {e}")
            raise
    else:
        # disagg. verify outputs
        baseline_outputs = None
        try:
            with open(args.file_name) as json_file:
                baseline_outputs = json.load(json_file)
        except OSError as e:
            # This branch reads the baseline file, so report a read error.
            print(f"Error reading from file: {e}")
            raise
        assert isinstance(baseline_outputs, dict)
        assert len(baseline_outputs) == len(output_strs)
        for prompt, output in baseline_outputs.items():
            assert prompt in output_strs, f"{prompt} not included"
            assert output == output_strs[prompt], (
                f"baseline_output: {output} != PD output: {output_strs[prompt]}"
            )
if __name__ == "__main__":
main()
...@@ -4,8 +4,11 @@ import os ...@@ -4,8 +4,11 @@ import os
import openai import openai
PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost")
PREFILL_PORT = os.getenv("PREFILL_PORT", None) PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_HOST = os.getenv("DECODE_HOST", "localhost")
DECODE_PORT = os.getenv("DECODE_PORT", None) DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_HOST = os.getenv("PROXY_HOST", "localhost")
PROXY_PORT = os.getenv("PROXY_PORT", None) PROXY_PORT = os.getenv("PROXY_PORT", None)
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
...@@ -21,15 +24,15 @@ def test_edge_cases(): ...@@ -21,15 +24,15 @@ def test_edge_cases():
# Set the OpenAI API key and base URL # Set the OpenAI API key and base URL
decode_client = openai.OpenAI( decode_client = openai.OpenAI(
api_key="MY_KEY", api_key="MY_KEY",
base_url=f"http://localhost:{DECODE_PORT}/v1", base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1",
) )
prefill_client = openai.OpenAI( prefill_client = openai.OpenAI(
api_key="MY_KEY", api_key="MY_KEY",
base_url=f"http://localhost:{PREFILL_PORT}/v1", base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1",
) )
proxy_client = openai.OpenAI( proxy_client = openai.OpenAI(
api_key="MY_KEY", api_key="MY_KEY",
base_url=f"http://localhost:{PROXY_PORT}/v1", base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1",
) )
# Get the list of models # Get the list of models
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import argparse import argparse
import itertools import itertools
import logging
import os import os
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
...@@ -11,9 +12,8 @@ import httpx ...@@ -11,9 +12,8 @@ import httpx
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from vllm.logger import init_logger logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger = init_logger(__name__)
@asynccontextmanager @asynccontextmanager
......
...@@ -32,7 +32,7 @@ The class provides the following primitives: ...@@ -32,7 +32,7 @@ The class provides the following primitives:
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch import torch
...@@ -46,6 +46,12 @@ if TYPE_CHECKING: ...@@ -46,6 +46,12 @@ if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
# Type of the callback that copies KV blocks between two sets of buffers.
# Positional arguments, in order:
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
# where both tensor arguments are dicts mapping str to torch.Tensor, both
# index arguments are lists of int and direction is "h2d" or "d2h".
# The callback returns None.
CopyBlocksOp = Callable[[
    dict[str, torch.Tensor], dict[
        str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"]
], None]
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -127,6 +133,13 @@ class KVConnectorBase_V1(ABC): ...@@ -127,6 +133,13 @@ class KVConnectorBase_V1(ABC):
""" """
return return
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
    """
    Receive the xPU-specific op used to copy KV between host and device.

    Needed when host buffer is used for kv transfer (e.g., in NixlConnector).
    This implementation ignores ``copy_operation`` and does nothing.
    """
    return None
@abstractmethod @abstractmethod
def start_load_kv(self, forward_context: "ForwardContext", def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None: **kwargs) -> None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc import gc
import time import time
from contextlib import contextmanager from contextlib import contextmanager
...@@ -23,12 +22,10 @@ from vllm.config import (CompilationLevel, VllmConfig, ...@@ -23,12 +22,10 @@ from vllm.config import (CompilationLevel, VllmConfig,
from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank, get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model) prepare_communication_buffer_for_model)
from vllm.forward_context import (DPMetadata, get_forward_context, from vllm.forward_context import DPMetadata, set_forward_context
set_forward_context)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
...@@ -66,6 +63,8 @@ from vllm.v1.spec_decode.medusa import MedusaProposer ...@@ -66,6 +63,8 @@ from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
...@@ -88,7 +87,7 @@ else: ...@@ -88,7 +87,7 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
class GPUModelRunner(LoRAModelRunnerMixin): class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__( def __init__(
self, self,
...@@ -1357,7 +1356,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1357,7 +1356,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Return empty ModelRunnerOutput if there's no work to do. # Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output) return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
# Prepare the decoder inputs. # Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices, (attn_metadata, attention_cuda_graphs, logits_indices,
...@@ -1745,52 +1745,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1745,52 +1745,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids return spec_token_ids
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
def propose_ngram_draft_token_ids( def propose_ngram_draft_token_ids(
self, self,
sampled_token_ids: list[list[int]], sampled_token_ids: list[list[int]],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define KV connector functionality mixin for model runners.
"""
import copy
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin:
    """KV connector helpers meant to be mixed into a model runner.

    Every helper is a no-op (or returns an empty result) when no KV transfer
    group is present.
    """

    @staticmethod
    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
        """Bind this step's connector metadata and start loading KV."""
        if not has_kv_transfer_group():
            return

        connector = get_kv_transfer_group()
        assert isinstance(connector, KVConnectorBase_V1)
        assert scheduler_output.kv_connector_metadata is not None
        # Update KVConnector with the KVConnector metadata forward().
        connector.bind_connector_metadata(
            scheduler_output.kv_connector_metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        connector.start_load_kv(get_forward_context())

    @staticmethod
    def maybe_wait_for_kv_save() -> None:
        """Call ``wait_for_save()`` on the connector, if there is one."""
        if not has_kv_transfer_group():
            return
        get_kv_transfer_group().wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """Return the connector's ``get_finished`` result for this step.

        Returns ``(None, None)`` when there is no KV transfer group.
        """
        if not has_kv_transfer_group():
            return None, None
        finished_ids = scheduler_output.finished_req_ids
        return get_kv_transfer_group().get_finished(finished_ids)

    def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
                                vllm_config: VllmConfig) -> ModelRunnerOutput:
        """Drive the KV connector for a step with no forward pass.

        Returns the shared EMPTY_MODEL_RUNNER_OUTPUT when no transfer
        finished, otherwise a shallow copy of it carrying the finished sets.
        """
        # KV send/recv even if no work to do.
        with set_forward_context(None, vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)
            sent, received = self.get_finished_kv_transfers(scheduler_output)

        if not (sent or received):
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Copy first so the shared empty output object is never mutated.
        result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        result.finished_sending = sent
        result.finished_recving = received
        return result
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import bisect import bisect
import gc import gc
import time import time
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
...@@ -20,6 +20,8 @@ from vllm.attention.layer import Attention ...@@ -20,6 +20,8 @@ from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (ParallelConfig, VllmConfig, from vllm.config import (ParallelConfig, VllmConfig,
get_layers_from_vllm_config, update_config) get_layers_from_vllm_config, update_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA from vllm.lora.layers import BaseLayerWithLoRA
...@@ -46,6 +48,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists, ...@@ -46,6 +48,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors, ModelRunnerOutput) LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
...@@ -97,7 +101,7 @@ MIN_NUM_SEQS = 8 ...@@ -97,7 +101,7 @@ MIN_NUM_SEQS = 8
# The dummy_run should be comprehensive, ensuring all potential input shapes and # The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch predictions are included as subgraph inputs to facilitate # branch predictions are included as subgraph inputs to facilitate
# pre-compilation. # pre-compilation.
class TPUModelRunner(LoRAModelRunnerMixin): class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__( def __init__(
self, self,
...@@ -971,8 +975,12 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -971,8 +975,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# Update cached state # Update cached state
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do. if not has_kv_transfer_group():
return EMPTY_MODEL_RUNNER_OUTPUT # Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.is_multimodal_model: if self.is_multimodal_model:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
...@@ -986,6 +994,12 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -986,6 +994,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
start_index = 0 start_index = 0
combined_selected_tokens: list[torch.Tensor] = [] combined_selected_tokens: list[torch.Tensor] = []
combined_logprobs: list[LogprobsLists] = [] combined_logprobs: list[LogprobsLists] = []
# NOTE: setup current batch's metadata for kv connector.
# Currently, only verified with NixlConnector
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
while start_index < self.input_batch.num_reqs: while start_index < self.input_batch.num_reqs:
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index) end_index = self._prepare_inputs(scheduler_output, start_index)
...@@ -1032,6 +1046,14 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1032,6 +1046,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
start_index = end_index start_index = end_index
# NOTE: current kv load and save get h2d/d2h copies involved.
# Those copies are blocking. Once they become async., kv_save
# should be called right after each single forward pass,
# instead of the forwards of the entire input batch.
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
selected_token_ids = torch.cat(combined_selected_tokens, dim=0) selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
if tpu_sampling_metadata.logprobs: if tpu_sampling_metadata.logprobs:
...@@ -1126,6 +1148,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1126,6 +1148,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists, logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[], pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
) )
# Check there are no new graphs compiled - all the graphs should be # Check there are no new graphs compiled - all the graphs should be
...@@ -1637,6 +1661,10 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1637,6 +1661,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
for cache in self.kv_caches: for cache in self.kv_caches:
xs.mark_sharding(cache, self.mesh, (None, 'x', None, None)) xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
def reset_dynamo_cache(self): def reset_dynamo_cache(self):
if self.is_multimodal_model: if self.is_multimodal_model:
compiled_model = self.model.get_language_model().model compiled_model = self.model.get_language_model().model
...@@ -1851,6 +1879,75 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: ...@@ -1851,6 +1879,75 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return paddings[index] return paddings[index]
def _make_src_and_dst_indices(
src_block_ids: list[int],
dst_block_ids: list[int],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
src_indices = torch.tensor(src_block_ids,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_block_ids,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
@torch.compile(backend="openxla")
def _insert_blocks_to_tpu(
    cpu_cache: torch.Tensor,
    tpu_cache: torch.Tensor,
    cpu_block_indices: torch.Tensor,
    tpu_block_indices: torch.Tensor,
) -> None:
    """Copy selected blocks of ``cpu_cache`` into ``tpu_cache``, in place.

    The entries of ``cpu_cache`` selected by ``cpu_block_indices`` are moved
    to ``tpu_cache.device`` and assigned to the entries of ``tpu_cache``
    selected by ``tpu_block_indices``.
    """
    # NOTE(review): presumably marks tpu_cache as a donated buffer so the
    # in-place update can reuse its storage -- confirm against torch_xla.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
    tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
        tpu_cache.device)
@torch.compile(backend="openxla")
def _swap_out_tpu_blocks(
    tpu_cache: torch.Tensor,
    cpu_cache: torch.Tensor,
    tpu_block_indices: torch.Tensor,
    cpu_block_indices: torch.Tensor,
) -> None:
    """Copy selected tpu blocks to cpu blocks, in place.

    The entries of ``tpu_cache`` selected by ``tpu_block_indices`` are
    brought to the CPU and assigned to the entries of ``cpu_cache`` selected
    by ``cpu_block_indices``.
    """
    # NOTE(review): tpu_cache is only read in this function, yet it is the
    # tensor marked as buffer donor -- confirm this is intended.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
    cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()
def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy KV-cache blocks, layer by layer, from one buffer set to another.

    Args:
        src_kv_caches: Mapping of layer name to the tensor to read from.
        dst_kv_caches: Mapping of layer name to the tensor to write into.
            It is looked up with every key of ``src_kv_caches``.
        src_block_ids: Block ids to read in each source tensor.
        dst_block_ids: Block ids to write in each destination tensor; must
            have the same length as ``src_block_ids``.
        direction: ``"h2d"`` dispatches to ``_insert_blocks_to_tpu``; any
            other value dispatches to ``_swap_out_tpu_blocks``.

    Returns:
        None. The call is a silent no-op when either cache mapping or either
        id list is empty, or when the two id lists differ in length.
    """
    have_all_inputs = (src_kv_caches and dst_kv_caches and src_block_ids
                       and dst_block_ids)
    if not have_all_inputs or len(src_block_ids) != len(dst_block_ids):
        return

    # The first tensor of each mapping decides where the index tensors live.
    first_src = next(iter(src_kv_caches.values()))
    first_dst = next(iter(dst_kv_caches.values()))
    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=first_src.device,
        dst_device=first_dst.device)

    if direction == "h2d":
        transfer = _insert_blocks_to_tpu
    else:
        transfer = _swap_out_tpu_blocks

    # The same index tensors are reused for every layer.
    for layer_name, src_tensor in src_kv_caches.items():
        dst_tensor = dst_kv_caches[layer_name]
        transfer(src_tensor, dst_tensor, src_indices, dst_indices)
def _get_padded_num_kv_cache_update_slices( def _get_padded_num_kv_cache_update_slices(
num_tokens: int, max_num_reqs: int, page_size: int, num_tokens: int, max_num_reqs: int, page_size: int,
num_slices_per_kv_cache_update_block: int) -> int: num_slices_per_kv_cache_update_block: int) -> int:
......
...@@ -12,9 +12,11 @@ import torch_xla.debug.profiler as xp ...@@ -12,9 +12,11 @@ import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr import torch_xla.runtime as xr
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
...@@ -118,7 +120,7 @@ class TPUWorker: ...@@ -118,7 +120,7 @@ class TPUWorker:
# Initialize the distributed environment. # Initialize the distributed environment.
self._init_tpu_worker_distributed_environment( self._init_tpu_worker_distributed_environment(
self.parallel_config, self.rank, self.distributed_init_method, self.vllm_config, self.rank, self.distributed_init_method,
self.local_rank) self.local_rank)
# Device initialization should happen after initializing # Device initialization should happen after initializing
...@@ -242,7 +244,9 @@ class TPUWorker: ...@@ -242,7 +244,9 @@ class TPUWorker:
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output) output = self.model_runner.execute_model(scheduler_output)
return output if self.is_driver_worker else None # every worker's output is needed when kv_transfer_group is setup
return output if self.is_driver_worker or has_kv_transfer_group(
) else None
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
if self.rank < 1: if self.rank < 1:
...@@ -294,7 +298,7 @@ class TPUWorker: ...@@ -294,7 +298,7 @@ class TPUWorker:
def _init_tpu_worker_distributed_environment( def _init_tpu_worker_distributed_environment(
self, self,
parallel_config: ParallelConfig, vllm_config: VllmConfig,
rank: int, rank: int,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
local_rank: int = -1, local_rank: int = -1,
...@@ -306,6 +310,7 @@ class TPUWorker: ...@@ -306,6 +310,7 @@ class TPUWorker:
# the input objects on CPU. The all-reduce and all-gather ops on TPU # the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context. # own context.
parallel_config = vllm_config.parallel_config
init_distributed_environment( init_distributed_environment(
world_size=parallel_config.world_size, world_size=parallel_config.world_size,
rank=rank, rank=rank,
...@@ -317,6 +322,8 @@ class TPUWorker: ...@@ -317,6 +322,8 @@ class TPUWorker:
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
try: try:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker from tpu_commons.worker import TPUWorker as TPUCommonsWorker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment