Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
890aa93d
Unverified
Commit
890aa93d
authored
May 28, 2024
by
Isotr0py
Committed by
GitHub
May 27, 2024
Browse files
[Model] Add support for falcon-11B (#5069)
parent
fbdb7b3e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
15 deletions
+40
-15
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+40
-15
No files found.
vllm/model_executor/models/falcon.py
View file @
890aa93d
...
@@ -41,7 +41,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -41,7 +41,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -246,18 +246,26 @@ class FalconDecoderLayer(nn.Module):
...
@@ -246,18 +246,26 @@ class FalconDecoderLayer(nn.Module):
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
config
=
config
self
.
config
=
config
if
config
.
n
ew_decoder_architecture
:
if
(
config
.
n
um_ln_in_parallel_attn
is
None
# The layer norm before self-attention
and
config
.
new_decoder_architecture
):
self
.
ln_attn
=
LayerNorm
(
hidden_size
,
config
.
num_ln_in_parallel_attn
=
2
eps
=
config
.
layer_norm_epsilon
)
# The layer norm before the MLP
if
not
config
.
parallel_attn
:
self
.
ln_mlp
=
L
ayer
N
orm
(
hidden_size
,
eps
=
config
.
l
ayer
_n
orm
_epsilon
)
self
.
post_attention_l
ayer
n
orm
=
L
ayer
N
orm
(
else
:
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
input_layernorm
=
LayerNorm
(
hidden_size
,
self
.
input_layernorm
=
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
eps
=
config
.
layer_norm_epsilon
)
if
not
config
.
parallel_attn
:
else
:
self
.
post_attention_layernorm
=
LayerNorm
(
if
config
.
num_ln_in_parallel_attn
==
2
:
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
# The layer norm before self-attention
self
.
ln_attn
=
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
# The layer norm before the MLP
self
.
ln_mlp
=
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
else
:
self
.
input_layernorm
=
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
reduce_row_parallel_results
=
not
(
config
.
new_decoder_architecture
self
.
reduce_row_parallel_results
=
not
(
config
.
new_decoder_architecture
or
config
.
parallel_attn
)
or
config
.
parallel_attn
)
...
@@ -271,7 +279,7 @@ class FalconDecoderLayer(nn.Module):
...
@@ -271,7 +279,7 @@ class FalconDecoderLayer(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
if
self
.
config
.
n
ew_decoder_architecture
:
if
self
.
config
.
n
um_ln_in_parallel_attn
==
2
:
attention_layernorm_out
=
self
.
ln_attn
(
hidden_states
)
attention_layernorm_out
=
self
.
ln_attn
(
hidden_states
)
mlp_layernorm_out
=
self
.
ln_mlp
(
hidden_states
)
mlp_layernorm_out
=
self
.
ln_mlp
(
hidden_states
)
else
:
else
:
...
@@ -294,6 +302,10 @@ class FalconDecoderLayer(nn.Module):
...
@@ -294,6 +302,10 @@ class FalconDecoderLayer(nn.Module):
residual
+=
attention_output
residual
+=
attention_output
mlp_layernorm_out
=
self
.
post_attention_layernorm
(
residual
)
mlp_layernorm_out
=
self
.
post_attention_layernorm
(
residual
)
if
(
self
.
config
.
new_decoder_architecture
and
self
.
config
.
parallel_attn
and
self
.
config
.
num_ln_in_parallel_attn
==
1
):
mlp_layernorm_out
=
attention_layernorm_out
# MLP.
# MLP.
mlp_output
,
mlp_bias
=
self
.
mlp
(
mlp_layernorm_out
)
mlp_output
,
mlp_bias
=
self
.
mlp
(
mlp_layernorm_out
)
if
self
.
reduce_row_parallel_results
and
mlp_bias
is
not
None
:
if
self
.
reduce_row_parallel_results
and
mlp_bias
is
not
None
:
...
@@ -375,7 +387,20 @@ class FalconForCausalLM(nn.Module):
...
@@ -375,7 +387,20 @@ class FalconForCausalLM(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
FalconModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
FalconModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
# only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default
self
.
tie_word_embeddings
=
(
config
.
tie_word_embeddings
if
config
.
tie_word_embeddings
is
not
None
else
True
)
if
self
.
tie_word_embeddings
:
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
@@ -419,8 +444,8 @@ class FalconForCausalLM(nn.Module):
...
@@ -419,8 +444,8 @@ class FalconForCausalLM(nn.Module):
num_query_heads_per_kv_head
=
total_num_heads
//
total_num_kv_heads
num_query_heads_per_kv_head
=
total_num_heads
//
total_num_kv_heads
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
name
==
"lm_head.weight"
:
if
name
==
"lm_head.weight"
and
self
.
tie_word_embeddings
:
# Falcon uses tied embeddings.
# Falcon uses tied embeddings
except Falcon-11b
.
continue
continue
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment