chenpangpang / transformers / Commits / f464f10a

Commit f464f10a (unverified), authored Apr 20, 2021 by Patrick von Platen, committed by GitHub on Apr 20, 2021.

[Generate] Remove outdated code (#11331)

* remove update function
* update
* refactor more
* refactor

Parent: bfd83c17
Showing 1 changed file with 20 additions and 60 deletions.
src/transformers/generation_utils.py (+20, −60)
...

@@ -483,31 +483,6 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"] = encoder_outputs
         return input_ids, model_kwargs
 
-    @staticmethod
-    def _init_sequence_length_for_generation(
-        input_ids: torch.LongTensor, max_length: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-        sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length)
-
-        cur_len = input_ids.shape[-1]
-        return sequence_lengths, unfinished_sequences, cur_len
-
-    @staticmethod
-    def _update_seq_length_for_generation(
-        sequence_lengths: torch.LongTensor,
-        unfinished_sequences: torch.LongTensor,
-        cur_len: int,
-        is_eos_in_next_token: torch.BoolTensor,
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
-        # check if sentence is not finished yet
-        is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
-
-        # update sentence length
-        sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)
-        unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
-        return sequence_lengths, unfinished_sequences
-
     @staticmethod
     def _update_model_kwargs_for_generation(
         outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
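The two helpers deleted above were the only producers of `sequence_lengths`, and no caller read that tensor after the decoding loop, which appears to be the "outdated code" the commit title refers to. For context, a minimal standalone sketch of the bookkeeping that `_update_seq_length_for_generation` performed, reproduced outside the class with made-up values (the torch calls are standard; the tensor values are illustrative only):

    import torch

    unfinished_sequences = torch.ones(3, dtype=torch.long)     # 1 = still generating
    sequence_lengths = torch.full((3,), 20, dtype=torch.long)  # pretend max_length = 20
    cur_len = 5
    is_eos_in_next_token = torch.tensor([False, True, False])

    # a sequence that was unfinished and just produced EOS records its final length
    is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
    sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)

    # EOS-producing sequences are marked finished for all later steps
    unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())

    print(sequence_lengths)      # tensor([20,  5, 20])
    print(unfinished_sequences)  # tensor([1, 0, 1])

Only the last update of that bookkeeping survives the refactor, inlined into the decoding loops in the hunks below.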
...

@@ -1271,10 +1246,9 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
         )
 
-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]
 
         this_peer_finished = False  # used by synced_gpus only
         while cur_len < max_length:
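The replacement initialization is a one-liner worth a close look: `Tensor.new` creates the flag tensor with the same dtype and device as `input_ids`, so the elementwise arithmetic later in the loop needs no casts or device transfers. A quick standalone check with a hypothetical two-prompt batch (the token values are placeholders, not from the source):

    import torch

    input_ids = torch.tensor([[101, 7592], [101, 2088]])  # hypothetical 2-prompt batch

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    cur_len = input_ids.shape[-1]

    print(unfinished_sequences)  # tensor([1, 1]) -- both sequences still generating
    print(cur_len)               # 2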
...

@@ -1330,29 +1304,23 @@ class GenerationMixin:
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
 
-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
-
-            # increase cur_len
             cur_len = cur_len + 1
 
-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break
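With the helper gone, the greedy loop keeps only two lines of finished-sequence bookkeeping, and they compose cleanly: finished rows are forced to emit `pad_token_id`, and the multiplicative flag update can only move a row from 1 to 0. A small sketch with assumed ids (`pad_token_id=0`, `eos_token_id=2` are illustrative, not from the source):

    import torch

    pad_token_id, eos_token_id = 0, 2            # assumed ids for illustration
    unfinished_sequences = torch.tensor([1, 0])  # second sequence finished earlier
    next_tokens = torch.tensor([7, 9])           # raw argmax output for this step

    # finished sentences have their next token replaced by the padding token
    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    print(next_tokens)  # tensor([7, 0])

    # a sequence that just emitted EOS flips its flag to 0; rows already at 0 stay 0
    unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
    print(unfinished_sequences)  # tensor([1, 0])

Note the ordering: padding is applied before the flag update, so even in configurations where `pad_token_id == eos_token_id`, an already-finished row cannot change state.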
...

@@ -1511,10 +1479,9 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
         )
 
-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]
 
         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
...

@@ -1571,32 +1538,25 @@ class GenerationMixin:
             # sample
             probs = F.softmax(next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-
-            # increase cur_len
             cur_len = cur_len + 1
 
-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break

...
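Taken together, the post-commit sample loop (and its greedy twin above, which differs only in using `torch.argmax`) reduces to a short, flat sequence of steps. A condensed, self-contained sketch of that control flow, under the simplifying assumptions that random logits stand in for a real forward pass and that `stopping_criteria` / `synced_gpus` handling is omitted:

    import torch

    torch.manual_seed(0)
    pad_token_id, eos_token_id, vocab_size, max_length = 0, 2, 11, 8  # illustrative

    input_ids = torch.tensor([[5, 6], [5, 7]])  # hypothetical prompts
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    cur_len = input_ids.shape[-1]

    while cur_len < max_length:
        # stand-in for the model forward pass that produces next-token logits
        next_token_scores = torch.randn(input_ids.shape[0], vocab_size)

        # sample (the greedy path uses torch.argmax here instead)
        probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids and length for the next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        cur_len = cur_len + 1

        # if eos_token was found in one sentence, set sentence to finished
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished; max_length is the while condition
        if unfinished_sequences.max() == 0:
            break

    print(input_ids)  # padded past EOS for any sequence that finished early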