Commit 84e5aba2 authored by xiabo
Browse files

解决pd分离非对称切分通信组过多问题

parent cd42bf87
...@@ -12,7 +12,7 @@ from vllm.config import VllmConfig ...@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine) P2pNcclEngine, RemoteAddr)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -108,6 +108,12 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -108,6 +108,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
port_offset=self._rank, port_offset=self._rank,
config=self.config, config=self.config,
model_config=vllm_config.model_config, model_config=vllm_config.model_config,
dp_rank=self._dp_rank,
pp_rank=self._pp_rank,
tp_rank=self._tp_rank,
dp_size=self._dp_size,
pp_size=self._pp_size,
tp_size=self._tp_size
) if role == KVConnectorRole.WORKER else None ) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -177,13 +183,11 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -177,13 +183,11 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Only consumer/decode loads KV Cache # Only consumer/decode loads KV Cache
if self.is_producer: if self.is_producer:
return return
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if attn_metadata is None: if attn_metadata is None:
return return
def inject_kv_into_layer( def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor, dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor, src_kv_cache: torch.Tensor,
...@@ -274,7 +278,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -274,7 +278,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
logger.warning("🚧src_kv_cache is None, %s", logger.warning("🚧src_kv_cache is None, %s",
request.request_id) request.request_id)
continue continue
inject_kv_into_layer(kv_cache_layer, kv_cache, inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id) request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name tensor_id = request.request_id + "#" + layer_name
...@@ -436,7 +439,9 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -436,7 +439,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
...@@ -467,6 +472,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -467,6 +472,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!") logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d): elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache, self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla) is_mla)
# if (self.pp_size == 1): # if (self.pp_size == 1):
......
...@@ -72,6 +72,13 @@ def set_p2p_nccl_context(num_channels: str): ...@@ -72,6 +72,13 @@ def set_p2p_nccl_context(num_channels: str):
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class RemoteAddr:
    """Routing information for one KV-cache transfer target.

    Groups the three values the send path needs for a peer: which
    communicator to use, which ZMQ endpoint receives the control message,
    and which rank inside the communicator is the destination.
    """

    # Key used to look up the (comm, rank) pair in the engine's ``comms``
    # table; the sender builds it as "<p_ip>:<p_port>_<d_ip>:<d_port>".
    pd_pair_id: str = ""
    # "ip:port" endpoint that the control socket connects to.
    zmq_address: str = ""
    # Destination rank passed to the NCCL send for this peer.
    comm_rank: int = 0
class P2pNcclEngine: class P2pNcclEngine:
def __init__(self, def __init__(self,
...@@ -79,11 +86,23 @@ class P2pNcclEngine: ...@@ -79,11 +86,23 @@ class P2pNcclEngine:
port_offset: int, port_offset: int,
config: KVTransferConfig, config: KVTransferConfig,
model_config: ModelConfig, model_config: ModelConfig,
dp_rank: int = 0,
pp_rank: int = 0,
tp_rank: int = 0,
dp_size: int = 0,
pp_size: int = 0,
tp_size: int = 0,
library_path: Optional[str] = None) -> None: library_path: Optional[str] = None) -> None:
self.config = config self.config = config
self.model_config = model_config self.model_config = model_config
self.rank = port_offset self.rank = port_offset
self.local_rank = local_rank self.local_rank = local_rank
self.dp_rank = dp_rank
self.pp_rank = pp_rank
self.tp_rank = tp_rank
self.dp_size = dp_size
self.pp_size = pp_size
self.tp_size = tp_size
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path) self.nccl = NCCLLibrary(library_path)
...@@ -184,7 +203,8 @@ class P2pNcclEngine: ...@@ -184,7 +203,8 @@ class P2pNcclEngine:
self._listener_thread.start() self._listener_thread.start()
self._ping_thread = None self._ping_thread = None
if port_offset == 0 and self.proxy_address != "": # if port_offset == 0 and self.proxy_address != "":
if self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping, self._ping_thread = threading.Thread(target=self._ping,
daemon=True) daemon=True)
self._ping_thread.start() self._ping_thread.start()
...@@ -196,6 +216,21 @@ class P2pNcclEngine: ...@@ -196,6 +216,21 @@ class P2pNcclEngine:
self.http_address, self.zmq_address, self.proxy_address, self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels) self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect_new(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt(zmq.SNDHWM, 10000)
sock.setsockopt(zmq.RCVHWM, 5000)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.TCP_KEEPALIVE, 1)
sock.setsockopt_string(zmq.IDENTITY, f"P-{self.zmq_address}")
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
return self.socks[remote_address]
def _create_connect(self, remote_address: typing.Optional[str] = None): def _create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None assert remote_address is not None
if remote_address not in self.socks: if remote_address not in self.socks:
...@@ -228,30 +263,36 @@ class P2pNcclEngine: ...@@ -228,30 +263,36 @@ class P2pNcclEngine:
is_mla: bool) -> list[any]: is_mla: bool) -> list[any]:
tensor_id = self.get_tensor_id(request_id, layer_name) tensor_id = self.get_tensor_id(request_id, layer_name)
remote_ip, remote_port = self.parse_request_id(request_id, True) remote_ip, remote_port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False)
pd_pair_id = p_ip + ":" + str(p_port) + "_" + remote_ip + ":" + str(remote_port)
# remote_port = 22001
if not self.enable_asymmetric_p2p: if not self.enable_asymmetric_p2p:
remote_address = remote_ip + ":" + str(remote_port + self.rank) remote_address = remote_ip + ":" + str(remote_port + self.rank)
return [(tensor_id, remote_address, tensor)] remote_addr = RemoteAddr(pd_pair_id, remote_address, self.rank + self.pp_size * self.tp_size)
# logger.info(f"""+++++xiabo tensor_id:{tensor_id} request_id:{request_id} remote_address:{remote_address}""")
return [(tensor_id, remote_addr, tensor)]
if not is_mla: if not is_mla:
logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!") logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!")
remote_pp_rank = self.compute_remote_pp_rank(layer_name) remote_pp_rank = self.compute_remote_pp_rank(layer_name)
items: list[Any] = [] items: list[Any] = []
up_down = 1
# remote_tp_rank = self.tp_rank * self.multp # remote_tp_rank = self.tp_rank * self.multp
for d_tp_rank in range(self.remote_tp_size): for d_tp_rank in range(self.remote_tp_size):
for mul_tp in range(self.multp): for mul_tp in range(self.multp):
if self.tp_rank + mul_tp * self.tp_size == d_tp_rank: if self.tp_rank + mul_tp * self.tp_size == d_tp_rank:
remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
remote_address = remote_ip + ":" + str(remote_port + remote_port_offset) remote_address = remote_ip + ":" + str(remote_port + remote_port_offset)
remote_addr = RemoteAddr(pd_pair_id, remote_address, remote_port_offset + self.pp_size * self.tp_size)
logger.debug( logger.debug(
"📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, " "Wait to send::%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)", tensor_id, "(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d) comm_rank (%d -> %d)", tensor_id,
tensor.shape, self.pp_rank, self.tp_rank, remote_address, tensor.shape, self.pp_rank, self.tp_rank, remote_address,
remote_pp_rank, self.rank * mul_tp + self.rank) remote_pp_rank, self.rank * mul_tp + self.rank, self.rank, remote_port_offset + self.pp_size * self.tp_size)
items.append([tensor_id, remote_address, tensor]) items.append([tensor_id, remote_addr, tensor])
return items return items
def send_tensor_new( def send_tensor_new(
...@@ -481,7 +522,56 @@ class P2pNcclEngine: ...@@ -481,7 +522,56 @@ class P2pNcclEngine:
self.recv_store[tensor_id] = tensor self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id) self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify() self.recv_store_cv.notify()
elif data["cmd"] == "PUT_NEW":
tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart(
[remote_address, b"0"])
# comm, rank = self.comms[remote_address.decode()]
# self._recv(comm, tensor, rank ^ 1, self.recv_stream)
comm, rank = self.comms[data["pd_pair_id"]]
self._recv(comm, tensor, int(data["comm_rank"]), self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart(
[remote_address, b"1"])
tensor = None
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address,
remote_address.decode(), data)
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "comm_init":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = int(data["rank"])
world_size = int(data["world_size"])
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
world_size, unique_id, rank)
self.comms[data["pd_pair_id"]] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "GET": elif data["cmd"] == "GET":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
with self.send_store_cv: with self.send_store_cv:
...@@ -538,7 +628,8 @@ class P2pNcclEngine: ...@@ -538,7 +628,8 @@ class P2pNcclEngine:
self.send_stream.wait_event(tbo_evt) self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address) self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else: else:
self._send_sync(tensor_id, tensor, remote_address) # logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self._send_sync_new(tensor_id, tensor, remote_address)
def wait_for_sent(self): def wait_for_sent(self):
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
...@@ -620,6 +711,48 @@ class P2pNcclEngine: ...@@ -620,6 +711,48 @@ class P2pNcclEngine:
return True return True
def _send_sync_new(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[RemoteAddr] = None,
) -> bool:
if remote_address is None:
return False
if remote_address.zmq_address not in self.socks:
# logger.info(f"""=============xiabo remote_address.zmq_address:{remote_address.zmq_address}""")
self._create_connect_new(remote_address.zmq_address)
sock = self.socks[remote_address.zmq_address]
comm, rank = self.comms[remote_address.pd_pair_id]
data = {
"cmd": "PUT_NEW",
"tensor_id": tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", ""),
"pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank
}
# logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address.zmq_address, rank, data, tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode())
return False
self._send(comm, tensor.to(self.device), remote_address.comm_rank, self.send_stream)
if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id)
return True
def _send_sync( def _send_sync(
self, self,
tensor_id: str, tensor_id: str,
...@@ -697,18 +830,46 @@ class P2pNcclEngine: ...@@ -697,18 +830,46 @@ class P2pNcclEngine:
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self):
    """Announce this engine to the proxy over a DEALER socket.

    Rank 0 first sends a "P_init"/"D_init" message carrying the HTTP and
    ZMQ addresses plus the dp/pp/tp sizes. Every rank then sends one
    "P"/"D" message carrying its own dp/pp/tp rank. Each message is sent
    once; there is no periodic resend, and the method returns after the
    last send.
    """
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")

    def publish(data):
        # Log the payload, then ship it msgpack-encoded to the proxy.
        logger.info(f"""_ping data:{data}""")
        sock.send(msgpack.dumps(data))

    if self.rank == 0:
        # Topology announcement, sent by the first rank only.
        publish({
            "type": "P_init" if self.config.is_kv_producer else "D_init",
            "http_address": self.http_address,
            "zmq_address": self.zmq_address,
            "dp_size" : self.dp_size,
            "pp_size" : self.pp_size,
            "tp_size" : self.tp_size
        })

    # Per-rank registration, sent by every rank.
    publish({
        "type": "P" if self.config.is_kv_producer else "D",
        "http_address": self.http_address,
        "dp_rank" : self.dp_rank,
        "pp_rank" : self.pp_rank,
        "tp_rank" : self.tp_rank,
        "zmq_address": self.zmq_address
    })
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, ( assert tensor.device == self.device, (
...@@ -747,7 +908,7 @@ class P2pNcclEngine: ...@@ -747,7 +908,7 @@ class P2pNcclEngine:
current_layer_idx = extract_layer_index(layer_name) current_layer_idx = extract_layer_index(layer_name)
for d_pp_rank in range(self.remote_pp_size): for d_pp_rank in range(self.remote_pp_size):
start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size) start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""") # logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
if (current_layer_idx == self.total_num_hidden_layers): if (current_layer_idx == self.total_num_hidden_layers):
return self.remote_pp_size - 1 return self.remote_pp_size - 1
if start <= current_layer_idx < end: if start <= current_layer_idx < end:
......
...@@ -122,7 +122,7 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, ...@@ -122,7 +122,7 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
if remaining_layers := num_hidden_layers % pp_size: if remaining_layers := num_hidden_layers % pp_size:
for i in range(2, remaining_layers + 2): for i in range(2, remaining_layers + 2):
partitions[-i] += 1 partitions[-i] += 1
logger.info( logger.debug(
"Hidden layers were unevenly partitioned: [%s]. " "Hidden layers were unevenly partitioned: [%s]. "
"This can be manually overridden using the " "This can be manually overridden using the "
"VLLM_PP_LAYER_PARTITION environment variable", "VLLM_PP_LAYER_PARTITION environment variable",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment