OpenDAS / text-generation-inference / Commits / 678b2f39

Commit 678b2f39 (unverified), authored Mar 26, 2023 by OlivierDehaene, committed via GitHub on Mar 26, 2023

feat(server): cleanup flash neox loading (#139)

Parent: d6a93fe9
Showing 2 changed files with 53 additions and 22 deletions:

- server/text_generation_server/models/flash_neox.py (+1, -2)
- server/text_generation_server/models/flash_neox_modeling.py (+52, -20)
server/text_generation_server/models/flash_neox.py

```diff
@@ -450,8 +450,6 @@ class FlashNeoX(Model):
             next_batch_input_ids = next_batch_input_ids[0].view(1)
             next_batch_past_key_values = next_batch_past_key_values[0]
 
-        print(next_batch_input_ids.shape)
-
         next_batch = FlashNeoXBatch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
@@ -507,6 +505,7 @@ class FlashNeoXSharded(FlashNeoX):
                 rank=self.rank,
                 world_size=self.world_size,
             )
+        model.post_load_weights()
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
         super(FlashNeoX, self).__init__(
```
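The second hunk is the heart of the cleanup: one-time weight layout work now runs in an explicit `post_load_weights()` call right after the sharded state dict is materialized, instead of lazily on the first forward pass. A minimal, self-contained sketch of that pattern (everything below is illustrative; only the `post_load_weights` name mirrors the diff):

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

    def post_load_weights(self):
        # one-time fixup: store the weight pre-transposed so forward can
        # call torch.matmul without a per-call .T
        self.proj.weight = nn.Parameter(self.proj.weight.T.contiguous())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, self.proj.weight)

model = TinyModel()
model.load_state_dict(model.state_dict())  # stand-in for real checkpoint loading
model.post_load_weights()                  # finalize once, as the diff does
model = model.eval()
print(model(torch.randn(2, 4)).shape)      # torch.Size([2, 4])
```

Finalizing before `eval().to(dtype)` and before the distributed barrier means every rank reaches the barrier with weights already in their final layout, and no branch is left in the per-token hot path.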
server/text_generation_server/models/flash_neox_modeling.py

```diff
 import torch
 import torch.distributed
 
+from torch.nn import functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
```
```diff
@@ -24,13 +26,11 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.swap_dims = True
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.swap_dims:
-            self.weight = nn.Parameter(self.weight.T)
-            self.swap_dims = False
+    def transpose_weight(self):
+        self.weight = nn.Parameter(self.weight.T)
 
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.bias is not None:
             return torch.addmm(self.bias, input, self.weight)
         return torch.matmul(input, self.weight)
```
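For reference, a quick self-contained check of what `transpose_weight()` buys: once an `nn.Linear` weight is stored pre-transposed, `torch.addmm(bias, input, weight)` reproduces `F.linear` with no transpose in the hot path (sizes below are arbitrary):

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
inp = torch.randn(3, 8)
w = torch.randn(16, 8)   # nn.Linear layout: (out_features, in_features)
b = torch.randn(16)

reference = F.linear(inp, w, b)        # computes inp @ w.T + b
w_t = w.T.contiguous()                 # what transpose_weight() stores
fast = torch.addmm(b, inp, w_t)        # fused bias add + matmul, no transpose

print(torch.allclose(reference, fast, atol=1e-6))  # True
```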
```diff
@@ -120,6 +120,10 @@ class TensorParallelEmbedding(nn.Embedding):
         self.min_id = self.tp_rank * block_size
         self.max_id = (self.tp_rank + 1) * block_size
 
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
         super().__init__(
             block_size,
             embedding_dim,
```
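The `min_id`/`max_id`/`null_idx` bookkeeping is easiest to see with concrete numbers; here is a tiny illustration assuming a vocabulary of 1024 split across 4 tensor-parallel ranks (the sizes are hypothetical, not taken from the repo):

```python
vocab_size, tp_world_size = 1024, 4
block_size = vocab_size // tp_world_size  # 256 embedding rows per rank

for tp_rank in range(tp_world_size):
    min_id = tp_rank * block_size         # first token id this rank owns
    max_id = (tp_rank + 1) * block_size   # one past the last owned id
    null_idx = block_size                 # extra all-zero row, local index 256
    print(tp_rank, (min_id, max_id), null_idx)
# rank 0 owns ids [0, 256), rank 1 owns [256, 512), and so on
```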
```diff
@@ -133,15 +137,19 @@ class TensorParallelEmbedding(nn.Embedding):
             dtype=dtype,
         )
 
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # `0` if input is in the correct interval, else `1`
-        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
         # translate for [0, self.max_id - self.min_id[
-        input = input - self.min_id
-        # default all out of bounds values to `0`
-        input[input_mask] = 0
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
         out = super().forward(input)
-        out[input_mask] = 0.0
         torch.distributed.all_reduce(out, group=self.process_group)
         return out
```
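The rewritten `forward` drops the in-place masking pair (`input[input_mask] = 0` then `out[input_mask] = 0.0`) in favor of a single `torch.where` that routes out-of-range ids to the padded zero row, so no output zeroing is needed before the `all_reduce`. A self-contained sketch of the trick, with toy shapes and a pretend rank 1 (nothing here is the repo's code):

```python
import torch
from torch.nn import functional as F

block_size, dim = 4, 3
min_id, max_id, null_idx = 4, 8, block_size   # pretend tp_rank 1 of 2

# this rank's embedding shard, padded with one zero row (what add_null_idx does)
shard = torch.arange(block_size * dim, dtype=torch.float).view(block_size, dim)
shard = F.pad(shard, (0, 0, 0, 1))            # shape (5, 3); row 4 is all zeros

input_ids = torch.tensor([2, 5, 7, 9])        # only 5 and 7 belong to this rank
local = torch.where(
    (min_id > input_ids) | (input_ids >= max_id),
    torch.tensor(null_idx),                   # out-of-range -> the zero row
    input_ids - min_id,                       # in-range -> local row offset
)
out = F.embedding(local, shard)
print(local)  # tensor([4, 1, 3, 4])
print(out)    # rows 0 and 3 are zeros; all_reduce sums in the owning rank
```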
```diff
@@ -214,11 +222,9 @@ class FlashNeoxAttention(torch.nn.Module):
             hidden_size,
             process_group=process_group,
         )
-        self.swap_dims = True
 
-    def _swap_dims(self):
-        """Swap dims for the first inference to avoid an additional permute"""
+    # TODO: remove and swap dims when loading weights
+    def shuffle_qkv_dims(self):
+        """Swap dims to avoid an additional permute"""
         self.query_key_value.weight = torch.nn.Parameter(
             self.query_key_value.weight.view(
                 self.num_heads, 3, self.head_size, self.hidden_size
@@ -231,7 +237,6 @@ class FlashNeoxAttention(torch.nn.Module):
             .permute(1, 0, 2)
             .reshape(-1)
         )
-        self.swap_dims = False
 
     def forward(
         self,
```
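`shuffle_qkv_dims()` reorders the fused QKV projection from the checkpoint's head-major layout to a qkv-major layout, so `forward` can view its output as `(-1, 3, num_heads, head_size)` without a runtime permute. A toy reconstruction with hypothetical sizes (the real method applies the same `view`/`permute`/`reshape` to `query_key_value.weight`):

```python
import torch

num_heads, head_size, hidden = 2, 3, 5

# checkpoint layout: output rows grouped per head as [q_h0, k_h0, v_h0, q_h1, ...]
w = torch.randn(num_heads * 3 * head_size, hidden)

# regroup rows as [q_h0, q_h1, k_h0, k_h1, v_h0, v_h1]
shuffled = (
    w.view(num_heads, 3, head_size, hidden)
    .permute(1, 0, 2, 3)
    .reshape(-1, hidden)
)

x = torch.randn(7, hidden)
qkv = (x @ shuffled.T).view(-1, 3, num_heads, head_size)
print(qkv.shape)  # torch.Size([7, 3, 2, 3]) -> q, k, v split for free
```

The bias gets the analogous 3-D treatment, which is the `.permute(1, 0, 2).reshape(-1)` visible in the second hunk above.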
```diff
@@ -244,9 +249,6 @@ class FlashNeoxAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        if self.swap_dims:
-            self._swap_dims()
-
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
         qkv_rot = self.rotary_emb(qkv, cos, sin)
```
```diff
@@ -329,7 +331,6 @@ class FlashMLP(nn.Module):
             hidden_size,
             process_group=process_group,
         )
-        self.heuristic = "auto"
         self.process_group = process_group
 
     def forward(self, hidden_states):
```
```diff
@@ -531,6 +532,25 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
+    def post_load_weights(self):
+        if isinstance(self.embed_in, TensorParallelEmbedding):
+            self.embed_in.add_null_idx()
+        for layer in self.layers:
+            layer: FlashNeoXLayer
+            layer.attention.shuffle_qkv_dims()
+            layer.attention.query_key_value.transpose_weight()
+            layer.attention.dense.transpose_weight()
+            layer.mlp.dense_h_to_4h.transpose_weight()
+            layer.mlp.dense_4h_to_h.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXModel, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,
```
```diff
@@ -627,6 +647,18 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             config.hidden_size, config.vocab_size, bias=False
         )
 
+    def post_load_weights(self):
+        self.gpt_neox.post_load_weights()
+        self.embed_out.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,
```
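With both overrides in place, loading stays a single call: `from_pretrained` chains into `post_load_weights`, which fans out to `add_null_idx()`, `shuffle_qkv_dims()`, and `transpose_weight()` across the layers. A hedged usage sketch, assuming `FlashGPTNeoXForCausalLM` from this file is importable and with an illustrative checkpoint id and dtype:

```python
import torch

# from_pretrained now runs post_load_weights() itself, so the returned model
# already has pre-transposed linears, shuffled QKV weights, and the padded
# embedding row; no lazy fixup happens on the first forward pass.
model = FlashGPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/gpt-neox-20b",   # illustrative model id
    torch_dtype=torch.float16,   # illustrative dtype
)
```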