OpenDAS / text-generation-inference / Commits

Unverified commit ac8c0f6f, authored Apr 21, 2023 by Nick Hill; committed by GitHub, Apr 21, 2023

feat(server): flash attention past key value optimizations (#213)
parent 274513e6

Showing 1 changed file with 39 additions and 20 deletions:

server/text_generation_server/models/flash_causal_lm.py (+39, -20)
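The heart of the change: instead of re-slicing and re-assembling the past key/value cache on every decode step, the server now preallocates the cache to its final length with torch.nn.functional.pad and, for larger batches, interleaves a single shared zero-padding tensor between per-request pasts. A minimal sketch (not part of the commit) of the F.pad axis semantics, assuming an illustrative cache layout of (num_layers, seq_len, 2, num_heads, head_dim); the real shapes come from the model:

    import torch
    import torch.nn.functional as F

    # Illustrative layout only: dim 1 is the sequence axis.
    present = torch.randn(2, 5, 2, 4, 8)

    # F.pad consumes (left, right) pairs starting from the LAST dimension, so
    # (0,0, 0,0, 0,0, 0,N) leaves the three trailing dims alone and appends
    # N zero token slots to dim 1.
    max_new_tokens = 16
    preallocated = F.pad(present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens))
    print(preallocated.shape)  # torch.Size([2, 21, 2, 4, 8])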
@@ -38,7 +38,7 @@ class FlashCausalLMBatch(Batch):
     # cumulative sequence lengths
     cu_seqlens: List[int]
     max_seqlen: int
-    past_key_values: Optional[List[torch.Tensor]]
+    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]
     # All tokens
     all_input_ids: List[List[int]]
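The widened annotation reflects the two storage strategies that follow: a single preallocated tensor when the batch holds one request, and a list of per-request tensors otherwise (plus None before prefill). A hypothetical consumer-side sketch of the three states (helper name invented for illustration):

    from typing import List, Optional, Union
    import torch

    PastKV = Optional[Union[torch.Tensor, List[torch.Tensor]]]

    def describe(past: PastKV) -> str:
        # Invented helper, not part of the commit.
        if past is None:
            return "prefill: no past yet"
        if isinstance(past, torch.Tensor):
            return f"bs=1: preallocated tensor with capacity {past.shape[1]}"
        return f"bs>1: {len(past)} list entries (pasts interleaved with pads)"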
@@ -53,6 +53,9 @@ class FlashCausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]

+    # Constant shared tensor, ref here just so that it's accessible in concatenate()
+    past_pad: Optional[torch.Tensor]
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id, requests=self.requests, size=len(self)
@@ -149,6 +152,8 @@ class FlashCausalLMBatch(Batch):
         if len(requests) == len(self):
             return self

+        single_request = len(requests) == 1
+
         # Cumulative length
         cumulative_length = 0
@@ -182,7 +187,9 @@ class FlashCausalLMBatch(Batch):
             position_ids.append(self.position_ids[idx])
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
-            past_key_values.append(self.past_key_values[idx])
+            if not single_request:
+                past_key_values.append(self.past_key_values[2 * idx])
+                past_key_values.append(self.past_key_values[1])

             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
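With more than one request, the list stores entries pairwise: the past for request idx sits at index 2 * idx, and every odd slot holds the shared zero pad, so filter() can take any odd entry (index 1 here) as the padding companion. A toy sketch of that indexing, assuming the interleaved layout (shapes illustrative):

    import torch

    pad = torch.zeros(2, 1, 2, 4, 8)                       # shared zero slot
    pasts = [torch.randn(2, 3, 2, 4, 8) for _ in range(3)]
    past_key_values = []
    for p in pasts:                                        # [p0, pad, p1, pad, p2, pad]
        past_key_values.extend([p, pad])

    idx = 2                                                # request kept by filter()
    kept = [past_key_values[2 * idx], past_key_values[1]]  # its past + one shared pad
    assert kept[0] is pasts[2] and kept[1] is pad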
@@ -196,6 +203,13 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += request_input_length

+        if single_request:
+            # Preallocate tensor for bs = 1 case
+            past_key_values = torch.nn.functional.pad(
+                self.past_key_values[0],
+                (0, 0, 0, 0, 0, 0, 0,
+                 stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens),
+            )
+
         return FlashCausalLMBatch(
             batch_id=self.batch_id,
             requests=requests,
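When filtering leaves exactly one request, its past is re-padded to the room it can still need, max_new_tokens - current_tokens slots; e.g. with max_new_tokens = 20 and 4 tokens already generated, the sequence axis grows by 16. A sizing sketch under the same assumed layout:

    import torch
    import torch.nn.functional as F

    max_new_tokens, current_tokens = 20, 4
    remaining = max_new_tokens - current_tokens   # 16 slots still needed

    past = torch.randn(2, 7, 2, 4, 8)             # this request's past so far
    preallocated = F.pad(past, (0, 0, 0, 0, 0, 0, 0, remaining))
    print(preallocated.shape[1])                  # 23 == 7 + 16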
@@ -256,7 +270,11 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
-            past_key_values.extend(batch.past_key_values)
+            if len(batch) != 1:
+                past_key_values.extend(batch.past_key_values)
+            else:
+                past_key_values.append(batch.past_key_values[:, : batch.input_lengths[0]])
+                past_key_values.append(batch.past_pad)

             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
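concatenate() has to undo the bs = 1 preallocation before merging batches: the single tensor is sliced back to the tokens actually present (batch.input_lengths[0]) and one shared pad is appended, so the merged list keeps the interleaved shape. A slicing sketch:

    import torch

    input_length = 9
    preallocated = torch.randn(2, 25, 2, 4, 8)  # capacity 25, only 9 slots filled

    trimmed = preallocated[:, :input_length]    # drop unused capacity on dim 1
    print(trimmed.shape)                        # torch.Size([2, 9, 2, 4, 8])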
@@ -303,6 +321,7 @@ class FlashCausalLM(Model):
         quantize: bool = False,
         decode_buffer: int = 3,
     ):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -359,10 +378,8 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
         if len(batch) == 1:
             input_ids = batch.input_ids[0].view(-1)
-            past_key_values = (
-                batch.past_key_values[0] if batch.past_key_values is not None else None
-            )
+            # No need to slice this down
+            past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
             input_ids = torch.cat(batch.input_ids).view(-1)
@@ -392,7 +409,18 @@ class FlashCausalLM(Model):
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
-            batch.past_key_values = [None] * len(batch)
+            # Initialize past padding tensor
+            if self.past_pad is None:
+                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+            # Set in batch in case it needs to be used later in concatenate()
+            batch.past_pad = self.past_pad
+            if len(batch) == 1:
+                # Preallocate tensor for bs = 1 case
+                batch.past_key_values = torch.nn.functional.pad(
+                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
+                )
+            else:
+                batch.past_key_values = [None, self.past_pad] * len(batch)

         # Cumulative length
         cumulative_length = 0
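The shared pad is built lazily from present on the first prefill, so it inherits the model's dtype, device, and head shapes automatically: one zero slot on the sequence axis, shared by reference across requests. A sketch with the assumed layout:

    import torch

    present = torch.randn(2, 5, 2, 4, 8)   # assumed (layers, seq, 2, heads, head_dim)

    # new_zeros copies dtype/device from `present`; the sequence axis collapses
    # to a single zero slot that every request shares by reference.
    past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
    print(past_pad.shape)                   # torch.Size([2, 1, 2, 4, 8])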
@@ -477,21 +505,10 @@ class FlashCausalLM(Model):
                 generated_text = GeneratedText(
                     output_text, stopping_criteria.current_tokens, reason, seed
                 )
+                # CAUTION: generation will be stopped so no need to pad
+                # This will make the next forward crash if the request does not get filtered
+                new_input_length = input_length
+                past = present[:, start_index:end_index]
             else:
                 stopped = False
                 generated_text = None
+                # Pad present for next iter attention
+                new_input_length = input_length + 1
+                past = torch.nn.functional.pad(
+                    present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
+                )

             # Prefill
             if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
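In steady-state decoding each live request's past grows by exactly one zero slot per step, making room for the token just produced; stopped requests skip the pad, which, as the CAUTION comment says, leaves a tensor that would crash the next forward unless the request is filtered out. A per-step sketch:

    import torch
    import torch.nn.functional as F

    present = torch.randn(2, 6, 2, 4, 8)
    start_index, end_index = 0, 6           # this request's span in the packed batch

    # One extra zero slot on the sequence axis for the next decode step.
    past = F.pad(present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1))
    print(past.shape[1])                    # 7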
@@ -522,6 +539,7 @@ class FlashCausalLM(Model):
             generations.append(generation)
-            new_input_length = input_length + 1
+            cumulative_length += input_length

             # Update values
             batch.input_ids[i] = next_token_id
@@ -532,7 +550,8 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids
             batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
-            batch.past_key_values[i] = past
+            if len(batch) != 1:
+                batch.past_key_values[i * 2] = present[:, start_index:end_index]

             # Cumulative sum
             batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length