Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ce75efee
Unverified
Commit
ce75efee
authored
May 28, 2025
by
Lucas Wilkinson
Committed by
GitHub
May 28, 2025
Browse files
[BugFix] FA2 MLA Accuracy Issue (#18807)
Signed-off-by:
LucasWilkinson
<
lwilkinson@neuralmagic.com
>
parent
aa42561e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
8 deletions
+16
-8
csrc/attention/merge_attn_states.cu
csrc/attention/merge_attn_states.cu
+8
-0
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+4
-4
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+4
-4
No files found.
csrc/attention/merge_attn_states.cu
View file @
ce75efee
...
@@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
...
@@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
TORCH_CHECK
(
head_size
%
pack_size
==
0
,
TORCH_CHECK
(
head_size
%
pack_size
==
0
,
"headsize must be multiple of pack_size:"
,
pack_size
);
"headsize must be multiple of pack_size:"
,
pack_size
);
TORCH_CHECK
(
output
.
stride
(
-
2
)
==
head_size
&&
output
.
stride
(
-
1
)
==
1
,
"output heads must be contiguous in memory"
);
TORCH_CHECK
(
prefix_output
.
stride
(
-
2
)
==
head_size
&&
prefix_output
.
stride
(
-
1
)
==
1
,
"prefix_output heads must be contiguous in memory"
);
TORCH_CHECK
(
suffix_output
.
stride
(
-
2
)
==
head_size
&&
suffix_output
.
stride
(
-
1
)
==
1
,
"suffix_output heads must be contiguous in memory"
);
float
*
output_lse_ptr
=
nullptr
;
float
*
output_lse_ptr
=
nullptr
;
if
(
output_lse
.
has_value
())
{
if
(
output_lse
.
has_value
())
{
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
...
...
vllm/attention/backends/mla/common.py
View file @
ce75efee
...
@@ -1093,10 +1093,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1093,10 +1093,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if
isinstance
(
attn_out
,
tuple
):
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
*
rest
=
attn_out
attn_out
,
*
rest
=
attn_out
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
if
return_softmax_lse
:
...
@@ -1294,6 +1290,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1294,6 +1290,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
)
)
# unpad if necessary
if
self
.
_pad_v
:
output
=
output
[...,
:
v
.
shape
[
-
1
]]
return
output
.
flatten
(
start_dim
=-
2
)
return
output
.
flatten
(
start_dim
=-
2
)
@
abstractmethod
@
abstractmethod
...
...
vllm/v1/attention/backends/mla/common.py
View file @
ce75efee
...
@@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
isinstance
(
attn_out
,
tuple
):
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
lse
=
attn_out
[
0
],
attn_out
[
1
]
attn_out
,
lse
=
attn_out
[
0
],
attn_out
[
1
]
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
if
return_softmax_lse
:
...
@@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
)
)
# unpad if necessary
if
self
.
_pad_v
:
output
=
output
[...,
:
v
.
shape
[
-
1
]]
return
output
.
flatten
(
start_dim
=-
2
)
return
output
.
flatten
(
start_dim
=-
2
)
@
abstractmethod
@
abstractmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment