Commit 1edffefe authored by wanghl6
Browse files

[BUGFIX] 解决 mqa_logits 在大 batch size 下导致的 OOM 问题

parent bcb2ba6c
...@@ -75,7 +75,7 @@ def sparse_attn_indexer( ...@@ -75,7 +75,7 @@ def sparse_attn_indexer(
) )
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
...@@ -116,14 +116,6 @@ def sparse_attn_indexer( ...@@ -116,14 +116,6 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
elif get_gcn_arch_name() == "gfx938": elif get_gcn_arch_name() == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
...@@ -134,19 +126,6 @@ def sparse_attn_indexer( ...@@ -134,19 +126,6 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k_fp8,
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k_fp8.shape[0],
q_fp8.shape[1],
q_fp8.shape[2],
k_scale.view(torch.float32).flatten(),
True
)
else: else:
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
...@@ -156,46 +135,104 @@ def sparse_attn_indexer( ...@@ -156,46 +135,104 @@ def sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k_fp8,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k_fp8.shape[0],
q_fp8.shape[1],
q_fp8.shape[2],
None,
True
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[ q_all = q_fp8[chunk.token_start:chunk.token_end]
chunk.token_start : chunk.token_end, :topk_tokens weights_all = weights[chunk.token_start:chunk.token_end]
] ks_all = chunk.cu_seqlen_ks
if not envs.USE_LIGHTOP_TOPK: ke_all = chunk.cu_seqlen_ke
torch.ops._C.top_k_per_row_prefill(
logits, num_q = q_all.shape[0]
chunk.cu_seqlen_ks, num_k = k_fp8.shape[0]
chunk.cu_seqlen_ke,
topk_indices, MAX_ELEMENTS = 1024 * 1024 * 1024 # 4GB
num_rows, if (num_q <= 65536 and num_k <= 65536): # if num_q <= 65536 and num_k <= 65536 and (num_q * num_k <= MAX_ELEMENTS):
logits.stride(0), MAX_Q_CHUNK = max(1, num_q)
logits.stride(1),
topk_tokens,
)
else: else:
op.top_k_per_row_prefill( MAX_Q_CHUNK = max(1024, MAX_ELEMENTS // max(1, num_k))
logits, MAX_Q_CHUNK = min(MAX_Q_CHUNK, max(1, num_q))
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, #存储q的起始和终止地址
topk_indices, slices = []
num_rows,
logits.stride(0), for start_idx in range(0, num_q, MAX_Q_CHUNK):
logits.stride(1), end_idx = min(start_idx + MAX_Q_CHUNK, num_q)
topk_tokens, slices.append((start_idx, end_idx))
)
for q_start, q_end in slices:
if q_end <= q_start:
continue
q_slice = q_all[q_start:q_end]
weights_slice = weights_all[q_start:q_end]
ks_slice = ks_all[q_start:q_end]
ke_slice = ke_all[q_start:q_end]
if not current_platform.is_rocm():
logits_slice = fp8_mqa_logits(
q_slice,
(k_fp8, k_scale.view(torch.float32).flatten()),
weights_slice,
ks_slice,
ke_slice,
)
elif get_gcn_arch_name() == "gfx938":
logits_slice = op.mqa_logits(
q_slice,
k_fp8,
weights_slice,
ks_slice,
ke_slice,
q_slice.shape[0],
k_fp8.shape[0],
q_slice.shape[1],
q_slice.shape[2],
k_scale.view(torch.float32).flatten(),
True
)
else:
logits_slice = op.mqa_logits(
q_slice,
k_fp8,
weights_slice.to(torch.float32),
ks_slice,
ke_slice,
q_slice.shape[0],
k_fp8.shape[0],
q_slice.shape[1],
q_slice.shape[2],
None,
True
)
num_rows_slice = logits_slice.shape[0]
topk_indices_slice = topk_indices_buffer[
chunk.token_start + q_start : chunk.token_start + q_end, :topk_tokens
]
if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill(
logits_slice,
ks_slice,
ke_slice,
topk_indices_slice,
num_rows_slice,
logits_slice.stride(0),
logits_slice.stride(1),
topk_tokens,
)
else:
op.top_k_per_row_prefill(
logits_slice,
ks_slice,
ke_slice,
topk_indices_slice,
num_rows_slice,
logits_slice.stride(0),
logits_slice.stride(1),
topk_tokens,
)
if has_decode: if has_decode:
decode_metadata = attn_metadata.decode decode_metadata = attn_metadata.decode
...@@ -242,7 +279,6 @@ def sparse_attn_indexer( ...@@ -242,7 +279,6 @@ def sparse_attn_indexer(
max_model_len, max_model_len,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
...@@ -423,6 +459,4 @@ class SparseAttnIndexer(CustomOp): ...@@ -423,6 +459,4 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len, self.max_model_len,
self.max_total_seq_len, self.max_total_seq_len,
self.topk_indices_buffer, self.topk_indices_buffer,
) )
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment