Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c0557478
Unverified
Commit
c0557478
authored
Nov 23, 2024
by
youkaichao
Committed by
GitHub
Nov 23, 2024
Browse files
[model][utils] add extract_layer_index utility function (#10599)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
eda2b358
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
59 additions
and
51 deletions
+59
-51
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+18
-23
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+11
-8
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+5
-10
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+2
-6
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+2
-4
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+21
-0
No files found.
vllm/model_executor/models/arctic.py
View file @
c0557478
...
@@ -33,7 +33,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -33,7 +33,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.arctic
import
ArcticConfig
from
vllm.transformers_utils.configs.arctic
import
ArcticConfig
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -44,15 +44,14 @@ class ArcticMLP(nn.Module):
...
@@ -44,15 +44,14 @@ class ArcticMLP(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
ArcticConfig
,
config
:
ArcticConfig
,
layer_id
:
int
,
expert_id
:
int
=
-
1
,
expert_id
:
int
=
-
1
,
is_residual_mlp
:
bool
=
False
,
is_residual_mlp
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
):
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
expert_id
=
expert_id
self
.
expert_id
=
expert_id
self
.
layer_id
=
layer_id
self
.
ffn_dim
=
config
.
intermediate_size
if
not
is_residual_mlp
\
self
.
ffn_dim
=
config
.
intermediate_size
if
not
is_residual_mlp
\
else
self
.
hidden_size
else
self
.
hidden_size
...
@@ -85,13 +84,14 @@ class ArcticMoE(nn.Module):
...
@@ -85,13 +84,14 @@ class ArcticMoE(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
ArcticConfig
,
config
:
ArcticConfig
,
layer_id
:
int
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
):
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
layer_id
=
extract_layer_index
(
prefix
)
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
num_experts
=
config
.
num_local_experts
self
.
num_experts
=
config
.
num_local_experts
...
@@ -109,15 +109,16 @@ class ArcticMoE(nn.Module):
...
@@ -109,15 +109,16 @@ class ArcticMoE(nn.Module):
if
not
self
.
is_moe_layer
:
if
not
self
.
is_moe_layer
:
self
.
mlp
=
ArcticMLP
(
config
,
self
.
mlp
=
ArcticMLP
(
config
,
layer_id
=
layer_id
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
)
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.mlp"
)
else
:
else
:
self
.
gate
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
num_experts
,
self
.
num_experts
,
bias
=
False
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate"
)
if
self
.
is_quant
:
if
self
.
is_quant
:
self
.
ws
=
DeepSpeedFPParameter
(
self
.
ws
=
DeepSpeedFPParameter
(
torch
.
Size
((
self
.
num_experts
,
2
*
self
.
intermediate_size
,
torch
.
Size
((
self
.
num_experts
,
2
*
self
.
intermediate_size
,
...
@@ -220,14 +221,12 @@ class ArcticAttention(nn.Module):
...
@@ -220,14 +221,12 @@ class ArcticAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
ArcticConfig
,
config
:
ArcticConfig
,
layer_idx
:
Optional
[
int
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -298,26 +297,25 @@ class ArcticDecoderLayer(nn.Module):
...
@@ -298,26 +297,25 @@ class ArcticDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
ArcticConfig
,
config
:
ArcticConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
layer_idx
=
extract_layer_index
(
prefix
)
is_moe_layer
=
(
layer_idx
+
1
)
%
config
.
moe_layer_frequency
==
0
is_moe_layer
=
(
layer_idx
+
1
)
%
config
.
moe_layer_frequency
==
0
self
.
use_residual
=
config
.
use_residual
and
is_moe_layer
self
.
use_residual
=
config
.
use_residual
and
is_moe_layer
self
.
self_attn
=
ArcticAttention
(
config
,
self
.
self_attn
=
ArcticAttention
(
config
,
layer_idx
,
cache_config
,
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
block_sparse_moe
=
ArcticMoE
(
self
.
block_sparse_moe
=
ArcticMoE
(
config
,
config
,
layer_id
=
layer_idx
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
(
not
self
.
use_residual
))
reduce_results
=
(
not
self
.
use_residual
),
prefix
=
f
"
{
prefix
}
.block_sparse_moe"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -328,9 +326,9 @@ class ArcticDecoderLayer(nn.Module):
...
@@ -328,9 +326,9 @@ class ArcticDecoderLayer(nn.Module):
self
.
residual_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
residual_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
residual_mlp
=
ArcticMLP
(
config
,
self
.
residual_mlp
=
ArcticMLP
(
config
,
layer_id
=
layer_idx
,
is_residual_mlp
=
True
,
is_residual_mlp
=
True
,
reduce_results
=
False
)
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.residual_mlp"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -384,11 +382,8 @@ class ArcticModel(nn.Module):
...
@@ -384,11 +382,8 @@ class ArcticModel(nn.Module):
org_num_embeddings
=
self
.
vocab_size
)
org_num_embeddings
=
self
.
vocab_size
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
ArcticDecoderLayer
(
config
,
lambda
prefix
:
ArcticDecoderLayer
(
int
(
prefix
.
split
(
"."
)[
-
1
]),
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
_attn_implementation
=
config
.
_attn_implementation
self
.
_attn_implementation
=
config
.
_attn_implementation
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
vllm/model_executor/models/deepseek.py
View file @
c0557478
...
@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -63,6 +63,7 @@ class DeepseekMLP(nn.Module):
...
@@ -63,6 +63,7 @@ class DeepseekMLP(nn.Module):
hidden_act
:
str
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
...
@@ -92,6 +93,7 @@ class DeepseekMoE(nn.Module):
...
@@ -92,6 +93,7 @@ class DeepseekMoE(nn.Module):
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -260,12 +262,12 @@ class DeepseekDecoderLayer(nn.Module):
...
@@ -260,12 +262,12 @@ class DeepseekDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
layer_idx
=
extract_layer_index
(
prefix
)
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
...
@@ -285,13 +287,16 @@ class DeepseekDecoderLayer(nn.Module):
...
@@ -285,13 +287,16 @@ class DeepseekDecoderLayer(nn.Module):
if
(
config
.
n_routed_experts
is
not
None
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
quant_config
=
quant_config
)
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
else
:
else
:
self
.
mlp
=
DeepseekMLP
(
self
.
mlp
=
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -347,11 +352,9 @@ class DeepseekModel(nn.Module):
...
@@ -347,11 +352,9 @@ class DeepseekModel(nn.Module):
)
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
DeepseekDecoderLayer
(
config
,
lambda
prefix
:
DeepseekDecoderLayer
(
int
(
prefix
.
split
(
"."
)[
-
1
]),
config
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
cache_config
,
),
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
...
...
vllm/model_executor/models/gemma2.py
View file @
c0557478
...
@@ -42,7 +42,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -42,7 +42,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -85,7 +86,6 @@ class Gemma2MLP(nn.Module):
...
@@ -85,7 +86,6 @@ class Gemma2MLP(nn.Module):
class
Gemma2Attention
(
nn
.
Module
):
class
Gemma2Attention
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
layer_idx
:
int
,
config
:
Gemma2Config
,
config
:
Gemma2Config
,
hidden_size
:
int
,
hidden_size
:
int
,
num_heads
:
int
,
num_heads
:
int
,
...
@@ -98,7 +98,6 @@ class Gemma2Attention(nn.Module):
...
@@ -98,7 +98,6 @@ class Gemma2Attention(nn.Module):
attn_logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_logits_soft_cap
:
Optional
[
float
]
=
None
,
prefix
:
str
=
""
)
->
None
:
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
config
=
config
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -145,6 +144,7 @@ class Gemma2Attention(nn.Module):
...
@@ -145,6 +144,7 @@ class Gemma2Attention(nn.Module):
# reference:
# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx
=
extract_layer_index
(
prefix
)
use_sliding_window
=
(
layer_idx
%
2
==
0
and
use_sliding_window
=
(
layer_idx
%
2
==
0
and
config
.
interleaved_sliding_window
is
not
None
)
config
.
interleaved_sliding_window
is
not
None
)
sliding_window
=
config
.
interleaved_sliding_window
if
\
sliding_window
=
config
.
interleaved_sliding_window
if
\
...
@@ -178,7 +178,6 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -178,7 +178,6 @@ class Gemma2DecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
layer_idx
:
int
,
config
:
Gemma2Config
,
config
:
Gemma2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
...
@@ -187,7 +186,6 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -187,7 +186,6 @@ class Gemma2DecoderLayer(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
Gemma2Attention
(
self
.
self_attn
=
Gemma2Attention
(
layer_idx
=
layer_idx
,
config
=
config
,
config
=
config
,
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
...
@@ -262,11 +260,8 @@ class Gemma2Model(nn.Module):
...
@@ -262,11 +260,8 @@ class Gemma2Model(nn.Module):
)
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Gemma2DecoderLayer
(
int
(
prefix
.
split
(
"."
)[
-
1
]),
lambda
prefix
:
Gemma2DecoderLayer
(
config
,
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
vllm/model_executor/models/olmoe.py
View file @
c0557478
...
@@ -181,7 +181,6 @@ class OlmoeDecoderLayer(nn.Module):
...
@@ -181,7 +181,6 @@ class OlmoeDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
...
@@ -264,11 +263,8 @@ class OlmoeModel(nn.Module):
...
@@ -264,11 +263,8 @@ class OlmoeModel(nn.Module):
)
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
OlmoeDecoderLayer
(
config
,
lambda
prefix
:
OlmoeDecoderLayer
(
int
(
prefix
.
split
(
"."
)[
-
1
]),
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
c0557478
...
@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -244,7 +244,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
...
@@ -244,7 +244,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
...
@@ -269,6 +268,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
...
@@ -269,6 +268,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
# `mlp_only_layers` in the config.
# `mlp_only_layers` in the config.
layer_idx
=
extract_layer_index
(
prefix
)
mlp_only_layers
=
([]
if
not
hasattr
(
config
,
"mlp_only_layers"
)
else
mlp_only_layers
=
([]
if
not
hasattr
(
config
,
"mlp_only_layers"
)
else
config
.
mlp_only_layers
)
config
.
mlp_only_layers
)
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
...
@@ -337,8 +337,6 @@ class Qwen2MoeModel(nn.Module):
...
@@ -337,8 +337,6 @@ class Qwen2MoeModel(nn.Module):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen2MoeDecoderLayer
(
config
=
config
,
lambda
prefix
:
Qwen2MoeDecoderLayer
(
config
=
config
,
layer_idx
=
int
(
prefix
.
split
(
"."
)[
-
1
]),
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
prefix
),
...
...
vllm/model_executor/models/utils.py
View file @
c0557478
...
@@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str:
...
@@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str:
The string "prefix.name" if prefix was non-empty, otherwise just "name".
The string "prefix.name" if prefix was non-empty, otherwise just "name".
"""
"""
return
name
if
not
prefix
else
f
"
{
prefix
}
.
{
name
}
"
return
name
if
not
prefix
else
f
"
{
prefix
}
.
{
name
}
"
def extract_layer_index(layer_name: str) -> int:
    """
    Extract the layer index from the module name.

    The name is split on "." and every component that parses as an integer
    is collected; exactly one such component must be present.

    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError

    Args:
        layer_name: Dot-separated module name (e.g. a layer prefix).

    Returns:
        The single integer component found in ``layer_name``.

    Raises:
        ValueError: If ``layer_name`` contains zero, or more than one,
            integer component.
    """
    int_vals: List[int] = []
    for subname in layer_name.split("."):
        try:
            int_vals.append(int(subname))
        except ValueError:
            # Non-numeric component (e.g. "layers", "self_attn"); skip it.
            continue
    # Raise explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, which would silently return the first integer of
    # an ambiguous name (or hit IndexError when there is none), and the
    # documented failure mode is ValueError.
    if len(int_vals) != 1:
        raise ValueError(f"layer name {layer_name} should"
                         " only contain one integer")
    return int_vals[0]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment