sglang · commit b044400d (unverified)
Authored Jul 03, 2025 by YanbingJiang; committed via GitHub on Jul 02, 2025

Support non-contiguous query input for extend/decode attention (#7462)

Parent: 40e5cb7a
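For context: before this commit, the CPU decode/extend kernels hard-coded the query strides from num_heads and head_size, which silently assumed a densely packed query tensor. The sketch below (PyTorch, hypothetical sizes not taken from the commit) shows the stride arithmetic the kernels rely on:

    import torch

    # Hypothetical sizes, for illustration only.
    num_tokens, num_heads, head_size = 8, 4, 64
    q = torch.randn(num_tokens, num_heads, head_size)

    # For a packed query these are exactly the values the kernels used to
    # hard-code: q_strideM = num_heads * head_size, q_strideH = head_size.
    q_strideM, q_strideH = q.stride(0), q.stride(1)
    assert q_strideM == num_heads * head_size and q_strideH == head_size

    # Element (t, h, d) lives at flat offset t*q_strideM + h*q_strideH + d,
    # which is how the kernels index the query once strides are passed in.
    t, h, d = 3, 2, 17
    assert q.reshape(-1)[t * q_strideM + h * q_strideH + d] == q[t, h, d]

Taking q_strideM and q_strideH from the tensor itself, as this commit does, keeps the packed case working while also accepting layouts where the outer strides differ.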
Showing 4 changed files with 29 additions and 13 deletions (+29 −13):

    sgl-kernel/csrc/cpu/decode.cpp   +17 −7
    sgl-kernel/csrc/cpu/extend.cpp   +8 −4
    test/srt/cpu/test_decode.py      +2 −1
    test/srt/cpu/test_extend.py      +2 −1
sgl-kernel/csrc/cpu/decode.cpp
@@ -874,6 +874,8 @@ void decode_attention_kernel_impl(
     int64_t head_size,
     int64_t head_size_v,
     int64_t num_kv_splits,
+    int64_t q_strideM,
+    int64_t q_strideH,
     int64_t k_strideN,
     int64_t k_strideH,
     int64_t v_strideN,
@@ -886,8 +888,6 @@ void decode_attention_kernel_impl(
   using Vec = at::vec::Vectorized<float>;
   // strides
-  const int64_t q_strideM = num_heads * head_size;
-  const int64_t q_strideH = head_size;
   const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
   const int64_t l_stride2 = head_size_v + 1;
@@ -1017,6 +1017,8 @@ void decode_attention_mla_kernel_impl(
     int64_t head_size,
     int64_t head_size_v,
     int64_t num_kv_splits,
+    int64_t q_strideM,
+    int64_t q_strideH,
     int64_t k_strideN,
     int64_t k_strideH,
     int64_t v_strideN,
@@ -1033,8 +1035,6 @@ void decode_attention_mla_kernel_impl(
   const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
   // strides
-  const int64_t q_strideM = num_heads * head_size;
-  const int64_t q_strideH = head_size;
   const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
   const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
   const int64_t l_stride2 = head_size_v + 1;
@@ -1209,6 +1209,8 @@ void decode_attention_grouped_kernel_impl(
     int64_t head_size,
     int64_t head_size_v,
     int64_t num_kv_splits,
+    int64_t q_strideM,
+    int64_t q_strideH,
     int64_t k_strideN,
     int64_t k_strideH,
     int64_t v_strideN,
@@ -1227,8 +1229,6 @@ void decode_attention_grouped_kernel_impl(
   const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H);
   // strides
-  const int64_t q_strideM = num_heads * head_size;
-  const int64_t q_strideH = head_size;
   const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
   const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
   const int64_t l_stride2 = head_size_v + 1;
@@ -1391,7 +1391,7 @@ void decode_attention_cpu(
       std::vector<c10::IValue>(
           {query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));
 
-  CHECK_INPUT(query);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
   // for MLA, key and value shares the same storage and value could be non-contiguous
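The switch from CHECK_INPUT to CHECK_LAST_DIM_CONTIGUOUS_INPUT relaxes the requirement: only the innermost (head_size) dimension must be dense, while outer dimensions may have arbitrary strides. A rough Python analogue of the two predicates (the macros are from the diff; the helper functions here are illustrative, not the real macro implementations):

    import torch

    def is_fully_contiguous(t: torch.Tensor) -> bool:
        # Analogue of CHECK_INPUT: dense row-major storage required.
        return t.is_contiguous()

    def is_last_dim_contiguous(t: torch.Tensor) -> bool:
        # Analogue of CHECK_LAST_DIM_CONTIGUOUS_INPUT: only the innermost
        # dimension must be dense; outer strides are unconstrained.
        return t.stride(-1) == 1

    q = torch.randn(8, 4, 64)
    q_nc = q.transpose(0, 1).contiguous().transpose(0, 1)

    print(is_fully_contiguous(q_nc))     # False -- old check would reject it
    print(is_last_dim_contiguous(q_nc))  # True  -- new check accepts it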
@@ -1422,6 +1422,10 @@ void decode_attention_cpu(
   CHECK_EQ(attn_logits.size(3), head_size_v + 1);
   CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
 
+  // strides for query
+  int64_t q_strideM = query.stride(0);
+  int64_t q_strideH = query.stride(1);
+
   // strides for k_buffer and v_buffer
   int64_t k_strideN = k_buffer.stride(0);
   int64_t k_strideH = k_buffer.stride(1);
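Reading the strides off the tensor at the op boundary is what replaces the hard-coded num_heads * head_size / head_size values deleted from the kernel bodies above. A small sketch of the values the kernel now receives under each layout (shapes are hypothetical):

    import torch

    num_tokens, num_heads, head_size = 8, 4, 64
    q = torch.randn(num_tokens, num_heads, head_size)

    # Packed layout: matches the old hard-coded strides.
    assert q.stride(0) == num_heads * head_size  # q_strideM == 256
    assert q.stride(1) == head_size              # q_strideH == 64

    # Head-major layout: same logical tensor, different physical strides.
    q_nc = q.transpose(0, 1).contiguous().transpose(0, 1)
    assert q_nc.stride(0) == head_size               # q_strideM == 64
    assert q_nc.stride(1) == num_tokens * head_size  # q_strideH == 512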
@@ -1497,6 +1501,8 @@ void decode_attention_cpu(
         head_size,
         head_size_v,
         num_kv_splits,
+        q_strideM,
+        q_strideH,
         k_strideN,
         k_strideH,
         v_strideN,
@@ -1523,6 +1529,8 @@ void decode_attention_cpu(
         head_size,
         head_size_v,
         num_kv_splits,
+        q_strideM,
+        q_strideH,
         k_strideN,
         k_strideH,
         v_strideN,
@@ -1550,6 +1558,8 @@ void decode_attention_cpu(
         head_size,
         head_size_v,
         num_kv_splits,
+        q_strideM,
+        q_strideH,
         k_strideN,
         k_strideH,
         v_strideN,
sgl-kernel/csrc/cpu/extend.cpp
@@ -240,6 +240,8 @@ void extend_attention_kernel_impl(
     int num_heads_kv,
     int head_size,
     int head_size_v,
+    int q_strideM,
+    int q_strideH,
     int ke_strideN,
     int ke_strideH,
     int ve_strideN,
@@ -259,8 +261,6 @@ void extend_attention_kernel_impl(
   using Vec = at::vec::Vectorized<float>;
   // strides
-  const int q_strideM = num_heads * head_size;
-  const int q_strideH = head_size;
   const int o_strideM = num_heads * head_size_v;
   const int o_strideH = head_size_v;
@@ -606,7 +606,7 @@ void extend_attention_cpu(
           extend_seq_lens,
           extend_start_loc}));
 
-  CHECK_INPUT(q_extend);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
   CHECK_INPUT(o_extend);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
@@ -623,7 +623,9 @@ void extend_attention_cpu(
   int head_size = q_extend.size(2);
   int head_size_v = v_extend.size(2);
 
-  // strides for k_extend and v_extend
+  // strides for q_extend, k_extend and v_extend
+  int q_strideM = q_extend.stride(0);
+  int q_strideH = q_extend.stride(1);
   int ke_strideN = k_extend.stride(0);
   int ke_strideH = k_extend.stride(1);
   int ve_strideN = v_extend.stride(0);
@@ -698,6 +700,8 @@ void extend_attention_cpu(
         num_heads_kv,
         head_size,
         head_size_v,
+        q_strideM,
+        q_strideH,
         ke_strideN,
         ke_strideH,
         ve_strideN,
test/srt/cpu/test_decode.py
@@ -102,9 +102,10 @@ class TestDecodeAttention(CustomTestCase):
             device=device,
         )
 
-        # k_buffer, v_buffer, key and value supports non-contiguous tensors
+        # k_buffer, v_buffer, query, key and value supports non-contiguous tensors
         k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
         v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
+        q = q.transpose(0, 1).contiguous().transpose(0, 1)
         key = key.transpose(0, 1).contiguous().transpose(0, 1)
         value = value.transpose(0, 1).contiguous().transpose(0, 1)
 
         torch.ops.sgl_kernel.decode_attention_cpu(
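The transpose(0, 1).contiguous().transpose(0, 1) idiom used in these tests is a compact way to produce a tensor that is numerically identical to the original but no longer contiguous, so the op's non-contiguous path is exercised against unchanged reference values. A standalone sketch:

    import torch

    x = torch.randn(8, 4, 64)

    # Materialize the data head-major, then view it back token-major:
    # same shape and values, but stride(0) < stride(1), so not contiguous.
    x_nc = x.transpose(0, 1).contiguous().transpose(0, 1)

    assert not x_nc.is_contiguous()
    assert x_nc.stride(-1) == 1          # still last-dim contiguous
    torch.testing.assert_close(x, x_nc)  # identical values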
test/srt/cpu/test_extend.py
@@ -123,7 +123,8 @@ class TestExtendAttention(CustomTestCase):
                 (b_seq_len_extend[i], H_Q, D), dtype=dtype
             )
 
-        # k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
+        # q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
+        q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
         k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
         v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
         k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)