"vllm/vscode:/vscode.git/clone" did not exist on "b634e619bbcfed0abe4e01d0e2d97fb1fdfdbbd5"
Commit 5f308e68 authored by yangshj1's avatar yangshj1
Browse files

add cpp, async tensor transfer between pp ranks

parent 7df6b8ee
......@@ -3,6 +3,8 @@
import os
import copy
import queue
import threading
import gc
import time
import weakref
......@@ -323,6 +325,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self._recv_queue: queue.Queue[IntermediateTensors] = queue.Queue()
self._recv_thread: Optional[threading.Thread] = None
self._recv_stream: Optional[torch.cuda.Stream] = None
self._recv_event: Optional[torch.cuda.Event] = None
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
......@@ -730,8 +736,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
current_cpu_slice_clone = self.query_start_loc_cpu[:num_reqs + 1].clone()
self.query_start_loc[:num_reqs + 1].copy_(current_cpu_slice_clone, non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
......@@ -740,7 +746,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
self.query_start_loc[num_reqs + 1:].fill_(
self.query_start_loc_cpu[num_reqs].item())
current_cpu_slice_clone[num_reqs].item())
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
......@@ -1337,6 +1343,22 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
finished_recving=finished_recving,
)
def _recv_tensor_dict(self):
with torch.cuda.stream(self._recv_stream):
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
)
)
self._recv_event.record(self._recv_stream)
self._recv_queue.put(intermediate_tensors)
def _tensor_dict_recv_thread(self):
torch.cuda.set_device(self.device)
self_rank = get_pp_group().rank_in_group
while True:
self._recv_tensor_dict()
@torch.inference_mode()
def execute_model(
self,
......@@ -1426,9 +1448,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
def _recv_tensor_dict():
return self._recv_queue.get()
if self._recv_thread is None:
self._recv_stream = torch.cuda.Stream()
self._recv_event = torch.cuda.Event()
self._recv_thread = threading.Thread(target=self._tensor_dict_recv_thread, daemon=True, name="pp_recv_thread")
self._recv_thread.start()
intermediate_tensors = _recv_tensor_dict()
torch.cuda.current_stream().wait_event(self._recv_event)
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
......@@ -1494,6 +1525,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
if isinstance(model_output, IntermediateTensors):
residual_clone = model_output.tensors["residual"].clone()
hidden_states.tensors["residual"] = residual_clone
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
......
......@@ -2,7 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import queue
import os
import threading
from typing import TYPE_CHECKING, Optional
import torch
......@@ -81,6 +83,10 @@ class Worker(WorkerBase):
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
self._send_queue: queue.Queue[tuple[IntermediateTensors, SchedulerOutput, torch.cuda.Event]] = queue.Queue(1)
self._send_thread: Optional[threading.Thread] = None
self._send_stream: Optional[torch.cuda.Stream] = None
self._send_event: Optional[torch.cuda.Event] = None
def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
......@@ -305,16 +311,35 @@ class Worker(WorkerBase):
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def _send_tensor_dict(self):
intermediate_tensors, scheduler_output, event = self._send_queue.get()
assert event is not None
# 等待event在GPU执行完成
event.synchronize()
def _send_tensor_dict():
# [waterliao mark]
# logger.info(f"[waterliao debug] 222 rank{self.local_rank} _send_tensor_dict intermediate_tensors:{intermediate_tensors}, residual.shape={intermediate_tensors['residual'].shape}")
get_pp_group().send_tensor_dict(
intermediate_tensors.tensors,
all_gather_group=get_tp_group(),
)
_send_tensor_dict()
def _tensor_dict_send_thread(self):
torch.cuda.set_device(self.device)
torch.cuda.set_stream(self._send_stream)
while True:
self._send_tensor_dict()
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
if envs.VLLM_ZERO_OVERHEAD:
use_stream = zero_overhead_stream(self.device)
with torch.cuda.stream(use_stream):
......@@ -327,8 +352,14 @@ class Worker(WorkerBase):
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
if self._send_thread is None:
self._send_stream = torch.cuda.Stream()
self._send_thread = threading.Thread(target=self._tensor_dict_send_thread, daemon=True, name="pp_send_thread")
self._send_thread.start()
self._send_event = torch.cuda.Event()
send_event = self._send_event
send_event.record()
self._send_queue.put((output, scheduler_output, send_event))
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment