chenpangpang / transformers · Commits
Commit ca2047bc
authored Mar 09, 2020 by Patrick von Platen
parent 41b437ea

refactor variable naming and improve tf generate in line with torch generate
Showing 2 changed files with 219 additions and 91 deletions.
src/transformers/modeling_tf_utils.py   +170 / -49
src/transformers/modeling_utils.py      +49 / -42
src/transformers/modeling_tf_utils.py
(diff collapsed)
src/transformers/modeling_utils.py
@@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
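Note: the hunk above is the argument-resolution block at the top of generate(), where every sampling parameter falls back to the model config when the caller leaves it as None. A minimal sketch of that fallback pattern, using an illustrative GenerationDefaults dataclass rather than the library's actual config class:

from dataclasses import dataclass

@dataclass
class GenerationDefaults:
    # illustrative defaults, not the actual transformers config values
    max_length: int = 20
    num_beams: int = 1
    temperature: float = 1.0

def resolve_generation_args(config, max_length=None, num_beams=None, temperature=None):
    # an explicitly passed argument wins; otherwise fall back to the config default
    max_length = max_length if max_length is not None else config.max_length
    num_beams = num_beams if num_beams is not None else config.num_beams
    temperature = temperature if temperature is not None else config.temperature
    return max_length, num_beams, temperature

print(resolve_generation_args(GenerationDefaults(), num_beams=5))  # (20, 5, 1.0)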
@@ -852,7 +853,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                device=next(self.parameters()).device,
            )
            cur_len = 1
-            self.model.decoder.generation_mode = True
+            # put model in generation mode if it has one
+            if hasattr(self.model, "generation_mode"):
+                self.model.decoder.generation_mode = True
        else:
            encoder_inputs = None
            cur_len = input_ids.shape[-1]
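Note: this hunk swaps an unconditional attribute write for a hasattr guard, so only models that actually expose a generation_mode attribute get switched into generation mode. A small sketch of the guard pattern with made-up stand-in classes (ToyDecoder, ToySeq2Seq and ToyEncoderOnly are illustrative, not library classes):

class ToyDecoder:
    generation_mode = False

class ToySeq2Seq:
    # exposes the marker attribute the guard checks for
    generation_mode = False

    def __init__(self):
        self.decoder = ToyDecoder()

class ToyEncoderOnly:
    pass  # no generation_mode attribute; the guard must leave it alone

def put_in_generation_mode(model):
    # only flip the flag when the model supports it
    if hasattr(model, "generation_mode"):
        model.decoder.generation_mode = True

seq2seq, encoder_only = ToySeq2Seq(), ToyEncoderOnly()
put_in_generation_mode(seq2seq)
put_in_generation_mode(encoder_only)  # no AttributeError thanks to the guard
print(seq2seq.decoder.generation_mode)  # True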
@@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
-                cur_len,
-                max_length,
-                min_length,
-                do_sample,
-                early_stopping,
-                temperature,
-                top_k,
-                top_p,
-                repetition_penalty,
-                no_repeat_ngram_size,
-                bos_token_id,
-                pad_token_id,
-                eos_token_ids,
-                effective_batch_size,
-                num_return_sequences,
-                length_penalty,
-                num_beams,
-                vocab_size,
-                encoder_inputs,
-                attention_mask,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                early_stopping=early_stopping,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bos_token_id=bos_token_id,
+                pad_token_id=pad_token_id,
+                eos_token_ids=eos_token_ids,
+                batch_size=effective_batch_size,
+                num_return_sequences=num_return_sequences,
+                length_penalty=length_penalty,
+                num_beams=num_beams,
+                vocab_size=vocab_size,
+                encoder_inputs=encoder_inputs,
+                attention_mask=attention_mask,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
-                cur_len,
-                max_length,
-                min_length,
-                do_sample,
-                temperature,
-                top_k,
-                top_p,
-                repetition_penalty,
-                no_repeat_ngram_size,
-                pad_token_id,
-                eos_token_ids,
-                effective_batch_size,
-                encoder_inputs,
-                attention_mask,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                pad_token_id=pad_token_id,
+                eos_token_ids=eos_token_ids,
+                batch_size=effective_batch_size,
+                encoder_inputs=encoder_inputs,
+                attention_mask=attention_mask,
            )
        return output
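Note: the only change in this hunk is calling the two internal helpers with keyword instead of positional arguments; with signatures this long, keyword calls keep each value bound to the intended parameter even if the helper's parameter list later grows or is reordered. A toy illustration with a hypothetical long-signature function (not the library's _generate_beam_search):

def decode(input_ids, cur_len, max_length, temperature, top_k, top_p):
    # stand-in for a long-signature helper; returns its arguments for inspection
    return dict(cur_len=cur_len, max_length=max_length,
                temperature=temperature, top_k=top_k, top_p=top_p)

# positional: correctness depends on remembering the exact parameter order
risky = decode([0], 1, 20, 0.7, 50, 0.9)

# keyword: each value is bound to a named parameter, so order no longer matters
safe = decode([0], cur_len=1, max_length=20, temperature=0.7, top_k=50, top_p=0.9)

assert risky == safe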
@@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                next_sent_beam = []
                # next tokens for this sentence
-                for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
+                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
+                    zip(next_tokens[batch_idx], next_scores[batch_idx])
+                ):
                    # get beam and word IDs
-                    beam_id = idx // vocab_size
-                    token_id = idx % vocab_size
+                    beam_id = beam_token_id // vocab_size
+                    token_id = beam_token_id % vocab_size
                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence
                    if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
-                        # when passed to num_beams hypotheses, continue
-                        if i >= num_beams:
+                        # if beam_token does not belong to top num_beams tokens, it should not be added
+                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
-                            input_ids[effective_beam_id].clone(), score.item(),
+                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
                        )
                    else:
                        # add next predicted word if it is not eos_token
-                        next_sent_beam.append((score, token_id, effective_beam_id))
+                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
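Note: the renamed loop still decodes candidates the same way: next_tokens holds indices into the beams' scores flattened over a num_beams * vocab_size axis, so integer division recovers the source beam and the remainder recovers the vocabulary token. A toy sketch of that decoding (the 2 * num_beams over-collection mirrors why a candidate's rank can exceed num_beams; the values are illustrative):

import torch

num_beams, vocab_size = 3, 10
# fake next-token scores for one sentence: one row per beam
scores = torch.randn(num_beams, vocab_size)

# flatten beams and vocabulary into one axis and take the best 2 * num_beams candidates
next_scores, next_tokens = torch.topk(scores.view(-1), 2 * num_beams)

for beam_token_id, beam_token_score in zip(next_tokens, next_scores):
    beam_id = beam_token_id // vocab_size   # which beam the candidate extends
    token_id = beam_token_id % vocab_size   # which vocabulary token it appends
    print(int(beam_id), int(token_id), float(beam_token_score))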