Commit 6ca3d790 authored by zhuwenwen
Browse files

支持pd分离p2p_async & 解决oom问题

parent f29b58c3
...@@ -18,6 +18,7 @@ from vllm.forward_context import get_forward_context ...@@ -18,6 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -213,6 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -213,6 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
kv_cache_layer = kv_cache[ \ kv_cache_layer = kv_cache[ \
forward_context.virtual_engine] forward_context.virtual_engine]
if self.p2p_nccl_engine.tensor_split_num == P2pNcclEngine.TENSOR_SPLIT_OFF:
kv_cache = self.p2p_nccl_engine.recv_tensor( kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name) request.request_id + "#" + layer_name)
...@@ -234,6 +236,61 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -234,6 +236,61 @@ class P2pNcclConnector(KVConnectorBase_V1):
if isinstance(tensor, tuple): if isinstance(tensor, tuple):
addr, _, _ = tensor addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr) self.p2p_nccl_engine.pool.free(addr)
else:
dst_kv_cache_layer_shape = kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
num_pages * page_size, -1)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
2, num_pages * page_size, -1)
inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num))
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_token = kv_cache.shape[0]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
else:
num_token = kv_cache.shape[1]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[:, request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[:, request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
inject_start_index += num_token
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
...@@ -296,30 +353,29 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -296,30 +353,29 @@ class P2pNcclConnector(KVConnectorBase_V1):
request.slot_mapping_device = \ request.slot_mapping_device = \
request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True) request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
slot_mapping = request.slot_mapping_device slot_mapping = request.slot_mapping_device
kv_cache = extract_kv_from_layer(kv_layer, slot_mapping)
tbo_evt = torch.cuda.Event(enable_timing=False) tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record() tbo_evt.record()
pp_rank = (self.parallel_config.rank // pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \ self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1): if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt) (kv_layer, slot_mapping), remote_address, tbo_evt)
elif (self.pp_size == 2): elif (self.pp_size == 2):
if (pp_rank == 0): if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt) (kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4), tbo_evt) (kv_layer, slot_mapping), ip + ":" + str(port + self._rank + 4), tbo_evt)
else: else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt) (kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4), tbo_evt) (kv_layer, slot_mapping), ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8): elif (self.pp_size == 8):
for i in range(8): for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i), tbo_evt) (kv_layer, slot_mapping), ip + ":" + str(port + i), tbo_evt)
else: else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!") print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else: else:
......
...@@ -21,6 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ...@@ -21,6 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool) TensorMemoryPool)
from vllm.utils import current_stream, get_ip from vllm.utils import current_stream, get_ip
from vllm import envs from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -62,6 +63,8 @@ def set_p2p_nccl_context(num_channels: str): ...@@ -62,6 +63,8 @@ def set_p2p_nccl_context(num_channels: str):
class P2pNcclEngine: class P2pNcclEngine:
TENSOR_SPLIT_OFF = 0
def __init__(self, def __init__(self,
local_rank: int, local_rank: int,
config: KVTransferConfig, config: KVTransferConfig,
...@@ -111,9 +114,12 @@ class P2pNcclEngine: ...@@ -111,9 +114,12 @@ class P2pNcclEngine:
self.recv_store_cv = threading.Condition() self.recv_store_cv = threading.Condition()
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
# self.send_stream = tbo_all_reduce_stream
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
self.p2p_async_buf = None
self.tensor_split_num: int = 0
mem_pool_size_gb = self.config.get_from_extra_config( mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) * self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
...@@ -208,7 +214,54 @@ class P2pNcclEngine: ...@@ -208,7 +214,54 @@ class P2pNcclEngine:
return self._send_sync(tensor_id, tensor, remote_address) return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC": elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv: with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor, tbo_evt]) self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True
def p2p_async_send_tensor(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
else:
if self.send_type == "PUT":
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
kv_layer, slot_mapping = tensor # tesor (kv_layer, slot_mapping)
self.send_queue.append([tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
self.send_queue_cv.notify() self.send_queue_cv.notify()
else: # GET else: # GET
with self.send_store_cv: with self.send_store_cv:
...@@ -313,6 +366,10 @@ class P2pNcclEngine: ...@@ -313,6 +366,10 @@ class P2pNcclEngine:
self.zmq_address, remote_address.decode(), rank) self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
else:
self.tensor_split_num= self.TENSOR_SPLIT_OFF
try: try:
with torch.cuda.stream(self.recv_stream): with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"], tensor = torch.empty(data["shape"],
...@@ -392,11 +449,16 @@ class P2pNcclEngine: ...@@ -392,11 +449,16 @@ class P2pNcclEngine:
with self.send_queue_cv: with self.send_queue_cv:
while not self.send_queue: while not self.send_queue:
self.send_queue_cv.wait() self.send_queue_cv.wait()
tensor_id, remote_address, tensor, tbo_evt = self.send_queue.popleft() if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt = self.send_queue.popleft()
else:
tensor_id, remote_address, tensor = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None: if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt) self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else:
self._send_sync(tensor_id, tensor, remote_address) self._send_sync(tensor_id, tensor, remote_address)
def wait_for_sent(self): def wait_for_sent(self):
...@@ -410,6 +472,75 @@ class P2pNcclEngine: ...@@ -410,6 +472,75 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank) " to be empty, rank:%d", duration * 1000, self.rank)
def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor,
                      slot_mapping: torch.Tensor, remote_address: str) -> bool:
    """Gather the KV entries selected by ``slot_mapping`` and send them in chunks.

    The layer is flattened to one row per slot, the rows named by
    ``slot_mapping`` are gathered into a reusable staging buffer of at most
    ``self.p2p_async_kv_tokens`` rows, and each chunk is announced with a
    ``PUT`` header on the ZMQ socket before being handed to ``self._send``.
    Chunk ``i`` is announced under the id ``tensor_id + "#" + str(i)`` and
    every header carries the total chunk count as ``tensor_split_num``.

    Args:
        tensor_id: Base identifier for this transfer.
        kv_layer: KV cache of one layer. A 3-D tensor is treated as the MLA
            layout ``(num_pages, page_size, ...)``; anything else is treated
            as ``(2, num_pages, page_size, ...)``.
        slot_mapping: 1-D index tensor of flattened slot positions to send.
            NOTE(review): assumed to live on the same device as ``kv_layer``
            and to be ready before work queued on ``self.send_stream`` runs
            - confirm against the caller.
        remote_address: Peer address; a connection is created on first use.

    Returns:
        ``True`` when every chunk was acknowledged with ``b"0"`` and sent,
        ``False`` as soon as the peer answers anything else.
    """
    if remote_address not in self.socks:
        self._create_connect(remote_address)
    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    is_mla = (kv_layer.ndim == 3)
    chunk_tokens = self.p2p_async_kv_tokens
    num_tokens = slot_mapping.shape[0]
    # Ceiling division: number of chunks needed to cover all tokens.
    pack_num = (num_tokens - 1) // chunk_tokens + 1
    self.tensor_split_num = pack_num
    dtype_name = str(kv_layer.dtype).replace("torch.", "")

    with torch.cuda.stream(self.send_stream):
        # Flatten once: the view does not depend on the chunk index.
        if is_mla:
            num_pages, page_size = kv_layer.shape[0], kv_layer.shape[1]
            flat_kv = kv_layer.reshape(num_pages * page_size, -1)
            token_dim = 0
            lead = 1
        else:
            num_pages, page_size = kv_layer.shape[1], kv_layer.shape[2]
            flat_kv = kv_layer.reshape(2, num_pages * page_size, -1)
            token_dim = 1
            lead = 2

        # Take the row width from the flattened view rather than from
        # kv_layer.shape[-1]; the two differ whenever the layer has more than
        # one trailing axis (e.g. separate head and head_dim axes).
        hidden_dim = flat_kv.shape[-1]

        if is_mla:
            buf_shape = (chunk_tokens, hidden_dim)
        else:
            buf_shape = (2, chunk_tokens, hidden_dim)

        # Reuse the staging buffer only while it still matches this layer.
        # It is allocated under send_stream, the only stream that touches it
        # in this method, so replacing it stays ordered with earlier work.
        buf = self.p2p_async_buf
        if (buf is None or tuple(buf.shape) != buf_shape
                or buf.dtype != kv_layer.dtype
                or buf.device != kv_layer.device):
            buf = torch.empty(buf_shape,
                              dtype=kv_layer.dtype,
                              device=kv_layer.device)
            self.p2p_async_buf = buf
        flat_buf = buf.view(-1)

        for pack_idx in range(pack_num):
            start = pack_idx * chunk_tokens
            end = min(start + chunk_tokens, num_tokens)
            count = end - start
            sub_index = slot_mapping[start:end]

            if is_mla:
                tx_shape = (count, hidden_dim)
            else:
                tx_shape = (2, count, hidden_dim)

            # Stage into a contiguous prefix of the buffer. Slicing
            # buf[:, :count] would give a strided view for a partial chunk.
            # NOTE(review): _send is not visible here; a contiguous tensor is
            # correct whether or not it sends from a raw data pointer.
            send_tensor = flat_buf[:lead * count * hidden_dim].view(tx_shape)
            torch.index_select(flat_kv, dim=token_dim, index=sub_index,
                               out=send_tensor)

            header = {
                "cmd": "PUT",
                "tensor_id": tensor_id + "#" + str(pack_idx),  # append pack_idx
                "pack_idx": pack_idx,
                "tensor_split_num": pack_num,
                "shape": tx_shape,
                "dtype": dtype_name
            }
            sock.send(msgpack.dumps(header))

            response = sock.recv()
            if response != b"0":
                logger.error(
                    "🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s",
                    self.zmq_address, remote_address, rank,
                    tuple(send_tensor.shape),
                    send_tensor.element_size() * send_tensor.numel() / 1024**3,
                    response.decode()
                )
                return False

            self._send(comm, send_tensor, rank ^ 1, self.send_stream)

    if self.send_type == "PUT_ASYNC":
        self._have_sent_tensor_id(tensor_id)
    return True
def _send_sync( def _send_sync(
self, self,
tensor_id: str, tensor_id: str,
......
...@@ -174,6 +174,7 @@ if TYPE_CHECKING: ...@@ -174,6 +174,7 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_MORI_EP: bool = False VLLM_USE_MORI_EP: bool = False
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1146,6 +1147,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1146,6 +1147,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm pd separation will be used async # vllm pd separation will be used async
"VLLM_P2P_ASYNC": "VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))), lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -159,8 +159,6 @@ def prepare_tbo_atten_metadata( ...@@ -159,8 +159,6 @@ def prepare_tbo_atten_metadata(
# The block_table for RIGHT starts from (req_offset-1). # The block_table for RIGHT starts from (req_offset-1).
# Align both offsets to that, and re-build the seq_lens for row-0. # Align both offsets to that, and re-build the seq_lens for row-0.
seq_len_offset = req_offset - 1 seq_len_offset = req_offset - 1
# query_start_offset = req_offset - 1
query_start_offset = req_offset query_start_offset = req_offset
# row-0 is the split request (global row index = req_offset-1): # row-0 is the split request (global row index = req_offset-1):
...@@ -182,7 +180,6 @@ def prepare_tbo_atten_metadata( ...@@ -182,7 +180,6 @@ def prepare_tbo_atten_metadata(
else: else:
# RIGHT without split-in-req: natural positions # RIGHT without split-in-req: natural positions
seq_len_offset = req_offset seq_len_offset = req_offset
# query_start_offset = req_offset
query_start_offset = req_offset + 1 query_start_offset = req_offset + 1
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device) seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
......
...@@ -1302,7 +1302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1302,7 +1302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
# profile.StartTracer()
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment