Commit 4612aad6 authored by Your Name's avatar Your Name
Browse files

[P/D][Feat]支持dp并行

parent cd42bf87
...@@ -54,6 +54,7 @@ class KVConnectorFactory: ...@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls, cls,
config: "VllmConfig", config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1: ) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, " raise ValueError("Attempting to initialize a V1 Connector, "
...@@ -81,7 +82,7 @@ class KVConnectorFactory: ...@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process # - Co-locate with worker process
# - Should only be used inside the forward context & attention layer # - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation # We build separately to enforce strict separation
return connector_cls(config, role) return connector_cls(config, role, dp_rank)
# Register various connectors here. # Register various connectors here.
......
...@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
import os import os
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine) P2pNcclEngine, RemoteAddr)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group, get_dp_group, get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
import zmq
import msgpack
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata): ...@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class P2pNcclConnector(KVConnectorBase_V1): class P2pNcclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role) super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {} self._requests_need_load: dict[str, Any] = {}
...@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._tp_size = get_tp_group().world_size \ self._tp_size = get_tp_group().world_size \
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine( self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank, local_rank=self._local_rank,
port_offset=self._rank,
config=self.config, config=self.config,
model_config=vllm_config.model_config, hostname="",
port_offset=self._rank,
dp_rank=self._dp_rank,
pp_rank=self._pp_rank,
tp_rank=self._tp_rank,
dp_size=self._dp_size,
pp_size=self._pp_size,
tp_size=self._tp_size
) if role == KVConnectorRole.WORKER else None ) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.pp_size = self.parallel_config.pipeline_parallel_size self.pp_size = self.parallel_config.pipeline_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size self.tp_size = self.parallel_config.tensor_parallel_size
self.num_card = self.pp_size * self.tp_size self.num_card = self.pp_size * self.tp_size
self.multiple_machines = 1 if self.num_card > 8 else 0
self.remote_tp_size = self.config.get_from_extra_config( if self.is_producer and self.multiple_machines == 1:
"remote_tp_size", self.tp_size)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", self.pp_size)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
self.remote_num_card = self.remote_tp_size * self.remote_pp_size
self.multiple_machines_d = 1 if self.remote_num_card > 8 else 0
self.multiple_machines_p = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines_p == 1:
self.ip_map = {} self.ip_map = {}
self.duplicate_keys = [] self.duplicate_keys = []
config_file = os.getenv('IP_CONFIG_FILE') config_file = os.getenv('IP_CONFIG_FILE')
...@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1):
print(f"Error: Exception occurred while reading configuration file - {e}") print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key):
    """Return the entry stored under ``key`` in ``self.ip_map``.

    Args:
        key: Lookup key for the IP map.

    Returns:
        The mapped value, or ``None`` when ``key`` is not present or when
        this instance has no ``ip_map`` at all.
    """
    # NOTE(review): ``ip_map`` looks like it is only assigned on the
    # producer / multi-machine path of ``__init__``; on every other instance
    # a plain attribute access would raise AttributeError, so a missing map
    # is treated as an empty one.
    ip_map = getattr(self, "ip_map", None)
    if not ip_map:
        return None
    return ip_map.get(key)
def register_req(self, request_id: str):
    """Send a "Req" registration message for ``request_id`` on ``self.req_sock``.

    The message carries the instance type ("P" when
    ``self.config.is_kv_producer`` is truthy, otherwise "D"), this
    instance's ``http_address``, the request id and ``self.dp_rank``.
    It is msgpack-encoded and sent as a single frame.

    Args:
        request_id: Identifier of the request being registered.
    """
    # NOTE(review): ``req_sock``, ``http_address`` and ``dp_rank`` appear to
    # be set only when the connector is built with the SCHEDULER role;
    # confirm that callers only reach this method on such an instance.
    if self.config.is_kv_producer:
        instance_type = "P"
    else:
        instance_type = "D"

    payload = msgpack.dumps({
        "type": "Req",
        "instance_type": instance_type,
        "http_address": self.http_address,
        "request_id": request_id,
        "dp_rank": self.dp_rank,
    })
    self.req_sock.send(payload)
# ============================== # ==============================
# Worker-side methods # Worker-side methods
# ============================== # ==============================
...@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1):
2, num_pages * page_size, -1) 2, num_pages * page_size, -1)
inject_start_index = 0 inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num): req_layer = f"{request.request_id}#{layer_name}"
with self.p2p_nccl_engine.recv_store_cv:
while req_layer not in self.p2p_nccl_engine.recv_split_nums:
self.p2p_nccl_engine.recv_store_cv.wait()
split_num = self.p2p_nccl_engine.recv_split_nums.get(req_layer)
for num in range(split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor( kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num)) request.request_id + "#" + layer_name + "#" + str(num))
...@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# inject_kv_into_layer(kv_cache_layer, kv_cache, # inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id) # request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num) tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store: if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None) tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop( self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
...@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
def extract_kv_from_layer( def extract_kv_from_layer(
layer: torch.Tensor, layer: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
...@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC: if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
slot_mapping = request.slot_mapping slot_mapping = request.slot_mapping
if request.slot_mapping_device is None: if request.slot_mapping_device is None:
request.slot_mapping_device = \ request.slot_mapping_device = \
...@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
slot_mapping = request.slot_mapping_device slot_mapping = request.slot_mapping_device
tbo_evt = torch.cuda.Event(enable_timing=False) tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record() tbo_evt.record()
pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \ pending = False
self.parallel_config.pipeline_parallel_size with self.p2p_nccl_engine.req_status_cv:
if (self.pp_size == 1): if request_id not in self.p2p_nccl_engine.req_status:
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, pending = True
(kv_layer, slot_mapping), remote_address, tbo_evt) if pending:
elif (self.pp_size == 2): self.p2p_nccl_engine.pending_tensor(request_id, layer_name,
if (pp_rank == 0): (kv_layer, slot_mapping), tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
(kv_layer, slot_mapping), remote_address, tbo_evt) else :
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, req_data = self.p2p_nccl_engine.req_status[request_id]
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank + 4), tbo_evt) assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
else: for i in range(req_data.dst_num):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + i), tbo_evt) (kv_layer, slot_mapping), remote_addr, tbo_evt)
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!") # self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
# (kv_layer, slot_mapping), remote_address, tbo_evt)
else: else:
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size pending = False
) % self.parallel_config.pipeline_parallel_size with self.p2p_nccl_engine.req_status_cv:
if (self.multiple_machines_p and self.multiple_machines_d): if request_id not in self.p2p_nccl_engine.req_status:
ip_second = self.get_ip_value(ip) pending = True
if (self.pp_size == 1): if pending:
if self._rank < 8: self.p2p_nccl_engine.pending_tensor(request_id, layer_name,
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, kv_cache)
kv_cache, remote_address) logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, else :
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8)) req_data = self.p2p_nccl_engine.req_status[request_id]
elif (self.pp_size == 2): assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
if (pp_rank == 0): for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_addr)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else:
logger.error("Error: not support!!!!!!")
def wait_for_save(self): def wait_for_save(self):
pass pass
# if self.is_producer: # if self.is_producer:
...@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens = ( num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id] scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens) num_tokens = (num_scheduled_tokens + num_computed_tokens)
# assert req_id in self.chunked_prefill assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill:
continue
block_ids = new_block_ids[0] block_ids = new_block_ids[0]
if not resumed_from_preemption: if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids) block_ids = (self.chunked_prefill[req_id][0] + block_ids)
......
...@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface): ...@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported " "Multiple KV cache groups are not currently supported "
"with KV connectors") "with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
...@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface): ...@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if request.is_finished():
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
...@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface): ...@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked # pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \ if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget: num_new_tokens > token_budget:
break
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
...@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface): ...@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface): ...@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked # pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \ if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget: num_new_tokens > token_budget:
break
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
...@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface): ...@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if spec_token_ids is not None: if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted. # Needs to happen after new_token_ids are accepted.
......
...@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore): ...@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str], def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str], coord_output_path: Optional[str],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment