Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87d87147
Unverified
Commit
87d87147
authored
May 16, 2025
by
learner0810
Committed by
GitHub
May 16, 2025
Browse files
[Model] Use autoweightloader for dbrx (#18251)
Signed-off-by:
learner0810
<
zhongjun.li@daocloud.io
>
parent
a5f8c111
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
47 deletions
+53
-47
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+53
-47
No files found.
vllm/model_executor/models/dbrx.py
View file @
87d87147
...
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -319,6 +319,7 @@ class DbrxModel(nn.Module):
...
@@ -319,6 +319,7 @@ class DbrxModel(nn.Module):
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
quant_config
=
quant_config
self
.
wte
=
VocabParallelEmbedding
(
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
d_model
,
config
.
d_model
,
...
@@ -364,6 +365,55 @@ class DbrxModel(nn.Module):
...
@@ -364,6 +365,55 @@ class DbrxModel(nn.Module):
hidden_states
=
self
.
norm_f
(
hidden_states
)
hidden_states
=
self
.
norm_f
(
hidden_states
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
expert_params_mapping
=
[(
"w13"
if
weight_name
in
[
"w1"
,
"v1"
]
else
"w2"
,
f
"mlp.
{
weight_name
}
"
,
)
for
weight_name
in
[
"w1"
,
"v1"
,
"w2"
]]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
if
name
.
endswith
((
"w1"
,
"w2"
,
"v1"
)):
name
=
name
+
"_weight"
for
param_name
,
weight_name
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_name
,
name
)
break
else
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
DbrxForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
DbrxForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
@@ -417,49 +467,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
...
@@ -417,49 +467,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
expert_params_mapping
=
[(
loader
=
AutoWeightsLoader
(
self
)
"w13"
if
weight_name
in
[
"w1"
,
"v1"
]
else
"w2"
,
return
loader
.
load_weights
(
weights
)
f
"mlp.
{
weight_name
}
"
,
)
for
weight_name
in
[
"w1"
,
"v1"
,
"w2"
]]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
if
name
.
endswith
((
"w1"
,
"w2"
,
"v1"
)):
name
=
name
+
"_weight"
for
param_name
,
weight_name
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_name
,
name
)
break
else
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment