Commit 7d5faa43 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-pd-all' into 'v0.9.2-dev'

mla模型P、D单实例单机的任意切分方式(满足D的tp>=P的tp)使用

See merge request dcutoolkit/deeplearing/vllm!315
parents bac269d7 4f51931d
...@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context ...@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -90,14 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -90,14 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1):
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._local_rank = get_world_group().local_rank \ self._local_rank = get_world_group().local_rank \
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._dp_rank = get_dp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._pp_rank = get_pp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._tp_rank = get_tp_group().rank_in_group \ self._tp_rank = get_tp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._dp_size = get_dp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._pp_size = get_pp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._tp_size = get_tp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine( self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank, local_rank=self._local_rank,
config=self.config,
hostname="",
port_offset=self._rank, port_offset=self._rank,
config=self.config,
model_config=vllm_config.model_config,
) if role == KVConnectorRole.WORKER else None ) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -365,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -365,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
def extract_kv_from_layer( def extract_kv_from_layer(
layer: torch.Tensor, layer: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
...@@ -455,29 +467,31 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -455,29 +467,31 @@ class P2pNcclConnector(KVConnectorBase_V1):
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!") logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d): elif (not self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 1): self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, is_mla)
kv_cache, remote_address) # if (self.pp_size == 1):
elif (self.pp_size == 2): # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
if (pp_rank == 0): # kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, # elif (self.pp_size == 2):
kv_cache, remote_address) # if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4)) # kv_cache, remote_address)
else: # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, # kv_cache, ip + ":" + str(port + self._rank + 4))
kv_cache, remote_address) # else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4)) # kv_cache, remote_address)
elif (self.pp_size == 8): # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
for i in range(8): # kv_cache, ip + ":" + str(port + self._rank - 4))
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, # elif (self.pp_size == 8):
kv_cache, ip + ":" + str(port + i)) # for i in range(8):
elif (self.enable_asymmetric_p2p): # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, # kv_cache, ip + ":" + str(port + i))
kv_cache, remote_address) # elif (self.enable_asymmetric_p2p):
else: # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!") # kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else: else:
logger.error("Error: not support!!!!!!") logger.error("Error: not support!!!!!!")
def wait_for_save(self): def wait_for_save(self):
......
...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import msgpack import msgpack
import torch import torch
import zmq import zmq
import regex
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
...@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip ...@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip
from vllm import envs from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from dataclasses import dataclass
from vllm.model_executor.models.utils import extract_layer_index
from vllm.distributed.utils import get_pp_indices
from vllm.config import ModelConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__) ...@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32 DEFAULT_MEM_POOL_SIZE_GB = 32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@contextmanager @contextmanager
def set_p2p_nccl_context(num_channels: str): def set_p2p_nccl_context(num_channels: str):
...@@ -65,22 +76,39 @@ class P2pNcclEngine: ...@@ -65,22 +76,39 @@ class P2pNcclEngine:
def __init__(self, def __init__(self,
local_rank: int, local_rank: int,
port_offset: int,
config: KVTransferConfig, config: KVTransferConfig,
hostname: str = "", model_config: ModelConfig,
port_offset: int = 0,
library_path: Optional[str] = None) -> None: library_path: Optional[str] = None) -> None:
self.config = config self.config = config
self.model_config = model_config
self.rank = port_offset self.rank = port_offset
self.local_rank = local_rank self.local_rank = local_rank
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path) self.nccl = NCCLLibrary(library_path)
if not hostname: self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
hostname = get_ip() "num_hidden_layers", 0)
self.pp_rank = get_pp_group().rank_in_group
self.tp_rank = get_tp_group().rank_in_group
self.pp_size = get_pp_group().world_size
self.tp_size = get_tp_group().world_size
if config.is_kv_producer:
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", 1)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", 1)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
self._hostname = hostname self._hostname = get_ip()
self._port = port self._port = port
# Each card corresponds to a ZMQ address. # Each card corresponds to a ZMQ address.
...@@ -195,6 +223,61 @@ class P2pNcclEngine: ...@@ -195,6 +223,61 @@ class P2pNcclEngine:
return self.socks[remote_address], self.comms[remote_address] return self.socks[remote_address], self.comms[remote_address]
def get_send_queue_items(self, request_id: str, layer_name: str,
                         tensor: torch.Tensor,
                         is_mla: bool) -> list[Any]:
    """Build the send items for one layer's KV tensor.

    Each item is a ``(tensor_id, remote_address, tensor)`` tuple.

    Args:
        request_id: Request id that embeds the decode address; it is
            parsed with ``parse_request_id(request_id, True)``.
        layer_name: Name of the layer whose KV cache is being sent.
        tensor: The tensor to send; passed through unchanged.
        is_mla: Whether the attention metadata is MLA. A non-MLA model on
            the asymmetric path is reported via ``logger.error`` but the
            items are still produced.

    Returns:
        A single item addressed to ``remote_port + self.rank`` when
        ``enable_asymmetric_p2p`` is off. Otherwise one item per remote TP
        rank ``tp_rank + k * tp_size`` for ``k`` in ``range(multp)``,
        addressed to port
        ``remote_port + remote_pp_rank * remote_tp_size + d_tp_rank``.
    """
    tensor_id = self.get_tensor_id(request_id, layer_name)
    remote_ip, remote_port = self.parse_request_id(request_id, True)

    if not self.enable_asymmetric_p2p:
        # Symmetric layout: the peer with the same global rank receives.
        remote_address = remote_ip + ":" + str(remote_port + self.rank)
        return [(tensor_id, remote_address, tensor)]

    if not is_mla:
        logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!")

    # NOTE(review): compute_remote_pp_rank returns -1 when no remote stage
    # owns the layer, which yields a port below ``remote_port`` here.
    remote_pp_rank = self.compute_remote_pp_rank(layer_name)

    items: list[Any] = []
    for mul_tp in range(self.multp):
        # Local TP rank r feeds remote TP ranks r, r + tp_size, ...
        d_tp_rank = self.tp_rank + mul_tp * self.tp_size
        if d_tp_rank >= self.remote_tp_size:
            break
        remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
        remote_address = remote_ip + ":" + str(remote_port + remote_port_offset)
        logger.debug(
            "📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
            "(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)", tensor_id,
            tensor.shape, self.pp_rank, self.tp_rank, remote_address,
            remote_pp_rank, d_tp_rank)
        # Same tuple shape as the symmetric path above.
        items.append((tensor_id, remote_address, tensor))
    return items
def send_tensor_new(
    self,
    request_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    is_mla: bool = False,
) -> bool:
    """Send one layer's KV tensor to every remote rank that needs it.

    The destinations come from ``get_send_queue_items``.

    Args:
        request_id: Request id carrying the remote address.
        layer_name: Name of the layer the tensor belongs to.
        tensor: The tensor to send.
        is_mla: Forwarded to ``get_send_queue_items``.

    Returns:
        For ``send_type == "PUT"``: whether every ``send_sync`` call
        succeeded (evaluation stops at the first failure).
        For ``"PUT_ASYNC"``: ``True`` once the items are queued.
        For ``"GET"`` or any other value: ``False`` after logging an error.
    """
    if self.send_type == "PUT":
        return all(
            self.send_sync(item) for item in self.get_send_queue_items(
                request_id, layer_name, tensor, is_mla))

    if self.send_type == "PUT_ASYNC":
        # Build the items before taking the lock so it is held only for
        # the enqueue itself.
        items = self.get_send_queue_items(request_id, layer_name, tensor,
                                          is_mla)
        with self.send_queue_cv:
            self.send_queue.extend(items)
            self.send_queue_cv.notify()
        return True

    if self.send_type == "GET":
        logger.error(" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!")
    else:
        logger.error("P2PNCCL unsupported send_type: %s", self.send_type)
    # Always return a real bool; this path used to fall through to None.
    return False
def send_tensor( def send_tensor(
self, self,
tensor_id: str, tensor_id: str,
...@@ -659,3 +742,38 @@ class P2pNcclEngine: ...@@ -659,3 +742,38 @@ class P2pNcclEngine:
self._send_thread.join() self._send_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
def compute_remote_pp_rank(self, layer_name: str) -> int:
    """Return the remote pipeline-parallel rank that owns ``layer_name``.

    The layer index is extracted from the name and matched against the
    ``[start, end)`` range that ``get_pp_indices`` assigns to each of the
    ``remote_pp_size`` remote stages.

    Args:
        layer_name: Layer name containing the layer index.

    Returns:
        The owning remote PP rank. An index equal to
        ``total_num_hidden_layers`` maps to the last remote stage
        (presumably an extra layer placed after the regular decoder
        layers; confirm against the model definition). Returns ``-1``
        when no stage's range contains the index.
    """
    current_layer_idx = extract_layer_index(layer_name)

    # This check does not depend on the candidate rank, so do it once
    # before the loop.
    if current_layer_idx == self.total_num_hidden_layers:
        return self.remote_pp_size - 1

    for d_pp_rank in range(self.remote_pp_size):
        start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank,
                                    self.remote_pp_size)
        # Runs for every layer sent, so log lazily at debug level.
        logger.debug(
            "compute_remote_pp_rank : current_layer_idx:%s start:%s end:%s",
            current_layer_idx, start, end)
        if start <= current_layer_idx < end:
            return d_pp_rank
    return -1
@staticmethod
def get_tensor_id(request_id: str, layer_name: str) -> str:
return request_id + "#" + layer_name
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
    """Extract the peer ``(host, port)`` embedded in a request id.

    Args:
        request_id: Request id containing ``___decode_addr_<host>:<port>``
            or ``___prefill_addr_<host>:<port>___``.
        is_prefill: When true, look for the decode address; otherwise look
            for the prefill address.

    Returns:
        The host string and the port as an int.

    Raises:
        ValueError: If the expected address marker is not present.
    """
    pattern = (r"___decode_addr_(.*):(\d+)"
               if is_prefill else r"___prefill_addr_(.*):(\d+)___")
    found = regex.search(pattern, request_id)
    if found is None:
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")
    host, port_text = found.groups()
    return host, int(port_text)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment