Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
619 additions
and
523 deletions
+619
-523
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+20
-19
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+16
-16
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+16
-16
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+16
-16
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+16
-16
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+17
-16
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+17
-16
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+25
-21
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+5
-4
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+18
-17
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+148
-53
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+21
-20
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+16
-16
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+160
-171
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+18
-18
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+17
-16
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+17
-17
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+17
-16
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+17
-16
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+22
-23
No files found.
vllm/model_executor/models/gemma.py
View file @
1591c68f
...
@@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
...
@@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -55,10 +56,10 @@ def _get_gemma_act_fn(
...
@@ -55,10 +56,10 @@ def _get_gemma_act_fn(
"in the config JSON file when it was initially released. "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
f
"`
{
hidden_act
}
`, edit the config JSON to set "
"`
%s
`, edit the config JSON to set "
f
"`hidden_activation=
{
hidden_act
}
` instead of `hidden_act`. "
"`hidden_activation=
%s
` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
"See https://github.com/huggingface/transformers/pull/29402 "
"for more details."
)
"for more details."
,
hidden_act
,
hidden_act
)
return
GeluAndMul
(
approximate
=
"tanh"
)
return
GeluAndMul
(
approximate
=
"tanh"
)
elif
hidden_activation
==
"gelu_pytorch_tanh"
:
elif
hidden_activation
==
"gelu_pytorch_tanh"
:
return
GeluAndMul
(
approximate
=
"tanh"
)
return
GeluAndMul
(
approximate
=
"tanh"
)
...
@@ -77,17 +78,17 @@ class GemmaMLP(nn.Module):
...
@@ -77,17 +78,17 @@ class GemmaMLP(nn.Module):
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
Optional
[
str
]
=
None
,
hidden_act
:
Optional
[
str
]
=
None
,
hidden_activation
:
Optional
[
str
]
=
None
,
hidden_activation
:
Optional
[
str
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
act_fn
=
_get_gemma_act_fn
(
hidden_act
,
hidden_activation
)
self
.
act_fn
=
_get_gemma_act_fn
(
hidden_act
,
hidden_activation
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -106,7 +107,7 @@ class GemmaAttention(nn.Module):
...
@@ -106,7 +107,7 @@ class GemmaAttention(nn.Module):
head_dim
:
int
,
head_dim
:
int
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -135,13 +136,13 @@ class GemmaAttention(nn.Module):
...
@@ -135,13 +136,13 @@ class GemmaAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module):
...
@@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GemmaConfig
,
config
:
GemmaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -187,14 +188,14 @@ class GemmaDecoderLayer(nn.Module):
...
@@ -187,14 +188,14 @@ class GemmaDecoderLayer(nn.Module):
head_dim
=
config
.
head_dim
,
head_dim
=
config
.
head_dim
,
max_position_embeddings
=
config
.
max_position_embeddings
,
max_position_embeddings
=
config
.
max_position_embeddings
,
rope_theta
=
config
.
rope_theta
,
rope_theta
=
config
.
rope_theta
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
mlp
=
GemmaMLP
(
self
.
mlp
=
GemmaMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
hidden_activation
=
getattr
(
config
,
"hidden_activation"
,
None
),
hidden_activation
=
getattr
(
config
,
"hidden_activation"
,
None
),
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -235,7 +236,7 @@ class GemmaModel(nn.Module):
...
@@ -235,7 +236,7 @@ class GemmaModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GemmaConfig
,
config
:
GemmaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -245,7 +246,7 @@ class GemmaModel(nn.Module):
...
@@ -245,7 +246,7 @@ class GemmaModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
GemmaDecoderLayer
(
config
,
linear_method
)
GemmaDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module):
...
@@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GemmaConfig
,
config
:
GemmaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
del
lora_config
# Unused.
del
lora_config
# Unused.
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
GemmaModel
(
config
,
linear_method
)
self
.
model
=
GemmaModel
(
config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt2.py
View file @
1591c68f
...
@@ -27,10 +27,11 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -27,10 +27,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -44,7 +45,7 @@ class GPT2Attention(nn.Module):
...
@@ -44,7 +45,7 @@ class GPT2Attention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPT2Config
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -61,13 +62,13 @@ class GPT2Attention(nn.Module):
...
@@ -61,13 +62,13 @@ class GPT2Attention(nn.Module):
self
.
head_dim
,
self
.
head_dim
,
total_num_heads
,
total_num_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
...
@@ -90,7 +91,7 @@ class GPT2MLP(nn.Module):
...
@@ -90,7 +91,7 @@ class GPT2MLP(nn.Module):
self
,
self
,
intermediate_size
:
int
,
intermediate_size
:
int
,
config
:
GPT2Config
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -98,15 +99,14 @@ class GPT2MLP(nn.Module):
...
@@ -98,15 +99,14 @@ class GPT2MLP(nn.Module):
hidden_size
,
hidden_size
,
intermediate_size
,
intermediate_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
intermediate_size
)
...
@@ -122,7 +122,7 @@ class GPT2Block(nn.Module):
...
@@ -122,7 +122,7 @@ class GPT2Block(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPT2Config
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -130,9 +130,9 @@ class GPT2Block(nn.Module):
...
@@ -130,9 +130,9 @@ class GPT2Block(nn.Module):
hidden_size
)
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPT2Attention
(
config
,
linear_method
)
self
.
attn
=
GPT2Attention
(
config
,
quant_config
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
linear_method
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -163,7 +163,7 @@ class GPT2Model(nn.Module):
...
@@ -163,7 +163,7 @@ class GPT2Model(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPT2Config
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -174,7 +174,7 @@ class GPT2Model(nn.Module):
...
@@ -174,7 +174,7 @@ class GPT2Model(nn.Module):
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
h
=
nn
.
ModuleList
([
self
.
h
=
nn
.
ModuleList
([
GPT2Block
(
config
,
linear_method
)
GPT2Block
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -203,12 +203,12 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -203,12 +203,12 @@ class GPT2LMHeadModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPT2Config
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
GPT2Model
(
config
,
linear_method
)
self
.
transformer
=
GPT2Model
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
1591c68f
...
@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module):
...
@@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -72,14 +73,14 @@ class GPTBigCodeAttention(nn.Module):
...
@@ -72,14 +73,14 @@ class GPTBigCodeAttention(nn.Module):
total_num_heads
,
total_num_heads
,
total_num_kv_heads
,
total_num_kv_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -111,7 +112,7 @@ class GPTBigMLP(nn.Module):
...
@@ -111,7 +112,7 @@ class GPTBigMLP(nn.Module):
self
,
self
,
intermediate_size
:
int
,
intermediate_size
:
int
,
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -119,15 +120,14 @@ class GPTBigMLP(nn.Module):
...
@@ -119,15 +120,14 @@ class GPTBigMLP(nn.Module):
hidden_size
,
hidden_size
,
intermediate_size
,
intermediate_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
intermediate_size
)
...
@@ -143,7 +143,7 @@ class GPTBigCodeBlock(nn.Module):
...
@@ -143,7 +143,7 @@ class GPTBigCodeBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -151,9 +151,9 @@ class GPTBigCodeBlock(nn.Module):
...
@@ -151,9 +151,9 @@ class GPTBigCodeBlock(nn.Module):
hidden_size
)
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPTBigCodeAttention
(
config
,
linear_method
)
self
.
attn
=
GPTBigCodeAttention
(
config
,
quant_config
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPTBigMLP
(
inner_dim
,
config
,
linear_method
)
self
.
mlp
=
GPTBigMLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -184,7 +184,7 @@ class GPTBigCodeModel(nn.Module):
...
@@ -184,7 +184,7 @@ class GPTBigCodeModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -195,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
...
@@ -195,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
h
=
nn
.
ModuleList
([
self
.
h
=
nn
.
ModuleList
([
GPTBigCodeBlock
(
config
,
linear_method
)
GPTBigCodeBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -224,12 +224,12 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -224,12 +224,12 @@ class GPTBigCodeForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTBigCodeConfig
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
linear_method
)
self
.
transformer
=
GPTBigCodeModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt_j.py
View file @
1591c68f
...
@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -44,7 +45,7 @@ class GPTJAttention(nn.Module):
...
@@ -44,7 +45,7 @@ class GPTJAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTJConfig
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
total_num_heads
=
config
.
num_attention_heads
...
@@ -56,13 +57,13 @@ class GPTJAttention(nn.Module):
...
@@ -56,13 +57,13 @@ class GPTJAttention(nn.Module):
self
.
head_size
,
self
.
head_size
,
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
out_proj
=
RowParallelLinear
(
self
.
out_proj
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
tp_world_size
=
get_tensor_model_parallel_world_size
()
tp_world_size
=
get_tensor_model_parallel_world_size
()
...
@@ -105,21 +106,20 @@ class GPTJMLP(nn.Module):
...
@@ -105,21 +106,20 @@ class GPTJMLP(nn.Module):
self
,
self
,
intermediate_size
:
int
,
intermediate_size
:
int
,
config
:
GPTJConfig
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
n_embd
hidden_size
=
config
.
n_embd
self
.
fc_in
=
ColumnParallelLinear
(
self
.
fc_in
=
ColumnParallelLinear
(
hidden_size
,
hidden_size
,
intermediate_size
,
intermediate_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
fc_out
=
RowParallelLinear
(
self
.
fc_out
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
intermediate_size
)
...
@@ -135,14 +135,14 @@ class GPTJBlock(nn.Module):
...
@@ -135,14 +135,14 @@ class GPTJBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTJConfig
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
inner_dim
=
(
4
*
config
.
n_embd
inner_dim
=
(
4
*
config
.
n_embd
if
config
.
n_inner
is
None
else
config
.
n_inner
)
if
config
.
n_inner
is
None
else
config
.
n_inner
)
self
.
ln_1
=
nn
.
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
nn
.
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPTJAttention
(
config
,
linear_method
)
self
.
attn
=
GPTJAttention
(
config
,
quant_config
)
self
.
mlp
=
GPTJMLP
(
inner_dim
,
config
,
linear_method
)
self
.
mlp
=
GPTJMLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -169,7 +169,7 @@ class GPTJModel(nn.Module):
...
@@ -169,7 +169,7 @@ class GPTJModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTJConfig
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -179,7 +179,7 @@ class GPTJModel(nn.Module):
...
@@ -179,7 +179,7 @@ class GPTJModel(nn.Module):
self
.
embed_dim
,
self
.
embed_dim
,
)
)
self
.
h
=
nn
.
ModuleList
(
self
.
h
=
nn
.
ModuleList
(
[
GPTJBlock
(
config
,
linear_method
)
for
_
in
range
(
config
.
n_layer
)])
[
GPTJBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layer
)])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
def
forward
(
...
@@ -207,13 +207,13 @@ class GPTJForCausalLM(nn.Module):
...
@@ -207,13 +207,13 @@ class GPTJForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTJConfig
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
assert
not
config
.
tie_word_embeddings
assert
not
config
.
tie_word_embeddings
self
.
transformer
=
GPTJModel
(
config
,
linear_method
)
self
.
transformer
=
GPTJModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
n_embd
,
config
.
n_embd
,
...
...
vllm/model_executor/models/gpt_neox.py
View file @
1591c68f
...
@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module):
...
@@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTNeoXConfig
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
total_num_heads
=
config
.
num_attention_heads
...
@@ -63,13 +64,13 @@ class GPTNeoXAttention(nn.Module):
...
@@ -63,13 +64,13 @@ class GPTNeoXAttention(nn.Module):
self
.
head_size
,
self
.
head_size
,
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
self
.
bias
,
bias
=
self
.
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
dense
=
RowParallelLinear
(
self
.
dense
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
self
.
bias
,
bias
=
self
.
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
scaling
=
self
.
head_size
**-
0.5
scaling
=
self
.
head_size
**-
0.5
rotary_dim
=
int
(
self
.
head_size
*
config
.
rotary_pct
)
rotary_dim
=
int
(
self
.
head_size
*
config
.
rotary_pct
)
...
@@ -105,20 +106,19 @@ class GPTNeoXMLP(nn.Module):
...
@@ -105,20 +106,19 @@ class GPTNeoXMLP(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTNeoXConfig
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
intermediate_size
,
config
.
intermediate_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
dense_4h_to_h
=
RowParallelLinear
(
self
.
dense_4h_to_h
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
intermediate_size
,
config
.
hidden_size
,
config
.
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
config
.
intermediate_size
)
config
.
intermediate_size
)
...
@@ -134,7 +134,7 @@ class GPTNeoXLayer(nn.Module):
...
@@ -134,7 +134,7 @@ class GPTNeoXLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTNeoXConfig
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
use_parallel_residual
=
config
.
use_parallel_residual
self
.
use_parallel_residual
=
config
.
use_parallel_residual
...
@@ -142,8 +142,8 @@ class GPTNeoXLayer(nn.Module):
...
@@ -142,8 +142,8 @@ class GPTNeoXLayer(nn.Module):
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
attention
=
GPTNeoXAttention
(
config
,
linear_method
)
self
.
attention
=
GPTNeoXAttention
(
config
,
quant_config
)
self
.
mlp
=
GPTNeoXMLP
(
config
,
linear_method
)
self
.
mlp
=
GPTNeoXMLP
(
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -182,7 +182,7 @@ class GPTNeoXModel(nn.Module):
...
@@ -182,7 +182,7 @@ class GPTNeoXModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
GPTNeoXConfig
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -192,7 +192,7 @@ class GPTNeoXModel(nn.Module):
...
@@ -192,7 +192,7 @@ class GPTNeoXModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
GPTNeoXLayer
(
config
,
linear_method
)
GPTNeoXLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
final_layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
final_layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
@@ -223,12 +223,12 @@ class GPTNeoXForCausalLM(nn.Module):
...
@@ -223,12 +223,12 @@ class GPTNeoXForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
gpt_neox
=
GPTNeoXModel
(
config
,
linear_method
)
self
.
gpt_neox
=
GPTNeoXModel
(
config
,
quant_config
)
self
.
embed_out
=
ParallelLMHead
(
self
.
embed_out
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
...
...
vllm/model_executor/models/internlm2.py
View file @
1591c68f
...
@@ -9,11 +9,12 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -9,11 +9,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -30,17 +31,17 @@ class InternLM2MLP(nn.Module):
...
@@ -30,17 +31,17 @@ class InternLM2MLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
w2
=
RowParallelLinear
(
intermediate_size
,
self
.
w2
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -63,7 +64,7 @@ class InternLM2Attention(nn.Module):
...
@@ -63,7 +64,7 @@ class InternLM2Attention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -94,13 +95,13 @@ class InternLM2Attention(nn.Module):
...
@@ -94,13 +95,13 @@ class InternLM2Attention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
wo
=
RowParallelLinear
(
self
.
wo
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module):
...
@@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -150,13 +151,13 @@ class InternLMDecoderLayer(nn.Module):
...
@@ -150,13 +151,13 @@ class InternLMDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
feed_forward
=
InternLM2MLP
(
self
.
feed_forward
=
InternLM2MLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
attention_norm
=
RMSNorm
(
config
.
hidden_size
,
self
.
attention_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -195,7 +196,7 @@ class InternLM2Model(nn.Module):
...
@@ -195,7 +196,7 @@ class InternLM2Model(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -206,7 +207,7 @@ class InternLM2Model(nn.Module):
...
@@ -206,7 +207,7 @@ class InternLM2Model(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
InternLMDecoderLayer
(
config
,
linear_method
)
InternLMDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
InternLM2Model
(
config
,
linear_method
)
self
.
model
=
InternLM2Model
(
config
,
quant_config
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/jais.py
View file @
1591c68f
...
@@ -29,10 +29,11 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -29,10 +29,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -68,7 +69,7 @@ class JAISAttention(nn.Module):
...
@@ -68,7 +69,7 @@ class JAISAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
JAISConfig
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -88,13 +89,13 @@ class JAISAttention(nn.Module):
...
@@ -88,13 +89,13 @@ class JAISAttention(nn.Module):
self
.
head_dim
,
self
.
head_dim
,
total_num_heads
,
total_num_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
...
@@ -128,7 +129,7 @@ class JAISMLP(nn.Module):
...
@@ -128,7 +129,7 @@ class JAISMLP(nn.Module):
self
,
self
,
intermediate_size
:
int
,
intermediate_size
:
int
,
config
:
JAISConfig
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -137,19 +138,19 @@ class JAISMLP(nn.Module):
...
@@ -137,19 +138,19 @@ class JAISMLP(nn.Module):
hidden_size
,
hidden_size
,
intermediate_size
,
intermediate_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
c_fc2
=
(
ColumnParallelLinear
(
self
.
c_fc2
=
(
ColumnParallelLinear
(
hidden_size
,
hidden_size
,
intermediate_size
,
intermediate_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
if
self
.
swiglu
else
None
)
)
if
self
.
swiglu
else
None
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
act
=
SwiGLUActivation
()
self
.
act
=
SwiGLUActivation
()
...
@@ -169,7 +170,7 @@ class JAISBlock(nn.Module):
...
@@ -169,7 +170,7 @@ class JAISBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
JAISConfig
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -177,9 +178,9 @@ class JAISBlock(nn.Module):
...
@@ -177,9 +178,9 @@ class JAISBlock(nn.Module):
hidden_size
)
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
JAISAttention
(
config
,
linear_method
)
self
.
attn
=
JAISAttention
(
config
,
quant_config
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
JAISMLP
(
inner_dim
,
config
,
linear_method
)
self
.
mlp
=
JAISMLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -210,7 +211,7 @@ class JAISModel(nn.Module):
...
@@ -210,7 +211,7 @@ class JAISModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
JAISConfig
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -227,7 +228,7 @@ class JAISModel(nn.Module):
...
@@ -227,7 +228,7 @@ class JAISModel(nn.Module):
else
:
else
:
self
.
embeddings_scale
=
config
.
mup_embeddings_scale
self
.
embeddings_scale
=
config
.
mup_embeddings_scale
self
.
h
=
nn
.
ModuleList
([
self
.
h
=
nn
.
ModuleList
([
JAISBlock
(
config
,
linear_method
)
JAISBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module):
...
@@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
JAISConfig
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
JAISModel
(
config
,
linear_method
)
self
.
transformer
=
JAISModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
if
hasattr
(
config
,
"width_scale"
):
if
hasattr
(
config
,
"width_scale"
):
self
.
output_logits_scale
=
config
.
width_scale
self
.
output_logits_scale
=
config
.
width_scale
...
...
vllm/model_executor/models/llama.py
View file @
1591c68f
...
@@ -33,11 +33,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -33,11 +33,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -56,17 +57,17 @@ class LlamaMLP(nn.Module):
...
@@ -56,17 +57,17 @@ class LlamaMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QKVParallelLinear
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -89,7 +90,7 @@ class LlamaAttention(nn.Module):
...
@@ -89,7 +90,7 @@ class LlamaAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -131,13 +132,13 @@ class LlamaAttention(nn.Module):
...
@@ -131,13 +132,13 @@ class LlamaAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
bias
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
bias
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -174,12 +175,16 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -174,12 +175,16 @@ class LlamaDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
LlamaConfig
,
config
:
LlamaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
and
getattr
(
config
,
"original_max_position_embeddings"
,
None
):
rope_scaling
[
"original_max_position_embeddings"
]
=
(
config
.
original_max_position_embeddings
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
8192
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
...
@@ -195,7 +200,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -195,7 +200,7 @@ class LlamaDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
bias
=
attention_bias
,
bias
=
attention_bias
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
)
)
...
@@ -203,7 +208,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -203,7 +208,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -244,7 +249,7 @@ class LlamaModel(nn.Module):
...
@@ -244,7 +249,7 @@ class LlamaModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
LlamaConfig
,
config
:
LlamaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -260,7 +265,7 @@ class LlamaModel(nn.Module):
...
@@ -260,7 +265,7 @@ class LlamaModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
LlamaDecoderLayer
(
config
,
linear_method
)
LlamaDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -325,13 +330,12 @@ class LlamaForCausalLM(nn.Module):
...
@@ -325,13 +330,12 @@ class LlamaForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
LlamaConfig
,
config
:
LlamaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
LlamaModel
(
config
,
quant_config
,
lora_config
=
lora_config
)
self
.
model
=
LlamaModel
(
config
,
linear_method
,
lora_config
=
lora_config
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
...
@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"
.
qkv_proj"
,
"
.
q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"
.
qkv_proj"
,
"
.
k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"
.
qkv_proj"
,
"
.
v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"
.
gate_up_proj"
,
"
.
gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"
.
gate_up_proj"
,
"
.
up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
...
...
vllm/model_executor/models/llava.py
View file @
1591c68f
...
@@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig
...
@@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VisionLanguageConfig
from
vllm.config
import
VisionLanguageConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
...
@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
"LlavaConfig"
,
config
:
"LlavaConfig"
,
vision_language_config
:
VisionLanguageConfig
,
vision_language_config
:
VisionLanguageConfig
,
linear_method
:
Optional
[
"LinearMethodBase
"
]
=
None
)
->
None
:
quant_config
:
Optional
[
"QuantizationConfig
"
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module):
...
@@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module):
text_hidden_size
=
config
.
text_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
linear_method
)
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
self
.
unpadded_vocab_size
,
...
...
vllm/model_executor/models/minicpm.py
View file @
1591c68f
...
@@ -35,12 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -35,12 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -84,7 +85,7 @@ class MiniCPMMoE(nn.Module):
...
@@ -84,7 +85,7 @@ class MiniCPMMoE(nn.Module):
self
.
num_total_experts
,
self
.
num_total_experts
,
bias
=
False
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
linear_method
=
None
)
quant_config
=
None
)
self
.
ws
=
nn
.
Parameter
(
self
.
ws
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
torch
.
empty
(
self
.
num_total_experts
,
...
@@ -147,17 +148,17 @@ class MiniCPMMLP(nn.Module):
...
@@ -147,17 +148,17 @@ class MiniCPMMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -180,7 +181,7 @@ class MiniCPMAttention(nn.Module):
...
@@ -180,7 +181,7 @@ class MiniCPMAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -211,13 +212,13 @@ class MiniCPMAttention(nn.Module):
...
@@ -211,13 +212,13 @@ class MiniCPMAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module):
...
@@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -274,7 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
...
@@ -274,7 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
if
self
.
num_experts
==
0
:
if
self
.
num_experts
==
0
:
...
@@ -282,7 +283,7 @@ class MiniCPMDecoderLayer(nn.Module):
...
@@ -282,7 +283,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
else
:
else
:
self
.
mlp
=
MiniCPMMoE
(
num_experts
=
config
.
num_experts
,
self
.
mlp
=
MiniCPMMoE
(
num_experts
=
config
.
num_experts
,
...
@@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module):
...
@@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -345,7 +346,7 @@ class MiniCPMModel(nn.Module):
...
@@ -345,7 +346,7 @@ class MiniCPMModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
MiniCPMDecoderLayer
(
config
,
linear_method
)
MiniCPMDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module):
...
@@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
MiniCPMModel
(
config
,
self
.
model
=
MiniCPMModel
(
config
,
linear_method
,
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
)
unpadded_vocab_size
=
config
.
vocab_size
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
...
...
vllm/model_executor/models/mixtral.py
View file @
1591c68f
...
@@ -27,6 +27,7 @@ import torch
...
@@ -27,6 +27,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
...
@@ -34,13 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -34,13 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
from
vllm.model_executor.layers.quantization.base_config
import
(
per_tensor_quantize
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -69,7 +70,7 @@ class MixtralMoE(nn.Module):
...
@@ -69,7 +70,7 @@ class MixtralMoE(nn.Module):
intermediate_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
...
@@ -77,50 +78,90 @@ class MixtralMoE(nn.Module):
...
@@ -77,50 +78,90 @@ class MixtralMoE(nn.Module):
self
.
top_k
=
top_k
self
.
top_k
=
top_k
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
//
self
.
tp_size
self
.
intermediate_size
=
intermediate_size
//
self
.
tp_size
self
.
quant_config
=
quant_config
# FIXME(pcmoritz): Make this more general to support different
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
# quantization schemes
self
.
use_fp8
=
isinstance
(
linear_method
,
Fp8LinearMethod
)
self
.
use_fp8
=
isinstance
(
quant_config
,
Fp8Config
)
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
params_dtype
=
params_dtype
# Gate always runs at half / full precision for now.
self
.
gate
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
num_total_experts
,
self
.
num_total_experts
,
bias
=
False
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
linear_method
=
None
)
quant_config
=
None
)
if
self
.
use_fp8
:
params_dtype
=
torch
.
float8_e4m3fn
self
.
w
s
=
nn
.
Parameter
(
self
.
w
13_weight
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
torch
.
empty
(
self
.
num_total_experts
,
2
*
self
.
intermediate_size
,
2
*
self
.
intermediate_size
,
self
.
hidden_size
,
self
.
hidden_size
,
device
=
"cuda"
,
dtype
=
params_dtype
))
dtype
=
self
.
params_dtype
))
self
.
w2_weight
=
nn
.
Parameter
(
self
.
w2s
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
torch
.
empty
(
self
.
num_total_experts
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
intermediate_size
,
self
.
intermediate_size
,
device
=
"cuda"
,
dtype
=
params_dtype
))
dtype
=
self
.
params_dtype
))
set_weight_attrs
(
self
.
w13_weight
,
{
# Scaling factors for FP8 weights
self
.
ws_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
device
=
"cuda"
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
if
self
.
use_fp8
else
None
self
.
w2s_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
device
=
"cuda"
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
if
self
.
use_fp8
else
None
set_weight_attrs
(
self
.
ws
,
{
"weight_loader"
:
self
.
weight_loader
,
"weight_loader"
:
self
.
weight_loader
,
})
})
set_weight_attrs
(
self
.
w2
s
,
{
set_weight_attrs
(
self
.
w2
_weight
,
{
"weight_loader"
:
self
.
weight_loader
,
"weight_loader"
:
self
.
weight_loader
,
})
})
# Used for fp8.
self
.
w13_scale
=
None
self
.
w2_scale
=
None
self
.
a13_scale
=
None
self
.
a2_scale
=
None
if
self
.
use_fp8
:
# WEIGHT_SCALE (for fp8)
self
.
w13_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
self
.
w2_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
self
.
w13_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
})
set_weight_attrs
(
self
.
w2_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
})
# ACT_SCALE (for fp8)
if
quant_config
.
activation_scheme
==
"static"
:
if
not
quant_config
.
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
self
.
a13_scale
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
self
.
a2_scale
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
set_weight_attrs
(
self
.
a13_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
})
set_weight_attrs
(
self
.
a2_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
})
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
expert_id
:
int
):
weight_name
:
str
,
expert_id
:
int
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
...
@@ -134,18 +175,49 @@ class MixtralMoE(nn.Module):
...
@@ -134,18 +175,49 @@ class MixtralMoE(nn.Module):
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
if
weight_name
.
endswith
(
"w2.weight"
):
if
weight_name
.
endswith
(
"w2.weight"
):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
if
"act_scale"
in
weight_name
or
"weight_scale"
in
weight_name
:
param_data
[
expert_id
]
=
loaded_weight
def
process_weights_after_loading
(
self
):
def
process_weights_after_loading
(
self
):
if
self
.
use_fp8
:
# Fp8 is the only case where we need to process after loading.
ws
=
torch
.
empty_like
(
self
.
ws
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
if
not
self
.
use_fp8
:
w2s
=
torch
.
empty_like
(
self
.
w2s
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
return
# If checkpoint is fp16, quantize here.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
w13_weight
=
torch
.
empty_like
(
self
.
w13_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
w2_weight
=
torch
.
empty_like
(
self
.
w2_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
for
expert
in
range
(
self
.
num_total_experts
):
for
expert
in
range
(
self
.
num_total_experts
):
ws
[
expert
,
:,
:],
self
.
ws_scale
[
expert
]
=
per_tensor_quantize
(
w13_weight
[
expert
,
:,
:],
self
.
w13_scale
[
self
.
ws
.
data
[
expert
,
:,
:])
expert
]
=
ops
.
scaled_fp8_quant
(
w2s
[
expert
,
:,
:],
self
.
w2s_scale
[
self
.
w13_weight
.
data
[
expert
,
:,
:])
expert
]
=
per_tensor_quantize
(
self
.
w2s
.
data
[
expert
,
:,
:])
w2_weight
[
expert
,
:,
:],
self
.
w2_scale
[
self
.
ws
=
nn
.
Parameter
(
ws
,
requires_grad
=
False
)
expert
]
=
ops
.
scaled_fp8_quant
(
self
.
w2s
=
nn
.
Parameter
(
w2s
,
requires_grad
=
False
)
self
.
w2_weight
.
data
[
expert
,
:,
:])
self
.
w13_weight
=
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
self
.
w2_weight
=
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
a13_scale
is
None
or
self
.
a2_scale
is
None
:
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
self
.
a13_scale
)
or
not
all_close_1d
(
self
.
a2_scale
)):
print_warning_once
(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. "
)
self
.
a13_scale
=
nn
.
Parameter
(
self
.
a13_scale
.
max
(),
requires_grad
=
False
)
self
.
a2_scale
=
nn
.
Parameter
(
self
.
a2_scale
.
max
(),
requires_grad
=
False
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_size
=
hidden_states
.
shape
num_tokens
,
hidden_size
=
hidden_states
.
shape
...
@@ -153,15 +225,17 @@ class MixtralMoE(nn.Module):
...
@@ -153,15 +225,17 @@ class MixtralMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
w
s
,
self
.
w
13_weight
,
self
.
w2
s
,
self
.
w2
_weight
,
router_logits
,
router_logits
,
self
.
top_k
,
self
.
top_k
,
renormalize
=
True
,
renormalize
=
True
,
inplace
=
True
,
inplace
=
True
,
use_fp8
=
self
.
use_fp8
,
use_fp8
=
self
.
use_fp8
,
w1_scale
=
self
.
ws_scale
,
w1_scale
=
self
.
w13_scale
,
w2_scale
=
self
.
w2s_scale
)
w2_scale
=
self
.
w2_scale
,
a1_scale
=
self
.
a13_scale
,
a2_scale
=
self
.
a2_scale
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
...
@@ -178,7 +252,7 @@ class MixtralAttention(nn.Module):
...
@@ -178,7 +252,7 @@ class MixtralAttention(nn.Module):
num_kv_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
max_position
:
int
=
4096
*
32
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -203,12 +277,14 @@ class MixtralAttention(nn.Module):
...
@@ -203,12 +277,14 @@ class MixtralAttention(nn.Module):
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
if
isinstance
(
linear_method
,
Fp8LinearMethod
):
if
isinstance
(
quant_config
,
Fp8Config
)
and
not
quant_config
.
is_checkpoint_fp8_serialized
:
print_warning_once
(
print_warning_once
(
"For Mixtral FP8 quantization, we currently do not quantize "
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
"the attention layers until their FP8 performance is improved."
)
)
linear_method
=
None
quant_config
=
None
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
hidden_size
,
...
@@ -216,13 +292,13 @@ class MixtralAttention(nn.Module):
...
@@ -216,13 +292,13 @@ class MixtralAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -259,7 +335,7 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -259,7 +335,7 @@ class MixtralDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MixtralConfig
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -272,13 +348,13 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -272,13 +348,13 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads
=
config
.
num_key_value_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
sliding_window
=
config
.
sliding_window
,
sliding_window
=
config
.
sliding_window
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
block_sparse_moe
=
MixtralMoE
(
self
.
block_sparse_moe
=
MixtralMoE
(
num_experts
=
config
.
num_local_experts
,
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -318,7 +394,7 @@ class MixtralModel(nn.Module):
...
@@ -318,7 +394,7 @@ class MixtralModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MixtralConfig
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -334,7 +410,7 @@ class MixtralModel(nn.Module):
...
@@ -334,7 +410,7 @@ class MixtralModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
MixtralDecoderLayer
(
config
,
linear_method
=
linear_method
)
MixtralDecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -384,14 +460,13 @@ class MixtralForCausalLM(nn.Module):
...
@@ -384,14 +460,13 @@ class MixtralForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MixtralConfig
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
MixtralModel
(
config
,
self
.
model
=
MixtralModel
(
config
,
linear_method
,
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
...
@@ -443,11 +518,26 @@ class MixtralForCausalLM(nn.Module):
...
@@ -443,11 +518,26 @@ class MixtralForCausalLM(nn.Module):
]
]
expert_params_mapping
=
[
expert_params_mapping
=
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
# (param_name, weight_name, expert_id)
(
"ws"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2s"
,
(
"w13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight_scale"
,
expert_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
+
[
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2_weight"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
)
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
+
[
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"a2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.act_scale"
,
expert_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
...
@@ -486,3 +576,8 @@ class MixtralForCausalLM(nn.Module):
...
@@ -486,3 +576,8 @@ class MixtralForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
assert
len
(
x
.
shape
)
==
1
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
vllm/model_executor/models/mixtral_quant.py
View file @
1591c68f
...
@@ -34,11 +34,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -34,11 +34,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -55,7 +56,7 @@ class MixtralMLP(nn.Module):
...
@@ -55,7 +56,7 @@ class MixtralMLP(nn.Module):
num_experts
:
int
,
num_experts
:
int
,
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
num_experts
=
num_experts
self
.
num_experts
=
num_experts
...
@@ -65,15 +66,15 @@ class MixtralMLP(nn.Module):
...
@@ -65,15 +66,15 @@ class MixtralMLP(nn.Module):
self
.
w1
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
w1
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
ffn_dim
,
self
.
ffn_dim
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
w2
=
ReplicatedLinear
(
self
.
ffn_dim
,
self
.
w2
=
ReplicatedLinear
(
self
.
ffn_dim
,
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
w3
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
w3
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
ffn_dim
,
self
.
ffn_dim
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
# TODO: Use vllm's SiluAndMul
# TODO: Use vllm's SiluAndMul
self
.
act_fn
=
nn
.
SiLU
()
self
.
act_fn
=
nn
.
SiLU
()
...
@@ -92,7 +93,7 @@ class MixtralMoE(nn.Module):
...
@@ -92,7 +93,7 @@ class MixtralMoE(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MixtralConfig
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -115,14 +116,14 @@ class MixtralMoE(nn.Module):
...
@@ -115,14 +116,14 @@ class MixtralMoE(nn.Module):
MixtralMLP
(
self
.
num_total_experts
,
MixtralMLP
(
self
.
num_total_experts
,
config
.
hidden_size
,
config
.
hidden_size
,
config
.
intermediate_size
,
config
.
intermediate_size
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
idx
in
self
.
expert_indicies
else
None
if
idx
in
self
.
expert_indicies
else
None
for
idx
in
range
(
self
.
num_total_experts
)
for
idx
in
range
(
self
.
num_total_experts
)
])
])
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
num_total_experts
,
self
.
num_total_experts
,
bias
=
False
,
bias
=
False
,
linear_method
=
None
)
quant_config
=
None
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
@@ -162,7 +163,7 @@ class MixtralAttention(nn.Module):
...
@@ -162,7 +163,7 @@ class MixtralAttention(nn.Module):
num_kv_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
max_position
:
int
=
4096
*
32
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -193,13 +194,13 @@ class MixtralAttention(nn.Module):
...
@@ -193,13 +194,13 @@ class MixtralAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MixtralConfig
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -249,9 +250,9 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -249,9 +250,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads
=
config
.
num_key_value_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
sliding_window
=
config
.
sliding_window
,
sliding_window
=
config
.
sliding_window
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
block_sparse_moe
=
MixtralMoE
(
config
=
config
,
self
.
block_sparse_moe
=
MixtralMoE
(
config
=
config
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -291,7 +292,7 @@ class MixtralModel(nn.Module):
...
@@ -291,7 +292,7 @@ class MixtralModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MixtralConfig
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -302,7 +303,7 @@ class MixtralModel(nn.Module):
...
@@ -302,7 +303,7 @@ class MixtralModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
MixtralDecoderLayer
(
config
,
linear_method
=
linear_method
)
MixtralDecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module):
...
@@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MixtralConfig
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
MixtralModel
(
config
,
linear_method
)
self
.
model
=
MixtralModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/mpt.py
View file @
1591c68f
...
@@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
...
@@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MPTConfig
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
d_model
=
config
.
d_model
self
.
d_model
=
config
.
d_model
...
@@ -65,7 +66,7 @@ class MPTAttention(nn.Module):
...
@@ -65,7 +66,7 @@ class MPTAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
not
config
.
no_bias
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
if
self
.
qk_ln
:
if
self
.
qk_ln
:
self
.
q_ln
=
nn
.
LayerNorm
(
self
.
d_model
)
self
.
q_ln
=
nn
.
LayerNorm
(
self
.
d_model
)
...
@@ -74,7 +75,7 @@ class MPTAttention(nn.Module):
...
@@ -74,7 +75,7 @@ class MPTAttention(nn.Module):
self
.
d_model
,
self
.
d_model
,
self
.
d_model
,
self
.
d_model
,
bias
=
not
config
.
no_bias
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
tp_world_size
=
get_tensor_model_parallel_world_size
()
tp_world_size
=
get_tensor_model_parallel_world_size
()
...
@@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
...
@@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MPTConfig
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
d_model
hidden_size
=
config
.
d_model
...
@@ -143,15 +144,14 @@ class MPTMLP(nn.Module):
...
@@ -143,15 +144,14 @@ class MPTMLP(nn.Module):
hidden_size
,
hidden_size
,
intermediate_size
,
intermediate_size
,
bias
=
not
config
.
no_bias
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
"gelu"
,
quant_config
,
intermediate_size
)
self
.
act
=
get_act_fn
(
"gelu"
,
quant_config
,
intermediate_size
)
self
.
down_proj
=
RowParallelLinear
(
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
not
config
.
no_bias
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -166,14 +166,14 @@ class MPTBlock(nn.Module):
...
@@ -166,14 +166,14 @@ class MPTBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MPTConfig
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
d_model
hidden_size
=
config
.
d_model
self
.
norm_1
=
nn
.
LayerNorm
(
hidden_size
)
self
.
norm_1
=
nn
.
LayerNorm
(
hidden_size
)
self
.
attn
=
MPTAttention
(
config
,
linear_method
)
self
.
attn
=
MPTAttention
(
config
,
quant_config
)
self
.
norm_2
=
nn
.
LayerNorm
(
hidden_size
)
self
.
norm_2
=
nn
.
LayerNorm
(
hidden_size
)
self
.
ffn
=
MPTMLP
(
config
,
linear_method
)
self
.
ffn
=
MPTMLP
(
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -201,7 +201,7 @@ class MPTModel(nn.Module):
...
@@ -201,7 +201,7 @@ class MPTModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MPTConfig
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
assert
config
.
embedding_fraction
==
1.0
assert
config
.
embedding_fraction
==
1.0
...
@@ -212,7 +212,7 @@ class MPTModel(nn.Module):
...
@@ -212,7 +212,7 @@ class MPTModel(nn.Module):
config
.
d_model
,
config
.
d_model
,
)
)
self
.
blocks
=
nn
.
ModuleList
(
self
.
blocks
=
nn
.
ModuleList
(
[
MPTBlock
(
config
,
linear_method
)
for
_
in
range
(
config
.
n_layers
)])
[
MPTBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layers
)])
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
)
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
)
if
config
.
no_bias
:
if
config
.
no_bias
:
for
module
in
self
.
modules
():
for
module
in
self
.
modules
():
...
@@ -246,14 +246,14 @@ class MPTForCausalLM(nn.Module):
...
@@ -246,14 +246,14 @@ class MPTForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
MPTConfig
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
assert
config
.
tie_word_embeddings
assert
config
.
tie_word_embeddings
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
MPTModel
(
config
,
linear_method
)
self
.
transformer
=
MPTModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/olmo.py
View file @
1591c68f
# coding=utf-8
# coding=utf-8
# Adapted from
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2023 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
#
# BSD 3-Clause License
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# Licensed under the Apache License, Version 2.0 (the "License");
# All rights reserved.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Redistribution and use in source and binary forms, with or without
# http://www.apache.org/licenses/LICENSE-2.0
# modification, are permitted provided that the following conditions are met:
#
#
# * Redistributions of source code must retain the above copyright notice, this
# Unless required by applicable law or agreed to in writing, software
# list of conditions and the following disclaimer.
# distributed under the License is distributed on an "AS IS" BASIS,
#
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * Redistributions in binary form must reproduce the above copyright notice,
# See the License for the specific language governing permissions and
# this list of conditions and the following disclaimer in the documentation
# limitations under the License.
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
"""Inference-only OLMo model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
# this model must need this dependency
from
hf_olmo
import
OLMoConfig
from
torch
import
nn
from
torch
import
nn
from
transformers
import
OlmoConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -70,56 +54,53 @@ class OlmoAttention(nn.Module):
...
@@ -70,56 +54,53 @@ class OlmoAttention(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
O
LM
oConfig
,
config
:
O
lm
oConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
hidden_size
=
config
.
d_model
self
.
hidden_size
=
config
.
hidden_size
assert
config
.
d_model
%
config
.
n_heads
==
0
tensor_model_parallel_world_size
=
(
tensor_model_parallel_world_size
=
(
get_tensor_model_parallel_world_size
())
get_tensor_model_parallel_world_size
())
self
.
total_num_heads
=
self
.
config
.
n_heads
self
.
total_num_heads
=
config
.
num_attention_heads
assert
self
.
hidden_size
%
self
.
total_num_heads
==
0
assert
self
.
total_num_heads
%
tensor_model_parallel_world_size
==
0
assert
self
.
total_num_heads
%
tensor_model_parallel_world_size
==
0
self
.
num_heads
=
(
self
.
total_num_heads
//
self
.
num_heads
=
(
self
.
total_num_heads
//
tensor_model_parallel_world_size
)
tensor_model_parallel_world_size
)
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
max_position_embeddings
=
config
.
max_position_embeddings
self
.
rope_theta
=
config
.
rope_theta
self
.
clip_qkv
=
config
.
clip_qkv
# Layer norms.
self
.
attn_norm
=
nn
.
LayerNorm
(
config
.
d_model
,
elementwise_affine
=
False
,
bias
=
False
)
# Attention input projection. Projects x -> (q, k, v)
# Attention input projection. Projects x -> (q, k, v)
self
.
att
_proj
=
QKVParallelLinear
(
self
.
qkv
_proj
=
QKVParallelLinear
(
config
.
d_model
,
self
.
hidden_size
,
self
.
head_dim
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
config
.
include
_bias
,
bias
=
config
.
attention
_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
# Rotary embeddings.
# Rotary embeddings.
if
self
.
config
.
rope
:
self
.
rotary_emb
=
get_rope
(
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
self
.
head_dim
,
max_position_embeddings
=
getattr
(
config
,
rotary_dim
=
self
.
head_dim
,
"max_position_embeddings"
,
8192
)
max_position
=
self
.
max_position_embeddings
,
self
.
rotary_emb
=
get_rope
(
base
=
self
.
rope_theta
,
self
.
head_dim
,
)
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scale
=
self
.
scaling
)
scale
=
self
.
scaling
)
# Attention output projection.
# Attention output projection.
self
.
attn_out
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
config
.
d_model
,
self
.
hidden_size
,
config
.
d_model
,
self
.
hidden_size
,
bias
=
config
.
include
_bias
,
bias
=
config
.
attention
_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
def
forward
(
def
forward
(
...
@@ -129,13 +110,13 @@ class OlmoAttention(nn.Module):
...
@@ -129,13 +110,13 @@ class OlmoAttention(nn.Module):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
attn_norm
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
att_proj
(
hidden_states
)
if
self
.
clip_qkv
is
not
None
:
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
config
.
rope
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
attn_out
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -148,57 +129,44 @@ class OlmoMLP(nn.Module):
...
@@ -148,57 +129,44 @@ class OlmoMLP(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
O
LM
oConfig
,
config
:
O
lm
oConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
hidden_size
=
(
config
.
mlp_hidden_size
if
config
.
mlp_hidden_size
self
.
hidden_size
=
config
.
hidden_size
is
not
None
else
config
.
mlp_ratio
*
config
.
d_model
)
self
.
intermediate_size
=
config
.
intermediate_size
# Layer norms.
self
.
ff_norm
=
nn
.
LayerNorm
(
config
.
d_model
,
elementwise_affine
=
False
,
bias
=
False
)
# Feed-forward input projection.
# Feed-forward input projection.
self
.
ff
_proj
=
MergedColumnParallelLinear
(
self
.
gate_up
_proj
=
MergedColumnParallelLinear
(
config
.
d_model
,
self
.
hidden_size
,
[
self
.
hidden_size
//
2
]
*
2
,
[
self
.
intermediate_size
]
*
2
,
bias
=
config
.
include_bias
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
# Activation function.
# Activation function.
self
.
act
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
self
.
act
.
output_multiplier
=
0.5
assert
(
self
.
act
.
output_multiplier
*
self
.
hidden_size
)
%
1
==
0
# Feed-forward output projection.
# Feed-forward output projection.
self
.
ff_out
=
RowParallelLinear
(
self
.
down_proj
=
RowParallelLinear
(
int
(
self
.
act
.
output_multiplier
*
self
.
hidden
_size
)
,
self
.
intermediate
_size
,
config
.
d_model
,
self
.
hidden_size
,
bias
=
config
.
include_bias
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Add feed-forward projection.
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
# shape: (batch_size, seq_len, d_model)
x
=
self
.
act_fn
(
gate_up
)
og_x
=
x
x
,
_
=
self
.
down_proj
(
x
)
x
=
self
.
ff_norm
(
x
)
x
,
_
=
self
.
ff_proj
(
x
)
x
=
self
.
act
(
x
)
x
,
_
=
self
.
ff_out
(
x
)
x
=
og_x
+
x
return
x
return
x
class
Olmo
Block
(
nn
.
Module
):
class
Olmo
DecoderLayer
(
nn
.
Module
):
"""
"""
This is a typical transformer block where the output is
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
computed as ``MLP(LN(x + Attention(LN(x))))``
...
@@ -206,14 +174,22 @@ class OlmoBlock(nn.Module):
...
@@ -206,14 +174,22 @@ class OlmoBlock(nn.Module):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
config
:
O
LM
oConfig
,
config
:
O
lm
oConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
# Attention block.
# Attention block.
self
.
attn
=
OlmoAttention
(
config
,
linear_method
)
self
.
self_
attn
=
OlmoAttention
(
config
,
quant_config
)
# MLP block.
# MLP block.
self
.
mlp
=
OlmoMLP
(
config
,
linear_method
)
self
.
mlp
=
OlmoMLP
(
config
,
quant_config
)
# LayerNorm
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
elementwise_affine
=
False
,
bias
=
False
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
elementwise_affine
=
False
,
bias
=
False
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -223,52 +199,37 @@ class OlmoBlock(nn.Module):
...
@@ -223,52 +199,37 @@ class OlmoBlock(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Attention block.
# Attention block.
og_x
=
hidden_states
residual
=
hidden_states
x
=
self
.
attn
(
positions
,
hidden_states
,
kv_cache
,
attn_metadata
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
x
=
x
+
og_x
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
kv_cache
,
attn_metadata
)
hidden_states
=
hidden_states
+
residual
# MLP block.
# MLP block.
hidden_states
=
self
.
mlp
(
x
)
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
return
hidden_states
class
OlmoModel
(
nn
.
Module
):
class
OlmoModel
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
O
LM
oConfig
,
config
:
O
lm
oConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
transformer
=
nn
.
ModuleDict
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
dict
(
config
.
hidden_size
)
wte
=
VocabParallelEmbedding
(
self
.
layers
=
nn
.
ModuleList
([
config
.
embedding_size
or
config
.
vocab_size
,
OlmoDecoderLayer
(
config
,
quant_config
)
config
.
d_model
,
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
),
])
ln_f
=
nn
.
LayerNorm
(
config
.
d_model
,
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
elementwise_affine
=
False
,
elementwise_affine
=
False
,
bias
=
False
),
bias
=
False
)
))
blocks
=
[
OlmoBlock
(
config
,
linear_method
)
for
i
in
range
(
config
.
n_layers
)
]
if
self
.
config
.
block_group_size
>
1
:
raise
NotImplementedError
(
"Block group size > 1 not supported yet"
)
else
:
self
.
transformer
.
update
({
"blocks"
:
nn
.
ModuleList
(
blocks
)})
if
not
config
.
weight_tying
:
self
.
transformer
.
update
({
"ff_out"
:
ColumnParallelLinear
(
config
.
d_model
,
config
.
embedding_size
or
config
.
vocab_size
,
bias
=
config
.
include_bias
,
linear_method
=
linear_method
,
)
})
def
forward
(
def
forward
(
self
,
self
,
...
@@ -282,39 +243,48 @@ class OlmoModel(nn.Module):
...
@@ -282,39 +243,48 @@ class OlmoModel(nn.Module):
"""
"""
# Get embeddings of input.
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
# shape: (batch_size, seq_len, d_model)
x
=
self
.
transformer
.
wte
(
input_ids
)
# type: ignore
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
# embed positions
hidden_states
=
inputs_embeds
# Apply blocks one-by-one.
# Apply blocks one-by-one.
for
block
_idx
,
block
in
enumerate
(
self
.
transformer
.
block
s
):
for
layer
_idx
,
decoder_layer
in
enumerate
(
self
.
layer
s
):
# shape: (batch_size, seq_len, d_model)
# shape: (batch_size, seq_len, d_model)
x
=
block
(
hidden_states
=
decoder_layer
(
positions
,
positions
,
x
,
hidden_states
,
kv_caches
[
block
_idx
],
kv_caches
[
layer
_idx
],
attn_metadata
,
attn_metadata
,
)
)
# Apply final layer norm.
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
# shape: (batch_size, seq_len or 1, d_model)
x
=
self
.
transformer
.
ln_f
(
x
)
# type: ignore
hidden_states
=
self
.
norm
(
hidden_states
)
return
x
return
hidden_states
class
O
LM
oForCausalLM
(
nn
.
Module
):
class
O
lm
oForCausalLM
(
nn
.
Module
):
"""
"""
Extremely barebones HF model wrapper.
Extremely barebones HF model wrapper.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
config
:
O
LM
oConfig
,
config
:
O
lm
oConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
OlmoModel
(
config
,
quant_config
)
self
.
model
=
OlmoModel
(
config
,
linear_method
)
if
config
.
tie_word_embeddings
:
self
.
lm_head_weight
=
(
self
.
model
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
model
.
embed_tokens
.
weight
if
config
.
weight_tying
else
else
:
self
.
model
.
transformer
.
ff_out
.
weight
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
@@ -348,20 +318,39 @@ class OLMoForCausalLM(nn.Module):
...
@@ -348,20 +318,39 @@ class OLMoForCausalLM(nn.Module):
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# attention
if
"rotary_emb.inv_freq"
in
name
:
if
".att"
in
name
:
continue
name
=
name
.
replace
(
".att"
,
".attn.att"
)
if
(
"rotary_emb.cos_cached"
in
name
# mlp
or
"rotary_emb.sin_cached"
in
name
):
if
".ff_proj"
in
name
:
# Models trained using ColossalAI may include these tensors in
name
=
name
.
replace
(
".ff_proj"
,
".mlp.ff_proj"
)
# the checkpoint. Skip them.
# Reverse the weight for the MergeColumnParallelLinear
continue
loaded_weight
=
torch
.
concat
(
loaded_weight
.
chunk
(
2
)[::
-
1
])
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
".ff_out"
in
name
and
"transformer.ff_out"
not
in
name
:
if
weight_name
not
in
name
:
name
=
name
.
replace
(
".ff_out"
,
".mlp.ff_out"
)
continue
# there is no bias in olmo
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
# Skip loading extra bias for GPTQ models.
weight_loader
=
getattr
(
param
,
"weight_loader"
,
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
default_weight_loader
)
continue
weight_loader
(
param
,
loaded_weight
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/opt.py
View file @
1591c68f
...
@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
...
@@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
embed_dim
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
num_heads
:
int
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
embed_dim
=
embed_dim
self
.
embed_dim
=
embed_dim
...
@@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
...
@@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
self
.
head_dim
,
self
.
head_dim
,
total_num_heads
,
total_num_heads
,
bias
=
bias
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
out_proj
=
RowParallelLinear
(
self
.
out_proj
=
RowParallelLinear
(
embed_dim
,
embed_dim
,
embed_dim
,
embed_dim
,
bias
=
bias
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
...
@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
OPTConfig
,
config
:
OPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
...
@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
embed_dim
=
self
.
embed_dim
,
embed_dim
=
self
.
embed_dim
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
bias
=
config
.
enable_bias
,
bias
=
config
.
enable_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
do_layer_norm_before
=
config
.
do_layer_norm_before
self
.
do_layer_norm_before
=
config
.
do_layer_norm_before
...
@@ -127,16 +128,15 @@ class OPTDecoderLayer(nn.Module):
...
@@ -127,16 +128,15 @@ class OPTDecoderLayer(nn.Module):
self
.
embed_dim
,
self
.
embed_dim
,
config
.
ffn_dim
,
config
.
ffn_dim
,
bias
=
config
.
enable_bias
,
bias
=
config
.
enable_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
activation_fn
=
get_act_fn
(
config
.
activation_function
,
self
.
activation_fn
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
config
.
ffn_dim
)
quant_config
,
config
.
ffn_dim
)
self
.
fc2
=
RowParallelLinear
(
self
.
fc2
=
RowParallelLinear
(
config
.
ffn_dim
,
config
.
ffn_dim
,
self
.
embed_dim
,
self
.
embed_dim
,
bias
=
config
.
enable_bias
,
bias
=
config
.
enable_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
,
self
.
embed_dim
,
...
@@ -181,7 +181,7 @@ class OPTDecoder(nn.Module):
...
@@ -181,7 +181,7 @@ class OPTDecoder(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
OPTConfig
,
config
:
OPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -202,7 +202,7 @@ class OPTDecoder(nn.Module):
...
@@ -202,7 +202,7 @@ class OPTDecoder(nn.Module):
self
.
project_out
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
project_out
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
word_embed_proj_dim
,
config
.
word_embed_proj_dim
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
else
:
else
:
self
.
project_out
=
None
self
.
project_out
=
None
...
@@ -210,7 +210,7 @@ class OPTDecoder(nn.Module):
...
@@ -210,7 +210,7 @@ class OPTDecoder(nn.Module):
self
.
project_in
=
ReplicatedLinear
(
config
.
word_embed_proj_dim
,
self
.
project_in
=
ReplicatedLinear
(
config
.
word_embed_proj_dim
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
else
:
else
:
self
.
project_in
=
None
self
.
project_in
=
None
...
@@ -226,7 +226,7 @@ class OPTDecoder(nn.Module):
...
@@ -226,7 +226,7 @@ class OPTDecoder(nn.Module):
self
.
final_layer_norm
=
None
self
.
final_layer_norm
=
None
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
OPTDecoderLayer
(
config
,
linear_method
)
OPTDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
...
@@ -259,10 +259,10 @@ class OPTModel(nn.Module):
...
@@ -259,10 +259,10 @@ class OPTModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
OPTConfig
,
config
:
OPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
decoder
=
OPTDecoder
(
config
,
linear_method
)
self
.
decoder
=
OPTDecoder
(
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -279,12 +279,12 @@ class OPTForCausalLM(nn.Module):
...
@@ -279,12 +279,12 @@ class OPTForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
OPTModel
(
config
,
linear_method
)
self
.
model
=
OPTModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
model
.
decoder
.
embed_tokens
.
weight
self
.
lm_head_weight
=
self
.
model
.
decoder
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/orion.py
View file @
1591c68f
...
@@ -13,11 +13,12 @@ from transformers import PretrainedConfig
...
@@ -13,11 +13,12 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -34,17 +35,17 @@ class OrionMLP(nn.Module):
...
@@ -34,17 +35,17 @@ class OrionMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -67,7 +68,7 @@ class OrionAttention(nn.Module):
...
@@ -67,7 +68,7 @@ class OrionAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -98,13 +99,13 @@ class OrionAttention(nn.Module):
...
@@ -98,13 +99,13 @@ class OrionAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module):
...
@@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -154,13 +155,13 @@ class OrionDecoderLayer(nn.Module):
...
@@ -154,13 +155,13 @@ class OrionDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
mlp
=
OrionMLP
(
self
.
mlp
=
OrionMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
@@ -201,7 +202,7 @@ class OrionModel(nn.Module):
...
@@ -201,7 +202,7 @@ class OrionModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -212,7 +213,7 @@ class OrionModel(nn.Module):
...
@@ -212,7 +213,7 @@ class OrionModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
OrionDecoderLayer
(
config
,
linear_method
)
OrionDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module):
...
@@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
OrionModel
(
config
,
linear_method
)
self
.
model
=
OrionModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/phi.py
View file @
1591c68f
...
@@ -45,10 +45,11 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -45,10 +45,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -62,7 +63,7 @@ class PhiAttention(nn.Module):
...
@@ -62,7 +63,7 @@ class PhiAttention(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -80,12 +81,12 @@ class PhiAttention(nn.Module):
...
@@ -80,12 +81,12 @@ class PhiAttention(nn.Module):
self
.
head_size
,
self
.
head_size
,
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
dense
=
RowParallelLinear
(
self
.
dense
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
scaling
=
self
.
head_size
**-
0.5
scaling
=
self
.
head_size
**-
0.5
...
@@ -125,7 +126,7 @@ class PhiMLP(nn.Module):
...
@@ -125,7 +126,7 @@ class PhiMLP(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
n_inner
=
getattr
(
config
,
"n_inner"
,
None
)
n_inner
=
getattr
(
config
,
"n_inner"
,
None
)
...
@@ -134,14 +135,13 @@ class PhiMLP(nn.Module):
...
@@ -134,14 +135,13 @@ class PhiMLP(nn.Module):
self
.
fc1
=
ColumnParallelLinear
(
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
n_inner
,
n_inner
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
fc2
=
RowParallelLinear
(
self
.
fc2
=
RowParallelLinear
(
n_inner
,
n_inner
,
config
.
hidden_size
,
config
.
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
n_inner
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
n_inner
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
...
@@ -155,12 +155,12 @@ class PhiLayer(nn.Module):
...
@@ -155,12 +155,12 @@ class PhiLayer(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
self_attn
=
PhiAttention
(
config
,
linear_method
)
self
.
self_attn
=
PhiAttention
(
config
,
quant_config
)
self
.
mlp
=
PhiMLP
(
config
,
linear_method
)
self
.
mlp
=
PhiMLP
(
config
,
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -186,14 +186,14 @@ class PhiModel(nn.Module):
...
@@ -186,14 +186,14 @@ class PhiModel(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
PhiLayer
(
config
,
linear_method
)
PhiLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
final_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
final_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
@@ -225,12 +225,12 @@ class PhiForCausalLM(nn.Module):
...
@@ -225,12 +225,12 @@ class PhiForCausalLM(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
PhiModel
(
config
,
linear_method
)
self
.
model
=
PhiModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
...
...
vllm/model_executor/models/qwen.py
View file @
1591c68f
...
@@ -14,11 +14,12 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -14,11 +14,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -35,17 +36,17 @@ class QWenMLP(nn.Module):
...
@@ -35,17 +36,17 @@ class QWenMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
=
"silu"
,
hidden_act
:
str
=
"silu"
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -67,7 +68,7 @@ class QWenAttention(nn.Module):
...
@@ -67,7 +68,7 @@ class QWenAttention(nn.Module):
max_position_embeddings
:
int
,
max_position_embeddings
:
int
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -83,13 +84,13 @@ class QWenAttention(nn.Module):
...
@@ -83,13 +84,13 @@ class QWenAttention(nn.Module):
self
.
head_dim
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
...
@@ -122,7 +123,7 @@ class QWenBlock(nn.Module):
...
@@ -122,7 +123,7 @@ class QWenBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
ln_1
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -134,13 +135,13 @@ class QWenBlock(nn.Module):
...
@@ -134,13 +135,13 @@ class QWenBlock(nn.Module):
config
.
max_position_embeddings
,
config
.
max_position_embeddings
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
ln_2
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
QWenMLP
(
config
.
hidden_size
,
self
.
mlp
=
QWenMLP
(
config
.
hidden_size
,
config
.
intermediate_size
//
2
,
config
.
intermediate_size
//
2
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -174,7 +175,7 @@ class QWenModel(nn.Module):
...
@@ -174,7 +175,7 @@ class QWenModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -185,7 +186,7 @@ class QWenModel(nn.Module):
...
@@ -185,7 +186,7 @@ class QWenModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
h
=
nn
.
ModuleList
([
self
.
h
=
nn
.
ModuleList
([
QWenBlock
(
config
,
linear_method
)
QWenBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module):
...
@@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
QWenModel
(
config
,
linear_method
)
self
.
transformer
=
QWenModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/qwen2.py
View file @
1591c68f
...
@@ -33,11 +33,12 @@ from vllm.config import LoRAConfig
...
@@ -33,11 +33,12 @@ from vllm.config import LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module):
...
@@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module):
...
@@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module):
max_position
:
int
=
4096
*
32
,
max_position
:
int
=
4096
*
32
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
use_sliding_window
:
bool
=
False
,
use_sliding_window
:
bool
=
False
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module):
...
@@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module):
self
,
self
,
config
:
Qwen2Config
,
config
:
Qwen2Config
,
layer_idx
:
int
,
layer_idx
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module):
num_kv_heads
=
config
.
num_key_value_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
use_sliding_window
=
use_sliding_window
,
use_sliding_window
=
use_sliding_window
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
sliding_window
=
config
.
sliding_window
)
sliding_window
=
config
.
sliding_window
)
self
.
mlp
=
Qwen2MLP
(
self
.
mlp
=
Qwen2MLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module):
...
@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
Qwen2Config
,
config
:
Qwen2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -233,7 +234,7 @@ class Qwen2Model(nn.Module):
...
@@ -233,7 +234,7 @@ class Qwen2Model(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
Qwen2DecoderLayer
(
config
,
layer_idx
,
linear_method
)
Qwen2DecoderLayer
(
config
,
layer_idx
,
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
Qwen2Config
,
config
:
Qwen2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
del
lora_config
del
lora_config
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2Model
(
config
,
linear_method
)
self
.
model
=
Qwen2Model
(
config
,
quant_config
)
if
config
.
tie_word_embeddings
:
if
config
.
tie_word_embeddings
:
self
.
lm_head_weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head_weight
=
self
.
model
.
embed_tokens
.
weight
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
1591c68f
...
@@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module):
...
@@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
)
reduce_results
=
reduce_results
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
...
@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
...
@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
...
@@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
)
reduce_results
=
False
)
for
idx
in
range
(
self
.
n_routed_experts
)
for
idx
in
range
(
self
.
n_routed_experts
)
])
])
...
@@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
...
@@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
n_routed_experts
,
self
.
n_routed_experts
,
bias
=
False
,
bias
=
False
,
linear_method
=
None
)
quant_config
=
None
)
if
config
.
shared_expert_intermediate_size
>
0
:
if
config
.
shared_expert_intermediate_size
>
0
:
self
.
shared_expert
=
Qwen2MoeMLP
(
self
.
shared_expert
=
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
shared_expert_intermediate_size
,
intermediate_size
=
config
.
shared_expert_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
,
reduce_results
=
False
,
)
)
else
:
else
:
...
@@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module):
...
@@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module):
...
@@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
...
@@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
layer_idx
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
...
@@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
if
(
config
.
num_experts
is
not
None
if
(
config
.
num_experts
is
not
None
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
self
.
mlp
=
Qwen2MoeSparseMoeBlock
(
config
=
config
,
self
.
mlp
=
Qwen2MoeSparseMoeBlock
(
config
=
config
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
else
:
else
:
self
.
mlp
=
Qwen2MoeMLP
(
self
.
mlp
=
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module):
...
@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module):
...
@@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
Qwen2MoeDecoderLayer
(
config
,
Qwen2MoeDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
)
layer_idx
,
linear_method
=
linear_method
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2MoeModel
(
config
,
linear_method
)
self
.
model
=
Qwen2MoeModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
Prev
1
…
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment