Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f89d18ff
Unverified
Commit
f89d18ff
authored
Nov 10, 2024
by
youkaichao
Committed by
GitHub
Nov 11, 2024
Browse files
[6/N] pass whole config to inner model (#10205)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
f0f2e563
Changes
69
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
213 additions
and
272 deletions
+213
-272
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+10
-13
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+26
-22
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+11
-12
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+11
-14
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-7
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+11
-10
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+11
-16
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+10
-9
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+13
-20
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+11
-10
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+1
-5
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+11
-15
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+11
-18
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+4
-1
vllm/model_executor/models/exaone.py
vllm/model_executor/models/exaone.py
+13
-21
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+11
-10
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+17
-21
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+8
-16
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+9
-19
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+11
-13
No files found.
vllm/model_executor/models/arctic.py
View file @
f89d18ff
...
...
@@ -34,7 +34,8 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
...
...
@@ -364,14 +365,13 @@ class ArcticDecoderLayer(nn.Module):
@
support_torch_compile
class
ArcticModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
ArcticConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
...
...
@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
model
=
ArcticModel
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
)
self
.
model
=
ArcticModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
vocab_size
,
...
...
vllm/model_executor/models/baichuan.py
View file @
f89d18ff
...
...
@@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module):
@
support_torch_compile
class
BaiChuanModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
position_embedding
:
str
=
"ROPE"
,
)
->
None
:
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
...
...
@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
position_embedding
:
str
=
"ROPE"
,
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
cache_config
,
quant_config
)
self
.
model
=
BaiChuanModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
position_embedding
=
position_embedding
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
...
...
@@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has a lower case 'c'.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
super
().
__init__
(
vllm_config
,
prefix
,
"ROPE"
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
position_embedding
=
"ROPE"
)
else
:
# baichuan 13b, baichuan2 13b
super
().
__init__
(
vllm_config
,
prefix
,
"ALIBI"
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
position_embedding
=
"ALIBI"
)
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
...
@@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has an upper case 'C'.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
):
super
().
__init__
(
vllm_config
,
prefix
,
"ROPE"
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
position_embedding
=
"ROPE"
)
vllm/model_executor/models/bart.py
View file @
f89d18ff
...
...
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
maybe_prefix
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -739,13 +741,14 @@ class BartModel(nn.Module):
"encoder.embed_tokens.weight"
,
"decoder.embed_tokens.weight"
]
def
__init__
(
self
,
config
:
BartConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -810,20 +813,16 @@ class BartModel(nn.Module):
class
BartForConditionalGeneration
(
nn
.
Module
):
base_model_prefix
=
"model"
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
# currently all existing BART models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
model
=
BartModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
)
self
.
model
=
BartModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
...
...
vllm/model_executor/models/bert.py
View file @
f89d18ff
...
...
@@ -21,6 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.utils
import
maybe_prefix
class
BertEmbedding
(
nn
.
Module
):
...
...
@@ -309,12 +311,13 @@ class BertOutput(nn.Module):
class
BertModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
BertConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
embeddings
=
BertEmbedding
(
config
)
self
.
encoder
=
BertEncoder
(
config
,
cache_config
,
...
...
@@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module):
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
model
=
BertModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
BertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
CLS
,
...
...
vllm/model_executor/models/blip2.py
View file @
f89d18ff
...
...
@@ -23,7 +23,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
...
...
@@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_blip2
)
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -517,7 +513,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
"language_model"
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/bloom.py
View file @
f89d18ff
...
...
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
...
...
@@ -221,14 +222,13 @@ class BloomBlock(nn.Module):
@
support_torch_compile
class
BloomModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
embed_dim
=
config
.
hidden_size
# Embedding + LN Embedding
...
...
@@ -288,11 +288,12 @@ class BloomForCausalLM(nn.Module, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
BloomModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
word_embeddings
else
:
...
...
vllm/model_executor/models/chameleon.py
View file @
f89d18ff
...
...
@@ -37,7 +37,8 @@ from vllm.utils import print_warning_once
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
...
...
@@ -831,14 +832,13 @@ class ChameleonImageVocabularyMapping:
class
ChameleonModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
ChameleonConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
...
...
@@ -924,19 +924,14 @@ class ChameleonModel(nn.Module):
class
ChameleonForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
model
=
ChameleonModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
ChameleonModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
...
...
vllm/model_executor/models/chatglm.py
View file @
f89d18ff
...
...
@@ -39,7 +39,8 @@ from vllm.transformers_utils.configs import ChatGLMConfig
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
...
...
@@ -481,14 +482,13 @@ class GLMTransformer(nn.Module):
class
ChatGLMModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
ChatGLMConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
embedding
=
VocabParallelEmbedding
(
config
.
padded_vocab_size
,
...
...
@@ -600,7 +600,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
...
...
@@ -611,7 +610,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self
.
quant_config
=
quant_config
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
8192
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
ChatGLMModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
if
self
.
config
.
tie_word_embeddings
:
self
.
transformer
.
output_layer
.
weight
=
(
self
.
transformer
.
embedding
.
weight
)
...
...
vllm/model_executor/models/commandr.py
View file @
f89d18ff
...
...
@@ -28,7 +28,7 @@ from transformers import CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
@
torch
.
compile
...
...
@@ -253,15 +254,14 @@ class CohereDecoderLayer(nn.Module):
@
support_torch_compile
class
CohereModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
...
@@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
}
embedding_padding_modules
=
[]
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
...
...
@@ -353,10 +348,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
scale
=
config
.
logit_scale
)
self
.
model
=
CohereModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
)
self
.
model
=
CohereModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/dbrx.py
View file @
f89d18ff
...
...
@@ -25,7 +25,8 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
DbrxRouter
(
nn
.
Module
):
...
...
@@ -294,14 +295,13 @@ class DbrxBlock(nn.Module):
class
DbrxModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
DbrxConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
d_model
,
...
...
@@ -357,7 +357,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
if
config
.
tie_word_embeddings
:
...
...
@@ -365,7 +364,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
"tie_word_embeddings is not supported for Dbrx models."
)
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
DbrxModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
d_model
,
...
...
vllm/model_executor/models/decilm.py
View file @
f89d18ff
...
...
@@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
instead.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
...
...
vllm/model_executor/models/deepseek.py
View file @
f89d18ff
...
...
@@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
DeepseekMLP
(
nn
.
Module
):
...
...
@@ -326,14 +327,13 @@ class DeepseekModel(nn.Module):
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
...
...
@@ -383,18 +383,14 @@ class DeepseekModel(nn.Module):
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
DeepseekModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
f89d18ff
...
...
@@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
DeepseekV2MLP
(
nn
.
Module
):
...
...
@@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
...
...
@@ -479,21 +479,14 @@ class DeepseekV2Model(nn.Module):
class
DeepseekV2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekV2Model
(
config
,
cache_config
,
quant_config
,
prefix
=
"model"
)
self
.
model
=
DeepseekV2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
...
...
vllm/model_executor/models/eagle.py
View file @
f89d18ff
...
...
@@ -14,6 +14,8 @@ from vllm.model_executor.models import ModelRegistry
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
maybe_prefix
class
EAGLE
(
nn
.
Module
):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
...
...
@@ -42,7 +44,8 @@ class EAGLE(nn.Module):
architectures
=
getattr
(
self
.
config
.
model
,
"architectures"
,
[])
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
self
.
model
=
model_cls
(
vllm_config
,
prefix
)
self
.
model
=
model_cls
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
fc
=
nn
.
Linear
(
config
.
model
.
hidden_size
*
2
,
config
.
model
.
hidden_size
,
bias
=
getattr
(
self
.
config
,
"eagle_fc_bias"
,
False
))
...
...
vllm/model_executor/models/exaone.py
View file @
f89d18ff
...
...
@@ -29,7 +29,7 @@ from torch import nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -54,7 +54,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
ExaoneGatedMLP
(
nn
.
Module
):
...
...
@@ -314,15 +315,14 @@ class ExaoneDecoderLayer(nn.Module):
@
support_torch_compile
class
ExaoneModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
ExaoneConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
((
lora_config
.
lora_extra_vocab_size
*
...
...
@@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"c_fc_1"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
...
...
@@ -453,11 +448,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
transformer
=
ExaoneModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
),
)
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
vllm/model_executor/models/falcon.py
View file @
f89d18ff
...
...
@@ -48,7 +48,8 @@ from vllm.transformers_utils.configs import RWConfig
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
...
@@ -332,14 +333,13 @@ class FalconDecoderLayer(nn.Module):
@
support_torch_compile
class
FalconModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
FalconConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
...
...
@@ -408,11 +408,12 @@ class FalconForCausalLM(nn.Module, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
FalconModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
FalconModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
...
...
vllm/model_executor/models/florence2.py
View file @
f89d18ff
...
...
@@ -3,13 +3,10 @@ from typing import Iterable, List, Optional, Tuple
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.bart
import
(
BartDecoder
,
BartEncoder
,
...
...
@@ -23,11 +20,13 @@ from .utils import AutoWeightsLoader
class
Florence2LanguageModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -93,15 +92,14 @@ class Florence2LanguageModel(nn.Module):
class
Florence2LanguageForConditionalGeneration
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
config
self
.
model
=
Florence2LanguageModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
model
=
Florence2LanguageModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
embed_scale
=
math
.
sqrt
(
config
.
d_model
)
if
config
.
scale_embedding
else
1.0
...
...
@@ -189,17 +187,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
class
Florence2ForConditionalGeneration
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
# TODO(Isotr0py): Add vision backbone
self
.
language_model
=
Florence2LanguageForConditionalGeneration
(
config
=
config
.
text_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
)
,
prefix
=
prefix
,
)
@
property
def
sampler
(
self
):
...
...
vllm/model_executor/models/gemma.py
View file @
f89d18ff
...
...
@@ -258,14 +258,13 @@ class GemmaDecoderLayer(nn.Module):
@
support_torch_compile
class
GemmaModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GemmaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
embed_tokens
=
VocabParallelEmbedding
(
...
...
@@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
...
...
@@ -389,9 +383,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
GemmaModel
(
config
,
cache_config
,
quant_config
,
self
.
model
=
GemmaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
...
...
vllm/model_executor/models/gemma2.py
View file @
f89d18ff
...
...
@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
...
...
@@ -243,11 +244,7 @@ class Gemma2DecoderLayer(nn.Module):
@
support_torch_compile
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
...
...
@@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
del
lora_config
# Unused.
...
...
@@ -414,7 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
quant_config
=
quant_config
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
soft_cap
=
config
.
final_logit_softcapping
)
self
.
sampler
=
get_sampler
()
...
...
@@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
model
=
Gemma2Model
(
vllm_config
,
prefix
)
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
...
...
vllm/model_executor/models/gpt2.py
View file @
f89d18ff
...
...
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
GPT2Attention
(
nn
.
Module
):
...
...
@@ -184,14 +185,13 @@ class GPT2Block(nn.Module):
@
support_torch_compile
class
GPT2Model
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
assert
not
config
.
add_cross_attention
assert
not
config
.
scale_attn_by_inverse_layer_idx
...
...
@@ -247,14 +247,12 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPT2Model
(
config
,
cache_config
,
quant_config
,
prefix
=
"transformer"
)
self
.
transformer
=
GPT2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment