Commit adbd3d7b authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev-wm' into 'v0.15.1-dev'

[perf]DSA架构模型支持mtp>1

See merge request dcutoolkit/deeplearing/vllm!521
parents 12b5bcb1 7eb2446c
...@@ -74,6 +74,12 @@ def sparse_attn_indexer( ...@@ -74,6 +74,12 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
# out-of-bounds reads in the kernel.
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache( ops.indexer_k_quant_and_cache(
k, k,
...@@ -135,10 +141,10 @@ def sparse_attn_indexer( ...@@ -135,10 +141,10 @@ def sparse_attn_indexer(
k_scale.view(torch.float32).flatten(), k_scale.view(torch.float32).flatten(),
True True
) )
else: else:
logits = op.mqa_logits( logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_fp8[chunk.token_start:chunk.token_end],
k, k,
weights[chunk.token_start:chunk.token_end].to(torch.float32), weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
...@@ -21,8 +22,10 @@ from vllm.v1.attention.backends.utils import ( ...@@ -21,8 +22,10 @@ from vllm.v1.attention.backends.utils import (
split_prefill_chunks, split_prefill_chunks,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.worker.cp_utils import get_total_cp_world_size
from lightop import gemmopt from lightop import gemmopt
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -214,14 +217,44 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -214,14 +217,44 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
else 0 else 0
) )
# Now deepgemm fp8_paged_mqa_logits does not support next_n > 2 # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1) #self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
self.reorder_batch_threshold += self.num_speculative_tokens
props = torch.cuda.get_device_properties(self.device) props = torch.cuda.get_device_properties(self.device)
sm_count = props.multi_processor_count sm_count = props.multi_processor_count
self.num_sms = sm_count self.num_sms = sm_count
# self.decode_lens_buffer = torch.empty(
# (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
# )
self.decode_lens_buffer = torch.empty( self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device (scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
# Pre-allocated buffers for flattening (spec decode).
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
dtype=torch.int32,
device=self.device,
)
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
max_num_blocks_per_req = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size * get_total_cp_world_size(),
)
self.expanded_block_table_buffer = torch.zeros(
(
scheduler_config.max_num_batched_tokens,
max_num_blocks_per_req,
),
dtype=torch.int32,
device=self.device,
) )
# See: DeepGMM/csrc/apis/attention.hpp # See: DeepGMM/csrc/apis/attention.hpp
...@@ -320,24 +353,81 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -320,24 +353,81 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1] common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
) )
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
# if is_deep_gemm_supported(): block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
max_decode_len = int(decode_lens_cpu.max().item())
if max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
self.arange_buffer[:actual_expanded] - expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.expanded_seq_lens_buffer[:actual_expanded] = (
expanded_base + positions_within + 1
)
self.expanded_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_rocm(): if current_platform.is_rocm():
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer = gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens, self.kv_cache_spec.block_size, self.num_sms
) )
else: else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens, self.kv_cache_spec.block_size, self.num_sms
) )
decode_metadata = DeepSeekV32IndexerDecodeMetadata( decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...], block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes], seq_lens=seq_lens,
decode_lens=decode_lens, decode_lens=decode_lens,
requires_padding=requires_padding, requires_padding=False,
schedule_metadata=self.scheduler_metadata_buffer, schedule_metadata=self.scheduler_metadata_buffer,
) )
......
...@@ -562,9 +562,23 @@ class SpecDecodeBaseProposer: ...@@ -562,9 +562,23 @@ class SpecDecodeBaseProposer:
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
) )
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
)
)
else:
draft_indexer_metadata = None
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions) self._set_positions(batch_size, clamped_positions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment