OpenDAS / text-generation-inference

Unverified commit afc5b999, authored Apr 21, 2023 by OlivierDehaene, committed via GitHub Apr 21, 2023
Parent: db4cb5e4

fix(server): cleanup new flash past_key_values logic (#217)
Showing 4 changed files with 5 additions and 12 deletions:
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+1, -1)
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (+1, -1)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+1, -1)
server/text_generation_server/models/flash_causal_lm.py (+2, -9)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -594,7 +594,7 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None

         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
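The same one-word comment fix recurs in flash_neox and flash_santacoder below; the surrounding code is unchanged. For context, a minimal sketch of the slicing the corrected comment describes, assuming past_key_values stacks one entry per layer with preallocated padding along the sequence axis. The helper name, shapes, and the else branch are hypothetical, not taken from the repo (the diff cuts off before the else arm):

import torch
from typing import Optional

def select_layer_past(
    past_key_values: torch.Tensor,
    layer_idx: int,
    slice_past_index: Optional[int],
) -> torch.Tensor:
    # Hypothetical layout: [num_layers, padded_seq_len, ...].
    # With no preallocated padding, use the layer's past as-is.
    if slice_past_index is None:
        return past_key_values[layer_idx]
    # Otherwise drop the padding rows beyond the real sequence length.
    return past_key_values[layer_idx, :slice_past_index]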
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -657,7 +657,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None

         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -520,7 +520,7 @@ class FlashSantacoderModel(nn.Module):
         residual = None

         for i, layer in enumerate(self.h):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
server/text_generation_server/models/flash_causal_lm.py
@@ -404,8 +404,7 @@ class FlashCausalLM(Model):
         # Shortcut when batch_size == 1
         if len(batch) == 1:
             input_ids = batch.input_ids[0].view(-1)
-            # Slice to remove extra padding
-            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
+            # No need to slice as flash attention will take care of it with cu_seqlens
             past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
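This hunk deletes a leftover slicing path in the batch-size-1 shortcut: flash attention consumes sequences packed into one [total_tokens, ...] tensor plus cumulative sequence lengths, so each sequence's boundaries are explicit and trailing padding is never read. A minimal illustration of that cu_seqlens bookkeeping, with made-up lengths:

import torch
import torch.nn.functional as F

# Three hypothetical requests with 5, 3 and 7 tokens each.
lengths = torch.tensor([5, 3, 7])
# Prepend a zero to the running total to get sequence boundaries.
cu_seqlens = F.pad(lengths.cumsum(0), (1, 0))
print(cu_seqlens)  # tensor([ 0,  5,  8, 15])
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed
# tensor, so padding past a sequence's end is simply never indexed.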
@@ -454,13 +453,7 @@ class FlashCausalLM(Model):
             )
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
-            if len(batch) == 1:
-                # Preallocate tensor for bs = 1 case
-                batch.past_key_values = torch.nn.functional.pad(
-                    present,
-                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
-                )
-            else:
+            if len(batch) != 1:
                 # Add padding after each sequence
                 # This will have the correct shape after the final past_key_values concatenation before the model
                 # forward
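This hunk removes the batch-size-1 preallocation and inverts the condition, so room for future tokens is only padded in the multi-sequence branch. As a reference for what the deleted call did: an 8-value tuple passed to torch.nn.functional.pad addresses the last four dimensions right-to-left, so the final pair grows the fourth-from-last axis. A sketch with hypothetical shapes (the real layout of present in the repo may differ):

import torch
import torch.nn.functional as F

# Hypothetical present layout: [num_layers, seq_len, num_heads, 2, head_dim].
present = torch.randn(2, 10, 8, 2, 64)
max_new_tokens = 20

# Pairs apply right-to-left: dims -1, -2, -3 untouched, dim -4 (seq_len)
# padded by max_new_tokens on the right to reserve room for new tokens.
padded = F.pad(present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens))
print(padded.shape)  # torch.Size([2, 30, 8, 2, 64])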