Commit fa57cb7f authored by 王敏's avatar 王敏
Browse files

[fix]修复test_attention单测中paged_attention_v1和paged_attention_v2 opcheck找不到attn_masks错误

parent 1d6cfb11
......@@ -206,7 +206,7 @@ def test_paged_attention(
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
......@@ -248,7 +248,7 @@ def test_paged_attention(
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment