Unverified commit b044400d authored by YanbingJiang, committed by GitHub

Support non-contiguous query input for extend/decode attention (#7462)

parent 40e5cb7a
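
The decode and extend attention kernels previously assumed a fully contiguous `[num_tokens, num_heads, head_size]` query and hard-coded its strides as `num_heads * head_size` and `head_size`. This change relaxes the input check to last-dimension contiguity and threads the actual `query.stride(0)` / `query.stride(1)` values (`q_strideM`, `q_strideH`) from the CPU entry points down into the kernel implementations. A minimal sketch of why the shape-derived strides break for a non-contiguous query (the shapes below are illustrative only, not taken from the kernels):

```python
import torch

# Illustrative shapes only (not tied to any particular model config).
num_tokens, num_heads, head_size = 4, 8, 128
q = torch.randn(num_tokens, num_heads, head_size)

# Contiguous layout: strides can be derived from the shape alone.
assert q.stride() == (num_heads * head_size, head_size, 1)

# The pattern used in the tests yields a tensor that is contiguous only in
# its last dimension; stride(0) and stride(1) no longer match the shape math.
q_nc = q.transpose(0, 1).contiguous().transpose(0, 1)
assert not q_nc.is_contiguous()
assert q_nc.stride(2) == 1                       # last dim still contiguous
assert q_nc.stride(0) == head_size               # q_strideM != num_heads * head_size
assert q_nc.stride(1) == num_tokens * head_size  # q_strideH != head_size
```

The tests exercise exactly this layout by round-tripping the query through `transpose(0, 1).contiguous().transpose(0, 1)`, which keeps the last dimension dense while reordering the outer two.
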
@@ -874,6 +874,8 @@ void decode_attention_kernel_impl(
     int64_t head_size,
     int64_t head_size_v,
     int64_t num_kv_splits,
+    int64_t q_strideM,
+    int64_t q_strideH,
     int64_t k_strideN,
     int64_t k_strideH,
     int64_t v_strideN,
@@ -886,8 +888,6 @@ void decode_attention_kernel_impl(
   using Vec = at::vec::Vectorized<float>;
   // strides
-  const int64_t q_strideM = num_heads * head_size;
-  const int64_t q_strideH = head_size;
   const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
   const int64_t l_stride2 = head_size_v + 1;
@@ -1017,6 +1017,8 @@ void decode_attention_mla_kernel_impl(
     int64_t head_size,
     int64_t head_size_v,
     int64_t num_kv_splits,
+    int64_t q_strideM,
+    int64_t q_strideH,
     int64_t k_strideN,
     int64_t k_strideH,
     int64_t v_strideN,
@@ -1033,8 +1035,6 @@ void decode_attention_mla_kernel_impl(
   const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
   // strides
-  const int64_t q_strideM = num_heads * head_size;
-  const int64_t q_strideH = head_size;
   const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
   const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
   const int64_t l_stride2 = head_size_v + 1;
@@ -1209,6 +1209,8 @@ void decode_attention_grouped_kernel_impl(
     int64_t head_size,
     int64_t head_size_v,
     int64_t num_kv_splits,
+    int64_t q_strideM,
+    int64_t q_strideH,
     int64_t k_strideN,
     int64_t k_strideH,
     int64_t v_strideN,
@@ -1227,8 +1229,6 @@ void decode_attention_grouped_kernel_impl(
   const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H);
   // strides
-  const int64_t q_strideM = num_heads * head_size;
-  const int64_t q_strideH = head_size;
   const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
   const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
   const int64_t l_stride2 = head_size_v + 1;
@@ -1391,7 +1391,7 @@ void decode_attention_cpu(
       std::vector<c10::IValue>(
           {query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));
-  CHECK_INPUT(query);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
   // for MLA, key and value shares the same storage and value could be non-contiguous
@@ -1422,6 +1422,10 @@ void decode_attention_cpu(
   CHECK_EQ(attn_logits.size(3), head_size_v + 1);
   CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
+  // strides for query
+  int64_t q_strideM = query.stride(0);
+  int64_t q_strideH = query.stride(1);
   // strides for k_buffer and v_buffer
   int64_t k_strideN = k_buffer.stride(0);
   int64_t k_strideH = k_buffer.stride(1);
@@ -1497,6 +1501,8 @@ void decode_attention_cpu(
       head_size,
       head_size_v,
       num_kv_splits,
+      q_strideM,
+      q_strideH,
       k_strideN,
       k_strideH,
       v_strideN,
@@ -1523,6 +1529,8 @@ void decode_attention_cpu(
       head_size,
       head_size_v,
       num_kv_splits,
+      q_strideM,
+      q_strideH,
       k_strideN,
       k_strideH,
       v_strideN,
@@ -1550,6 +1558,8 @@ void decode_attention_cpu(
       head_size,
       head_size_v,
       num_kv_splits,
+      q_strideM,
+      q_strideH,
       k_strideN,
       k_strideH,
       v_strideN,
...
@@ -240,6 +240,8 @@ void extend_attention_kernel_impl(
     int num_heads_kv,
     int head_size,
     int head_size_v,
+    int q_strideM,
+    int q_strideH,
     int ke_strideN,
     int ke_strideH,
     int ve_strideN,
@@ -259,8 +261,6 @@ void extend_attention_kernel_impl(
   using Vec = at::vec::Vectorized<float>;
   // strides
-  const int q_strideM = num_heads * head_size;
-  const int q_strideH = head_size;
   const int o_strideM = num_heads * head_size_v;
   const int o_strideH = head_size_v;
@@ -606,7 +606,7 @@ void extend_attention_cpu(
       extend_seq_lens,
       extend_start_loc}));
-  CHECK_INPUT(q_extend);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
   CHECK_INPUT(o_extend);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
@@ -623,7 +623,9 @@ void extend_attention_cpu(
   int head_size = q_extend.size(2);
   int head_size_v = v_extend.size(2);
-  // strides for k_extend and v_extend
+  // strides for q_extend, k_extend and v_extend
+  int q_strideM = q_extend.stride(0);
+  int q_strideH = q_extend.stride(1);
   int ke_strideN = k_extend.stride(0);
   int ke_strideH = k_extend.stride(1);
   int ve_strideN = v_extend.stride(0);
@@ -698,6 +700,8 @@ void extend_attention_cpu(
       num_heads_kv,
       head_size,
       head_size_v,
+      q_strideM,
+      q_strideH,
       ke_strideN,
       ke_strideH,
       ve_strideN,
...
@@ -102,9 +102,10 @@ class TestDecodeAttention(CustomTestCase):
             device=device,
         )
-        # k_buffer, v_buffer, key and value supports non-contiguous tensors
+        # k_buffer, v_buffer, query, key and value supports non-contiguous tensors
         k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
         v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
+        q = q.transpose(0, 1).contiguous().transpose(0, 1)
         key = key.transpose(0, 1).contiguous().transpose(0, 1)
         value = value.transpose(0, 1).contiguous().transpose(0, 1)
         torch.ops.sgl_kernel.decode_attention_cpu(
...
@@ -123,7 +123,8 @@ class TestExtendAttention(CustomTestCase):
                 (b_seq_len_extend[i], H_Q, D), dtype=dtype
             )
-        # k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
+        # q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
+        q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
         v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
         k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
...
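
Because the kernels now receive the query strides explicitly, only the last dimension of the query must be contiguous, which is why `CHECK_INPUT(query)` and `CHECK_INPUT(q_extend)` are relaxed to `CHECK_LAST_DIM_CONTIGUOUS_INPUT`. For reference, a stride-based row lookup sketched in Python with a hypothetical helper name (the kernels perform the equivalent pointer arithmetic in C++):

```python
import torch

def query_row(q: torch.Tensor, m: int, h: int) -> torch.Tensor:
    """Return the head_size values for token m, head h via explicit strides.

    Mirrors base + m * q_strideM + h * q_strideH as used by the kernels;
    valid as long as the last dimension has stride 1.
    """
    q_strideM, q_strideH, q_strideD = q.stride()
    assert q_strideD == 1, "last dim must be contiguous"
    offset = q.storage_offset() + m * q_strideM + h * q_strideH
    return torch.as_strided(q, (q.size(2),), (1,), offset)

# Sanity check against plain indexing on a non-contiguous query.
q = torch.randn(4, 8, 128).transpose(0, 1).contiguous().transpose(0, 1)
assert torch.equal(query_row(q, 2, 3), q[2, 3])
```
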