Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f464f10a
Unverified
Commit
f464f10a
authored
Apr 20, 2021
by
Patrick von Platen
Committed by
GitHub
Apr 20, 2021
Browse files
[Generate] Remove outdated code (#11331)
* remove update function * update * refactor more * refactor
parent
bfd83c17
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
60 deletions
+20
-60
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+20
-60
No files found.
src/transformers/generation_utils.py
View file @
f464f10a
...
...
@@ -483,31 +483,6 @@ class GenerationMixin:
model_kwargs
[
"encoder_outputs"
]
=
encoder_outputs
return
input_ids
,
model_kwargs
@staticmethod
def _init_sequence_length_for_generation(
    input_ids: torch.LongTensor, max_length: int
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Create the per-sequence bookkeeping tensors used while generating.

    Args:
        input_ids: Tensor whose first dimension indexes the sequences being
            generated and whose last dimension is the current length.
        max_length: Value every entry of ``sequence_lengths`` starts at.

    Returns:
        A ``(sequence_lengths, unfinished_sequences, cur_len)`` tuple:

        * ``sequence_lengths`` -- 1-D tensor with one entry per row of
          ``input_ids``, filled with ``max_length``.
        * ``unfinished_sequences`` -- 1-D tensor with one entry per row of
          ``input_ids``, filled with ``1``.
        * ``cur_len`` -- size of the last dimension of ``input_ids``.

        Both tensors share the dtype and device of ``input_ids``.
    """
    batch_size = input_ids.shape[0]

    # `new_ones` / `new_full` build initialised tensors on the same dtype and
    # device as `input_ids` in one step, instead of the legacy
    # `Tensor.new(n)` (uninitialised storage) followed by an in-place `fill_`.
    unfinished_sequences = input_ids.new_ones(batch_size)
    sequence_lengths = input_ids.new_full((batch_size,), max_length)

    cur_len = input_ids.shape[-1]
    return sequence_lengths, unfinished_sequences, cur_len
@staticmethod
def _update_seq_length_for_generation(
    sequence_lengths: torch.LongTensor,
    unfinished_sequences: torch.LongTensor,
    cur_len: int,
    is_eos_in_next_token: torch.BoolTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Record the length of sequences that just emitted EOS and mark them finished.

    Args:
        sequence_lengths: Per-sequence length tensor to update.
        unfinished_sequences: Per-sequence flag tensor; non-zero means the
            sequence is still being generated.
        cur_len: Length written into ``sequence_lengths`` for every sequence
            that finishes on this step.
        is_eos_in_next_token: Per-sequence mask, truthy where the token just
            produced is the EOS token.

    Returns:
        The updated ``(sequence_lengths, unfinished_sequences)``. Both are new
        tensors; the inputs are not modified in place.
    """
    # Normalise to a boolean mask first. On an integer tensor `~` is a bitwise
    # NOT (~1 == -2, ~0 == -1), which would turn the 0/1 "unfinished" flags
    # into garbage if a caller passed a 0/1 integer mask.
    eos_mask = is_eos_in_next_token.bool()

    # Sequences that were still running and produced EOS on this step.
    just_finished = unfinished_sequences.mul(eos_mask.long()).bool()

    # Freeze the length of the sequences that finished right now.
    sequence_lengths = sequence_lengths.masked_fill(just_finished, cur_len)

    # A sequence stays unfinished only if it was unfinished and did not emit EOS.
    unfinished_sequences = unfinished_sequences.mul((~eos_mask).long())
    return sequence_lengths, unfinished_sequences
@
staticmethod
def
_update_model_kwargs_for_generation
(
outputs
:
ModelOutput
,
model_kwargs
:
Dict
[
str
,
Any
],
is_encoder_decoder
:
bool
=
False
...
...
@@ -1271,10 +1246,9 @@ class GenerationMixin:
model_kwargs
[
"encoder_outputs"
].
get
(
"hidden_states"
)
if
output_hidden_states
else
None
)
# init sequence length tensors
sequence_lengths
,
unfinished_sequences
,
cur_len
=
self
.
_init_sequence_length_for_generation
(
input_ids
,
max_length
)
# keep track of which sequences are already finished
unfinished_sequences
=
input_ids
.
new
(
input_ids
.
shape
[
0
]).
fill_
(
1
)
cur_len
=
input_ids
.
shape
[
-
1
]
this_peer_finished
=
False
# used by synced_gpus only
while
cur_len
<
max_length
:
...
...
@@ -1330,29 +1304,23 @@ class GenerationMixin:
# argmax
next_tokens
=
torch
.
argmax
(
next_tokens_scores
,
dim
=-
1
)
#
add code that transforms
next
_
token
s to tokens_to_add
#
finished sentences should have their
next
token
be a padding token
if
eos_token_id
is
not
None
:
assert
pad_token_id
is
not
None
,
"If eos_token_id is defined, make sure that pad_token_id is defined."
next_tokens
=
next_tokens
*
unfinished_sequences
+
pad_token_id
*
(
1
-
unfinished_sequences
)
#
add token and increase length by one
#
update generated ids, model inputs, and length for next step
input_ids
=
torch
.
cat
([
input_ids
,
next_tokens
[:,
None
]],
dim
=-
1
)
# update sequence length
if
eos_token_id
is
not
None
:
sequence_lengths
,
unfinished_sequences
=
self
.
_update_seq_length_for_generation
(
sequence_lengths
,
unfinished_sequences
,
cur_len
,
next_tokens
==
eos_token_id
)
# update model kwargs
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
)
# increase cur_len
cur_len
=
cur_len
+
1
# stop when there is a </s> in each sentence, or if we exceed the maximum length
# if eos_token was found in one sentence, set sentence to finished
if
eos_token_id
is
not
None
:
unfinished_sequences
=
unfinished_sequences
.
mul
((
next_tokens
!=
eos_token_id
).
long
())
# stop when each sentence is finished, or if we exceed the maximum length
if
unfinished_sequences
.
max
()
==
0
or
stopping_criteria
(
input_ids
,
scores
):
if
not
synced_gpus
:
break
...
...
@@ -1511,10 +1479,9 @@ class GenerationMixin:
model_kwargs
[
"encoder_outputs"
].
get
(
"hidden_states"
)
if
output_hidden_states
else
None
)
# init sequence length tensors
sequence_lengths
,
unfinished_sequences
,
cur_len
=
self
.
_init_sequence_length_for_generation
(
input_ids
,
max_length
)
# keep track of which sequences are already finished
unfinished_sequences
=
input_ids
.
new
(
input_ids
.
shape
[
0
]).
fill_
(
1
)
cur_len
=
input_ids
.
shape
[
-
1
]
this_peer_finished
=
False
# used by synced_gpus only
# auto-regressive generation
...
...
@@ -1571,32 +1538,25 @@ class GenerationMixin:
# sample
probs
=
F
.
softmax
(
next_token_scores
,
dim
=-
1
)
next_tokens
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
1
)
#
add code that transforms
next
_
token
s to tokens_to_add
#
finished sentences should have their
next
token
be a padding token
if
eos_token_id
is
not
None
:
assert
pad_token_id
is
not
None
,
"If eos_token_id is defined, make sure that pad_token_id is defined."
next_tokens
=
next_tokens
*
unfinished_sequences
+
pad_token_id
*
(
1
-
unfinished_sequences
)
#
add token and increase length by one
#
update generated ids, model inputs, and length for next step
input_ids
=
torch
.
cat
([
input_ids
,
next_tokens
[:,
None
]],
dim
=-
1
)
# update sequence length
if
eos_token_id
is
not
None
:
sequence_lengths
,
unfinished_sequences
=
self
.
_update_seq_length_for_generation
(
sequence_lengths
,
unfinished_sequences
,
cur_len
,
next_tokens
==
eos_token_id
)
# update model kwargs
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
)
# increase cur_len
cur_len
=
cur_len
+
1
# stop when there is a </s> in each sentence, or if we exceed the maximum length
# if eos_token was found in one sentence, set sentence to finished
if
eos_token_id
is
not
None
:
unfinished_sequences
=
unfinished_sequences
.
mul
((
next_tokens
!=
eos_token_id
).
long
())
# stop when each sentence is finished, or if we exceed the maximum length
if
unfinished_sequences
.
max
()
==
0
or
stopping_criteria
(
input_ids
,
scores
):
if
not
synced_gpus
:
break
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment