Unverified Commit b2fac671 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[P/D] Heterogeneous TP (#18833)


Signed-off-by: default avatarnicklucche <nlucches@redhat.com>
parent 23027e2d
......@@ -8,7 +8,9 @@ MODELS=(
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
......@@ -74,9 +76,10 @@ run_tests_for_model() {
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs
GPU_ID=$((i % $(get_num_gpus)))
# Calculate port number (base port + instance number)
PORT=$((8100 + i))
# Calculate side channel port
# Calculate side channel port. Avoid clash with with TP workers.
SIDE_CHANNEL_PORT=$((5559 + i))
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
......@@ -87,6 +90,7 @@ run_tests_for_model() {
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
......@@ -109,7 +113,7 @@ run_tests_for_model() {
# Calculate port number (base port + instance number)
PORT=$((8200 + i))
# Calculate side channel port
SIDE_CHANNEL_PORT=$((5659 + i))
SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
......@@ -119,6 +123,7 @@ run_tests_for_model() {
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
......
......@@ -14,6 +14,7 @@ RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41,
"deepseek-ai/deepseek-vl2-small": 0.59
}
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501
......
......@@ -3,11 +3,12 @@
"""
KV cache helper for store.
"""
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger(__name__)
......@@ -90,3 +91,18 @@ class model_aware_kv_ops_helper:
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
def get_kv_connector_cache_layout():
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if vllm_config.model_config is None:
logger.warning("Unable to detect current VLLM config. " \
"Defaulting to NHD kv cache layout.")
else:
use_mla = vllm_config.model_config.use_mla
if not use_mla and kv_config.kv_connector == "NixlConnector":
logger.info("NixlConnector detected. Setting KV cache " \
"layout to HND for better xfer performance.")
return "HND"
return "NHD"
......@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
......@@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
# NOTE When running disaggregated PD with NIXL, HND layout is used for
# faster transfer. `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_connector_cache_layout()
if cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:
raise ValueError("Unknown cache layout format %s.", cache_layout)
return stride_order
@dataclass
class FlashAttentionMetadata:
......
......@@ -597,7 +597,8 @@ class WorkerWrapperBase:
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
self.worker.initialize_from_config(kv_cache_config) # type: ignore
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
with set_current_vllm_config(self.vllm_config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment