chenpangpang / transformers / Commits

Commit 9b8ee8ce, authored Mar 10, 2020 by Patrick von Platen
parent ca1330f0

    delete print and make style

Showing 1 changed file with 2 additions and 5 deletions.

src/transformers/modeling_tf_utils.py (+2, -5)
@@ -926,7 +926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             if temperature != 1.0:
                 next_token_logits = next_token_logits / temperature
             # calculate log softmax score
             scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
             # set eos token prob to zero if min_length is not reached
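The scoring step this hunk touches divides the raw next-token logits by the sampling temperature and converts them to log-probabilities. Below is a minimal, self-contained sketch of that step; the shapes and the temperature value are illustrative, not taken from the commit.

import tensorflow as tf

batch_size, num_beams, vocab_size = 2, 3, 5
temperature = 0.7  # illustrative; 1.0 would leave the logits unchanged

# one row of raw logits per (batch element, beam) pair
next_token_logits = tf.random.normal([batch_size * num_beams, vocab_size])

if temperature != 1.0:
    # temperature < 1.0 sharpens the distribution, > 1.0 flattens it
    next_token_logits = next_token_logits / temperature

# log-probabilities over the vocabulary, one row per (batch, beam) pair
scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
print(scores.shape)  # (6, 5)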
@@ -937,9 +937,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 )
                 eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
-                scores = set_tensor_by_indices_to_value(
-                    scores, eos_token_indices_mask, -float("inf")
-                )
+                scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))

             if no_repeat_ngram_size > 0:
                 # calculate a list of banned tokens to prevent repetitively generating the same ngrams
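This hunk only reflows the call that masks out the EOS token while the sequence is still shorter than min_length. The sketch below reproduces the idea with a stand-in for set_tensor_by_indices_to_value (assumed here to be a tf.where-based overwrite, consistent with how the diff uses it; the eos_token_id value is illustrative): a boolean vocabulary mask is broadcast across the batch and the matching score entries are set to -inf so EOS can never be selected.

import tensorflow as tf

batch_size, vocab_size = 2, 5
eos_token_id = 3  # illustrative token id

def set_tensor_by_indices_to_value(tensor, indices_mask, value):
    # Stand-in for the helper called in the diff: TF tensors are immutable,
    # so tf.where builds a new tensor with `value` wherever the mask is True.
    return tf.where(indices_mask, tf.fill(tf.shape(tensor), value), tensor)

scores = tf.math.log(tf.fill([batch_size, vocab_size], 0.2))  # uniform log-probs
is_token_logit_eos_token = tf.range(vocab_size) == eos_token_id  # (vocab_size,) bool
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
print(scores.numpy())  # column 3 is -inf in every row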
@@ -992,7 +990,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             )  # (batch_size, num_beams * vocab_size)

             next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
-            print(next_tokens)

             assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
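The final hunk deletes a leftover debugging print around beam-search candidate selection. For context, the surrounding logic flattens the scores over (num_beams * vocab_size) so tf.math.top_k can compare all continuations of a batch element at once, and keeps 2 * num_beams candidates so enough hypotheses survive even if some of the best ones are EOS. A minimal sketch, with shape_list written as a plausible stand-in for the helper referenced in the assert:

import tensorflow as tf

batch_size, num_beams, vocab_size = 2, 3, 50

def shape_list(x):
    # Stand-in for the shape helper in the assert: static dims where known,
    # dynamic ones otherwise.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

# one score per candidate token, flattened across beams
next_scores = tf.random.normal([batch_size, num_beams * vocab_size])

# keep twice as many candidates as beams
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)

assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]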