Commit a34bff19 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.9.2-dev-tx-cpp' into 'v0.9.2-dev'

V0.9.2 dev tx cpp

See merge request dcutoolkit/deeplearing/vllm!474
parents d761561a badaff2d
...@@ -219,6 +219,7 @@ if TYPE_CHECKING: ...@@ -219,6 +219,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False VLLM_USE_FUSED_DTBMM: bool = False
VLLM_FUSE_CAT_AND_CAST_FP8: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_DTBMM": "VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")), ("true", "1")),
"VLLM_FUSE_CAT_AND_CAST_FP8":
lambda: (os.environ.get("VLLM_FUSE_CAT_AND_CAST_FP8", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) use_flash_fp8_arch = ( \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" \
if envs.VLLM_USE_OPT_CAT: and envs.VLLM_USE_FLASH_ATTN_FP8
if k_nope.shape[0] > 1024: )
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
else: if use_fused_fp8_op:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), from lightop import op
dim=-1) q, k, v = op.ds_fused_qkv_cast_fp8(
q,
kv_nope,
k_pe_expanded,
self.qk_nope_head_dim,
self.v_head_dim
)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k_nope, v = kv_nope\
dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8: if envs.VLLM_USE_OPT_CAT:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) if k_nope.shape[0] > 1024:
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) k = lightop_concat_prefill_helper(k_nope, k_pe_expanded, dim=2)
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1]) else:
q_descale = q_descale.expand(descale_shape) k = torch.cat((k_nope, k_pe_expanded), dim=-1)
k_descale = k_descale.expand(descale_shape) else:
v_descale = v_descale.expand(descale_shape) k = torch.cat((k_nope, k_pe_expanded), dim=-1)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn) if use_flash_fp8_arch:
v = v.to(torch.float8_e4m3fn) q_descale = None
k_descale = None
v_descale = None
if not use_fused_fp8_op:
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
attn_output, attn_softmax_lse = \ attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims( self._flash_attn_varlen_diff_headdims(
q=q, q=q,
...@@ -1134,32 +1145,41 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1134,32 +1145,41 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\ use_flash_fp8_arch = ( \
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" \
and envs.VLLM_USE_FLASH_ATTN_FP8
if envs.VLLM_USE_OPT_CAT: )
if k_nope.shape[0] > 1024: use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), if use_fused_fp8_op:
dim=2) from lightop import op
else: k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), q, k, v = op.ds_fused_qkv_cast_fp8(
dim=-1) q,
kv_nope,
k_pe_expanded,
self.qk_nope_head_dim,
self.v_head_dim
)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8: if k_nope.shape[0] > 1024:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=2)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) else:
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1]) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
q_descale = q_descale.expand(descale_shape) else:
k_descale = k_descale.expand(descale_shape) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
v_descale = v_descale.expand(descale_shape)
if use_flash_fp8_arch:
q = q.to(torch.float8_e4m3fn) q_descale = None
k = k.to(torch.float8_e4m3fn) k_descale = None
v = v.to(torch.float8_e4m3fn) v_descale = None
if not use_fused_fp8_op:
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
output = self._flash_attn_varlen_diff_headdims( output = self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens] decode_q = q[:num_decode_tokens]
...@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False, False,
1e-6, 1e-6,
) )
if has_prefill: if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[:num_actual_toks, ...] prefill_k_c_normed = key_normed[:num_actual_toks, ...]
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import os import os
import copy import copy
import queue
import threading
import gc import gc
import time import time
import weakref import weakref
...@@ -323,6 +325,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -323,6 +325,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
self._recv_queue: queue.Queue[IntermediateTensors] = queue.Queue()
self._recv_thread: Optional[threading.Thread] = None
self._recv_stream: Optional[torch.cuda.Stream] = None
self._recv_event: Optional[torch.cuda.Event] = None
# Layer pairings for cross-layer KV sharing. # Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it # If an Attention layer `layer_name` is in the keys of this dict, it
...@@ -730,8 +736,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -730,8 +736,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.positions_cpu[:total_num_scheduled_tokens], self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True) non_blocking=True)
self.query_start_loc[:num_reqs + 1].copy_( current_cpu_slice_clone = self.query_start_loc_cpu[:num_reqs + 1].clone()
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.query_start_loc[:num_reqs + 1].copy_(current_cpu_slice_clone, non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
...@@ -740,7 +746,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -740,7 +746,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Note: pad query_start_loc to be non-decreasing, as kernels # Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that # like FlashAttention requires that
self.query_start_loc[num_reqs + 1:].fill_( self.query_start_loc[num_reqs + 1:].fill_(
self.query_start_loc_cpu[num_reqs].item()) current_cpu_slice_clone[num_reqs].item())
query_start_loc = self.query_start_loc[:num_reqs + 1] query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
...@@ -1337,6 +1343,22 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1337,6 +1343,22 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
finished_recving=finished_recving, finished_recving=finished_recving,
) )
def _recv_tensor_dict(self):
    """Receive one pipeline-parallel tensor dict and hand it to the consumer.

    The receive is issued while ``self._recv_stream`` is the current CUDA
    stream. The result is wrapped in ``IntermediateTensors``,
    ``self._recv_event`` is recorded on that stream, and the wrapped tensors
    are pushed onto ``self._recv_queue``.
    """
    recv_stream = self._recv_stream
    with torch.cuda.stream(recv_stream):
        tensor_dict = get_pp_group().recv_tensor_dict(
            all_gather_group=get_tp_group(),
        )
        received = IntermediateTensors(tensor_dict)
    # The stream is passed explicitly, so this record does not depend on
    # which stream happens to be current here.
    self._recv_event.record(recv_stream)
    # Publish only after the event has been recorded, so a consumer that
    # dequeues the tensors can wait on the event.
    self._recv_queue.put(received)
def _tensor_dict_recv_thread(self):
    """Run the background receive loop.

    Selects ``self.device`` as the CUDA device and then calls
    ``self._recv_tensor_dict()`` forever. The loop has no exit condition;
    the thread is created with ``daemon=True`` by its launcher.
    """
    torch.cuda.set_device(self.device)
    while True:
        self._recv_tensor_dict()
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -1426,9 +1448,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1426,9 +1448,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
else: else:
def _recv_tensor_dict():
return self._recv_queue.get()
if self._recv_thread is None:
self._recv_stream = torch.cuda.Stream()
self._recv_event = torch.cuda.Event()
self._recv_thread = threading.Thread(target=self._tensor_dict_recv_thread, daemon=True, name="pp_recv_thread")
self._recv_thread.start()
intermediate_tensors = _recv_tensor_dict()
torch.cuda.current_stream().wait_event(self._recv_event)
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True) num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode. # Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
...@@ -1494,6 +1525,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1494,6 +1525,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
hidden_states = model_output hidden_states = model_output
if isinstance(model_output, IntermediateTensors):
residual_clone = model_output.tensors["residual"].clone()
hidden_states.tensors["residual"] = residual_clone
aux_hidden_states = None aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
...@@ -3290,6 +3324,16 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3290,6 +3324,16 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
else: else:
def _recv_tensor_dict():
return self._recv_queue.get()
if self._recv_thread is None:
self._recv_stream = torch.cuda.Stream()
self._recv_event = torch.cuda.Event()
self._recv_thread = threading.Thread(target=self._tensor_dict_recv_thread, daemon=True, name="pp_recv_thread")
self._recv_thread.start()
intermediate_tensors = _recv_tensor_dict()
torch.cuda.current_stream().wait_event(self._recv_event)
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True) num_input_tokens, intermediate_tensors, True)
...@@ -3358,6 +3402,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3358,6 +3402,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
hidden_states = model_output hidden_states = model_output
if isinstance(model_output, IntermediateTensors):
residual_clone = model_output.tensors["residual"].clone()
hidden_states.tensors["residual"] = residual_clone
aux_hidden_states = None aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import queue
import os import os
import threading
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -81,6 +83,10 @@ class Worker(WorkerBase): ...@@ -81,6 +83,10 @@ class Worker(WorkerBase):
torch_profiler_trace_dir, use_gzip=True)) torch_profiler_trace_dir, use_gzip=True))
else: else:
self.profiler = None self.profiler = None
self._send_queue: queue.Queue[tuple[IntermediateTensors, SchedulerOutput, torch.cuda.Event]] = queue.Queue(1)
self._send_thread: Optional[threading.Thread] = None
self._send_stream: Optional[torch.cuda.Stream] = None
self._send_event: Optional[torch.cuda.Event] = None
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0] free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
...@@ -305,16 +311,33 @@ class Worker(WorkerBase): ...@@ -305,16 +311,33 @@ class Worker(WorkerBase):
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
def _send_tensor_dict(self):
    """Send one queued set of intermediate tensors through the PP group.

    Blocks on ``self._send_queue`` for the next
    ``(intermediate_tensors, scheduler_output, event)`` item, waits for the
    CUDA event to complete, and then sends ``intermediate_tensors.tensors``
    via the pipeline-parallel group.
    """
    # The scheduler output travels with the queue item but is not needed
    # for the send itself.
    intermediate_tensors, _scheduler_output, event = self._send_queue.get()
    assert event is not None
    # Wait until the GPU work recorded by the event has finished.
    event.synchronize()
    get_pp_group().send_tensor_dict(
        intermediate_tensors.tensors,
        all_gather_group=get_tp_group(),
    )
def _tensor_dict_send_thread(self):
    """Run the background send loop.

    Selects ``self.device`` as the CUDA device, makes ``self._send_stream``
    the current CUDA stream, and then keeps draining the send queue one
    item at a time. The loop never returns.
    """
    torch.cuda.set_device(self.device)
    torch.cuda.set_stream(self._send_stream)
    send_next = self._send_tensor_dict
    while True:
        send_next()
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
if envs.VLLM_ZERO_OVERHEAD: if envs.VLLM_ZERO_OVERHEAD:
use_stream = zero_overhead_stream(self.device) use_stream = zero_overhead_stream(self.device)
with torch.cuda.stream(use_stream): with torch.cuda.stream(use_stream):
...@@ -327,8 +350,14 @@ class Worker(WorkerBase): ...@@ -327,8 +350,14 @@ class Worker(WorkerBase):
if parallel_config.distributed_executor_backend != "external_launcher" \ if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank: and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors) assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors, if self._send_thread is None:
all_gather_group=get_tp_group()) self._send_stream = torch.cuda.Stream()
self._send_thread = threading.Thread(target=self._tensor_dict_send_thread, daemon=True, name="pp_send_thread")
self._send_thread.start()
self._send_event = torch.cuda.Event()
send_event = self._send_event
send_event.record()
self._send_queue.put((output, scheduler_output, send_event))
return None return None
assert isinstance(output, ModelRunnerOutput) assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None return output if self.is_driver_worker else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment