Commit 45517312 authored by wanghl6's avatar wanghl6
Browse files

Merge branch 'v0.9.2-dev' into '0.9.2-dev-tx-kernel_fuse'

# Conflicts:
#   vllm/envs.py
#   vllm/v1/attention/backends/mla/common.py
parents d7bee8b6 a34bff19
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import os import os
import copy import copy
import queue
import threading
import gc import gc
import time import time
import weakref import weakref
...@@ -323,6 +325,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -323,6 +325,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
self._recv_queue: queue.Queue[IntermediateTensors] = queue.Queue()
self._recv_thread: Optional[threading.Thread] = None
self._recv_stream: Optional[torch.cuda.Stream] = None
self._recv_event: Optional[torch.cuda.Event] = None
# Layer pairings for cross-layer KV sharing. # Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it # If an Attention layer `layer_name` is in the keys of this dict, it
...@@ -730,8 +736,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -730,8 +736,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.positions_cpu[:total_num_scheduled_tokens], self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True) non_blocking=True)
self.query_start_loc[:num_reqs + 1].copy_( current_cpu_slice_clone = self.query_start_loc_cpu[:num_reqs + 1].clone()
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.query_start_loc[:num_reqs + 1].copy_(current_cpu_slice_clone, non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
...@@ -740,7 +746,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -740,7 +746,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Note: pad query_start_loc to be non-decreasing, as kernels # Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that # like FlashAttention requires that
self.query_start_loc[num_reqs + 1:].fill_( self.query_start_loc[num_reqs + 1:].fill_(
self.query_start_loc_cpu[num_reqs].item()) current_cpu_slice_clone[num_reqs].item())
query_start_loc = self.query_start_loc[:num_reqs + 1] query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
...@@ -1337,6 +1343,22 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1337,6 +1343,22 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
finished_recving=finished_recving, finished_recving=finished_recving,
) )
def _recv_tensor_dict(self):
with torch.cuda.stream(self._recv_stream):
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
)
)
self._recv_event.record(self._recv_stream)
self._recv_queue.put(intermediate_tensors)
def _tensor_dict_recv_thread(self):
torch.cuda.set_device(self.device)
self_rank = get_pp_group().rank_in_group
while True:
self._recv_tensor_dict()
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -1426,9 +1448,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1426,9 +1448,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
else: else:
def _recv_tensor_dict():
return self._recv_queue.get()
if self._recv_thread is None:
self._recv_stream = torch.cuda.Stream()
self._recv_event = torch.cuda.Event()
self._recv_thread = threading.Thread(target=self._tensor_dict_recv_thread, daemon=True, name="pp_recv_thread")
self._recv_thread.start()
intermediate_tensors = _recv_tensor_dict()
torch.cuda.current_stream().wait_event(self._recv_event)
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True) num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode. # Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
...@@ -1494,6 +1525,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1494,6 +1525,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
hidden_states = model_output hidden_states = model_output
if isinstance(model_output, IntermediateTensors):
residual_clone = model_output.tensors["residual"].clone()
hidden_states.tensors["residual"] = residual_clone
aux_hidden_states = None aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
...@@ -3290,6 +3324,16 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3290,6 +3324,16 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
else: else:
def _recv_tensor_dict():
return self._recv_queue.get()
if self._recv_thread is None:
self._recv_stream = torch.cuda.Stream()
self._recv_event = torch.cuda.Event()
self._recv_thread = threading.Thread(target=self._tensor_dict_recv_thread, daemon=True, name="pp_recv_thread")
self._recv_thread.start()
intermediate_tensors = _recv_tensor_dict()
torch.cuda.current_stream().wait_event(self._recv_event)
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True) num_input_tokens, intermediate_tensors, True)
...@@ -3358,6 +3402,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3358,6 +3402,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
hidden_states = model_output hidden_states = model_output
if isinstance(model_output, IntermediateTensors):
residual_clone = model_output.tensors["residual"].clone()
hidden_states.tensors["residual"] = residual_clone
aux_hidden_states = None aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import queue
import os import os
import threading
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -81,6 +83,10 @@ class Worker(WorkerBase): ...@@ -81,6 +83,10 @@ class Worker(WorkerBase):
torch_profiler_trace_dir, use_gzip=True)) torch_profiler_trace_dir, use_gzip=True))
else: else:
self.profiler = None self.profiler = None
self._send_queue: queue.Queue[tuple[IntermediateTensors, SchedulerOutput, torch.cuda.Event]] = queue.Queue(1)
self._send_thread: Optional[threading.Thread] = None
self._send_stream: Optional[torch.cuda.Stream] = None
self._send_event: Optional[torch.cuda.Event] = None
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0] free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
...@@ -305,16 +311,33 @@ class Worker(WorkerBase): ...@@ -305,16 +311,33 @@ class Worker(WorkerBase):
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
def _send_tensor_dict(self):
intermediate_tensors, scheduler_output, event = self._send_queue.get()
assert event is not None
# 等待event在GPU执行完成
event.synchronize()
def _send_tensor_dict():
get_pp_group().send_tensor_dict(
intermediate_tensors.tensors,
all_gather_group=get_tp_group(),
)
_send_tensor_dict()
def _tensor_dict_send_thread(self):
torch.cuda.set_device(self.device)
torch.cuda.set_stream(self._send_stream)
while True:
self._send_tensor_dict()
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
if envs.VLLM_ZERO_OVERHEAD: if envs.VLLM_ZERO_OVERHEAD:
use_stream = zero_overhead_stream(self.device) use_stream = zero_overhead_stream(self.device)
with torch.cuda.stream(use_stream): with torch.cuda.stream(use_stream):
...@@ -327,8 +350,14 @@ class Worker(WorkerBase): ...@@ -327,8 +350,14 @@ class Worker(WorkerBase):
if parallel_config.distributed_executor_backend != "external_launcher" \ if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank: and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors) assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors, if self._send_thread is None:
all_gather_group=get_tp_group()) self._send_stream = torch.cuda.Stream()
self._send_thread = threading.Thread(target=self._tensor_dict_send_thread, daemon=True, name="pp_send_thread")
self._send_thread.start()
self._send_event = torch.cuda.Event()
send_event = self._send_event
send_event.record()
self._send_queue.put((output, scheduler_output, send_event))
return None return None
assert isinstance(output, ModelRunnerOutput) assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None return output if self.is_driver_worker else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment