Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad44437b
Unverified
Commit
ad44437b
authored
Nov 20, 2024
by
Isotr0py
Committed by
GitHub
Nov 20, 2024
Browse files
[Bugfix] Fix Mamba model initialization and MLP Speculator weights loading (#10456)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
9e05252b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
7 deletions
+4
-7
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+2
-6
vllm/model_executor/models/mlp_speculator.py
vllm/model_executor/models/mlp_speculator.py
+2
-1
No files found.
vllm/model_executor/models/mamba.py
View file @
ad44437b
"""PyTorch MAMBA model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -243,10 +243,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
...
...
@@ -258,5 +256,3 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/mlp_speculator.py
View file @
ad44437b
...
...
@@ -193,7 +193,8 @@ class MLPSpeculator(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
param
=
params_dict
.
get
(
name
.
replace
(
"speculator."
,
""
))
name
=
name
.
replace
(
"speculator."
,
""
)
param
=
params_dict
.
get
(
name
)
if
param
is
not
None
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment