OpenDAS / text-generation-inference · Commits · 4a7dd408

Commit 4a7dd408 (unverified), authored Apr 24, 2023 by Nick Hill, committed by GitHub on Apr 24, 2023

feat(server): reduce memory requirement (#214)

Parent: 6ded76a4

Showing 5 changed files with 287 additions and 174 deletions:

- server/tests/models/test_bloom.py (+17 / -6)
- server/tests/models/test_causal_lm.py (+17 / -6)
- server/tests/models/test_seq2seq_lm.py (+23 / -10)
- server/text_generation_server/models/causal_lm.py (+125 / -76)
- server/text_generation_server/models/seq2seq_lm.py (+105 / -76)
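All five files follow the same idea: `filter` and `concatenate` now update batch state in place and drop references to source `past_key_values` tensors (via `del` or by overwriting list entries with `None`) as soon as each slice has been copied, instead of building a complete second copy of the cache before releasing the first. Below is a toy sketch of that layer-by-layer pattern; it is not code from this commit, and `concatenate_layerwise`, its shapes, and the example sizes are all invented for illustration.

```python
import torch


def concatenate_layerwise(caches, target_len):
    """Toy version of layer-by-layer cache concatenation with incremental GC.

    `caches` is a list of per-batch caches; each cache is a list of layers and
    each layer is a [key, value] pair shaped [batch, heads, seq, head_dim].
    References to source tensors are cleared as soon as they are copied, so
    peak memory stays close to one cache rather than two.
    """
    total_batch = sum(cache[0][0].shape[0] for cache in caches)
    merged = []
    for j in range(len(caches[0])):  # iterate over attention layers
        merged_layer = []
        for k in range(2):  # 0 = keys, 1 = values
            ref = caches[0][j][k]
            padded = ref.new_zeros((total_batch, ref.shape[1], target_len, ref.shape[3]))
            start = 0
            for cache in caches:
                t = cache[j][k]
                cache[j][k] = None  # clear the reference so its storage can be reused
                end = start + t.shape[0]
                seq = min(t.shape[2], target_len)
                padded[start:end, :, -seq:, :] = t[:, :, -seq:, :]
                del t
                start = end
            merged_layer.append(padded)
        merged.append(merged_layer)
    return merged


def make_cache(seq_len, layers=2, heads=4, head_dim=8):
    return [[torch.randn(1, heads, seq_len, head_dim) for _ in range(2)] for _ in range(layers)]


merged = concatenate_layerwise([make_cache(5), make_cache(3)], target_len=5)
print(merged[0][0].shape)  # torch.Size([2, 4, 5, 8])
```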
server/tests/models/test_bloom.py

```diff
@@ -175,12 +175,14 @@ def test_causal_lm_generate_token_completion_multi(
         generations[1].generated_text.generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    # Copy stopping_criterias before filtering
+    stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
+
     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
-        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens
         - 1
     ):
         generations, next_batch = default_bloom.generate_token(next_batch)
@@ -212,6 +214,15 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_bloom_batch
     _, next_batch_1 = default_bloom.generate_token(next_batch_1)

+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
+    ]
+
     next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])

     assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@@ -246,15 +257,15 @@ def test_batch_concatenate(
     assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:],
+            next_batch_1_past_key_values[i][0][:, :, -1:],
             past[0][1:, :, :, -1].reshape(-1, 64, 1),
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, -1:, :],
+            next_batch_1_past_key_values[i][1][:, -1:, :],
             past[1][1:, :, -1, :].reshape(-1, 1, 64),
         )
```
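The test above snapshots `past_key_values` with `.clone()` before calling `concatenate`, because the reworked `concatenate` clears each source batch's cache entries once they have been copied into the merged batch. A hypothetical illustration of why the snapshot is needed (the `past` and `snapshot` variables are invented, not repository code):

```python
import torch

# A stand-in for one batch's past_key_values: one layer of (key, value) tensors.
past = [[torch.randn(1, 4, 2, 8), torch.randn(1, 4, 2, 8)]]

# Snapshot taken the way the updated tests do it, before concatenating:
snapshot = [(k.clone(), v.clone()) for (k, v) in past]

# concatenate() now clears the source entries after copying them:
past[0][0] = None
past[0][1] = None

# Asserting against past[0][0][...] would now fail on `None`;
# the cloned snapshot still supports the slicing used in the assertions.
assert snapshot[0][0][:, :, -2:].shape == (1, 4, 2, 8)
```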
server/tests/models/test_causal_lm.py

```diff
@@ -173,12 +173,14 @@ def test_causal_lm_generate_token_completion_multi(
         generations[1].generated_text.generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

+    # Copy stopping_criterias before filtering
+    stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
+
     next_batch = next_batch.filter([next_batch.requests[0]])

     for _ in range(
-        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
-        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens
         - 1
     ):
         generations, next_batch = default_causal_lm.generate_token(next_batch)
@@ -209,6 +211,15 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_causal_lm_batch
     _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)

+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
+    ]
+
     next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])

     assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
@@ -244,14 +255,14 @@ def test_batch_concatenate(
     assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
         )

     for _ in range(
```
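As in the BLOOM test, `stopping_criterias` is copied before `filter`, since `filter` now updates the batch in place and keeps only the surviving requests' criteria. A rough sketch under that assumption; the `_Criteria` class and the numbers are invented for illustration:

```python
class _Criteria:
    """Minimal stand-in for a stopping criteria object."""

    def __init__(self, max_new_tokens):
        self.max_new_tokens = max_new_tokens


batch_criterias = [_Criteria(10), _Criteria(5)]

# Copy taken the way the updated tests do it, before filtering:
criterias_before = batch_criterias.copy()

# Simulate filter([requests[0]]): only index 0 survives on the batch itself.
batch_criterias = batch_criterias[:1]

# The dropped request's limit is still readable through the copy, which is what
# the test's range(...) expression relies on.
remaining_steps = criterias_before[0].max_new_tokens - criterias_before[1].max_new_tokens - 1
assert remaining_steps == 4
```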
server/tests/models/test_seq2seq_lm.py

```diff
@@ -219,6 +219,19 @@ def test_batch_concatenate(
     next_batch_1 = default_multi_requests_seq2seq_lm_batch
     _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)

+    # Copy hidden state because it is removed from the concatenated branches
+    next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
+    next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state
+
+    # Clone past_key_values before concatenating to compare after,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [
+        [t.clone() for t in layer] for layer in next_batch_0.past_key_values
+    ]
+    next_batch_1_past_key_values = [
+        [t.clone() for t in layer] for layer in next_batch_1.past_key_values
+    ]
+
     next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])

     assert next_batch.batch_id == 0
@@ -239,11 +252,11 @@ def test_batch_concatenate(
     assert torch.equal(
         next_batch.encoder_last_hidden_state[0],
-        next_batch_0.encoder_last_hidden_state[0, -2:],
+        next_batch_0_encoder_last_hidden_state[0, -2:],
     )
     assert torch.equal(
         next_batch.encoder_last_hidden_state[1:],
-        next_batch_1.encoder_last_hidden_state[:, -2:],
+        next_batch_1_encoder_last_hidden_state[:, -2:],
     )

     assert next_batch.input_lengths == [2, 2, 2]
@@ -275,24 +288,24 @@ def test_batch_concatenate(
     )

     for i, past in enumerate(next_batch.past_key_values):
-        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
+        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
+        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
+            next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
+        assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
+            next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]
         )

-        assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
+        assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])
         assert torch.equal(
-            next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
+            next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]
         )

     for _ in range(3):
```
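For the seq2seq batch, each `past_key_values` layer holds four tensors: decoder self-attention key and value (indices 0 and 1, sized by the decoder length) and encoder cross-attention key and value (indices 2 and 3, sized by the encoder input length). The sketch below uses hypothetical sizes, not repository code, and mirrors the indexing that the updated filter and the assertions above rely on:

```python
import torch

batch, heads, dec_len, enc_len, head_dim = 2, 4, 3, 6, 8
layer = [
    torch.randn(batch, heads, dec_len, head_dim),  # 0: decoder keys
    torch.randn(batch, heads, dec_len, head_dim),  # 1: decoder values
    torch.randn(batch, heads, enc_len, head_dim),  # 2: encoder keys
    torch.randn(batch, heads, enc_len, head_dim),  # 3: encoder values
]

# The in-place filter keeps the last (max_decoder_input_length - 1) positions for
# indices 0-1 and the last max_input_length positions for indices 2-3.
keep_indices, dec_past, enc_keep = [0], dec_len - 1, enc_len
layer[0] = layer[0][keep_indices, :, -dec_past:]
layer[1] = layer[1][keep_indices, :, -dec_past:]
layer[2] = layer[2][keep_indices, :, -enc_keep:]
layer[3] = layer[3][keep_indices, :, -enc_keep:]

assert layer[0].shape == (1, heads, dec_past, head_dim)
assert layer[2].shape == (1, heads, enc_keep, head_dim)
```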
server/text_generation_server/models/causal_lm.py

```diff
@@ -150,6 +150,8 @@ class CausalLMBatch(Batch):
         next_token_choosers = []
         stopping_criterias = []

+        new_padding_right_offset = 0
+
         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
             requests_idx_mapping[r.id] = i
@@ -164,36 +166,57 @@ class CausalLMBatch(Batch):
             max_input_length = max(max_input_length, request_input_length)

             next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criterias.append(self.stopping_criterias[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+            new_padding_right_offset = max(
+                new_padding_right_offset,
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens,
+            )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
-        attention_mask = self.attention_mask[keep_indices]
         position_ids = self.position_ids[keep_indices]
-        # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-        past_key_values = [
-            [
-                t.view(len(self), -1, *t.shape[-2:])[keep_indices]
-                for t in layer
-            ]
-            for layer in self.past_key_values
-        ]
+        self.attention_mask = self.attention_mask[
+            keep_indices,
+            -(self.padding_right_offset + max_input_length):
+            (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
+        ]

-        return CausalLMBatch(
-            batch_id=self.batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            padding_right_offset=self.padding_right_offset,
-            keys_head_dim_last=self.keys_head_dim_last,
-        )
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [list(layer) for layer in self.past_key_values]
+
+        # Update tensors in-place to allow incremental garbage collection
+        past_kv_length = max_input_length - 1
+        for layer in self.past_key_values:
+            past_keys, past_values = layer
+            if len(past_keys.shape) == 3:
+                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
+                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
+            if self.keys_head_dim_last:
+                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
+            else:
+                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
+            del past_keys
+            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
+            del past_values
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = input_ids
+        self.position_ids = position_ids
+        self.all_input_ids = all_input_ids
+        self.input_lengths = input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.max_input_length = max_input_length
+        self.padding_right_offset = new_padding_right_offset
+
+        return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -285,16 +308,23 @@ class CausalLMBatch(Batch):
                 position_ids = batch.position_ids.new_empty((total_batch_size, 1))
             position_ids[start_index:end_index] = batch.position_ids

-            for j, past in enumerate(batch.past_key_values):
-                past_keys, past_values = past
-
             # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
             # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
             # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
-                past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
-                past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
-
-                _, num_heads, padded_sequence_length, head_dim = past_values.shape
+            # And ensure that we can update tensors in-place
+            if type(batch.past_key_values[0]) == tuple:
+                batch.past_key_values = [
+                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    for layer in batch.past_key_values
+                ]
+            elif len(batch.past_key_values[0][0].shape) == 3:
+                for layer in batch.past_key_values:
+                    for k, t in enumerate(layer):
+                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])
+
+            start_index = end_index
+
+        first_past_kvs = batches[0].past_key_values
+        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

         padded_past_values_shape = (
             total_batch_size,
@@ -303,7 +333,7 @@ class CausalLMBatch(Batch):
             head_dim,
         )

-        if batch.keys_head_dim_last:
+        if batches[0].keys_head_dim_last:
             padded_past_keys_shape = padded_past_values_shape
         else:
             # seq_length is last for BLOOM
@@ -314,33 +344,52 @@ class CausalLMBatch(Batch):
                 max_input_length - 1,
             )

-            # This will run only once per layer
-            if j == len(past_key_values):
-                padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
-                padded_past_values = past_values.new_zeros(padded_past_values_shape)
-                past_key_values.append((padded_past_keys, padded_past_values))
-
-            # We slice the past keys and values to remove the padding from previous batches
-            if batch.keys_head_dim_last:
-                past_key_values[j][0][
-                    start_index:end_index, :, -(batch.max_input_length - 1):, :
-                ] = past_keys[:, :, -(batch.max_input_length - 1):, :]
-            else:
-                past_key_values[j][0][
-                    start_index:end_index, :, :, -(batch.max_input_length - 1):
-                ] = past_keys[:, :, :, -(batch.max_input_length - 1):]
-
-            past_key_values[j][1][
-                start_index:end_index, :, -(batch.max_input_length - 1):, :
-            ] = past_values[:, :, -(batch.max_input_length - 1):, :]
-
-            start_index += len(batch)
+        # Iterate over attention layers
+        # Concatenate past key values layer by layer to allow incremental garbage collection
+        for j in range(len(first_past_kvs)):
+            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
+            start_index = 0
+            for batch in batches:
+                past_keys = batch.past_key_values[j][0]
+                # Clear reference to the original tensor
+                batch.past_key_values[j][0] = None
+                # Slicing end index for this batch
+                end_index = start_index + len(batch)
+                # We slice the keys to remove the padding from previous batches
+                past_seq_len = batch.max_input_length - 1
+                if batch.keys_head_dim_last:
+                    padded_past_keys[
+                        start_index:end_index, :, -past_seq_len:, :
+                    ] = past_keys[:, :, -past_seq_len:, :]
+                else:
+                    # BLOOM case
+                    padded_past_keys[
+                        start_index:end_index, :, :, -past_seq_len:
+                    ] = past_keys[:, :, :, -past_seq_len:]
+                del past_keys
+                start_index = end_index
+
+            padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
+            start_index = 0
+            for batch in batches:
+                past_values = batch.past_key_values[j][1]
+                # Clear reference to the original tensor
+                batch.past_key_values[j][1] = None
+                # Slicing end index for this batch
+                end_index = start_index + len(batch)
+                # We slice the past values to remove the padding from previous batches
+                past_seq_len = batch.max_input_length - 1
+                padded_past_values[
+                    start_index:end_index, :, -past_seq_len:, :
+                ] = past_values[:, :, -past_seq_len:, :]
+                del past_values
+                start_index = end_index
+
+            past_key_values.append([padded_past_keys, padded_past_values])

         return cls(
             batch_id=batches[0].batch_id,
```
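The reworked `CausalLMBatch.filter` above also narrows the attention mask instead of carrying the full padded width forward: it keeps only the last `max_input_length` real columns of the surviving requests plus `new_padding_right_offset` columns of right padding. A worked example of the slice bounds with invented numbers:

```python
import torch

# Worked example (hypothetical sizes, not from the repository) of the slice
#   mask[keep_indices,
#        -(padding_right_offset + max_input_length):
#        (mask.shape[1] - padding_right_offset) + new_padding_right_offset]
total_width = 20              # current mask width
padding_right_offset = 6      # right padding currently reserved for future tokens
max_input_length = 5          # longest kept request
new_padding_right_offset = 2  # right padding the kept requests still need

mask = torch.ones(3, total_width)
keep_indices = [0, 2]

filtered = mask[
    keep_indices,
    -(padding_right_offset + max_input_length):
    (total_width - padding_right_offset) + new_padding_right_offset,
]

# Only max_input_length real columns plus the reduced right padding survive.
assert filtered.shape == (2, max_input_length + new_padding_right_offset)
```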
server/text_generation_server/models/seq2seq_lm.py

```diff
@@ -25,7 +25,7 @@ class Seq2SeqLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]

     # Encoder values
-    input_ids: torch.Tensor
+    input_ids: Optional[torch.Tensor]
     attention_mask: torch.Tensor

     # Decoder values
@@ -164,6 +164,7 @@ class Seq2SeqLMBatch(Batch):
         max_input_length = 0
         max_decoder_input_length = 0
+        padding_right_offset = 0

         for i, r in enumerate(requests):
             idx = self.requests_idx_mapping[r.id]
@@ -184,45 +185,53 @@ class Seq2SeqLMBatch(Batch):
             max_decoder_input_length = max(
                 max_decoder_input_length, request_decoder_input_length
             )
+            padding_right_offset = max(
+                padding_right_offset,
+                self.stopping_criterias[idx].max_new_tokens
+                - self.stopping_criterias[idx].current_tokens,
+            )

             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criterias.append(self.stopping_criterias[idx])

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        decoder_input_ids = self.decoder_input_ids[keep_indices]
-        attention_mask = self.attention_mask[keep_indices]
+        self.decoder_input_ids = self.decoder_input_ids[keep_indices]
+        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
         if self.decoder_attention_mask is not None:
-            decoder_attention_mask = self.decoder_attention_mask[keep_indices]
-        else:
-            decoder_attention_mask = None
+            self.decoder_attention_mask = self.decoder_attention_mask[
+                keep_indices,
+                -(self.padding_right_offset + max_decoder_input_length):
+                (self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
+            ]

-        encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
+        self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+            keep_indices, -max_input_length:
+        ]

-        past_key_values = [
-            [t[keep_indices] for t in layer] for layer in self.past_key_values
-        ]
-
-        return Seq2SeqLMBatch(
-            batch_id=self.batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=None,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            all_decoder_input_ids=all_decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            encoder_last_hidden_state=encoder_last_hidden_state,
-            past_key_values=past_key_values,
-            input_lengths=input_lengths,
-            decoder_input_lengths=decoder_input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            max_decoder_input_length=max_decoder_input_length,
-            padding_right_offset=self.padding_right_offset,
-        )
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [
+                [t for t in layer] for layer in self.past_key_values
+            ]
+
+        decoder_past_seq_len = max_decoder_input_length - 1
+        for layer in self.past_key_values:
+            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
+            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
+            layer[2] = layer[2][keep_indices, :, -max_input_length:]
+            layer[3] = layer[3][keep_indices, :, -max_input_length:]
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = None
+        self.all_decoder_input_ids = all_decoder_input_ids
+        self.input_lengths = input_lengths
+        self.decoder_input_lengths = decoder_input_lengths
+        self.offsets = offsets
+        self.token_offsets = token_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.max_input_length = max_input_length
+        self.max_decoder_input_length = max_decoder_input_length
+        self.padding_right_offset = padding_right_offset
+
+        return self

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -350,58 +359,78 @@ class Seq2SeqLMBatch(Batch):
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length:, :
             ] = batch.encoder_last_hidden_state[:, -batch.max_input_length:, :]
+            batch.encoder_last_hidden_state = None

-            # Iterate over attention layers
-            for j, past in enumerate(batch.past_key_values):
-                _, num_heads, _, head_dim = past[0].shape
-
-                # This will run only once per layer
-                if j == len(past_key_values):
-                    past_key_values.append([])
-
-                # Decoder past
-                for k, t in enumerate(past[:2]):
-                    padded_t_shape = (
-                        total_batch_size,
-                        num_heads,
-                        (max_decoder_input_length - 1),
-                        head_dim,
-                    )
-
-                    # Initialize tensors
-                    # This will run only once per layer and per past tensor
-                    if k == len(past_key_values[j]):
-                        past_key_values[j].append(t.new_zeros(padded_t_shape))
-
-                    # We slice the past keys and values to remove the padding from previous batches
-                    past_key_values[j][k][
-                        start_index:end_index, :, -(batch.max_decoder_input_length - 1):, :
-                    ] = t[:, :, -(batch.max_decoder_input_length - 1):, :]
-
-                # encoder past
-                for k, t in enumerate(past[2:]):
-                    padded_t_shape = (
-                        total_batch_size,
-                        num_heads,
-                        max_input_length,
-                        head_dim,
-                    )
-
-                    idx = k + 2
-
-                    # Initialize tensors
-                    # This will run only once per layer and per past tensor
-                    if idx == len(past_key_values[j]):
-                        past_key_values[j].append(t.new_zeros(padded_t_shape))
-
-                    past_key_values[j][idx][
-                        start_index:end_index, :, -batch.max_input_length:, :
-                    ] = t[:, :, -batch.max_input_length:, :]
-
-            start_index += len(batch)
+            # Ensure that we can update tensors in-place
+            if type(batch.past_key_values[0]) == tuple:
+                batch.past_key_values = [
+                    [t for t in layer] for layer in batch.past_key_values
+                ]
+
+            start_index = end_index
+
+        # Determine shapes for new past kv tensors
+        first_past_kvs = batches[0].past_key_values
+        _, num_heads, _, head_dim = first_past_kvs[0][0].shape
+
+        padded_dec_t_shape = (
+            total_batch_size,
+            num_heads,
+            (max_decoder_input_length - 1),
+            head_dim,
+        )
+
+        padded_enc_t_shape = (
+            total_batch_size,
+            num_heads,
+            max_input_length,
+            head_dim,
+        )
+
+        # Iterate over attention layers
+        for j in range(len(first_past_kvs)):
+            past_key_values.append([])
+
+            # Decoder past
+            for k in range(0, 2):
+                # Initialize tensors
+                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
+                past_key_values[j].append(padded_past_values)
+
+                start_index = 0
+                for batch in batches:
+                    t = batch.past_key_values[j][k]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][k] = None
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past keys and values to remove the padding from previous batches
+                    past_seq_len = batch.max_decoder_input_length - 1
+                    padded_past_values[
+                        start_index:end_index, :, -past_seq_len:, :
+                    ] = t[:, :, -past_seq_len:, :]
+                    del t
+                    start_index = end_index
+
+            # Encoder past
+            for k in range(2, 4):
+                # Initialize tensors
+                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
+                past_key_values[j].append(padded_past_values)
+
+                start_index = 0
+                for batch in batches:
+                    t = batch.past_key_values[j][k]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][k] = None
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past keys and values to remove the padding from previous batches
+                    padded_past_values[
+                        start_index:end_index, :, -batch.max_input_length:, :
+                    ] = t[:, :, -batch.max_input_length:, :]
+                    del t
+                    start_index = end_index

         return cls(
             batch_id=batches[0].batch_id,
```
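A common thread in both model files is dropping references as soon as a tensor has been consumed (`batch.past_key_values[j][k] = None`, `del t`, `input_ids = None`), so its storage can be reclaimed before the rest of the batch is processed rather than only after the whole filter or concatenate call finishes. A minimal demonstration of why that matters; this is a hypothetical sketch, not code from the repository:

```python
import gc
import weakref

import torch

# Allocate a large buffer and keep only a weak reference for observation.
t = torch.randn(1024, 1024)
ref = weakref.ref(t)

# Dropping the last strong reference (same effect as `del t` or overwriting a
# list slot with None) lets the allocator reclaim the buffer immediately.
t = None
gc.collect()

assert ref() is None  # the buffer is gone; nothing waits for the whole batch to finish
```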