Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
5c7c9f13
Unverified
Commit
5c7c9f13
authored
Jul 08, 2024
by
Daniël de Kok
Committed by
GitHub
Jul 08, 2024
Browse files
Falcon/DBRX: get correct number of key-value heads (#2205)
parent
153fcf77
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
6 deletions
+22
-6
server/text_generation_server/models/__init__.py
server/text_generation_server/models/__init__.py
+4
-0
server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
...tion_server/models/custom_modeling/flash_dbrx_modeling.py
+12
-0
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
...ration_server/models/custom_modeling/flash_rw_modeling.py
+1
-0
server/text_generation_server/models/flash_causal_lm.py
server/text_generation_server/models/flash_causal_lm.py
+5
-6
No files found.
server/text_generation_server/models/__init__.py
View file @
5c7c9f13
...
@@ -797,6 +797,10 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                aliases={
+                    "lm_head.weight": ["transformer.word_embeddings.weight"],
+                    "transformer.word_embeddings.weight": ["lm_head.weight"],
+                },
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=RWConfig,
...
...
server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
View file @
5c7c9f13
...
@@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig):
 class DbrxConfig(PretrainedConfig):
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "n_heads",
+        "num_hidden_layers": "n_layers",
+    }
+
     def __init__(
         self,
         d_model: int = 2048,
...
@@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig):
             **kwargs,
         )
 
+    @property
+    def num_key_value_heads(self):
+        # We can't use the attribute map, since the number of KV
+        # heads is not top-level.
+        return self.attn_config.kv_n_heads
+
 
 def promote_scalar(x: torch.Tensor) -> torch.Tensor:
     return x.view(1) if len(x.size()) == 0 else x
...
...
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
View file @
5c7c9f13
...
@@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig):
     attribute_map = {
         "num_hidden_layers": "n_layer",
         "num_attention_heads": "n_head",
+        "num_key_value_heads": "n_head_kv",
     }
 
     def __init__(
...
...
server/text_generation_server/models/flash_causal_lm.py
View file @
5c7c9f13
...
@@ -905,13 +905,12 @@ class FlashCausalLM(Model):
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
         if num_kv_heads is None:
-            num_kv_heads = getattr(config, "num_key_value_heads", None)
-            # GPT-2 workaround
-            if num_kv_heads is None:
-                num_kv_heads = getattr(config, "n_head", None)
-            if num_kv_heads is None:
-                raise ValueError("Cannot get the number of key/value heads")
+            # Order is important here.
+            for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
+                num_kv_heads = getattr(config, attr, None)
+                if num_kv_heads is not None:
+                    break
+        if num_kv_heads is None:
+            raise ValueError("Cannot get the number of key/value heads")
         self.num_kv_heads = (
             num_kv_heads // self.process_group.size()
             if num_kv_heads > 1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment