chenpangpang / transformers / Commits

Commit 1ba21f96, authored Mar 10, 2020 by Patrick von Platen
parent d997ac78

fix bug in tf no_repeat_ngram_size
Showing 1 changed file with 2 additions and 1 deletion

src/transformers/modeling_tf_utils.py (+2, -1)
...
@@ -942,7 +942,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         if no_repeat_ngram_size > 0:
             # calculate a list of banned tokens to prevent repetitively generating the same ngrams
             # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
-            banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
+            num_batch_hypotheses = batch_size * num_beams
+            banned_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
             # create banned_tokens boolean mask
             banned_tokens_indices_mask = []
             for banned_tokens_slice in banned_tokens:
...
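
The change widens the n-gram bookkeeping from batch_size rows to batch_size * num_beams rows: with beam search, input_ids holds one row per hypothesis, so banning n-grams for only the first batch_size rows leaves the remaining beams free to repeat. Below is a minimal, hypothetical sketch of that logic for illustration only; the helper name, argument shapes, and implementation are assumptions, not the library's calc_banned_tokens.

def banned_ngram_tokens(input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    """For each hypothesis row, return the tokens that would repeat an existing n-gram."""
    if cur_len + 1 < no_repeat_ngram_size:
        # not enough generated tokens yet to complete any n-gram
        return [[] for _ in range(num_hypos)]
    banned = []
    for hypo_idx in range(num_hypos):
        tokens = list(input_ids[hypo_idx][:cur_len])
        # index every n-gram generated so far by its (n-1)-token prefix
        seen = {}
        for i in range(len(tokens) - no_repeat_ngram_size + 1):
            ngram = tuple(tokens[i : i + no_repeat_ngram_size])
            seen.setdefault(ngram[:-1], []).append(ngram[-1])
        # tokens that would extend the trailing prefix into a repeated n-gram
        prefix = tuple(tokens[cur_len + 1 - no_repeat_ngram_size : cur_len])
        banned.append(seen.get(prefix, []))
    return banned

# Old call (buggy when num_beams > 1): only the first batch_size rows get ban lists.
#   banned_tokens = banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
# Fixed call: one ban list per hypothesis.
#   num_batch_hypotheses = batch_size * num_beams
#   banned_tokens = banned_ngram_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)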