Commit 73032f48 authored by xuxz's avatar xuxz
Browse files

[PD]回退p2pncclconnector

parent a997359c
...@@ -6,17 +6,19 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -6,17 +6,19 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
import os
from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine) P2pNcclEngine, RemoteAddr)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -35,6 +37,7 @@ class ReqMeta: ...@@ -35,6 +37,7 @@ class ReqMeta:
token_ids: torch.Tensor token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids # Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
slot_mapping_device: torch.Tensor = None
@staticmethod @staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
...@@ -54,7 +57,7 @@ class ReqMeta: ...@@ -54,7 +57,7 @@ class ReqMeta:
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
) )
@dataclass @dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata): class P2pNcclConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] requests: list[ReqMeta]
...@@ -87,13 +90,77 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -87,13 +90,77 @@ class P2pNcclConnector(KVConnectorBase_V1):
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._local_rank = get_world_group().local_rank \ self._local_rank = get_world_group().local_rank \
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._dp_rank = get_dp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._pp_rank = get_pp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._tp_rank = get_tp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._dp_size = get_dp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._pp_size = get_pp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._tp_size = get_tp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine( self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank, local_rank=self._local_rank,
config=self.config,
hostname="",
port_offset=self._rank, port_offset=self._rank,
config=self.config,
model_config=vllm_config.model_config,
dp_rank=self._dp_rank,
pp_rank=self._pp_rank,
tp_rank=self._tp_rank,
dp_size=self._dp_size,
pp_size=self._pp_size,
tp_size=self._tp_size
) if role == KVConnectorRole.WORKER else None ) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config
self.model_config = vllm_config.model_config
self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
"num_hidden_layers", 0)
self.pp_size = self.parallel_config.pipeline_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size
self.num_card = self.pp_size * self.tp_size
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", self.tp_size)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", self.pp_size)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
self.remote_num_card = self.remote_tp_size * self.remote_pp_size
self.multiple_machines_d = 1 if self.remote_num_card > 8 else 0
self.multiple_machines_p = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines_p == 1:
self.ip_map = {}
self.duplicate_keys = []
config_file = os.getenv('IP_CONFIG_FILE')
if not config_file:
print("Warning: Please set the IPVNet FILE environment variable for cross machine recognition of the second IP address")
return
try:
with open(config_file, 'r', encoding='utf-8') as file:
for line_num, line in enumerate(file, 1):
line = line.strip()
if line and not line.startswith('#'):
ips = line.split()
if len(ips) == 2:
first_ip, second_ip = ips
if first_ip not in self.ip_map:
self.ip_map[first_ip] = second_ip
else:
print(f"warning: num {line_num} Incorrect format : {line}")
except Exception as e:
print(f"Error: Exception occurred while reading configuration file - {e}")
def get_ip_value(self, key):
    """Return the second IP mapped to ``key``, or ``None`` if there is none.

    ``ip_map`` is only created in ``__init__`` for producers whose
    ``pp_size * tp_size`` exceeds 8 (multi-machine prefill). On every other
    instance the attribute does not exist, so a missing map is treated the
    same as a missing key instead of raising ``AttributeError``.

    Args:
        key: First-column IP address from the ``IP_CONFIG_FILE`` mapping.

    Returns:
        The second-column IP registered for ``key``, or ``None`` when the
        key (or the whole map) is absent.
    """
    ip_map = getattr(self, "ip_map", None)
    if ip_map is None:
        return None
    return ip_map.get(key)
# ============================== # ==============================
# Worker-side methods # Worker-side methods
...@@ -116,13 +183,11 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -116,13 +183,11 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Only consumer/decode loads KV Cache # Only consumer/decode loads KV Cache
if self.is_producer: if self.is_producer:
return return
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if attn_metadata is None: if attn_metadata is None:
return return
def inject_kv_into_layer( def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor, dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor, src_kv_cache: torch.Tensor,
...@@ -143,7 +208,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -143,7 +208,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
request_id (str): request id for log request_id (str): request id for log
""" """
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0] num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1] page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape( dst_kv_cache_layer = dst_kv_cache_layer.reshape(
...@@ -193,20 +258,95 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -193,20 +258,95 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Load the KV for each request each layer # Load the KV for each request each layer
for request in metadata.requests: for request in metadata.requests:
for layer_name in forward_context.no_compile_layers: for layer_name in forward_context.no_compile_layers:
attn_layer = forward_context.no_compile_layers[layer_name] layer = forward_context.no_compile_layers[layer_name]
kv_cache_layer = attn_layer.kv_cache[ \
forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None: if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue continue
inject_kv_into_layer(kv_cache_layer, kv_cache, kv_cache_layer = kv_cache[ \
request.slot_mapping, request.request_id) forward_context.virtual_engine]
if not envs.VLLM_P2P_ASYNC:
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
else:
dst_kv_cache_layer_shape = kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
num_pages * page_size, -1)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
2, num_pages * page_size, -1)
inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num))
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_token = kv_cache.shape[0]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
else:
num_token = kv_cache.shape[1]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[:, request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[:, request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
inject_start_index += num_token
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's """Blocking until the KV for a specific layer is loaded into vLLM's
...@@ -238,6 +378,8 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -238,6 +378,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
def extract_kv_from_layer( def extract_kv_from_layer(
layer: torch.Tensor, layer: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
...@@ -246,7 +388,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -246,7 +388,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx) Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise. if MLA is not used, and (num_pages, page_size, xxx) otherwise.
""" """
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1] num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...@@ -257,18 +399,112 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -257,18 +399,112 @@ class P2pNcclConnector(KVConnectorBase_V1):
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata) assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
slot_mapping = request.slot_mapping
if request.slot_mapping_device is None:
request.slot_mapping_device = \
request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
slot_mapping = request.slot_mapping_device
tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record()
pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank + 4), tbo_evt)
else:
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + i), tbo_evt)
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip)
if (self.pp_size == 1):
if self._rank < 8:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else:
logger.error("Error: not support!!!!!!")
def wait_for_save(self): def wait_for_save(self):
if self.is_producer: pass
assert self.p2p_nccl_engine is not None # if self.is_producer:
self.p2p_nccl_engine.wait_for_sent() # assert self.p2p_nccl_engine is not None
# self.p2p_nccl_engine.wait_for_sent()
def get_finished( def get_finished(
self, finished_req_ids: set[str], self, finished_req_ids: set[str],
...@@ -382,7 +618,9 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -382,7 +618,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens = ( num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id] scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens) num_tokens = (num_scheduled_tokens + num_computed_tokens)
assert req_id in self.chunked_prefill # assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill:
continue
block_ids = new_block_ids[0] block_ids = new_block_ids[0]
if not resumed_from_preemption: if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids) block_ids = (self.chunked_prefill[req_id][0] + block_ids)
...@@ -482,4 +720,4 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -482,4 +720,4 @@ class P2pNcclConnector(KVConnectorBase_V1):
for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim): for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
raise NotImplementedError( raise NotImplementedError(
"Currently, only symmetric TP is supported. Asymmetric TP, PP," "Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs.") "and others will be supported in future PRs.")
\ No newline at end of file
...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import msgpack import msgpack
import torch import torch
import zmq import zmq
import regex
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
...@@ -20,6 +21,13 @@ from vllm.distributed.device_communicators.pynccl_wrapper import ( ...@@ -20,6 +21,13 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool) TensorMemoryPool)
from vllm.utils import current_stream, get_ip from vllm.utils import current_stream, get_ip
from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from dataclasses import dataclass
from vllm.model_executor.models.utils import extract_layer_index
from vllm.distributed.utils import get_pp_indices
from vllm.config import ModelConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -28,6 +36,11 @@ logger = logging.getLogger(__name__) ...@@ -28,6 +36,11 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32 DEFAULT_MEM_POOL_SIZE_GB = 32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@contextmanager @contextmanager
def set_p2p_nccl_context(num_channels: str): def set_p2p_nccl_context(num_channels: str):
...@@ -59,17 +72,37 @@ def set_p2p_nccl_context(num_channels: str): ...@@ -59,17 +72,37 @@ def set_p2p_nccl_context(num_channels: str):
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class RemoteAddr:
    """Addressing information for one remote endpoint of a P/D transfer."""

    # Key identifying the prefill/decode pair; the engine stores its NCCL
    # communicators under this id.
    pd_pair_id: str = ""
    # "ip:port" string of the remote worker's ZMQ endpoint.
    zmq_address: str = ""
    # Rank forwarded as "comm_rank" in PUT_NEW messages.
    # NOTE(review): presumably the peer's rank inside the shared NCCL
    # communicator - confirm against `_send_sync_new`.
    comm_rank: int = 0
class P2pNcclEngine: class P2pNcclEngine:
def __init__(self, def __init__(self,
local_rank: int, local_rank: int,
port_offset: int,
config: KVTransferConfig, config: KVTransferConfig,
hostname: str = "", model_config: ModelConfig,
port_offset: int = 0, dp_rank: int = 0,
pp_rank: int = 0,
tp_rank: int = 0,
dp_size: int = 0,
pp_size: int = 0,
tp_size: int = 0,
library_path: Optional[str] = None) -> None: library_path: Optional[str] = None) -> None:
self.config = config self.config = config
self.model_config = model_config
self.rank = port_offset self.rank = port_offset
self.local_rank = local_rank self.local_rank = local_rank
self.dp_rank = dp_rank
self.pp_rank = pp_rank
self.tp_rank = tp_rank
self.dp_size = dp_size
self.pp_size = pp_size
self.tp_size = tp_size
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path) self.nccl = NCCLLibrary(library_path)
...@@ -95,7 +128,7 @@ class P2pNcclEngine: ...@@ -95,7 +128,7 @@ class P2pNcclEngine:
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
self._hostname = hostname self._hostname = get_ip()
self._port = port self._port = port
# Each card corresponds to a ZMQ address. # Each card corresponds to a ZMQ address.
...@@ -128,6 +161,10 @@ class P2pNcclEngine: ...@@ -128,6 +161,10 @@ class P2pNcclEngine:
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
self.p2p_async_buf = None
self.tensor_split_num: int = 0
mem_pool_size_gb = self.config.get_from_extra_config( mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
...@@ -167,11 +204,16 @@ class P2pNcclEngine: ...@@ -167,11 +204,16 @@ class P2pNcclEngine:
self._listener_thread.start() self._listener_thread.start()
self._ping_thread = None self._ping_thread = None
if port_offset == 0 and self.proxy_address != "": if self.multiple_machines:
self._ping_thread = threading.Thread(target=self._ping, if port_offset == 0 and self.proxy_address != "":
daemon=True) self._ping_thread = threading.Thread(target=self._ping,
self._ping_thread.start() daemon=True)
self._ping_thread.start()
else:
if self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping_new,
daemon=True)
self._ping_thread.start()
logger.info( logger.info(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, " "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_" "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
...@@ -179,6 +221,21 @@ class P2pNcclEngine: ...@@ -179,6 +221,21 @@ class P2pNcclEngine:
self.http_address, self.zmq_address, self.proxy_address, self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels) self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect_new(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt(zmq.SNDHWM, 10000)
sock.setsockopt(zmq.RCVHWM, 5000)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.TCP_KEEPALIVE, 1)
sock.setsockopt_string(zmq.IDENTITY, f"P-{self.zmq_address}")
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
return self.socks[remote_address]
def _create_connect(self, remote_address: typing.Optional[str] = None): def _create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None assert remote_address is not None
if remote_address not in self.socks: if remote_address not in self.socks:
...@@ -206,11 +263,73 @@ class P2pNcclEngine: ...@@ -206,11 +263,73 @@ class P2pNcclEngine:
return self.socks[remote_address], self.comms[remote_address] return self.socks[remote_address], self.comms[remote_address]
def get_send_queue_items(self, request_id: str, layer_name: str,
                         tensor: torch.Tensor,
                         is_mla: bool) -> list[Any]:
    """Build the send-queue entries needed to ship one layer's KV tensor.

    Each entry is a ``tensor_id, RemoteAddr, tensor`` triple. With symmetric
    parallelism a single entry targets the peer at the same rank. With
    ``enable_asymmetric_p2p`` one entry is produced for every destination
    TP rank ``tp_rank + k * tp_size`` (``k < multp``) that exists on the
    remote side.

    Args:
        request_id: Request id; the local and remote ``ip``/``port`` are
            parsed out of it via ``parse_request_id``.
        layer_name: Name of the layer the tensor belongs to.
        tensor: KV tensor to send.
        is_mla: Whether the model uses MLA attention. An error is logged
            for non-MLA models on the asymmetric path, but processing
            continues.

    Returns:
        List of ``tensor_id, remote_addr, tensor`` entries; may be empty on
        the asymmetric path when no destination rank matches.
    """
    tensor_id = self.get_tensor_id(request_id, layer_name)
    remote_ip, remote_port = self.parse_request_id(request_id, True)
    p_ip, p_port = self.parse_request_id(request_id, False)
    pd_pair_id = f"{p_ip}:{p_port}_{remote_ip}:{remote_port}"
    # Offset added to every remote comm_rank below.
    local_world_size = self.pp_size * self.tp_size

    if not self.enable_asymmetric_p2p:
        remote_address = f"{remote_ip}:{remote_port + self.rank}"
        remote_addr = RemoteAddr(pd_pair_id, remote_address,
                                 self.rank + local_world_size)
        return [(tensor_id, remote_addr, tensor)]

    if not is_mla:
        # NOTE(review): only logged; the items are still generated.
        logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!")

    remote_pp_rank = self.compute_remote_pp_rank(layer_name)
    items: list[Any] = []
    # Compute each destination TP rank directly instead of scanning every
    # (remote rank, multiplier) pair for a match.
    for mul_tp in range(self.multp):
        d_tp_rank = self.tp_rank + mul_tp * self.tp_size
        if not 0 <= d_tp_rank < self.remote_tp_size:
            continue
        remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
        remote_address = f"{remote_ip}:{remote_port + remote_port_offset}"
        remote_comm_rank = remote_port_offset + local_world_size
        remote_addr = RemoteAddr(pd_pair_id, remote_address,
                                 remote_comm_rank)
        # The remote "tp" field reports the destination TP rank.
        logger.debug(
            "Wait to send::%s, tensor_shape:%s, "
            "(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d) comm_rank (%d -> %d)", tensor_id,
            tensor.shape, self.pp_rank, self.tp_rank, remote_address,
            remote_pp_rank, d_tp_rank, self.rank, remote_comm_rank)
        items.append([tensor_id, remote_addr, tensor])
    return items
def send_tensor_new(
    self,
    request_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    is_mla: bool = False,
) -> bool:
    """Send one layer's KV tensor to the remote peer(s) of ``request_id``.

    The destinations come from ``get_send_queue_items``. What happens next
    depends on ``self.send_type``:

    * ``"PUT"``: every item is sent synchronously; sending stops at the
      first failure.
    * ``"PUT_ASYNC"``: the items are appended to ``self.send_queue`` and
      the queue's condition variable is notified.
    * ``"GET"``: not supported here; an error is logged.

    Args:
        request_id: Request id the tensor belongs to.
        layer_name: Name of the layer the tensor belongs to.
        tensor: KV tensor to send.
        is_mla: Whether the model uses MLA attention.

    Returns:
        ``True`` if the tensor was sent (PUT) or queued (PUT_ASYNC),
        ``False`` for a failed PUT, for GET and for any other send type.
    """
    if self.send_type == "PUT":
        # Items are (tensor_id, remote_addr, tensor); unpack them into the
        # (tensor_id, tensor, remote_address) order used by the async send
        # loop when it calls `_send_sync_new`.
        return all(
            self._send_sync_new(item_id, item_tensor, remote_addr)
            for item_id, remote_addr, item_tensor in
            self.get_send_queue_items(request_id, layer_name, tensor,
                                      is_mla))

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            for item in self.get_send_queue_items(request_id, layer_name,
                                                  tensor, is_mla):
                self.send_queue.append(item)
            self.send_queue_cv.notify()
        return True

    if self.send_type == "GET":
        logger.error(" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!")
    # Nothing was sent: report failure instead of implicitly returning None.
    return False
def send_tensor( def send_tensor(
self, self,
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
...@@ -250,6 +369,53 @@ class P2pNcclEngine: ...@@ -250,6 +369,53 @@ class P2pNcclEngine:
self.buffer_size / self.buffer_size_threshold * 100) self.buffer_size / self.buffer_size_threshold * 100)
return True return True
def p2p_async_send_tensor(
    self,
    tensor_id: str,
    tensor: typing.Union[torch.Tensor, tuple],
    remote_address: typing.Optional[str] = None,
    tbo_evt=None,
) -> bool:
    """Hand a KV payload to the engine for transfer to ``remote_address``.

    Args:
        tensor_id: Identifier under which the payload is stored or sent.
        tensor: Payload. On the ``PUT_ASYNC`` path this must be a
            ``(kv_layer, slot_mapping)`` pair, which is unpacked before
            queueing. The ``GET`` path calls ``element_size()``/``numel()``
            on it and therefore needs a single tensor.
        remote_address: "ip:port" of the destination. ``None`` stores the
            payload in the local ``recv_store`` instead of sending it.
        tbo_evt: Optional event queued with the payload on the
            ``PUT_ASYNC`` path. NOTE(review): it is ignored on the ``PUT``
            and ``GET`` paths - confirm that is intended.

    Returns:
        The result of ``_send_sync`` for ``PUT``; ``True`` otherwise.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    if self.send_type == "PUT":
        # NOTE(review): `tensor` is forwarded unchanged; callers that pass
        # the (kv_layer, slot_mapping) pair hand a tuple to `_send_sync`.
        return self._send_sync(tensor_id, tensor, remote_address)

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            kv_layer, slot_mapping = tensor
            # Field order matches what the send loop unpacks from the queue.
            self.send_queue.append(
                [tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
            self.send_queue_cv.notify()
        return True

    # GET: keep the tensor in `send_store`, evicting the oldest entries
    # while the buffer would exceed its threshold.
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        # Stop evicting once the store is empty: a tensor larger than the
        # threshold would otherwise raise StopIteration from `next()`.
        while (self.send_store and self.buffer_size + tensor_size
               > self.buffer_size_threshold):
            oldest_tensor_id = next(iter(self.send_store))
            oldest_tensor = self.send_store.pop(oldest_tensor_id)
            oldest_tensor_size = (oldest_tensor.element_size() *
                                  oldest_tensor.numel())
            self.buffer_size -= oldest_tensor_size
            logger.info(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                remote_address, tensor_id, tensor_size,
                self.buffer_size, oldest_tensor_size, self.rank)

        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
            remote_address, tensor_id, tensor_size, tensor.shape,
            self.rank, self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
def recv_tensor( def recv_tensor(
self, self,
...@@ -327,6 +493,8 @@ class P2pNcclEngine: ...@@ -327,6 +493,8 @@ class P2pNcclEngine:
self.zmq_address, remote_address.decode(), rank) self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
try: try:
with torch.cuda.stream(self.recv_stream): with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"], tensor = torch.empty(data["shape"],
...@@ -343,10 +511,6 @@ class P2pNcclEngine: ...@@ -343,10 +511,6 @@ class P2pNcclEngine:
# Store Tensor in memory pool # Store Tensor in memory pool
addr = self.pool.store_tensor(tensor) addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape) tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else: else:
self.buffer_size += tensor_size self.buffer_size += tensor_size
...@@ -363,7 +527,56 @@ class P2pNcclEngine: ...@@ -363,7 +527,56 @@ class P2pNcclEngine:
self.recv_store[tensor_id] = tensor self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id) self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify() self.recv_store_cv.notify()
elif data["cmd"] == "PUT_NEW":
tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart(
[remote_address, b"0"])
# comm, rank = self.comms[remote_address.decode()]
# self._recv(comm, tensor, rank ^ 1, self.recv_stream)
comm, rank = self.comms[data["pd_pair_id"]]
self._recv(comm, tensor, int(data["comm_rank"]), self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart(
[remote_address, b"1"])
tensor = None
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address,
remote_address.decode(), data)
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "comm_init":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = int(data["rank"])
world_size = int(data["world_size"])
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
world_size, unique_id, rank)
self.comms[data["pd_pair_id"]] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "GET": elif data["cmd"] == "GET":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
with self.send_store_cv: with self.send_store_cv:
...@@ -410,10 +623,21 @@ class P2pNcclEngine: ...@@ -410,10 +623,21 @@ class P2pNcclEngine:
with self.send_queue_cv: with self.send_queue_cv:
while not self.send_queue: while not self.send_queue:
self.send_queue_cv.wait() self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft() if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt = self.send_queue.popleft()
else:
tensor_id, remote_address, tensor = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify()
self._send_sync(tensor_id, tensor, remote_address) if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else:
if self.multiple_machines:
self._send_sync(tensor_id, tensor, remote_address)
else:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self._send_sync_new(tensor_id, tensor, remote_address)
def wait_for_sent(self): def wait_for_sent(self):
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
...@@ -518,7 +742,7 @@ class P2pNcclEngine: ...@@ -518,7 +742,7 @@ class P2pNcclEngine:
"pd_pair_id": remote_address.pd_pair_id, "pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank "comm_rank": rank
} }
# logger.info(f"""_send_sync_new:{data}""") logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
response = sock.recv() response = sock.recv()
...@@ -627,6 +851,36 @@ class P2pNcclEngine: ...@@ -627,6 +851,36 @@ class P2pNcclEngine:
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
time.sleep(3) time.sleep(3)
def _ping_new(self):
    """Announce this engine to the proxy over a ZMQ DEALER socket.

    A DEALER socket is created with ``self.zmq_address`` as its identity
    and connected to ``tcp://<self.proxy_address>``.  Two kinds of
    msgpack-encoded messages are then sent:

    * rank 0 only: an init message (``"P_init"`` / ``"D_init"``) carrying
      the HTTP address, the ZMQ address and the dp/pp/tp sizes;
    * every rank: a message (``"P"`` / ``"D"``) carrying the HTTP
      address, this rank's dp/pp/tp ranks and the ZMQ address.

    The ``P`` / ``D`` choice follows ``self.config.is_kv_producer``.
    Each message is sent exactly once; this method does not loop or
    sleep.
    """
    # NOTE(review): the original indentation is not visible in this view;
    # the rank-0 guard is assumed to cover only the init message, with the
    # per-rank message sent unconditionally -- confirm against the file.
    dealer = self.context.socket(zmq.DEALER)
    dealer.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    dealer.connect(f"tcp://{self.proxy_address}")

    is_producer = self.config.is_kv_producer

    if self.rank == 0:
        # Key order is kept as-is: msgpack serialises dicts in
        # insertion order.
        init_msg = {
            "type": "P_init" if is_producer else "D_init",
            "http_address": self.http_address,
            "zmq_address": self.zmq_address,
            "dp_size": self.dp_size,
            "pp_size": self.pp_size,
            "tp_size": self.tp_size,
        }
        dealer.send(msgpack.dumps(init_msg))

    rank_msg = {
        "type": "P" if is_producer else "D",
        "http_address": self.http_address,
        "dp_rank": self.dp_rank,
        "pp_rank": self.pp_rank,
        "tp_rank": self.tp_rank,
        "zmq_address": self.zmq_address,
    }
    dealer.send(msgpack.dumps(rank_msg))
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, ( assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
...@@ -727,4 +981,4 @@ class P2pNcclEngine: ...@@ -727,4 +981,4 @@ class P2pNcclEngine:
return ip, port return ip, port
raise ValueError( raise ValueError(
f"Request id {request_id} does not contain hostname and port") f"Request id {request_id} does not contain hostname and port")
\ No newline at end of file
...@@ -63,7 +63,7 @@ class TensorMemoryPool: ...@@ -63,7 +63,7 @@ class TensorMemoryPool:
than min_block_size than min_block_size
""" """
def __init__(self, max_block_size: int, min_block_size: int = 512): def __init__(self, max_block_size: int, min_block_size: int = 128):
if max_block_size <= 0 or min_block_size <= 0: if max_block_size <= 0 or min_block_size <= 0:
raise ValueError("Block sizes must be positive") raise ValueError("Block sizes must be positive")
if max_block_size < min_block_size: if max_block_size < min_block_size:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment