Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5fff3bd
Unverified
Commit
a5fff3bd
authored
Aug 04, 2025
by
Raghav Ravishankar
Committed by
GitHub
Aug 04, 2025
Browse files
Fix Arcee model weight loading: Add custom load_weights (#21725)
Signed-off-by:
alyosha-swamy
<
raghav@arcee.ai
>
parent
1539ced9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
6 deletions
+80
-6
tests/models/registry.py
tests/models/registry.py
+1
-2
vllm/model_executor/models/arcee.py
vllm/model_executor/models/arcee.py
+79
-4
No files found.
tests/models/registry.py
View file @
a5fff3bd
...
@@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"AquilaForCausalLM"
:
_HfExamplesInfo
(
"BAAI/AquilaChat2-7B"
,
"AquilaForCausalLM"
:
_HfExamplesInfo
(
"BAAI/AquilaChat2-7B"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"ArceeForCausalLM"
:
_HfExamplesInfo
(
"arcee-ai/AFM-4.5B-Base"
,
"ArceeForCausalLM"
:
_HfExamplesInfo
(
"arcee-ai/AFM-4.5B-Base"
),
is_available_online
=
False
),
"ArcticForCausalLM"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-instruct"
,
"ArcticForCausalLM"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-instruct"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"BaiChuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan-7B"
,
"BaiChuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan-7B"
,
...
...
vllm/model_executor/models/arcee.py
View file @
a5fff3bd
...
@@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
)
...
@@ -260,6 +262,81 @@ class ArceeModel(nn.Module):
...
@@ -260,6 +262,81 @@ class ArceeModel(nn.Module):
return
hidden_states
,
aux_hidden_states
return
hidden_states
,
aux_hidden_states
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
"""Load weights, mapping q/k/v projections to fused qkv_proj."""
stacked_params_mapping
=
[
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
continue
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
if
"scale"
in
name
:
remapped_name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
remapped_name
is
None
:
continue
name
=
remapped_name
mapped
=
False
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
mapped
=
True
break
if
is_pp_missing_parameter
(
name
,
self
):
mapped
=
True
break
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
# type: ignore[attr-defined]
weight_loader
(
param
,
loaded_weight
,
shard_id
)
loaded_params
.
add
(
name
)
mapped
=
True
break
if
mapped
:
continue
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
ArceeForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
ArceeForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
"""Arcee Model for causal language modeling, integrated with vLLM
"""Arcee Model for causal language modeling, integrated with vLLM
...
@@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else
:
else
:
# Placeholder for lm_head on non-last ranks
# Placeholder for lm_head on non-last ranks
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
=
PPMissingLayer
()
# Provide a reference to the model's method for generating empty
# tensors (used in pipeline parallel schedule)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
@@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
# Forward pass through the Arcee model backbone
model_output
=
self
.
model
(
input_ids
=
input_ids
,
model_output
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment