Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
18db78f2
Unverified
Commit
18db78f2
authored
Jul 19, 2024
by
Daniël de Kok
Committed by
GitHub
Jul 19, 2024
Browse files
Hotfix: various GPT-based model fixes (#2256)
parent
80adb5be
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
8 deletions
+21
-8
server/text_generation_server/models/__init__.py
server/text_generation_server/models/__init__.py
+5
-0
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
...tion_server/models/custom_modeling/flash_neox_modeling.py
+11
-4
server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
...erver/models/custom_modeling/flash_starcoder2_modeling.py
+5
-4
No files found.
server/text_generation_server/models/__init__.py
View file @
18db78f2
...
...
@@ -573,6 +573,10 @@ def get_model(
)
elif
model_type
==
GPT_NEOX
:
if
FLASH_ATTENTION
:
from
text_generation_server.models.custom_modeling.flash_neox_modeling
import
(
GPTNeoXConfig
,
)
return
FlashCausalLM
(
model_id
=
model_id
,
model_class
=
FlashGPTNeoXForCausalLM
,
...
...
@@ -582,6 +586,7 @@ def get_model(
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
lora_adapter_ids
=
lora_adapter_ids
,
config_class
=
GPTNeoXConfig
,
)
elif
sharded
:
return
CausalLM
(
...
...
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
View file @
18db78f2
...
...
@@ -24,7 +24,7 @@ import torch.distributed
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.models.gpt_neox
import
GPTNeoXConfig
from
transformers.models.gpt_neox
import
GPTNeoXConfig
as
TransformersGPTNeoXConfig
from
typing
import
Optional
,
List
,
Tuple
from
text_generation_server.layers.attention
import
(
...
...
@@ -45,6 +45,13 @@ from text_generation_server.layers.layernorm import (
from
text_generation_server.layers.rotary
import
(
PositionRotaryEmbedding
,
)
from
text_generation_server.utils.weights
import
UnquantizedWeight
class
GPTNeoXConfig
(
TransformersGPTNeoXConfig
):
attribute_map
=
{
"num_key_value_heads"
:
"num_attention_heads"
,
}
def
load_row
(
config
,
prefix
:
str
,
weights
,
bias
:
bool
):
...
...
@@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool):
def
load_qkv
(
config
,
prefix
:
str
,
weights
,
num_heads
,
head_size
,
hidden_size
):
weight
=
weights
.
get_multi_weights_col
([
prefix
],
dim
=
0
)
if
isinstance
(
weight
,
torch
.
Tensor
):
if
isinstance
(
weight
,
UnquantizedWeight
):
# Only on non quantized versions
weight
=
(
weight
.
view
(
weight
.
weight
=
(
weight
.
weight
.
view
(
num_heads
,
3
,
head_size
,
...
...
server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
View file @
18db78f2
...
...
@@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
from
text_generation_server.layers.rotary
import
(
PositionRotaryEmbedding
,
)
from
text_generation_server.utils.weights
import
UnquantizedWeight
class
Starcoder2Config
(
PretrainedConfig
):
...
...
@@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights):
dim
=
0
,
)
if
config
.
quantize
not
in
[
"gptq"
,
"awq"
,
"marlin"
]
:
weight
=
weight
.
to
(
dtype
=
weights
.
dtype
).
to
(
device
=
weights
.
device
)
if
isinstance
(
weight
,
UnquantizedWeight
)
:
weight
.
weight
=
weight
.
weight
.
to
(
dtype
=
weights
.
dtype
).
to
(
device
=
weights
.
device
)
head_size
=
config
.
hidden_size
//
config
.
num_attention_heads
num_heads
=
config
.
num_attention_heads
//
weights
.
process_group
.
size
()
num_key_value_heads
=
config
.
num_key_value_heads
//
weights
.
process_group
.
size
()
assert
list
(
weight
.
shape
)
==
[
assert
list
(
weight
.
weight
.
shape
)
==
[
(
num_heads
+
2
*
num_key_value_heads
)
*
head_size
,
config
.
hidden_size
,
],
f
"
{
list
(
weight
.
shape
)
}
!=
{
[(
num_heads
+
2
*
config
.
num_key_value_heads
)
*
head_size
,
config
.
hidden_size
]
}
"
],
f
"
{
list
(
weight
.
weight
.
shape
)
}
!=
{
[(
num_heads
+
2
*
config
.
num_key_value_heads
)
*
head_size
,
config
.
hidden_size
]
}
"
if
config
.
use_bias
:
w
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment