Commit 6a86ea6d authored by wanghl6's avatar wanghl6
Browse files

[DSA][BUGFIX]解决mqa_logits开PC时大bs导致的oom问题

parent 1edffefe
...@@ -20,10 +20,6 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri ...@@ -20,10 +20,6 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
from vllm.v1.worker.workspace import current_workspace_manager from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt from lightop import op, gemmopt
from vllm.attention.utils.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
elif current_platform.is_xpu(): elif current_platform.is_xpu():
...@@ -31,10 +27,10 @@ elif current_platform.is_xpu(): ...@@ -31,10 +27,10 @@ elif current_platform.is_xpu():
logger = init_logger(__name__) logger = init_logger(__name__)
@maybe_transfer_kv_layer
def sparse_attn_indexer( def sparse_attn_indexer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
layer_name:str, k_cache_prefix: str,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -60,7 +56,7 @@ def sparse_attn_indexer( ...@@ -60,7 +56,7 @@ def sparse_attn_indexer(
) )
return sparse_attn_indexer_fake( return sparse_attn_indexer_fake(
hidden_states, hidden_states,
layer_name, k_cache_prefix,
kv_cache, kv_cache,
q_fp8, q_fp8,
k, k,
...@@ -73,9 +69,9 @@ def sparse_attn_indexer( ...@@ -73,9 +69,9 @@ def sparse_attn_indexer(
total_seq_lens, total_seq_lens,
topk_indices_buffer, topk_indices_buffer,
) )
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens] slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
...@@ -322,7 +318,7 @@ def sparse_attn_indexer( ...@@ -322,7 +318,7 @@ def sparse_attn_indexer(
def sparse_attn_indexer_fake( def sparse_attn_indexer_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
layer_name: str, k_cache_prefix: str,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment