chenpangpang/transformers · Commit ca1330f0
authored Mar 10, 2020 by Patrick von Platen

do not mess with the negative sign

parent 10989715
Showing 1 changed file with 19 additions and 18 deletions.
src/transformers/modeling_tf_utils.py (+19, -18)
@@ -894,7 +894,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
         if do_sample is False:
             beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
-            beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+            beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
             beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
         else:
             beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
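This hunk is the commit's namesake: `tf.zeros(...) * 1e-9` still evaluates to all zeros, so every beam started with an identical score and the first top-k step could pick the exact same token for each beam. Multiplying ones by -1e9 gives every beam except the first a genuinely large negative score, so only the first beam's continuations survive the first selection. A minimal standalone sketch of the difference:

    import tensorflow as tf

    batch_size, num_beams = 2, 3
    # old: a "penalty" that is actually zero -- all beams tie
    old_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
    # new: a large negative score for the non-first beams
    new_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
    print(old_end.numpy())  # [[0. 0.] [0. 0.]] -- no effect at all
    print(new_end.numpy())  # [[-1.e+09 -1.e+09] ...] -- effectively -inf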
@@ -926,6 +926,21 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             if temperature != 1.0:
                 next_token_logits = next_token_logits / temperature
+            # calculate log softmax score
+            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+            # set eos token prob to zero if min_length is not reached
+            if eos_token_ids is not None and cur_len < min_length:
+                # create eos_token_ids boolean mask
+                is_token_logit_eos_token = tf.convert_to_tensor(
+                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+                )
+                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
+                scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
             if no_repeat_ngram_size > 0:
                 # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                 # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
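The masking goes through `set_tensor_by_indices_to_value`, a helper defined elsewhere in modeling_tf_utils.py, because TF tensors do not support in-place item assignment. A minimal sketch of how such a helper can be written (the file's actual implementation may differ):

    import tensorflow as tf

    def set_tensor_by_indices_to_value(tensor, indices, value):
        # indices is a boolean mask of the same shape as tensor;
        # build a tensor filled with `value` and select via tf.where
        value_tensor = tf.zeros_like(tensor) + value
        return tf.where(indices, value_tensor, tensor)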
@@ -937,24 +952,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                         [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                     )
-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
+                scores = set_tensor_by_indices_to_value(
+                    scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                 )
-            # set eos token prob to zero if min_length is not reached
-            if eos_token_ids is not None and cur_len < min_length:
-                # create eos_token_ids boolean mask
-                is_token_logit_eos_token = tf.convert_to_tensor(
-                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
-                )
-                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, eos_token_indices_mask, -float("inf")
-                )
-            # calculate log softmax score
-            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
             assert shape_list(scores) == [batch_size * num_beams, vocab_size]

             if do_sample:
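Taken together, these two hunks move the EOS and banned-token masking from the raw logits to the already-normalized log-softmax scores. The ordering matters: writing -inf into the logits before log_softmax renormalizes the remaining tokens, whereas masking after normalization leaves every other token's log-probability untouched. A small standalone comparison:

    import tensorflow as tf

    logits = tf.constant([[2.0, 1.0, 0.0]])
    mask = tf.constant([[False, True, False]])
    neg_inf = tf.fill(tf.shape(logits), -float("inf"))

    # mask first, then normalize: remaining scores are renormalized
    a = tf.nn.log_softmax(tf.where(mask, neg_inf, logits), axis=-1)
    # normalize first, then mask: other scores keep their values
    b = tf.where(mask, neg_inf, tf.nn.log_softmax(logits, axis=-1))
    print(a.numpy())  # approx [[-0.127 -inf -2.127]]
    print(b.numpy())  # approx [[-0.408 -inf -2.408]]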
@@ -991,6 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 )  # (batch_size, num_beams * vocab_size)

                 next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
+                print(next_tokens)
                 assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
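Because `top_k` runs over the flattened (num_beams * vocab_size) axis, each returned index encodes both a beam and a token, and later code has to split them apart again. A hedged sketch of that decoding (variable names are illustrative, not taken from the file):

    import tensorflow as tf

    num_beams, vocab_size = 2, 5
    flat_scores = tf.random.normal((1, num_beams * vocab_size))
    top_scores, top_idx = tf.math.top_k(flat_scores, k=2 * num_beams, sorted=True)
    beam_id = top_idx // vocab_size   # which beam the candidate came from
    token_id = top_idx % vocab_size   # which vocabulary token it is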
@@ -1064,7 +1066,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # re-order batch
             input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
             input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
-
             # re-order internal states
             if past:
                 past = self._reorder_cache(past, beam_idx)
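The stack-of-slices idiom in this last hunk is equivalent to a gather along the batch axis. A compact standalone sketch of the same re-ordering, with made-up shapes:

    import tensorflow as tf

    input_ids = tf.constant([[1, 2], [3, 4], [5, 6]])  # (num_beams, cur_len)
    beam_tokens = tf.constant([7, 8, 9])
    beam_idx = tf.constant([2, 0, 1])
    # same effect as tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
    reordered = tf.gather(input_ids, beam_idx)
    # append the newly chosen token of each beam
    extended = tf.concat([reordered, tf.expand_dims(beam_tokens, 1)], axis=-1)
    print(extended.numpy())  # [[5 6 7] [1 2 8] [3 4 9]]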