OpenDAS / text-generation-inference / Commits / ab96b9ae

Unverified commit ab96b9ae, authored Jul 27, 2023 by OlivierDehaene, committed by GitHub on Jul 27, 2023

feat(server): support new falcon config (#712)
parent 2efd46ef

Showing 2 changed files with 38 additions and 26 deletions (+38 -26):

server/text_generation_server/models/__init__.py  (+3 -8)
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py  (+35 -18)
server/text_generation_server/models/__init__.py (view file @ ab96b9ae)

@@ -200,13 +200,10 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,

@@ -215,9 +212,7 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
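The routing change above means a checkpoint whose config.json declares the new model_type "falcon" now takes the same path as the legacy "RefinedWeb"/"RefinedWebModel" aliases, and sharded loading is refused only when ALiBi is enabled rather than also for multi-query RefinedWebModel checkpoints; the error message is updated to "Sharded Falcon" to match. A minimal standalone sketch of that dispatch (the config_dict values and the supports_sharded_falcon helper are illustrative, not part of the server code):

# Sketch of the post-commit dispatch in get_model; FLASH_ATTENTION stands in
# for the server's runtime capability flag.
FLASH_ATTENTION = True

def supports_sharded_falcon(model_type: str, config_dict: dict) -> bool:
    # Falcon-family model types accepted after this commit.
    if model_type not in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        raise ValueError(f"{model_type} is not a Falcon-family model_type")
    # Sharded Falcon requires flash attention and no ALiBi positional bias.
    return FLASH_ATTENTION and not config_dict.get("alibi", False)

print(supports_sharded_falcon("falcon", {"alibi": False}))  # True
print(supports_sharded_falcon("falcon", {"alibi": True}))   # False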
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py (view file @ ab96b9ae)
@@ -49,8 +49,8 @@ class RWConfig(PretrainedConfig):
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -58,9 +58,10 @@ class RWConfig(PretrainedConfig):
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ class RWConfig(PretrainedConfig):
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
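The two ternaries added above let RWConfig read either naming scheme: the new Hugging Face falcon keys (num_hidden_layers, num_attention_heads) win when present, otherwise the legacy RefinedWeb keys (n_layer, n_head) arriving through **kwargs are used, with the old defaults of 2 and 8. Below is a standalone mirror of that fallback; resolve_layers_and_heads is an illustrative helper and the dict literals are hypothetical config.json fragments:

# Illustrative helper, not part of the server code.
def resolve_layers_and_heads(config: dict) -> tuple:
    cfg = dict(config)  # copy so the pops below do not mutate the caller's dict
    n_layer = (
        cfg["num_hidden_layers"]
        if cfg.get("num_hidden_layers") is not None
        else cfg.pop("n_layer", 2)
    )
    n_head = (
        cfg["num_attention_heads"]
        if cfg.get("num_attention_heads") is not None
        else cfg.pop("n_head", 8)
    )
    return n_layer, n_head

legacy = {"n_layer": 60, "n_head": 128}                           # old-style keys
current = {"num_hidden_layers": 60, "num_attention_heads": 128}   # new-style keys
assert resolve_layers_and_heads(legacy) == resolve_layers_and_heads(current) == (60, 128)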
@@ -91,10 +100,21 @@ class RWConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
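The constructor now also resolves the number of KV heads from either num_kv_heads (new key) or n_head_kv (legacy key, popped from **kwargs), falling back to the multi-query rule, and infers new_decoder_architecture for old checkpoints from the model_type string: legacy "RefinedWeb" configs are treated as the new (grouped-KV) architecture. A standalone mirror of both rules follows; the helper names and example dicts are illustrative only:

# Illustrative helpers mirroring the logic added to RWConfig.__init__ above.
def resolve_kv_heads(config: dict, n_head: int) -> int:
    if config.get("num_kv_heads") is not None:   # new falcon key
        return config["num_kv_heads"]
    if config.get("n_head_kv") is not None:      # legacy RefinedWeb key
        return config["n_head_kv"]
    return 1 if config.get("multi_query", False) else n_head

def resolve_new_decoder_architecture(config: dict, model_type: str) -> bool:
    if config.get("new_decoder_architecture") is not None:
        return config["new_decoder_architecture"]
    # Legacy configs: "RefinedWeb" denoted the large, grouped-KV architecture.
    return model_type == "RefinedWeb"

assert resolve_kv_heads({"num_kv_heads": 8}, n_head=128) == 8
assert resolve_kv_heads({"multi_query": True}, n_head=71) == 1
assert resolve_new_decoder_architecture({}, "RefinedWeb") is True
assert resolve_new_decoder_architecture({"new_decoder_architecture": True}, "falcon") is True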
@@ -530,26 +550,23 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(layer_id, config, weights)
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+            self.cache_size = self.h[0].self_attention.num_groups
+        else:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(layer_id, config, weights)
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_groups
-        else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
-            )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
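In FlashRWModel the layer-class choice now keys off config.new_decoder_architecture instead of comparing model_type strings, so the old "model_type ... is not supported" branch disappears; cache_size correspondingly comes from self_attention.num_groups for the new architecture and self_attention.num_heads_kv otherwise. A standalone sketch of that selection; the two classes below are empty stand-ins for the real layer implementations in flash_rw_modeling.py:

# Empty stand-ins for the real layer classes.
class FlashRWLayer: ...       # chosen when new_decoder_architecture is False
class FlashRWLargeLayer: ...  # chosen when new_decoder_architecture is True

def pick_layer_class(new_decoder_architecture: bool):
    # Mirrors the branch in FlashRWModel.__init__ after this commit.
    return FlashRWLargeLayer if new_decoder_architecture else FlashRWLayer

assert pick_layer_class(True) is FlashRWLargeLayer
assert pick_layer_class(False) is FlashRWLayer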