Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5afd3276
Unverified
Commit
5afd3276
authored
Oct 16, 2025
by
rongfu.leng
Committed by
GitHub
Oct 16, 2025
Browse files
[Feature] Add process_weights_after_loading to AttentionImpl (#26870)
Signed-off-by:
rongfu.leng
<
rongfu.leng@daocloud.io
>
parent
43721bc6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
10 deletions
+9
-10
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+3
-0
vllm/attention/layer.py
vllm/attention/layer.py
+1
-10
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+5
-0
No files found.
vllm/attention/backends/abstract.py
View file @
5afd3276
...
...
@@ -207,6 +207,9 @@ class AttentionImpl(ABC, Generic[T]):
"""
return
False
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
pass
class
MLAAttentionImpl
(
AttentionImpl
[
T
],
Generic
[
T
]):
@
abstractmethod
...
...
vllm/attention/layer.py
View file @
5afd3276
...
...
@@ -404,16 +404,7 @@ class Attention(nn.Module, AttentionLayerBase):
return
s
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
if
hasattr
(
self
.
impl
,
"process_weights_after_loading"
):
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
# FlashInfer requires attention sinks to be float32
if
self
.
backend
==
_Backend
.
FLASHINFER
and
hasattr
(
self
.
impl
,
"sinks"
):
from
vllm.v1.attention.backends.flashinfer
import
FlashInferImpl
assert
isinstance
(
self
.
impl
,
FlashInferImpl
)
if
self
.
impl
.
sinks
is
not
None
and
self
.
impl
.
sinks
.
dtype
!=
torch
.
float32
:
self
.
impl
.
sinks
=
self
.
impl
.
sinks
.
to
(
torch
.
float32
)
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
5afd3276
...
...
@@ -833,6 +833,11 @@ class FlashInferImpl(AttentionImpl):
return
self
.
support_trtllm_attn
# FlashInfer requires attention sinks to be float32
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
if
self
.
sinks
is
not
None
and
self
.
sinks
.
dtype
!=
torch
.
float32
:
self
.
sinks
=
self
.
sinks
.
to
(
torch
.
float32
)
def
forward
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment