Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
280d62b8
Unverified
Commit
280d62b8
authored
Apr 15, 2025
by
DefTruth
Committed by
GitHub
Apr 15, 2025
Browse files
[Kernel] Remove redundant Exp calculations (#16123)
Signed-off-by:
DefTruth
<
qiustudent_r@163.com
>
parent
1666e664
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
3 deletions
+6
-3
vllm/attention/ops/triton_merge_attn_states.py
vllm/attention/ops/triton_merge_attn_states.py
+6
-3
No files found.
vllm/attention/ops/triton_merge_attn_states.py
View file @
280d62b8
...
@@ -66,7 +66,10 @@ def merge_attn_states_kernel(
...
@@ -66,7 +66,10 @@ def merge_attn_states_kernel(
max_lse
=
tl
.
maximum
(
p_lse
,
s_lse
)
max_lse
=
tl
.
maximum
(
p_lse
,
s_lse
)
p_lse
=
p_lse
-
max_lse
p_lse
=
p_lse
-
max_lse
s_lse
=
s_lse
-
max_lse
s_lse
=
s_lse
-
max_lse
out_se
=
(
tl
.
exp
(
p_lse
)
+
tl
.
exp
(
s_lse
))
# Will reuse precomputed Exp values for scale factor computation.
p_se
=
tl
.
exp
(
p_lse
)
s_se
=
tl
.
exp
(
s_lse
)
out_se
=
(
p_se
+
s_se
)
if
OUTPUT_LSE
:
if
OUTPUT_LSE
:
out_lse
=
tl
.
log
(
out_se
)
+
max_lse
out_lse
=
tl
.
log
(
out_se
)
+
max_lse
...
@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
...
@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
# NOTE(woosuk): Be careful with the numerical stability.
# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale
=
tl
.
exp
(
p_
l
se
)
/
out_se
p_scale
=
p_se
/
out_se
s_scale
=
tl
.
exp
(
s_
l
se
)
/
out_se
s_scale
=
s_se
/
out_se
out
=
p_out
*
p_scale
+
s_out
*
s_scale
out
=
p_out
*
p_scale
+
s_out
*
s_scale
tl
.
store
(
output
+
token_idx
*
num_heads
*
HEAD_SIZE
+
tl
.
store
(
output
+
token_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
+
head_arange
,
head_idx
*
HEAD_SIZE
+
head_arange
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment