OpenDAS / text-generation-inference / Commits

Commit 44ce098c (unverified), authored Feb 24, 2023 by OlivierDehaene, committed by GitHub on Feb 24, 2023
feat(server): pre-allocate max attention mask (#75)
parent 78063c05

Showing 7 changed files with 148 additions and 114 deletions (+148, -114):
server/tests/models/test_bloom.py             +13 -12
server/tests/models/test_causal_lm.py         +13 -12
server/tests/models/test_seq2seq_lm.py        +4  -8
server/text_generation/models/causal_lm.py    +50 -21
server/text_generation/models/galactica.py    +0  -2
server/text_generation/models/seq2seq_lm.py   +68 -51
server/text_generation/models/t5.py           +0  -8
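In short, instead of growing the attention mask by one column with torch.cat on every decoding step, each batch now allocates its mask once at max_sequence_length + padding_right_offset (the largest max_new_tokens in the batch) and flips a single position per step. A minimal standalone sketch of that idea, with simplified, hypothetical names rather than the repository's exact code:

import torch

def preallocate_mask(tokenizer_mask: torch.Tensor, padding_right_offset: int) -> torch.Tensor:
    # Allocate the mask once at its final width instead of concatenating every step.
    batch_size, max_sequence_length = tokenizer_mask.shape
    mask = tokenizer_mask.new_zeros((batch_size, max_sequence_length + padding_right_offset))
    mask[:, :max_sequence_length] = tokenizer_mask  # copy what the tokenizer produced
    return mask

def decode_step(mask: torch.Tensor, padding_right_offset: int) -> int:
    # One in-place write marks the newly generated token as attendable;
    # no per-step tensor allocation is needed.
    mask[:, -padding_right_offset] = 1
    return padding_right_offset - 1  # one fewer reserved column remains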
server/tests/models/test_bloom.py
@@ -65,8 +65,8 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.input_ids[0][-1] == 10264
     assert torch.all(batch.input_ids[0][:-1] == 3)
-    assert batch.attention_mask[0][-1] == 1
-    assert torch.all(batch.attention_mask[0][:-1] == 0)
+    assert batch.attention_mask[0][0] == 1
+    assert torch.all(batch.attention_mask[0][1:] == 0)
     assert batch.past_key_values is None
@@ -98,16 +98,13 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert not next_batch.keys_head_dim_last
     assert len(next_batch.all_input_ids) == next_batch.size
-    assert (len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == sequence_length + 1)
+    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+    assert len(next_batch.attention_mask[0]) == 11
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
     assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
-    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
-    assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+    assert torch.all(next_batch.attention_mask[0][:2] == 1)
+    assert torch.all(next_batch.attention_mask[0][2:] == 0)
     assert next_batch.input_ids.shape == (next_batch.size, 1)
     assert next_batch.input_ids[0, 0] == 10264
@@ -213,9 +210,13 @@ def test_batch_concatenate(
     assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
     assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
-    assert torch.all(next_batch.attention_mask[0] == 1)
-    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
-    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+    assert torch.all(next_batch.attention_mask[0, :-next_batch.padding_right_offset] == 1)
+    assert torch.all(next_batch.attention_mask[1:, 1:-next_batch.padding_right_offset] == 1)
+    assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
     assert next_batch.batch_id == 0
     assert torch.all(next_batch.input_ids == 10264)
server/tests/models/test_causal_lm.py
@@ -62,8 +62,8 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.input_ids[0][-1] == 14402
     assert torch.all(batch.input_ids[0][:-1] == 50256)
-    assert batch.attention_mask[0][-1] == 1
-    assert torch.all(batch.attention_mask[0][:-1] == 0)
+    assert batch.attention_mask[0, 0] == 1
+    assert torch.all(batch.attention_mask[0, 1:] == 0)
     assert batch.past_key_values is None
@@ -94,17 +94,14 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert isinstance(next_batch, CausalLMBatch)
     assert len(next_batch.all_input_ids) == next_batch.size
-    assert (len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == sequence_length + 1)
+    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+    assert len(next_batch.attention_mask[0]) == 11
     assert next_batch.all_input_ids[0][-1] == 13
     assert next_batch.all_input_ids[0][-2] == 14402
     assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
-    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
-    assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+    assert torch.all(next_batch.attention_mask[0][0:2] == 1)
+    assert torch.all(next_batch.attention_mask[0][2:] == 0)
     assert next_batch.input_ids.shape == (next_batch.size, 1)
     assert next_batch.input_ids[0, 0] == 13
@@ -210,9 +207,13 @@ def test_batch_concatenate(
     assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
     assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
-    assert torch.all(next_batch.attention_mask[0] == 1)
-    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
-    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+    assert torch.all(next_batch.attention_mask[0, :-next_batch.padding_right_offset] == 1)
+    assert torch.all(next_batch.attention_mask[1:, 1:-next_batch.padding_right_offset] == 1)
+    assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
     assert next_batch.batch_id == 0
     assert next_batch.input_ids[0, 0] == 12355
server/tests/models/test_seq2seq_lm.py
@@ -106,7 +106,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, Seq2SeqLMBatch)
-    assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
+    assert next_batch.input_ids is None
     assert torch.equal(next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask)
@@ -220,11 +220,6 @@ def test_batch_concatenate(
     assert next_batch.batch_id == 0
-    assert torch.all(next_batch.input_ids[:, 0] == 4268)
-    assert torch.all(next_batch.input_ids[:, 1] == 1)
-    assert torch.all(next_batch.attention_mask == 1)
     assert torch.equal(next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0])
@@ -233,9 +228,10 @@ def test_batch_concatenate(
         next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
     )
-    assert torch.all(next_batch.decoder_attention_mask[0] == 1)
+    assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
+    assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
     assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
-    assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
+    assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)
     assert torch.equal(
         next_batch.encoder_last_hidden_state[0],
server/text_generation/models/causal_lm.py
@@ -37,6 +37,7 @@ class CausalLMBatch(Batch):
    # Metadata used for padding
    size: int
    max_sequence_length: int
+    padding_right_offset: int

    # Past metadata
    keys_head_dim_last: bool = True
@@ -61,22 +62,36 @@ class CausalLMBatch(Batch):
        input_lengths = []

        # Parse batch
+        max_sequence_length = 0
+        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer))
+            stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criterias.append(stopping_criteria)
+            max_sequence_length = max(max_sequence_length, r.input_length)
+            padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens)

        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros((pb.size, max_sequence_length + padding_right_offset))
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
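To make the allocated layout concrete, an illustrative example (values invented for illustration, not taken from the repository's tests, and assuming the tokenizer left-pads only to the longest input): two requests with input lengths 3 and 5 and max_new_tokens 4 and 2 give max_sequence_length = 5 and padding_right_offset = 4, so the mask gets 9 columns; the left 5 hold the tokenizer mask and the right 4 stay zero until generation fills them.

import torch

tokenizer_mask = torch.tensor([[0, 0, 1, 1, 1],   # length-3 request, left padded
                               [1, 1, 1, 1, 1]])  # length-5 request
max_sequence_length = 5
padding_right_offset = 4
attention_mask = tokenizer_mask.new_zeros((2, max_sequence_length + padding_right_offset))
attention_mask[:, :max_sequence_length] = tokenizer_mask
# attention_mask is now 2 x 9; columns 5..8 are reserved for tokens generated later.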
@@ -84,8 +99,8 @@ class CausalLMBatch(Batch):
        return cls(
            batch_id=pb.id,
            requests=pb.requests,
-            input_ids=tokenized_inputs["input_ids"],
-            attention_mask=tokenized_inputs["attention_mask"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=all_input_ids,
@@ -93,15 +108,21 @@ class CausalLMBatch(Batch):
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
-            max_sequence_length=max(input_lengths),
+            max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+        total_batch_size = 0
+        max_sequence_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += batch.size
+            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
@@ -144,13 +165,22 @@ class CausalLMBatch(Batch):
            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length),
+                    (total_batch_size, max_sequence_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
+            # and to remove unused allocated space
+            left_offset = max_sequence_length - batch.max_sequence_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            )
-            attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+            attention_mask[
+                start_index:end_index, left_offset:-padding_right_offset
+            ] = batch.attention_mask[:, batch_left_offset:-batch.padding_right_offset]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
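Worked through on hypothetical numbers (not taken from the code or the tests): with a merged max_sequence_length of 7 and padding_right_offset of 5, an incoming batch whose own max_sequence_length is 4 and padding_right_offset is 3, and whose mask is 7 columns wide, copies its 4 used columns into columns 3..6 of the 12-column merged mask.

# Hypothetical sizes inside concatenate()
max_sequence_length = 7          # across all batches being merged
padding_right_offset = 5         # across all batches being merged
batch_max_sequence_length = 4    # this batch
batch_padding_right_offset = 3   # this batch
batch_mask_width = batch_max_sequence_length + batch_padding_right_offset  # 7

left_offset = max_sequence_length - batch_max_sequence_length                                  # 3
batch_left_offset = batch_mask_width - batch_max_sequence_length - batch_padding_right_offset  # 0

# Destination slice left_offset:-padding_right_offset is columns 3..6 of the 12-wide buffer;
# source slice batch_left_offset:-batch_padding_right_offset is columns 0..3 of the 7-wide mask.
# Both are exactly batch_max_sequence_length columns, so the real tokens stay aligned.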
@@ -228,6 +258,7 @@ class CausalLMBatch(Batch):
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )
@@ -294,9 +325,12 @@ class CausalLM(Model):
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+        # slice the attention mask to the correct shape
+        attention_mask = batch.attention_mask[:, :-batch.padding_right_offset]
        logits, past = self.forward(
            batch.input_ids,
-            batch.attention_mask,
+            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )
@@ -448,14 +482,8 @@ class CausalLM(Model):
        next_batch_next_token_choosers = batch.next_token_choosers
        next_batch_stopping_criterias = batch.stopping_criterias

-        # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_attention_mask = torch.cat(
-            [
-                next_batch_attention_mask,
-                next_batch_attention_mask.new_ones(next_batch_size, 1),
-            ],
-            dim=1,
-        )
+        # Update attention_mask as we added a new token to input_ids
+        next_batch_attention_mask[:, -batch.padding_right_offset] = 1

        # Update position_ids
        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
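A minimal sketch of how the in-place update and the shrinking right offset interact over a few decode steps (a standalone toy loop, not the repository code):

import torch

mask = torch.zeros(1, 4 + 3, dtype=torch.long)    # 4 prompt slots + 3 reserved decode slots
mask[:, :4] = 1                                    # prompt tokens are attendable
padding_right_offset = 3

for _ in range(3):                                 # three decode steps
    model_mask = mask[:, :-padding_right_offset]   # the slice the forward pass would see
    # ... run the model with model_mask ...
    mask[:, -padding_right_offset] = 1             # mark the newly generated token
    padding_right_offset -= 1                      # one fewer reserved column remains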
@@ -473,6 +501,7 @@ class CausalLM(Model):
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
+            padding_right_offset=batch.padding_right_offset - 1,
            keys_head_dim_last=batch.keys_head_dim_last,
        )
        return generations, next_batch
server/text_generation/models/galactica.py
@@ -106,12 +106,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
        )

        # Tokenize batch
        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
server/text_generation/models/seq2seq_lm.py
@@ -42,6 +42,7 @@ class Seq2SeqLMBatch(Batch):
    size: int
    max_input_length: int
    max_decoder_input_length: int
+    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
@@ -68,6 +69,8 @@ class Seq2SeqLMBatch(Batch):
        decoder_input_lengths = []

        # Parse batch
+        max_input_length = 0
+        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
@@ -75,17 +78,20 @@ class Seq2SeqLMBatch(Batch):
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer))
+            stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criterias.append(stopping_criteria)
+            max_input_length = max(max_input_length, r.input_length)
+            padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens)

        # Tokenize batch
        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)

        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
@@ -107,6 +113,7 @@ class Seq2SeqLMBatch(Batch):
            size=len(pb.requests),
            max_input_length=max(input_lengths),
            max_decoder_input_length=1,
+            padding_right_offset=padding_right_offset,
        )

    @classmethod
@@ -115,11 +122,17 @@ class Seq2SeqLMBatch(Batch):
        """Concatenate multiple batches together by padding internal torch tensors"""
        # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_input_length = max(batch.max_input_length for batch in batches)
-        max_decoder_input_length = max(batch.max_decoder_input_length for batch in batches)
+        total_batch_size = 0
+        max_input_length = 0
+        max_decoder_input_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += batch.size
+            max_input_length = max(max_input_length, batch.max_input_length)
+            max_decoder_input_length = max(max_decoder_input_length, batch.max_decoder_input_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
@@ -129,7 +142,6 @@ class Seq2SeqLMBatch(Batch):
        stopping_criterias = []

        # Batch tensors
-        input_ids = None
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
@@ -155,16 +167,6 @@ class Seq2SeqLMBatch(Batch):
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

-            # Create padded tensor
-            if input_ids is None:
-                input_ids = batch.input_ids.new_zeros(
-                    (total_batch_size, max_input_length),
-                )
-            # Copy to correct indices
-            input_ids[
-                start_index:end_index, -batch.max_input_length:
-            ] = batch.input_ids[:, -batch.max_input_length:]

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
@@ -189,19 +191,29 @@ class Seq2SeqLMBatch(Batch):
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_decoder_input_length),
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )

            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
            # this batch. All generations are of length `batch.max_decoder_input_length`.
+            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            if batch.decoder_attention_mask is None:
-                decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length:
-                ] = 1
+                decoder_attention_mask[
+                    start_index:end_index, left_offset:-padding_right_offset
+                ] = 1
            # If it exists, we need to index
            else:
+                batch_left_offset = (
+                    batch.decoder_attention_mask.shape[1] - batch.max_decoder_input_length - batch.padding_right_offset
+                )
-                decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length:
-                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length:]
+                decoder_attention_mask[
+                    start_index:end_index, left_offset:-padding_right_offset
+                ] = batch.decoder_attention_mask[:, batch_left_offset:-batch.padding_right_offset]

            # Create padded tensor
            if encoder_last_hidden_state is None:
@@ -273,7 +285,7 @@ class Seq2SeqLMBatch(Batch):
        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
-            input_ids=input_ids,
+            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
@@ -286,6 +298,7 @@ class Seq2SeqLMBatch(Batch):
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
+            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
@@ -342,14 +355,6 @@ class Seq2SeqLM(Model):
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
-        if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
-
-        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
-        # internally...
-        if encoder_last_hidden_state is not None:
-            encoder_last_hidden_state = [encoder_last_hidden_state]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
@@ -369,12 +374,34 @@ class Seq2SeqLM(Model):
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
+        if batch.decoder_attention_mask is not None:
+            # slice to the correct shape
+            decoder_attention_mask = batch.decoder_attention_mask[:, :-batch.padding_right_offset]
+        else:
+            decoder_attention_mask = None
+
+        # check if first forward or not
+        if batch.past_key_values is not None:
+            # Only take the last token
+            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
+        else:
+            decoder_input_ids = batch.decoder_input_ids
+
+        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
+        # internally...
+        if batch.encoder_last_hidden_state is not None:
+            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
+        else:
+            encoder_last_hidden_state = batch.encoder_last_hidden_state
+
        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
-            batch.decoder_input_ids,
-            batch.decoder_attention_mask,
-            batch.encoder_last_hidden_state,
+            decoder_input_ids,
+            decoder_attention_mask,
+            encoder_last_hidden_state,
            batch.past_key_values,
        )
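The decoder side follows the same pattern; a hedged standalone sketch with a hypothetical helper name, simplified from what the hunk above shows: with a KV cache only the newest decoder token is fed, and the pre-allocated decoder mask is trimmed to its used columns before the forward pass.

import torch
from typing import Optional, Tuple

def prepare_decoder_inputs(
    decoder_input_ids: torch.Tensor,
    decoder_attention_mask: Optional[torch.Tensor],
    padding_right_offset: int,
    has_past: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # With a KV cache, only the most recent decoder token needs to be fed.
    if has_past:
        decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
    # Trim the unused, pre-allocated right padding before the forward pass.
    # Assumes padding_right_offset >= 1, as when it is initialised from max_new_tokens.
    if decoder_attention_mask is not None:
        decoder_attention_mask = decoder_attention_mask[:, :-padding_right_offset]
    return decoder_input_ids, decoder_attention_mask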
@@ -402,7 +429,6 @@ class Seq2SeqLM(Model):
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
-            batch.input_ids,
            batch.decoder_input_ids,
        )
@@ -414,7 +440,6 @@ class Seq2SeqLM(Model):
            logits,
            next_token_chooser,
            stopping_criteria,
-            input_tokens,
            decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
@@ -500,10 +525,8 @@ class Seq2SeqLM(Model):
        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
+            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
@@ -526,7 +549,6 @@ class Seq2SeqLM(Model):
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
-            next_batch_input_ids = batch.input_ids
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
@@ -536,20 +558,14 @@ class Seq2SeqLM(Model):
        next_batch_next_token_choosers = batch.next_token_choosers
        next_batch_stopping_criterias = batch.stopping_criterias

-        # Update decoder_attention_mask with padding as we added a new token to input_ids
+        # Update decoder_attention_mask as we added a new token to input_ids
        if next_batch_decoder_attention_mask is not None:
-            next_batch_decoder_attention_mask = torch.cat(
-                [
-                    next_batch_decoder_attention_mask,
-                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
-                ],
-                dim=1,
-            )
+            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
-            input_ids=next_batch_input_ids,
+            input_ids=None,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
@@ -562,5 +578,6 @@ class Seq2SeqLM(Model):
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
+            padding_right_offset=batch.padding_right_offset - 1,
        )
        return generations, next_batch
server/text_generation/models/t5.py
@@ -221,14 +221,6 @@ class T5Sharded(Seq2SeqLM):
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
-        if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
-
-        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
-        # internally...
-        if encoder_last_hidden_state is not None:
-            encoder_last_hidden_state = [encoder_last_hidden_state]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,