Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0475448e
"src/vscode:/vscode.git/clone" did not exist on "7200daa412b9d0738e655fbac99077f9b899d1f1"
Unverified
Commit
0475448e
authored
Aug 06, 2025
by
Ke Bao
Committed by
GitHub
Aug 06, 2025
Browse files
Optimize triton swa kernel by skipping computation (#8860)
parent
399e7ec8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
398 additions
and
98 deletions
+398
-98
benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py
...liding_window_attention_triton/bench_triton_swa_kernel.py
+283
-0
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
...glang/srt/layers/attention/triton_ops/extend_attention.py
+115
-98
No files found.
benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py
0 → 100644
View file @
0475448e
import
itertools
import
torch
import
torch.nn.functional
as
F
import
triton.testing
as
tt
from
sglang.srt.layers.attention.triton_ops.extend_attention
import
extend_attention_fwd
def extend_attention_fwd_torch(
    q: torch.Tensor,  # [extend_tokens, H_Q, D]
    k: torch.Tensor,  # [extend_tokens, H_KV, D]
    v: torch.Tensor,  # [extend_tokens, H_KV, D]
    o: torch.Tensor,  # [extend_tokens, H_Q, D]
    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    qo_indptr: torch.Tensor,  # [B+1]
    kv_indptr: torch.Tensor,  # [B+1]
    kv_indices: torch.Tensor,  # [prefix_tokens]
    sliding_window_size: int,
):
    """Pure-PyTorch reference for prefix+extend attention with an optional
    sliding window.

    For each sequence, attention is computed over the cached prefix K/V
    (gathered via ``kv_indices``) concatenated with the extend K/V, under a
    causal mask and, when ``sliding_window_size > 0``, a sliding-window mask.
    The result is written into ``o`` in place; nothing is returned.
    """
    num_seqs = qo_indptr.size(0) - 1
    _, num_q_heads, head_dim = q.shape
    _, num_kv_heads, _ = k.shape
    heads_per_group = num_q_heads // num_kv_heads
    sm_scale = 1.0 / head_dim**0.5

    for seq in range(num_seqs):
        qs = int(qo_indptr[seq].item())
        qe = int(qo_indptr[seq + 1].item())
        ks = int(kv_indptr[seq].item())
        ke = int(kv_indptr[seq + 1].item())

        # Gather this sequence's prefix from the flat caches, then append
        # the freshly computed extend K/V.
        prefix_loc = kv_indices[ks:ke]
        keys = torch.cat([k_cache[prefix_loc], k[qs:qe]], dim=0)  # [total_len, H_KV, D]
        vals = torch.cat([v_cache[prefix_loc], v[qs:qe]], dim=0)  # [total_len, H_KV, D]
        queries = q[qs:qe]  # [extend_len, H_Q, D]

        # GQA: replicate each KV head across its query-head group.
        if heads_per_group != 1:
            keys = keys.repeat_interleave(heads_per_group, dim=1)  # [total_len, H_Q, D]
            vals = vals.repeat_interleave(heads_per_group, dim=1)  # [total_len, H_Q, D]

        prefix_len = ke - ks
        extend_len = qe - qs
        total_len = prefix_len + extend_len

        # Absolute positions: keys span [0, total_len); the extend queries
        # sit right after the prefix.
        key_pos = torch.arange(total_len, device=q.device)
        qry_pos = prefix_len + torch.arange(extend_len, device=q.device)  # [extend_len]

        # Causal: a query may only see keys at or before its own position.
        visible = key_pos.unsqueeze(0) <= qry_pos.unsqueeze(1)

        # Sliding window: keys older than (query_pos - window) are hidden.
        if sliding_window_size is not None and sliding_window_size > 0:
            lowest = (qry_pos - sliding_window_size).clamp_min(0)  # [extend_len]
        else:
            lowest = torch.zeros_like(qry_pos)
        visible &= key_pos.unsqueeze(0) >= lowest.unsqueeze(1)

        logits = (
            torch.einsum("qhd,khd->qhk", queries, keys) * sm_scale
        )  # [extend_len, H_Q, total_len]
        logits = logits.masked_fill(~visible.unsqueeze(1), float("-inf"))
        probs = F.softmax(logits, dim=-1)
        o[qs:qe] = torch.einsum("qhk,khd->qhd", probs, vals)
def
_build_batch
(
B
,
N_CTX
,
H_Q
,
H_KV
,
D
,
WINDOW_SIZE
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
):
b_seq_len_prefix
=
torch
.
randint
(
1
,
max
(
2
,
N_CTX
//
2
),
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_seq_len_extend
=
torch
.
randint
(
1
,
max
(
2
,
N_CTX
//
2
),
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_seq_len
=
b_seq_len_prefix
+
b_seq_len_extend
b_start_loc
=
torch
.
zeros
((
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_start_loc
[
1
:]
=
torch
.
cumsum
(
b_seq_len
[:
-
1
],
0
)
b_start_loc_extend
=
torch
.
zeros
((
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
b_start_loc_extend
[
1
:]
=
torch
.
cumsum
(
b_seq_len_extend
[:
-
1
],
0
)
kv_indptr
=
torch
.
zeros
((
B
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
kv_indptr
[
1
:
B
+
1
]
=
torch
.
cumsum
(
b_seq_len_prefix
[:
B
],
dim
=
0
)
kv_indices
=
torch
.
zeros
(
(
int
(
b_seq_len_prefix
.
sum
().
item
()),),
dtype
=
torch
.
int32
,
device
=
device
)
for
i
in
range
(
B
):
s
=
kv_indptr
[
i
].
item
()
e
=
kv_indptr
[
i
+
1
].
item
()
kv_indices
[
s
:
e
]
=
torch
.
arange
(
b_start_loc
[
i
],
b_start_loc
[
i
]
+
b_seq_len_prefix
[
i
],
dtype
=
torch
.
int32
,
device
=
device
,
)
total_token_num
=
int
(
torch
.
sum
(
b_seq_len
).
item
())
extend_token_num
=
int
(
torch
.
sum
(
b_seq_len_extend
).
item
())
k_buffer
=
torch
.
empty
(
(
total_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
).
normal_
(
mean
=
0.1
,
std
=
0.2
)
v_buffer
=
torch
.
empty
(
(
total_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
).
normal_
(
mean
=
0.1
,
std
=
0.2
)
k_extend
=
torch
.
empty
((
extend_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
)
v_extend
=
torch
.
empty
((
extend_token_num
,
H_KV
,
D
),
dtype
=
dtype
,
device
=
device
)
q_extend
=
torch
.
empty
((
extend_token_num
,
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
)
for
i
in
range
(
B
):
extend_start_in_buffer
=
b_start_loc
[
i
]
+
b_seq_len_prefix
[
i
]
extend_end_in_buffer
=
b_start_loc
[
i
]
+
b_seq_len
[
i
]
extend_start
=
b_start_loc_extend
[
i
]
extend_end
=
b_start_loc_extend
[
i
]
+
b_seq_len_extend
[
i
]
k_extend
[
extend_start
:
extend_end
]
=
k_buffer
[
extend_start_in_buffer
:
extend_end_in_buffer
]
v_extend
[
extend_start
:
extend_end
]
=
v_buffer
[
extend_start_in_buffer
:
extend_end_in_buffer
]
q_extend
[
extend_start
:
extend_end
]
=
torch
.
empty
(
(
int
(
b_seq_len_extend
[
i
].
item
()),
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
).
normal_
(
mean
=
0.1
,
std
=
0.2
)
o_extend_triton
=
torch
.
empty
(
(
extend_token_num
,
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
)
o_extend_torch
=
torch
.
empty
((
extend_token_num
,
H_Q
,
D
),
dtype
=
dtype
,
device
=
device
)
b_seq_len_extend
=
b_seq_len
-
b_seq_len_prefix
max_len_extend
=
int
(
torch
.
max
(
b_seq_len_extend
,
0
)[
0
].
item
())
qo_indptr
=
torch
.
zeros
((
B
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
qo_indptr
[
1
:
B
+
1
]
=
torch
.
cumsum
(
b_seq_len_extend
[:
B
],
dim
=
0
)
inputs
=
dict
(
q_extend
=
q_extend
,
k_extend
=
k_extend
,
v_extend
=
v_extend
,
k_buffer
=
k_buffer
,
v_buffer
=
v_buffer
,
o_extend_triton
=
o_extend_triton
,
o_extend_torch
=
o_extend_torch
,
qo_indptr
=
qo_indptr
,
kv_indptr
=
kv_indptr
,
kv_indices
=
kv_indices
,
max_len_extend
=
max_len_extend
,
WINDOW_SIZE
=
WINDOW_SIZE
,
)
meta
=
dict
(
B
=
B
,
N_CTX
=
N_CTX
,
H_Q
=
H_Q
,
H_KV
=
H_KV
,
D
=
D
,
extend_token_num
=
extend_token_num
)
return
inputs
,
meta
def _run_triton(inputs):
    """Run the Triton extend-attention kernel on a `_build_batch` workload.

    Writes the result into ``inputs["o_extend_triton"]`` in place.
    """
    tensor_args = tuple(
        inputs[name]
        for name in (
            "q_extend",
            "k_extend",
            "v_extend",
            "o_extend_triton",
            "k_buffer",
            "v_buffer",
            "qo_indptr",
            "kv_indptr",
            "kv_indices",
        )
    )
    extend_attention_fwd(
        *tensor_args,
        custom_mask=None,
        is_causal=True,
        mask_indptr=None,
        max_len_extend=inputs["max_len_extend"],
        sliding_window_size=inputs["WINDOW_SIZE"],
    )
def _run_torch_ref(inputs):
    """Run the PyTorch reference implementation on a `_build_batch` workload.

    Writes the result into ``inputs["o_extend_torch"]`` in place.
    """
    args = [
        inputs[name]
        for name in (
            "q_extend",
            "k_extend",
            "v_extend",
            "o_extend_torch",
            "k_buffer",
            "v_buffer",
            "qo_indptr",
            "kv_indptr",
            "kv_indices",
            "WINDOW_SIZE",
        )
    ]
    extend_attention_fwd_torch(*args)
# Benchmark grid: every (context length, sliding-window size) pair.
# WINDOW_SIZE == -1 disables the sliding window (plain causal attention).
N_CTXS = [1024, 2048, 4096, 8192]
WINDOW_SIZES = [-1, 127, 256, 512]
CONFIGS = [(n_ctx, win) for n_ctx in N_CTXS for win in WINDOW_SIZES]
PROVIDERS = ["torch", "triton"]
@tt.perf_report(
    tt.Benchmark(
        x_names=["N_CTX", "WINDOW_SIZE"],
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=PROVIDERS,
        line_names=PROVIDERS,
        ylabel="Runtime (ms)",
        plot_name="extend_attention_triton_vs_torch",
        # Fixed arguments shared by every benchmarked configuration.
        args={
            "B": 32,
            "H_Q": 64,
            "H_KV": 8,
            "D": 128,
            "dtype": "bf16",
            "device": "cuda",
            "check_correctness": False,
            "warmup": 25,
            "rep": 100,
        },
    )
)
def bench(
    N_CTX,
    provider,
    B,
    H_Q,
    H_KV,
    D,
    dtype,
    device,
    WINDOW_SIZE,
    check_correctness,
    warmup,
    rep,
):
    """Benchmark one (N_CTX, WINDOW_SIZE, provider) cell of the sweep.

    Builds a fresh random workload (seeded for reproducibility), optionally
    cross-checks the Triton kernel against the PyTorch reference, then times
    the selected provider with ``triton.testing.do_bench``.

    Returns:
        Median runtime in milliseconds for the selected provider.

    Raises:
        AssertionError: if ``check_correctness`` is set and the Triton and
            reference outputs disagree beyond rtol/atol = 1e-3.
        ValueError: if ``provider`` is neither "triton" nor "torch".
    """
    # Seed both CPU and CUDA RNGs so every cell sees identical inputs.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    dt = dtype_map[dtype]
    inputs, _ = _build_batch(
        B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
    )
    # Only validate once, on the triton pass, to avoid duplicating the check.
    if check_correctness and provider == "triton":
        _run_triton(inputs)
        _run_torch_ref(inputs)
        torch.cuda.synchronize()
        if not torch.allclose(
            inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
        ):
            raise AssertionError("Mismatch between triton and torch reference.")
    if provider == "triton":
        ms = tt.do_bench(lambda: _run_triton(inputs), warmup=warmup, rep=rep)
    elif provider == "torch":
        ms = tt.do_bench(lambda: _run_torch_ref(inputs), warmup=warmup, rep=rep)
    else:
        raise ValueError(provider)
    return ms
if __name__ == "__main__":
    # Run the full benchmark sweep, printing results as a table (no plots).
    bench.run(print_data=True, show_plots=False)
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
View file @
0475448e
...
...
@@ -134,38 +134,6 @@ def _fwd_kernel(
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_seq_len_prefix
offs_kv_loc
=
tl
.
load
(
kv_indices
+
cur_seq_kv_start_idx
+
start_n
+
offs_n
,
mask
=
mask_n
,
other
=
0
)
# load k in transposed way
offs_buf_k
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
)
if
BLOCK_DPE
>
0
:
offs_kpe
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Buffer
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
*=
sm_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
:
custom_mask
=
tl
.
load
(
...
...
@@ -185,28 +153,72 @@ def _fwd_kernel(
cur_seq_len_prefix
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
]
)
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
final_mask
&=
window_mask
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_
max
_
fi
xed
,
e_max
)
SKIP_TILE
=
False
if
(
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
)
or
SLIDING_WINDOW_SIZE
>
0
:
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
fi
nal_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
if
not
SKIP_TILE
:
offs_kv_loc
=
tl
.
load
(
kv_indices
+
cur_seq_kv_start_idx
+
start_n
+
offs_n
,
mask
=
mask_n
,
other
=
0
,
)
offs_buf_v
=
(
offs_kv_loc
[:,
None
]
*
stride_buf_vbs
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
# load k in transposed way
offs_buf_k
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
,
)
e_max
=
n_e_max
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
)
if
BLOCK_DPE
>
0
:
offs_kpe
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Buffer
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
*=
sm_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
offs_buf_v
=
(
offs_kv_loc
[:,
None
]
*
stride_buf_vbs
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
,
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
e_max
=
n_e_max
# stage 2: compute the triangle part
...
...
@@ -219,35 +231,6 @@ def _fwd_kernel(
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_block_m_end
# load k in transposed way
offs_k
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Extend
+
offs_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
=
tl
.
dot
(
q
,
k
,
out_dtype
=
tl
.
float32
)
if
BLOCK_DPE
>
0
:
offs_kpe
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Extend
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
,
kpe
)
qk
*=
sm_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
:
custom_mask
=
tl
.
load
(
...
...
@@ -279,28 +262,62 @@ def _fwd_kernel(
)
final_mask
&=
window_mask
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
SKIP_TILE
=
False
if
USE_CUSTOM_MASK
or
SLIDING_WINDOW_SIZE
>
0
:
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
final_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
if
not
SKIP_TILE
:
# load k in transposed way
offs_k
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Extend
+
offs_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
)
offs_v
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[:,
None
])
*
stride_vbs
+
cur_kv_head
*
stride_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Extend
+
offs_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
qk
=
tl
.
dot
(
q
,
k
,
out_dtype
=
tl
.
float32
)
if
BLOCK_DPE
>
0
:
offs_kpe
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Extend
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
,
kpe
)
qk
*=
sm_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
offs_v
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[:,
None
])
*
stride_vbs
+
cur_kv_head
*
stride_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Extend
+
offs_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
e_max
=
n_e_max
e_max
=
n_e_max
if
HAS_SINK
:
cur_sink
=
tl
.
load
(
sink_ptr
+
cur_head
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment