Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a62aaf1d
Unverified
Commit
a62aaf1d
authored
Apr 26, 2024
by
Cody Yu
Committed by
GitHub
Apr 26, 2024
Browse files
[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
parent
603ad848
Changes
45
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
344 additions
and
329 deletions
+344
-329
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+22
-23
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+17
-16
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+17
-16
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+17
-16
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+17
-16
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+17
-16
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+17
-16
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+17
-16
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+17
-16
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+16
-16
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+5
-4
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+18
-17
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+22
-22
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+21
-20
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+17
-16
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+16
-16
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+19
-18
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+17
-16
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+18
-17
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+17
-16
No files found.
vllm/model_executor/models/deepseek.py
View file @
a62aaf1d
...
...
@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
...
...
@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
)
for
idx
in
range
(
self
.
n_routed_experts
)
])
...
...
@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
n_routed_experts
,
bias
=
False
,
linear_method
=
None
)
quant_config
=
None
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
...
@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
False
,
)
...
...
@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
self
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
linear_method
=
linear_method
)
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
quant_config
=
quant_config
)
else
:
self
.
mlp
=
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
DeepseekDecoderLayer
(
config
,
layer_idx
,
linear_method
=
linear_method
)
DeepseekDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
DeepseekModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/falcon.py
View file @
a62aaf1d
...
...
@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
...
...
@@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
self
.
total_num_kv_heads
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
...
...
@@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
self
.
hidden_size
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
reduce_results
=
self
.
reduce_row_parallel_results
)
self
.
use_rotary
=
config
.
rotary
...
...
@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -201,8 +202,8 @@ class FalconMLP(nn.Module):
4
*
hidden_size
,
bias
=
config
.
bias
,
skip_bias_add
=
True
,
linear_method
=
linear_method
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
quant_config
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
"gelu"
,
quant_config
,
4
*
hidden_size
)
self
.
reduce_row_parallel_results
=
not
(
config
.
new_decoder_architecture
or
config
.
parallel_attn
)
...
...
@@ -212,7 +213,7 @@ class FalconMLP(nn.Module):
bias
=
config
.
bias
,
skip_bias_add
=
True
,
reduce_results
=
self
.
reduce_row_parallel_results
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
...
...
@@ -229,13 +230,13 @@ class FalconDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
self_attention
=
FalconAttention
(
config
,
linear_method
)
self
.
mlp
=
FalconMLP
(
config
,
linear_method
)
self
.
self_attention
=
FalconAttention
(
config
,
quant_config
)
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
config
=
config
if
config
.
new_decoder_architecture
:
...
...
@@ -311,7 +312,7 @@ class FalconModel(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -327,7 +328,7 @@ class FalconModel(nn.Module):
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
FalconDecoderLayer
(
config
,
linear_method
)
FalconDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -359,12 +360,12 @@ class FalconForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
FalconModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
FalconModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gemma.py
View file @
a62aaf1d
...
...
@@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -77,17 +78,17 @@ class GemmaMLP(nn.Module):
intermediate_size
:
int
,
hidden_act
:
Optional
[
str
]
=
None
,
hidden_activation
:
Optional
[
str
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
act_fn
=
_get_gemma_act_fn
(
hidden_act
,
hidden_activation
)
def
forward
(
self
,
x
):
...
...
@@ -106,7 +107,7 @@ class GemmaAttention(nn.Module):
head_dim
:
int
,
max_position_embeddings
:
int
=
8192
,
rope_theta
:
float
=
10000
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -135,13 +136,13 @@ class GemmaAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
GemmaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -187,14 +188,14 @@ class GemmaDecoderLayer(nn.Module):
head_dim
=
config
.
head_dim
,
max_position_embeddings
=
config
.
max_position_embeddings
,
rope_theta
=
config
.
rope_theta
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
mlp
=
GemmaMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_activation
=
getattr
(
config
,
"hidden_activation"
,
None
),
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -235,7 +236,7 @@ class GemmaModel(nn.Module):
def
__init__
(
self
,
config
:
GemmaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -245,7 +246,7 @@ class GemmaModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
GemmaDecoderLayer
(
config
,
linear_method
)
GemmaDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
GemmaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
del
lora_config
# Unused.
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
GemmaModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
GemmaModel
(
config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt2.py
View file @
a62aaf1d
...
...
@@ -27,10 +27,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -44,7 +45,7 @@ class GPT2Attention(nn.Module):
def
__init__
(
self
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -61,13 +62,13 @@ class GPT2Attention(nn.Module):
self
.
head_dim
,
total_num_heads
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
...
...
@@ -90,7 +91,7 @@ class GPT2MLP(nn.Module):
self
,
intermediate_size
:
int
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -98,15 +99,15 @@ class GPT2MLP(nn.Module):
hidden_size
,
intermediate_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
...
...
@@ -122,7 +123,7 @@ class GPT2Block(nn.Module):
def
__init__
(
self
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -130,9 +131,9 @@ class GPT2Block(nn.Module):
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPT2Attention
(
config
,
linear_method
)
self
.
attn
=
GPT2Attention
(
config
,
quant_config
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
linear_method
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -163,7 +164,7 @@ class GPT2Model(nn.Module):
def
__init__
(
self
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -174,7 +175,7 @@ class GPT2Model(nn.Module):
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
h
=
nn
.
ModuleList
([
GPT2Block
(
config
,
linear_method
)
GPT2Block
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
...
...
@@ -203,12 +204,12 @@ class GPT2LMHeadModel(nn.Module):
def
__init__
(
self
,
config
:
GPT2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
GPT2Model
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
GPT2Model
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
a62aaf1d
...
...
@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module):
def
__init__
(
self
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -72,14 +73,14 @@ class GPTBigCodeAttention(nn.Module):
total_num_heads
,
total_num_kv_heads
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -111,7 +112,7 @@ class GPTBigMLP(nn.Module):
self
,
intermediate_size
:
int
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -119,15 +120,15 @@ class GPTBigMLP(nn.Module):
hidden_size
,
intermediate_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
...
...
@@ -143,7 +144,7 @@ class GPTBigCodeBlock(nn.Module):
def
__init__
(
self
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -151,9 +152,9 @@ class GPTBigCodeBlock(nn.Module):
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPTBigCodeAttention
(
config
,
linear_method
)
self
.
attn
=
GPTBigCodeAttention
(
config
,
quant_config
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPTBigMLP
(
inner_dim
,
config
,
linear_method
)
self
.
mlp
=
GPTBigMLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -184,7 +185,7 @@ class GPTBigCodeModel(nn.Module):
def
__init__
(
self
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -195,7 +196,7 @@ class GPTBigCodeModel(nn.Module):
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
h
=
nn
.
ModuleList
([
GPTBigCodeBlock
(
config
,
linear_method
)
GPTBigCodeBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
...
...
@@ -224,12 +225,12 @@ class GPTBigCodeForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
GPTBigCodeConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
GPTBigCodeModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt_j.py
View file @
a62aaf1d
...
...
@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -44,7 +45,7 @@ class GPTJAttention(nn.Module):
def
__init__
(
self
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
total_num_heads
=
config
.
num_attention_heads
...
...
@@ -56,13 +57,13 @@ class GPTJAttention(nn.Module):
self
.
head_size
,
self
.
total_num_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
out_proj
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
tp_world_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -105,21 +106,21 @@ class GPTJMLP(nn.Module):
self
,
intermediate_size
:
int
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
n_embd
self
.
fc_in
=
ColumnParallelLinear
(
hidden_size
,
intermediate_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
fc_out
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
...
...
@@ -135,14 +136,14 @@ class GPTJBlock(nn.Module):
def
__init__
(
self
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
inner_dim
=
(
4
*
config
.
n_embd
if
config
.
n_inner
is
None
else
config
.
n_inner
)
self
.
ln_1
=
nn
.
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPTJAttention
(
config
,
linear_method
)
self
.
mlp
=
GPTJMLP
(
inner_dim
,
config
,
linear_method
)
self
.
attn
=
GPTJAttention
(
config
,
quant_config
)
self
.
mlp
=
GPTJMLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -169,7 +170,7 @@ class GPTJModel(nn.Module):
def
__init__
(
self
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -179,7 +180,7 @@ class GPTJModel(nn.Module):
self
.
embed_dim
,
)
self
.
h
=
nn
.
ModuleList
(
[
GPTJBlock
(
config
,
linear_method
)
for
_
in
range
(
config
.
n_layer
)])
[
GPTJBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layer
)])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
...
...
@@ -207,13 +208,13 @@ class GPTJForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
GPTJConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
assert
not
config
.
tie_word_embeddings
self
.
transformer
=
GPTJModel
(
config
,
linear_method
)
self
.
transformer
=
GPTJModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
n_embd
,
...
...
vllm/model_executor/models/gpt_neox.py
View file @
a62aaf1d
...
...
@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module):
def
__init__
(
self
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
total_num_heads
=
config
.
num_attention_heads
...
...
@@ -63,13 +64,13 @@ class GPTNeoXAttention(nn.Module):
self
.
head_size
,
self
.
total_num_heads
,
bias
=
self
.
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
dense
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
self
.
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
scaling
=
self
.
head_size
**-
0.5
rotary_dim
=
int
(
self
.
head_size
*
config
.
rotary_pct
)
...
...
@@ -105,20 +106,20 @@ class GPTNeoXMLP(nn.Module):
def
__init__
(
self
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
dense_4h_to_h
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
config
.
intermediate_size
)
...
...
@@ -134,7 +135,7 @@ class GPTNeoXLayer(nn.Module):
def
__init__
(
self
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
use_parallel_residual
=
config
.
use_parallel_residual
...
...
@@ -142,8 +143,8 @@ class GPTNeoXLayer(nn.Module):
eps
=
config
.
layer_norm_eps
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
attention
=
GPTNeoXAttention
(
config
,
linear_method
)
self
.
mlp
=
GPTNeoXMLP
(
config
,
linear_method
)
self
.
attention
=
GPTNeoXAttention
(
config
,
quant_config
)
self
.
mlp
=
GPTNeoXMLP
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -182,7 +183,7 @@ class GPTNeoXModel(nn.Module):
def
__init__
(
self
,
config
:
GPTNeoXConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -192,7 +193,7 @@ class GPTNeoXModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
GPTNeoXLayer
(
config
,
linear_method
)
GPTNeoXLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
final_layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
...
@@ -223,12 +224,12 @@ class GPTNeoXForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
gpt_neox
=
GPTNeoXModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
gpt_neox
=
GPTNeoXModel
(
config
,
quant_config
)
self
.
embed_out
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
...
...
vllm/model_executor/models/internlm2.py
View file @
a62aaf1d
...
...
@@ -9,11 +9,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -30,17 +31,17 @@ class InternLM2MLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
w2
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -63,7 +64,7 @@ class InternLM2Attention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -94,13 +95,13 @@ class InternLM2Attention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
wo
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -150,13 +151,13 @@ class InternLMDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
feed_forward
=
InternLM2MLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
attention_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -195,7 +196,7 @@ class InternLM2Model(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -206,7 +207,7 @@ class InternLM2Model(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
InternLMDecoderLayer
(
config
,
linear_method
)
InternLMDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
InternLM2Model
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
InternLM2Model
(
config
,
quant_config
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/jais.py
View file @
a62aaf1d
...
...
@@ -29,10 +29,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -68,7 +69,7 @@ class JAISAttention(nn.Module):
def
__init__
(
self
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -88,13 +89,13 @@ class JAISAttention(nn.Module):
self
.
head_dim
,
total_num_heads
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
tp_rank
=
get_tensor_model_parallel_rank
()
...
...
@@ -128,7 +129,7 @@ class JAISMLP(nn.Module):
self
,
intermediate_size
:
int
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -137,19 +138,19 @@ class JAISMLP(nn.Module):
hidden_size
,
intermediate_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_fc2
=
(
ColumnParallelLinear
(
hidden_size
,
intermediate_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
if
self
.
swiglu
else
None
)
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
act
=
SwiGLUActivation
()
...
...
@@ -169,7 +170,7 @@ class JAISBlock(nn.Module):
def
__init__
(
self
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
...
...
@@ -177,9 +178,9 @@ class JAISBlock(nn.Module):
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
JAISAttention
(
config
,
linear_method
)
self
.
attn
=
JAISAttention
(
config
,
quant_config
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
JAISMLP
(
inner_dim
,
config
,
linear_method
)
self
.
mlp
=
JAISMLP
(
inner_dim
,
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -210,7 +211,7 @@ class JAISModel(nn.Module):
def
__init__
(
self
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -227,7 +228,7 @@ class JAISModel(nn.Module):
else
:
self
.
embeddings_scale
=
config
.
mup_embeddings_scale
self
.
h
=
nn
.
ModuleList
([
JAISBlock
(
config
,
linear_method
)
JAISBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
...
...
@@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module):
def
__init__
(
self
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
JAISModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
JAISModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
if
hasattr
(
config
,
"width_scale"
):
self
.
output_logits_scale
=
config
.
width_scale
...
...
vllm/model_executor/models/llama.py
View file @
a62aaf1d
...
...
@@ -33,11 +33,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -56,17 +57,17 @@ class LlamaMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QKVParallelLinear
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -89,7 +90,7 @@ class LlamaAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
...
...
@@ -131,13 +132,13 @@ class LlamaAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -174,7 +175,7 @@ class LlamaDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
LlamaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -199,7 +200,7 @@ class LlamaDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
bias
=
attention_bias
,
sliding_window
=
sliding_window
,
)
...
...
@@ -207,7 +208,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -248,7 +249,7 @@ class LlamaModel(nn.Module):
def
__init__
(
self
,
config
:
LlamaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -264,7 +265,7 @@ class LlamaModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
layers
=
nn
.
ModuleList
([
LlamaDecoderLayer
(
config
,
linear_method
)
LlamaDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -329,13 +330,12 @@ class LlamaForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
LlamaConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
LlamaModel
(
config
,
linear_method
,
lora_config
=
lora_config
)
self
.
model
=
LlamaModel
(
config
,
quant_config
,
lora_config
=
lora_config
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/llava.py
View file @
a62aaf1d
...
...
@@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VisionLanguageConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def
__init__
(
self
,
config
:
"LlavaConfig"
,
vision_language_config
:
VisionLanguageConfig
,
linear_method
:
Optional
[
"LinearMethodBase
"
]
=
None
)
->
None
:
quant_config
:
Optional
[
"QuantizationConfig
"
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module):
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
linear_method
=
linear_method
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
...
...
vllm/model_executor/models/minicpm.py
View file @
a62aaf1d
...
...
@@ -35,12 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -84,7 +85,7 @@ class MiniCPMMoE(nn.Module):
self
.
num_total_experts
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
,
linear_method
=
None
)
quant_config
=
None
)
self
.
ws
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
...
...
@@ -147,17 +148,17 @@ class MiniCPMMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -180,7 +181,7 @@ class MiniCPMAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -211,13 +212,13 @@ class MiniCPMAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -274,7 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
if
self
.
num_experts
==
0
:
...
...
@@ -282,7 +283,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
else
:
self
.
mlp
=
MiniCPMMoE
(
num_experts
=
config
.
num_experts
,
...
...
@@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -345,7 +346,7 @@ class MiniCPMModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
layers
=
nn
.
ModuleList
([
MiniCPMDecoderLayer
(
config
,
linear_method
)
MiniCPMDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
MiniCPMModel
(
config
,
linear_method
,
quant_config
,
lora_config
=
lora_config
)
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
...
...
vllm/model_executor/models/mixtral.py
View file @
a62aaf1d
...
...
@@ -27,6 +27,7 @@ import torch
from
torch
import
nn
from
transformers
import
MixtralConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
...
...
@@ -34,13 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
QKVParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
per_tensor_quantize
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -69,7 +70,7 @@ class MixtralMoE(nn.Module):
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
...
...
@@ -79,7 +80,7 @@ class MixtralMoE(nn.Module):
self
.
intermediate_size
=
intermediate_size
//
self
.
tp_size
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self
.
use_fp8
=
isinstance
(
linear_method
,
Fp8LinearMethod
)
self
.
use_fp8
=
isinstance
(
quant_config
,
Fp8Config
)
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
...
...
@@ -89,7 +90,7 @@ class MixtralMoE(nn.Module):
self
.
num_total_experts
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
,
linear_method
=
None
)
quant_config
=
None
)
self
.
ws
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
...
...
@@ -140,10 +141,10 @@ class MixtralMoE(nn.Module):
ws
=
torch
.
empty_like
(
self
.
ws
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
w2s
=
torch
.
empty_like
(
self
.
w2s
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
for
expert
in
range
(
self
.
num_total_experts
):
ws
[
expert
,
:,
:],
self
.
ws_scale
[
expert
]
=
per_tensor
_quant
ize
(
ws
[
expert
,
:,
:],
self
.
ws_scale
[
expert
]
=
ops
.
scaled_fp8
_quant
(
self
.
ws
.
data
[
expert
,
:,
:])
w2s
[
expert
,
:,
:],
self
.
w2s_scale
[
expert
]
=
per_tensor
_quant
ize
(
self
.
w2s
.
data
[
expert
,
:,
:])
expert
]
=
ops
.
scaled_fp8
_quant
(
self
.
w2s
.
data
[
expert
,
:,
:])
self
.
ws
=
nn
.
Parameter
(
ws
,
requires_grad
=
False
)
self
.
w2s
=
nn
.
Parameter
(
w2s
,
requires_grad
=
False
)
...
...
@@ -178,7 +179,7 @@ class MixtralAttention(nn.Module):
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
rope_theta
:
float
=
10000
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -203,12 +204,12 @@ class MixtralAttention(nn.Module):
self
.
rope_theta
=
rope_theta
self
.
sliding_window
=
sliding_window
if
isinstance
(
linear_method
,
Fp8LinearMethod
):
if
isinstance
(
quant_config
,
Fp8Config
):
print_warning_once
(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
linear_method
=
None
quant_config
=
None
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
...
...
@@ -216,13 +217,13 @@ class MixtralAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -259,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -272,13 +273,13 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
sliding_window
=
config
.
sliding_window
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
block_sparse_moe
=
MixtralMoE
(
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
...
@@ -318,7 +319,7 @@ class MixtralModel(nn.Module):
def
__init__
(
self
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -334,7 +335,7 @@ class MixtralModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
layers
=
nn
.
ModuleList
([
MixtralDecoderLayer
(
config
,
linear_method
=
linear_method
)
MixtralDecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -384,14 +385,13 @@ class MixtralForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
MixtralModel
(
config
,
linear_method
,
quant_config
,
lora_config
=
lora_config
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
a62aaf1d
...
...
@@ -34,11 +34,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
QKVParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -55,7 +56,7 @@ class MixtralMLP(nn.Module):
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
num_experts
=
num_experts
...
...
@@ -65,15 +66,15 @@ class MixtralMLP(nn.Module):
self
.
w1
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
ffn_dim
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
w2
=
ReplicatedLinear
(
self
.
ffn_dim
,
self
.
hidden_dim
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
w3
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
ffn_dim
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
# TODO: Use vllm's SiluAndMul
self
.
act_fn
=
nn
.
SiLU
()
...
...
@@ -92,7 +93,7 @@ class MixtralMoE(nn.Module):
def
__init__
(
self
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -115,14 +116,14 @@ class MixtralMoE(nn.Module):
MixtralMLP
(
self
.
num_total_experts
,
config
.
hidden_size
,
config
.
intermediate_size
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
idx
in
self
.
expert_indicies
else
None
for
idx
in
range
(
self
.
num_total_experts
)
])
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
num_total_experts
,
bias
=
False
,
linear_method
=
None
)
quant_config
=
None
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
...
@@ -162,7 +163,7 @@ class MixtralAttention(nn.Module):
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
rope_theta
:
float
=
10000
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -193,13 +194,13 @@ class MixtralAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -249,9 +250,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
sliding_window
=
config
.
sliding_window
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
block_sparse_moe
=
MixtralMoE
(
config
=
config
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
...
@@ -291,7 +292,7 @@ class MixtralModel(nn.Module):
def
__init__
(
self
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -302,7 +303,7 @@ class MixtralModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
MixtralDecoderLayer
(
config
,
linear_method
=
linear_method
)
MixtralDecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
MixtralConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
MixtralModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
MixtralModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/mpt.py
View file @
a62aaf1d
...
...
@@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
def
__init__
(
self
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
...
...
@@ -65,7 +66,7 @@ class MPTAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
if
self
.
qk_ln
:
self
.
q_ln
=
nn
.
LayerNorm
(
self
.
d_model
)
...
...
@@ -74,7 +75,7 @@ class MPTAttention(nn.Module):
self
.
d_model
,
self
.
d_model
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
tp_world_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
def
__init__
(
self
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
d_model
...
...
@@ -143,15 +144,15 @@ class MPTMLP(nn.Module):
hidden_size
,
intermediate_size
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
"gelu"
,
quant_config
,
intermediate_size
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -166,14 +167,14 @@ class MPTBlock(nn.Module):
def
__init__
(
self
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
d_model
self
.
norm_1
=
nn
.
LayerNorm
(
hidden_size
)
self
.
attn
=
MPTAttention
(
config
,
linear_method
)
self
.
attn
=
MPTAttention
(
config
,
quant_config
)
self
.
norm_2
=
nn
.
LayerNorm
(
hidden_size
)
self
.
ffn
=
MPTMLP
(
config
,
linear_method
)
self
.
ffn
=
MPTMLP
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -201,7 +202,7 @@ class MPTModel(nn.Module):
def
__init__
(
self
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
assert
config
.
embedding_fraction
==
1.0
...
...
@@ -212,7 +213,7 @@ class MPTModel(nn.Module):
config
.
d_model
,
)
self
.
blocks
=
nn
.
ModuleList
(
[
MPTBlock
(
config
,
linear_method
)
for
_
in
range
(
config
.
n_layers
)])
[
MPTBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
n_layers
)])
self
.
norm_f
=
nn
.
LayerNorm
(
config
.
d_model
)
if
config
.
no_bias
:
for
module
in
self
.
modules
():
...
...
@@ -246,14 +247,14 @@ class MPTForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
MPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
assert
config
.
tie_word_embeddings
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
transformer
=
MPTModel
(
config
,
linear_method
)
self
.
transformer
=
MPTModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/olmo.py
View file @
a62aaf1d
...
...
@@ -30,11 +30,12 @@ from transformers import OlmoConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -54,7 +55,7 @@ class OlmoAttention(nn.Module):
def
__init__
(
self
,
config
:
OlmoConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -79,7 +80,7 @@ class OlmoAttention(nn.Module):
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
config
.
attention_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# Rotary embeddings.
...
...
@@ -99,7 +100,7 @@ class OlmoAttention(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
config
.
attention_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
def
forward
(
...
...
@@ -129,7 +130,7 @@ class OlmoMLP(nn.Module):
def
__init__
(
self
,
config
:
OlmoConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -141,7 +142,7 @@ class OlmoMLP(nn.Module):
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
# Activation function.
...
...
@@ -152,7 +153,7 @@ class OlmoMLP(nn.Module):
self
.
intermediate_size
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
def
forward
(
...
...
@@ -174,13 +175,13 @@ class OlmoDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
OlmoConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
# Attention block.
self
.
self_attn
=
OlmoAttention
(
config
,
linear_method
)
self
.
self_attn
=
OlmoAttention
(
config
,
quant_config
)
# MLP block.
self
.
mlp
=
OlmoMLP
(
config
,
linear_method
)
self
.
mlp
=
OlmoMLP
(
config
,
quant_config
)
# LayerNorm
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
...
@@ -216,14 +217,14 @@ class OlmoModel(nn.Module):
def
__init__
(
self
,
config
:
OlmoConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
OlmoDecoderLayer
(
config
,
linear_method
)
OlmoDecoderLayer
(
config
,
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
...
@@ -270,11 +271,10 @@ class OlmoForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
OlmoConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
OlmoModel
(
config
,
linear_method
)
self
.
model
=
OlmoModel
(
config
,
quant_config
)
if
config
.
tie_word_embeddings
:
self
.
lm_head_weight
=
self
.
model
.
embed_tokens
.
weight
else
:
...
...
vllm/model_executor/models/opt.py
View file @
a62aaf1d
...
...
@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
embed_dim
:
int
,
num_heads
:
int
,
bias
:
bool
=
True
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
embed_dim
=
embed_dim
...
...
@@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
self
.
head_dim
,
total_num_heads
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
out_proj
=
RowParallelLinear
(
embed_dim
,
embed_dim
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
OPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
embed_dim
=
self
.
embed_dim
,
num_heads
=
config
.
num_attention_heads
,
bias
=
config
.
enable_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
do_layer_norm_before
=
config
.
do_layer_norm_before
...
...
@@ -127,16 +128,16 @@ class OPTDecoderLayer(nn.Module):
self
.
embed_dim
,
config
.
ffn_dim
,
bias
=
config
.
enable_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
activation_fn
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
config
.
ffn_dim
)
self
.
fc2
=
RowParallelLinear
(
config
.
ffn_dim
,
self
.
embed_dim
,
bias
=
config
.
enable_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
,
...
...
@@ -181,7 +182,7 @@ class OPTDecoder(nn.Module):
def
__init__
(
self
,
config
:
OPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -202,7 +203,7 @@ class OPTDecoder(nn.Module):
self
.
project_out
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
word_embed_proj_dim
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
else
:
self
.
project_out
=
None
...
...
@@ -210,7 +211,7 @@ class OPTDecoder(nn.Module):
self
.
project_in
=
ReplicatedLinear
(
config
.
word_embed_proj_dim
,
config
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
else
:
self
.
project_in
=
None
...
...
@@ -226,7 +227,7 @@ class OPTDecoder(nn.Module):
self
.
final_layer_norm
=
None
self
.
layers
=
nn
.
ModuleList
([
OPTDecoderLayer
(
config
,
linear_method
)
OPTDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
...
...
@@ -259,10 +260,10 @@ class OPTModel(nn.Module):
def
__init__
(
self
,
config
:
OPTConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
decoder
=
OPTDecoder
(
config
,
linear_method
)
self
.
decoder
=
OPTDecoder
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -279,12 +280,12 @@ class OPTForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
OPTModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
OPTModel
(
config
,
quant_config
)
self
.
lm_head_weight
=
self
.
model
.
decoder
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/orion.py
View file @
a62aaf1d
...
...
@@ -13,11 +13,12 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -34,17 +35,17 @@ class OrionMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -67,7 +68,7 @@ class OrionAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -98,13 +99,13 @@ class OrionAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -154,13 +155,13 @@ class OrionDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
mlp
=
OrionMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
...
@@ -201,7 +202,7 @@ class OrionModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -212,7 +213,7 @@ class OrionModel(nn.Module):
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
OrionDecoderLayer
(
config
,
linear_method
)
OrionDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
OrionModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
OrionModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/phi.py
View file @
a62aaf1d
...
...
@@ -45,10 +45,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -62,7 +63,7 @@ class PhiAttention(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -80,12 +81,12 @@ class PhiAttention(nn.Module):
self
.
head_size
,
self
.
total_num_heads
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
dense
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
scaling
=
self
.
head_size
**-
0.5
...
...
@@ -125,7 +126,7 @@ class PhiMLP(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
n_inner
=
getattr
(
config
,
"n_inner"
,
None
)
...
...
@@ -134,14 +135,14 @@ class PhiMLP(nn.Module):
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
n_inner
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
fc2
=
RowParallelLinear
(
n_inner
,
config
.
hidden_size
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
quant_config
=
getattr
(
quant_config
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
n_inner
)
def
forward
(
self
,
hidden_states
):
...
...
@@ -155,12 +156,12 @@ class PhiLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
self_attn
=
PhiAttention
(
config
,
linear_method
)
self
.
mlp
=
PhiMLP
(
config
,
linear_method
)
self
.
self_attn
=
PhiAttention
(
config
,
quant_config
)
self
.
mlp
=
PhiMLP
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -186,14 +187,14 @@ class PhiModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
PhiLayer
(
config
,
linear_method
)
PhiLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
final_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
...
@@ -225,12 +226,12 @@ class PhiForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
quant_config
=
quant_config
self
.
model
=
PhiModel
(
config
,
linear_method
)
self
.
model
=
PhiModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
...
...
vllm/model_executor/models/qwen.py
View file @
a62aaf1d
...
...
@@ -14,11 +14,12 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -35,17 +36,17 @@ class QWenMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
=
"silu"
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -67,7 +68,7 @@ class QWenAttention(nn.Module):
max_position_embeddings
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -83,13 +84,13 @@ class QWenAttention(nn.Module):
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
True
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
...
...
@@ -122,7 +123,7 @@ class QWenBlock(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
ln_1
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
...
@@ -134,13 +135,13 @@ class QWenBlock(nn.Module):
config
.
max_position_embeddings
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
ln_2
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
QWenMLP
(
config
.
hidden_size
,
config
.
intermediate_size
//
2
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
def
forward
(
self
,
...
...
@@ -174,7 +175,7 @@ class QWenModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -185,7 +186,7 @@ class QWenModel(nn.Module):
config
.
hidden_size
,
)
self
.
h
=
nn
.
ModuleList
([
QWenBlock
(
config
,
linear_method
)
QWenBlock
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
...
@@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
QWenModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
transformer
=
QWenModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment