Commit ba4f8267 (unverified) in norm/vllm
Authored Dec 19, 2023 by Woosuk Kwon; committed by GitHub on Dec 19, 2023

[BugFix] Fix weight loading for Mixtral with TP (#2208)

Parent: de60a3fb
Changes: 1 changed file with 5 additions and 26 deletions

vllm/model_executor/models/mixtral.py (+5, -26)
@@ -49,7 +49,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -94,30 +93,6 @@ class MixtralMLP(nn.Module):
         return current_hidden_states


-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):

     def __init__(
@@ -147,7 +122,7 @@ class MixtralMoE(nn.Module):
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -427,6 +402,10 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)