Commit 4612aad6 authored by Your Name's avatar Your Name
Browse files

[P/D][Feat]支持dp并行

parent cd42bf87
......@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls,
config: "VllmConfig",
role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, "
......@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return connector_cls(config, role)
return connector_cls(config, role, dp_rank)
# Register various connectors here.
......
......@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
import os
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group
P2pNcclEngine, RemoteAddr)
from vllm.distributed.parallel_state import get_world_group, get_dp_group, get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
import zmq
import msgpack
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
......@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class P2pNcclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
......@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
if role == KVConnectorRole.WORKER else 0
self._tp_size = get_tp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank,
port_offset=self._rank,
config=self.config,
model_config=vllm_config.model_config,
hostname="",
port_offset=self._rank,
dp_rank=self._dp_rank,
pp_rank=self._pp_rank,
tp_rank=self._tp_rank,
dp_size=self._dp_size,
pp_size=self._pp_size,
tp_size=self._tp_size
) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config
......@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.pp_size = self.parallel_config.pipeline_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size
self.num_card = self.pp_size * self.tp_size
self.multiple_machines = 1 if self.num_card > 8 else 0
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", self.tp_size)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", self.pp_size)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
self.remote_num_card = self.remote_tp_size * self.remote_pp_size
self.multiple_machines_d = 1 if self.remote_num_card > 8 else 0
self.multiple_machines_p = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines_p == 1:
if self.is_producer and self.multiple_machines == 1:
self.ip_map = {}
self.duplicate_keys = []
config_file = os.getenv('IP_CONFIG_FILE')
......@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1):
print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key):
return self.ip_map.get(key)
def register_req(self, request_id: str) :
data = {
"type": "Req",
"instance_type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"request_id": request_id,
"dp_rank" : self.dp_rank
}
self.req_sock.send(msgpack.dumps(data))
# ==============================
# Worker-side methods
# ==============================
......@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1):
2, num_pages * page_size, -1)
inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num):
req_layer = f"{request.request_id}#{layer_name}"
with self.p2p_nccl_engine.recv_store_cv:
while req_layer not in self.p2p_nccl_engine.recv_split_nums:
self.p2p_nccl_engine.recv_store_cv.wait()
split_num = self.p2p_nccl_engine.recv_split_nums.get(req_layer)
for num in range(split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num))
......@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
......@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
......@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
slot_mapping = request.slot_mapping
if request.slot_mapping_device is None:
request.slot_mapping_device = \
......@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
slot_mapping = request.slot_mapping_device
tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record()
pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank + 4), tbo_evt)
else:
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
pending = False
with self.p2p_nccl_engine.req_status_cv:
if request_id not in self.p2p_nccl_engine.req_status:
pending = True
if pending:
self.p2p_nccl_engine.pending_tensor(request_id, layer_name,
(kv_layer, slot_mapping), tbo_evt)
logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
else :
req_data = self.p2p_nccl_engine.req_status[request_id]
assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + i), tbo_evt)
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
(kv_layer, slot_mapping), remote_addr, tbo_evt)
# self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
# (kv_layer, slot_mapping), remote_address, tbo_evt)
else:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip)
if (self.pp_size == 1):
if self._rank < 8:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif (self.pp_size == 2):
if (pp_rank == 0):
pending = False
with self.p2p_nccl_engine.req_status_cv:
if request_id not in self.p2p_nccl_engine.req_status:
pending = True
if pending:
self.p2p_nccl_engine.pending_tensor(request_id, layer_name,
kv_cache)
logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
else :
req_data = self.p2p_nccl_engine.req_status[request_id]
assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else:
logger.error("Error: not support!!!!!!")
kv_cache, remote_addr)
def wait_for_save(self):
pass
# if self.is_producer:
......@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens)
# assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill:
continue
assert req_id in self.chunked_prefill
block_ids = new_block_ids[0]
if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids)
......
......@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
......@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface):
break
request = self.waiting.peek_request()
if request.is_finished():
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
......@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
break
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
......@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface):
break
request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
......@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
break
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
......@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
......
......@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment