Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f89d18ff
Unverified
Commit
f89d18ff
authored
Nov 10, 2024
by
youkaichao
Committed by
GitHub
Nov 11, 2024
Browse files
[6/N] pass whole config to inner model (#10205)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
f0f2e563
Changes
69
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
213 additions
and
272 deletions
+213
-272
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+10
-13
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+26
-22
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+11
-12
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+11
-14
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-7
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+11
-10
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+11
-16
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+10
-9
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+13
-20
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+11
-10
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+1
-5
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+11
-15
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+11
-18
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+4
-1
vllm/model_executor/models/exaone.py
vllm/model_executor/models/exaone.py
+13
-21
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+11
-10
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+17
-21
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+8
-16
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+9
-19
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+11
-13
No files found.
vllm/model_executor/models/arctic.py
View file @
f89d18ff
...
@@ -34,7 +34,8 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
...
@@ -34,7 +34,8 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -364,14 +365,13 @@ class ArcticDecoderLayer(nn.Module):
...
@@ -364,14 +365,13 @@ class ArcticDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
ArcticModel
(
nn
.
Module
):
class
ArcticModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
ArcticConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
...
@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
...
@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
model
=
ArcticModel
(
config
,
self
.
model
=
ArcticModel
(
vllm_config
=
vllm_config
,
cache_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
quant_config
,
prefix
=
prefix
)
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
self
.
vocab_size
,
self
.
vocab_size
,
...
...
vllm/model_executor/models/baichuan.py
View file @
f89d18ff
...
@@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module):
...
@@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
BaiChuanModel
(
nn
.
Module
):
class
BaiChuanModel
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
config
:
PretrainedConfig
,
self
,
position_embedding
:
str
,
vllm_config
:
VllmConfig
,
cache_con
fi
g
:
Optional
[
CacheConfig
]
=
None
,
pre
fi
x
:
str
=
""
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
position_embedding
:
str
=
"ROPE"
,
prefix
:
str
=
""
)
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
__init__
(
def
__init__
(
self
,
self
,
*
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
position_embedding
:
str
=
"ROPE"
,
position_embedding
:
str
=
"ROPE"
,
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
cache_config
,
self
.
model
=
BaiChuanModel
(
vllm_config
=
vllm_config
,
quant_config
)
prefix
=
prefix
,
position_embedding
=
position_embedding
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
...
@@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
...
@@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has a lower case 'c'.
NOTE: the class name has a lower case 'c'.
"""
"""
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
):
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
super
().
__init__
(
vllm_config
,
prefix
,
"ROPE"
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
position_embedding
=
"ROPE"
)
else
:
# baichuan 13b, baichuan2 13b
else
:
# baichuan 13b, baichuan2 13b
super
().
__init__
(
vllm_config
,
prefix
,
"ALIBI"
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
position_embedding
=
"ALIBI"
)
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
@@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
...
@@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has an upper case 'C'.
NOTE: the class name has an upper case 'C'.
"""
"""
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
super
().
__init__
(
vllm_config
=
vllm_config
,
vllm_config
:
VllmConfig
,
prefix
=
prefix
,
prefix
:
str
=
""
,
position_embedding
=
"ROPE"
)
):
super
().
__init__
(
vllm_config
,
prefix
,
"ROPE"
)
vllm/model_executor/models/bart.py
View file @
f89d18ff
...
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
maybe_prefix
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
...
@@ -739,13 +741,14 @@ class BartModel(nn.Module):
...
@@ -739,13 +741,14 @@ class BartModel(nn.Module):
"encoder.embed_tokens.weight"
,
"decoder.embed_tokens.weight"
"encoder.embed_tokens.weight"
,
"decoder.embed_tokens.weight"
]
]
def
__init__
(
self
,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
:
BartConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -810,20 +813,16 @@ class BartModel(nn.Module):
...
@@ -810,20 +813,16 @@ class BartModel(nn.Module):
class
BartForConditionalGeneration
(
nn
.
Module
):
class
BartForConditionalGeneration
(
nn
.
Module
):
base_model_prefix
=
"model"
base_model_prefix
=
"model"
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
# currently all existing BART models have `tie_word_embeddings` enabled
# currently all existing BART models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
config
=
config
self
.
model
=
BartModel
(
config
,
self
.
model
=
BartModel
(
vllm_config
=
vllm_config
,
cache_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
quant_config
,
lora_config
=
lora_config
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
...
...
vllm/model_executor/models/bert.py
View file @
f89d18ff
...
@@ -21,6 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -21,6 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.utils
import
maybe_prefix
class
BertEmbedding
(
nn
.
Module
):
class
BertEmbedding
(
nn
.
Module
):
...
@@ -309,12 +311,13 @@ class BertOutput(nn.Module):
...
@@ -309,12 +311,13 @@ class BertOutput(nn.Module):
class
BertModel
(
nn
.
Module
):
class
BertModel
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
:
BertConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
embeddings
=
BertEmbedding
(
config
)
self
.
embeddings
=
BertEmbedding
(
config
)
self
.
encoder
=
BertEncoder
(
config
,
self
.
encoder
=
BertEncoder
(
config
,
cache_config
,
cache_config
,
...
@@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module):
...
@@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module):
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
model
=
BertModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
BertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
PoolingType
.
CLS
,
pooling_type
=
PoolingType
.
CLS
,
...
...
vllm/model_executor/models/blip2.py
View file @
f89d18ff
...
@@ -23,7 +23,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip,
...
@@ -23,7 +23,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens
)
get_max_blip_image_tokens
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
# We use this internally as placeholders since there is no image token
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
# defined on the HuggingFace repo
...
@@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
...
@@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_blip2
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_blip2
)
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -517,7 +513,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -517,7 +513,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
"language_model"
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/bloom.py
View file @
f89d18ff
...
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
...
@@ -221,14 +222,13 @@ class BloomBlock(nn.Module):
...
@@ -221,14 +222,13 @@ class BloomBlock(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
BloomModel
(
nn
.
Module
):
class
BloomModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
# Embedding + LN Embedding
# Embedding + LN Embedding
...
@@ -288,11 +288,12 @@ class BloomForCausalLM(nn.Module, SupportsPP):
...
@@ -288,11 +288,12 @@ class BloomForCausalLM(nn.Module, SupportsPP):
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
BloomModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
word_embeddings
self
.
lm_head
=
self
.
transformer
.
word_embeddings
else
:
else
:
...
...
vllm/model_executor/models/chameleon.py
View file @
f89d18ff
...
@@ -37,7 +37,8 @@ from vllm.utils import print_warning_once
...
@@ -37,7 +37,8 @@ from vllm.utils import print_warning_once
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
# These configs are not part of the model config but the preprocessor
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
# and processor files, so we hardcode them in the model file for now.
...
@@ -831,14 +832,13 @@ class ChameleonImageVocabularyMapping:
...
@@ -831,14 +832,13 @@ class ChameleonImageVocabularyMapping:
class
ChameleonModel
(
nn
.
Module
):
class
ChameleonModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
ChameleonConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -924,19 +924,14 @@ class ChameleonModel(nn.Module):
...
@@ -924,19 +924,14 @@ class ChameleonModel(nn.Module):
class
ChameleonForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
class
ChameleonForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
SupportsPP
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
model
=
ChameleonModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
ChameleonModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
self
.
unpadded_vocab_size
,
...
...
vllm/model_executor/models/chatglm.py
View file @
f89d18ff
...
@@ -39,7 +39,8 @@ from vllm.transformers_utils.configs import ChatGLMConfig
...
@@ -39,7 +39,8 @@ from vllm.transformers_utils.configs import ChatGLMConfig
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -481,14 +482,13 @@ class GLMTransformer(nn.Module):
...
@@ -481,14 +482,13 @@ class GLMTransformer(nn.Module):
class
ChatGLMModel
(
nn
.
Module
):
class
ChatGLMModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
ChatGLMConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
embedding
=
VocabParallelEmbedding
(
config
.
padded_vocab_size
,
self
.
embedding
=
VocabParallelEmbedding
(
config
.
padded_vocab_size
,
...
@@ -600,7 +600,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
...
@@ -600,7 +600,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
...
@@ -611,7 +610,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
...
@@ -611,7 +610,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
8192
)
8192
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
ChatGLMModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
transformer
.
output_layer
.
weight
=
(
self
.
transformer
.
output_layer
.
weight
=
(
self
.
transformer
.
embedding
.
weight
)
self
.
transformer
.
embedding
.
weight
)
...
...
vllm/model_executor/models/commandr.py
View file @
f89d18ff
...
@@ -28,7 +28,7 @@ from transformers import CohereConfig
...
@@ -28,7 +28,7 @@ from transformers import CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
@
torch
.
compile
@
torch
.
compile
...
@@ -253,15 +254,14 @@ class CohereDecoderLayer(nn.Module):
...
@@ -253,15 +254,14 @@ class CohereDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
CohereModel
(
nn
.
Module
):
class
CohereModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
CohereConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
@@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
}
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
...
@@ -353,10 +348,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -353,10 +348,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
config
.
vocab_size
,
scale
=
config
.
logit_scale
)
scale
=
config
.
logit_scale
)
self
.
model
=
CohereModel
(
config
,
self
.
model
=
CohereModel
(
vllm_config
=
vllm_config
,
cache_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
quant_config
,
lora_config
=
lora_config
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/dbrx.py
View file @
f89d18ff
...
@@ -25,7 +25,8 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
...
@@ -25,7 +25,8 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
DbrxRouter
(
nn
.
Module
):
class
DbrxRouter
(
nn
.
Module
):
...
@@ -294,14 +295,13 @@ class DbrxBlock(nn.Module):
...
@@ -294,14 +295,13 @@ class DbrxBlock(nn.Module):
class
DbrxModel
(
nn
.
Module
):
class
DbrxModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
DbrxConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
wte
=
VocabParallelEmbedding
(
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
d_model
,
config
.
d_model
,
...
@@ -357,7 +357,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
...
@@ -357,7 +357,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
if
config
.
tie_word_embeddings
:
if
config
.
tie_word_embeddings
:
...
@@ -365,7 +364,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
...
@@ -365,7 +364,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
"tie_word_embeddings is not supported for Dbrx models."
)
"tie_word_embeddings is not supported for Dbrx models."
)
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
DbrxModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
d_model
,
config
.
d_model
,
...
...
vllm/model_executor/models/decilm.py
View file @
f89d18ff
...
@@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
...
@@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
instead.
instead.
"""
"""
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
config
.
num_key_value_heads
=
max
(
config
.
num_key_value_heads_per_layer
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
delattr
(
config
,
"num_key_value_heads_per_layer"
)
...
...
vllm/model_executor/models/deepseek.py
View file @
f89d18ff
...
@@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
DeepseekMLP
(
nn
.
Module
):
class
DeepseekMLP
(
nn
.
Module
):
...
@@ -326,14 +327,13 @@ class DeepseekModel(nn.Module):
...
@@ -326,14 +327,13 @@ class DeepseekModel(nn.Module):
fall_back_to_pt_during_load
=
False
fall_back_to_pt_during_load
=
False
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -383,18 +383,14 @@ class DeepseekModel(nn.Module):
...
@@ -383,18 +383,14 @@ class DeepseekModel(nn.Module):
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
DeepseekModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
f89d18ff
...
@@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
...
@@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module):
...
@@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load
=
False
fall_back_to_pt_during_load
=
False
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -479,21 +479,14 @@ class DeepseekV2Model(nn.Module):
...
@@ -479,21 +479,14 @@ class DeepseekV2Model(nn.Module):
class
DeepseekV2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
DeepseekV2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
DeepseekV2Model
(
config
,
self
.
model
=
DeepseekV2Model
(
vllm_config
=
vllm_config
,
cache_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
quant_config
,
prefix
=
"model"
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
...
...
vllm/model_executor/models/eagle.py
View file @
f89d18ff
...
@@ -14,6 +14,8 @@ from vllm.model_executor.models import ModelRegistry
...
@@ -14,6 +14,8 @@ from vllm.model_executor.models import ModelRegistry
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
maybe_prefix
class
EAGLE
(
nn
.
Module
):
class
EAGLE
(
nn
.
Module
):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
...
@@ -42,7 +44,8 @@ class EAGLE(nn.Module):
...
@@ -42,7 +44,8 @@ class EAGLE(nn.Module):
architectures
=
getattr
(
self
.
config
.
model
,
"architectures"
,
[])
architectures
=
getattr
(
self
.
config
.
model
,
"architectures"
,
[])
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
self
.
model
=
model_cls
(
vllm_config
,
prefix
)
self
.
model
=
model_cls
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
fc
=
nn
.
Linear
(
config
.
model
.
hidden_size
*
2
,
self
.
fc
=
nn
.
Linear
(
config
.
model
.
hidden_size
*
2
,
config
.
model
.
hidden_size
,
config
.
model
.
hidden_size
,
bias
=
getattr
(
self
.
config
,
"eagle_fc_bias"
,
False
))
bias
=
getattr
(
self
.
config
,
"eagle_fc_bias"
,
False
))
...
...
vllm/model_executor/models/exaone.py
View file @
f89d18ff
...
@@ -29,7 +29,7 @@ from torch import nn
...
@@ -29,7 +29,7 @@ from torch import nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
@@ -54,7 +54,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
...
@@ -54,7 +54,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
ExaoneGatedMLP
(
nn
.
Module
):
class
ExaoneGatedMLP
(
nn
.
Module
):
...
@@ -314,15 +315,14 @@ class ExaoneDecoderLayer(nn.Module):
...
@@ -314,15 +315,14 @@ class ExaoneDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
ExaoneModel
(
nn
.
Module
):
class
ExaoneModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
ExaoneConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
((
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
((
lora_config
.
lora_extra_vocab_size
*
...
@@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"c_fc_1"
:
(
"gate_up_proj"
,
1
),
"c_fc_1"
:
(
"gate_up_proj"
,
1
),
}
}
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
...
@@ -453,11 +448,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -453,11 +448,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
transformer
=
ExaoneModel
(
self
.
transformer
=
ExaoneModel
(
config
,
vllm_config
=
vllm_config
,
cache_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
),
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
,
)
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
vllm/model_executor/models/falcon.py
View file @
f89d18ff
...
@@ -48,7 +48,8 @@ from vllm.transformers_utils.configs import RWConfig
...
@@ -48,7 +48,8 @@ from vllm.transformers_utils.configs import RWConfig
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
@@ -332,14 +333,13 @@ class FalconDecoderLayer(nn.Module):
...
@@ -332,14 +333,13 @@ class FalconDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
FalconModel
(
nn
.
Module
):
class
FalconModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
FalconConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
num_heads
=
config
.
num_attention_heads
...
@@ -408,11 +408,12 @@ class FalconForCausalLM(nn.Module, SupportsPP):
...
@@ -408,11 +408,12 @@ class FalconForCausalLM(nn.Module, SupportsPP):
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
FalconModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
FalconModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
# only Falcon-11B doesn't share lm_head weight with word embeddings
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
# so we set tie_word_embeddings to True by default
...
...
vllm/model_executor/models/florence2.py
View file @
f89d18ff
...
@@ -3,13 +3,10 @@ from typing import Iterable, List, Optional, Tuple
...
@@ -3,13 +3,10 @@ from typing import Iterable, List, Optional, Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.bart
import
(
BartDecoder
,
BartEncoder
,
from
vllm.model_executor.models.bart
import
(
BartDecoder
,
BartEncoder
,
...
@@ -23,11 +20,13 @@ from .utils import AutoWeightsLoader
...
@@ -23,11 +20,13 @@ from .utils import AutoWeightsLoader
class
Florence2LanguageModel
(
nn
.
Module
):
class
Florence2LanguageModel
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -93,15 +92,14 @@ class Florence2LanguageModel(nn.Module):
...
@@ -93,15 +92,14 @@ class Florence2LanguageModel(nn.Module):
class
Florence2LanguageForConditionalGeneration
(
nn
.
Module
):
class
Florence2LanguageForConditionalGeneration
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
config
self
.
config
=
config
self
.
model
=
Florence2LanguageModel
(
config
,
self
.
model
=
Florence2LanguageModel
(
vllm_config
=
vllm_config
,
cache_config
=
cache_config
,
prefix
=
prefix
)
quant_config
=
quant_config
)
embed_scale
=
math
.
sqrt
(
embed_scale
=
math
.
sqrt
(
config
.
d_model
)
if
config
.
scale_embedding
else
1.0
config
.
d_model
)
if
config
.
scale_embedding
else
1.0
...
@@ -189,17 +187,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
...
@@ -189,17 +187,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
class
Florence2ForConditionalGeneration
(
nn
.
Module
):
class
Florence2ForConditionalGeneration
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
# TODO(Isotr0py): Add vision backbone
# TODO(Isotr0py): Add vision backbone
self
.
language_model
=
Florence2LanguageForConditionalGeneration
(
self
.
language_model
=
Florence2LanguageForConditionalGeneration
(
config
=
config
.
text_config
,
vllm_
config
=
vllm_config
.
with_hf_config
(
config
.
text_config
)
,
cache_config
=
cache_con
fi
g
,
prefix
=
pre
fi
x
,
quant_config
=
quant_config
)
)
@
property
@
property
def
sampler
(
self
):
def
sampler
(
self
):
...
...
vllm/model_executor/models/gemma.py
View file @
f89d18ff
...
@@ -258,14 +258,13 @@ class GemmaDecoderLayer(nn.Module):
...
@@ -258,14 +258,13 @@ class GemmaDecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
GemmaModel
(
nn
.
Module
):
class
GemmaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
GemmaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
...
@@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
...
@@ -389,9 +383,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -389,9 +383,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
GemmaModel
(
config
,
self
.
model
=
GemmaModel
(
vllm_config
=
vllm_config
,
cache_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
...
...
vllm/model_executor/models/gemma2.py
View file @
f89d18ff
...
@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
...
@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -243,11 +244,7 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -243,11 +244,7 @@ class Gemma2DecoderLayer(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
Gemma2Model
(
nn
.
Module
):
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
...
@@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj"
:
(
"gate_up_proj"
,
1
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
}
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
del
lora_config
# Unused.
del
lora_config
# Unused.
...
@@ -414,7 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -414,7 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# currently all existing Gemma models have `tie_word_embeddings` enabled
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
assert
config
.
tie_word_embeddings
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
logits_processor
=
LogitsProcessor
(
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
soft_cap
=
config
.
final_logit_softcapping
)
config
.
vocab_size
,
soft_cap
=
config
.
final_logit_softcapping
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
...
@@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
...
@@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
model
=
Gemma2Model
(
vllm_config
,
prefix
)
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
pooling_type
=
PoolingType
.
LAST
,
...
...
vllm/model_executor/models/gpt2.py
View file @
f89d18ff
...
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
GPT2Attention
(
nn
.
Module
):
class
GPT2Attention
(
nn
.
Module
):
...
@@ -184,14 +185,13 @@ class GPT2Block(nn.Module):
...
@@ -184,14 +185,13 @@ class GPT2Block(nn.Module):
@
support_torch_compile
@
support_torch_compile
class
GPT2Model
(
nn
.
Module
):
class
GPT2Model
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
assert
not
config
.
add_cross_attention
assert
not
config
.
add_cross_attention
assert
not
config
.
scale_attn_by_inverse_layer_idx
assert
not
config
.
scale_attn_by_inverse_layer_idx
...
@@ -247,14 +247,12 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
...
@@ -247,14 +247,12 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPT2Model
(
config
,
self
.
transformer
=
GPT2Model
(
vllm_config
=
vllm_config
,
cache_config
,
prefix
=
maybe_prefix
(
quant_config
,
prefix
,
"transformer"
))
prefix
=
"transformer"
)
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
self
.
lm_head
=
self
.
transformer
.
wte
else
:
else
:
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment