Commit fbfe20c6 authored by wanghl6's avatar wanghl6
Browse files

[DSA][BUGFIX]: 解决mqa_logits导致的oom问题

parent 3842b316
...@@ -30,6 +30,7 @@ elif current_platform.is_xpu(): ...@@ -30,6 +30,7 @@ elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
_GLOBAL_LOGITS_BUFFERS = {}
@maybe_transfer_kv_layer @maybe_transfer_kv_layer
def sparse_attn_indexer( def sparse_attn_indexer(
...@@ -50,7 +51,21 @@ def sparse_attn_indexer( ...@@ -50,7 +51,21 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run # careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
if q_fp8.dtype == fp8_dtype:
MAX_ELEMENTS = 65536 * 65536
elif q_fp8.dtype in (torch.bfloat16, torch.float16):
MAX_ELEMENTS = 16384 * 32768
else:
MAX_ELEMENTS = 16384 * 32768
device = q_fp8.device
if device not in _GLOBAL_LOGITS_BUFFERS or _GLOBAL_LOGITS_BUFFERS[device].numel() < MAX_ELEMENTS:
_GLOBAL_LOGITS_BUFFERS[device] = torch.empty(
MAX_ELEMENTS,
dtype=torch.float32,
device=device
)
logits_buffer = _GLOBAL_LOGITS_BUFFERS[device]
# assert isinstance(attn_metadata, dict) # assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run # Reserve workspace for indexer during profiling run
...@@ -116,14 +131,6 @@ def sparse_attn_indexer( ...@@ -116,14 +131,6 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
elif get_gcn_arch_name() == "gfx938": elif get_gcn_arch_name() == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
...@@ -134,19 +141,6 @@ def sparse_attn_indexer( ...@@ -134,19 +141,6 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k_fp8,
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k_fp8.shape[0],
q_fp8.shape[1],
q_fp8.shape[2],
k_scale.view(torch.float32).flatten(),
True
)
else: else:
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
...@@ -156,44 +150,117 @@ def sparse_attn_indexer( ...@@ -156,44 +150,117 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_all = q_fp8[chunk.token_start:chunk.token_end]
weights_all = weights[chunk.token_start:chunk.token_end]
ks_all = chunk.cu_seqlen_ks
ke_all = chunk.cu_seqlen_ke
num_q = q_all.shape[0]
num_k = k_fp8.shape[0]
is_q_fp16_bf16 = q_all.dtype in (torch.float16, torch.bfloat16)
align_size = 128 if is_q_fp16_bf16 else 1
kv_seq_len_aligned = (num_k + align_size - 1) // align_size * align_size
current_capacity = logits_buffer.numel()
MAX_Q_CHUNK = current_capacity // max(1, kv_seq_len_aligned)
if align_size > 1:
MAX_Q_CHUNK = (MAX_Q_CHUNK // align_size) * align_size
MAX_Q_CHUNK = max(1, MAX_Q_CHUNK)
slices = []
for start_idx in range(0, num_q, MAX_Q_CHUNK):
end_idx = min(start_idx + MAX_Q_CHUNK, num_q)
slices.append((start_idx, end_idx))
for q_start, q_end in slices:
if q_end <= q_start:
continue
q_slice = q_all[q_start:q_end]
weights_slice = weights_all[q_start:q_end]
ks_slice = ks_all[q_start:q_end]
ke_slice = ke_all[q_start:q_end]
q_len = q_end - q_start
q_seq_len_aligned = (q_len + align_size - 1) // align_size * align_size
required_size = q_seq_len_aligned * kv_seq_len_aligned
logits_slice_view = logits_buffer[:required_size].view(q_seq_len_aligned, kv_seq_len_aligned)
if not current_platform.is_rocm():
logits_slice = fp8_mqa_logits(
q_slice,
(k_fp8, k_scale.view(torch.float32).flatten()),
weights_slice,
ks_slice,
ke_slice,
)
elif get_gcn_arch_name() == "gfx938":
op.mqa_logits(
q_slice,
k_fp8,
weights_slice,
ks_slice,
ke_slice,
q_slice.shape[0], # logical lengths
k_fp8.shape[0],
q_slice.shape[1],
q_slice.shape[2],
k_scale.view(torch.float32).flatten(),
True,
logits_slice_view # padded properly out of box for hardware requirements
)
# Extract the exact logical valid window for downstream topk
logits_slice = logits_slice_view[:q_len, :num_k]
else:
op.mqa_logits(
q_slice,
k_fp8, k_fp8,
weights[chunk.token_start:chunk.token_end].to(torch.float32), weights_slice.to(torch.float32),
chunk.cu_seqlen_ks, ks_slice,
chunk.cu_seqlen_ke, ke_slice,
q_fp8[chunk.token_start:chunk.token_end].shape[0], q_slice.shape[0],
k_fp8.shape[0], k_fp8.shape[0],
q_fp8.shape[1], q_slice.shape[1],
q_fp8.shape[2], q_slice.shape[2],
None, None,
True True,
logits_slice_view # padded properly out of box for hardware requirements
) )
num_rows = logits.shape[0] # Extract the exact logical valid window for downstream topk
logits_slice = logits_slice_view[:q_len, :num_k]
topk_indices = topk_indices_buffer[ num_rows_slice = logits_slice.shape[0]
chunk.token_start : chunk.token_end, :topk_tokens
topk_indices_slice = topk_indices_buffer[
chunk.token_start + q_start : chunk.token_start + q_end, :topk_tokens
] ]
if not envs.USE_LIGHTOP_TOPK: if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill( torch.ops._C.top_k_per_row_prefill(
logits, logits_slice,
chunk.cu_seqlen_ks, ks_slice,
chunk.cu_seqlen_ke, ke_slice,
topk_indices, topk_indices_slice,
num_rows, num_rows_slice,
logits.stride(0), logits_slice.stride(0), # Automatically fetches kv_seq_len_aligned stride
logits.stride(1), logits_slice.stride(1),
topk_tokens, topk_tokens,
) )
else: else:
op.top_k_per_row_prefill( op.top_k_per_row_prefill(
logits, logits_slice,
chunk.cu_seqlen_ks, ks_slice,
chunk.cu_seqlen_ke, ke_slice,
topk_indices, topk_indices_slice,
num_rows, num_rows_slice,
logits.stride(0), logits_slice.stride(0),
logits.stride(1), logits_slice.stride(1),
topk_tokens, topk_tokens,
) )
...@@ -424,5 +491,3 @@ class SparseAttnIndexer(CustomOp): ...@@ -424,5 +491,3 @@ class SparseAttnIndexer(CustomOp):
self.max_total_seq_len, self.max_total_seq_len,
self.topk_indices_buffer, self.topk_indices_buffer,
) )
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment