Unverified Commit b2fac671 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[P/D] Heterogeneous TP (#18833)


Signed-off-by: default avatarnicklucche <nlucches@redhat.com>
parent 23027e2d
...@@ -8,7 +8,9 @@ MODELS=( ...@@ -8,7 +8,9 @@ MODELS=(
# Number of prefill and decode instances to create # Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2 NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
# Find the git repository root directory # Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel) GIT_ROOT=$(git rev-parse --show-toplevel)
...@@ -74,9 +76,10 @@ run_tests_for_model() { ...@@ -74,9 +76,10 @@ run_tests_for_model() {
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs # Calculate GPU ID - we'll distribute across available GPUs
GPU_ID=$((i % $(get_num_gpus))) GPU_ID=$((i % $(get_num_gpus)))
# Calculate port number (base port + instance number) # Calculate port number (base port + instance number)
PORT=$((8100 + i)) PORT=$((8100 + i))
# Calculate side channel port # Calculate side channel port. Avoid clash with TP workers.
SIDE_CHANNEL_PORT=$((5559 + i)) SIDE_CHANNEL_PORT=$((5559 + i))
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
...@@ -87,6 +90,7 @@ run_tests_for_model() { ...@@ -87,6 +90,7 @@ run_tests_for_model() {
--enforce-eager \ --enforce-eager \
--disable-log-requests \ --disable-log-requests \
--gpu-memory-utilization 0.2 \ --gpu-memory-utilization 0.2 \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then if [ -n "$model_args" ]; then
...@@ -109,7 +113,7 @@ run_tests_for_model() { ...@@ -109,7 +113,7 @@ run_tests_for_model() {
# Calculate port number (base port + instance number) # Calculate port number (base port + instance number)
PORT=$((8200 + i)) PORT=$((8200 + i))
# Calculate side channel port # Calculate side channel port
SIDE_CHANNEL_PORT=$((5659 + i)) SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
...@@ -119,6 +123,7 @@ run_tests_for_model() { ...@@ -119,6 +123,7 @@ run_tests_for_model() {
--enforce-eager \ --enforce-eager \
--disable-log-requests \ --disable-log-requests \
--gpu-memory-utilization 0.2 \ --gpu-memory-utilization 0.2 \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then if [ -n "$model_args" ]; then
......
...@@ -14,6 +14,7 @@ RTOL = 0.03 ...@@ -14,6 +14,7 @@ RTOL = 0.03
# Model-specific expected values # Model-specific expected values
EXPECTED_VALUES = { EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41, "Qwen/Qwen3-0.6B": 0.41,
"deepseek-ai/deepseek-vl2-small": 0.59
} }
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501
......
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
""" """
KV cache helper for store. KV cache helper for store.
""" """
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -90,3 +91,18 @@ class model_aware_kv_ops_helper: ...@@ -90,3 +91,18 @@ class model_aware_kv_ops_helper:
layer.self_attn.attn._k_scale, layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale, layer.self_attn.attn._v_scale,
) )
def get_kv_connector_cache_layout() -> str:
    """Pick the KV cache layout name for the currently active vLLM config.

    Reads the config returned by ``get_current_vllm_config()`` and returns
    ``"HND"`` only when the model does not use MLA and the configured KV
    connector is ``"NixlConnector"``. In every other case the default
    ``"NHD"`` is returned.

    Returns:
        Either ``"HND"`` or ``"NHD"``.
    """
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config

    if vllm_config.model_config is None:
        logger.warning("Unable to detect current VLLM config. "
                       "Defaulting to NHD kv cache layout.")
        return "NHD"

    # Without a KV transfer config there is no connector to inspect; fall
    # back to the default layout rather than failing on `None.kv_connector`.
    if kv_config is None:
        return "NHD"

    use_mla = vllm_config.model_config.use_mla
    if not use_mla and kv_config.kv_connector == "NixlConnector":
        logger.info("NixlConnector detected. Setting KV cache "
                    "layout to HND for better xfer performance.")
        return "HND"
    return "NHD"
...@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states ...@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version) get_flash_attn_version)
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size) return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    """Return the dimension permutation for the KV cache memory layout.

    The permutation maps the logical shape produced by
    ``get_kv_cache_shape`` to the memory layout selected by
    ``get_kv_connector_cache_layout()``.

    Returns:
        ``(0, 1, 2, 3, 4)`` (identity) for the ``"NHD"`` layout, or
        ``(0, 1, 3, 2, 4)`` (dims 2 and 3 swapped) for ``"HND"``.

    Raises:
        ValueError: If the reported layout is neither ``"NHD"`` nor
            ``"HND"``.
    """
    # NOTE When running disaggregated PD with NIXL, HND layout is used for
    # faster transfer. `stride_order` indicates the permutation that gets
    # us from `get_kv_cache_shape` to the actual memory layout we want.
    cache_layout = get_kv_connector_cache_layout()
    if cache_layout == "NHD":
        stride_order = (0, 1, 2, 3, 4)
    elif cache_layout == "HND":
        stride_order = (0, 1, 3, 2, 4)
    else:
        # Exceptions do not apply %-style arguments the way logging calls
        # do, so the message has to be formatted up front.
        raise ValueError(f"Unknown cache layout format {cache_layout}.")
    return stride_order
@dataclass @dataclass
class FlashAttentionMetadata: class FlashAttentionMetadata:
......
...@@ -597,6 +597,7 @@ class WorkerWrapperBase: ...@@ -597,6 +597,7 @@ class WorkerWrapperBase:
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank] kv_cache_config = kv_cache_configs[self.rpc_rank]
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self): def init_device(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment