chenpangpang / transformers · Commits · bdd3d0c7

Unverified commit bdd3d0c7, authored Mar 04, 2020 by Thomas Wolf, committed by GitHub on Mar 04, 2020

Merge pull request #3118 from patrickvonplaten/add_beam_search_to_generation_tf_2_0

Add beam search to generation tf 2 0

Parents: c440030e, 7a89a3e4
Changes: 4 changed files with 334 additions and 24 deletions (+334 -24)
src/transformers/modeling_tf_ctrl.py     +2   -2
src/transformers/modeling_tf_gpt2.py     +2   -2
src/transformers/modeling_tf_utils.py    +310 -18
tests/test_modeling_tf_common.py         +20  -2
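For context (not part of the diff): with this merge, the TF 2.0 generate() method gains the same beam-search arguments as the PyTorch implementation. A minimal usage sketch, assuming a stock GPT-2 checkpoint and the tokenizer/model classes of this transformers version:

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tf.constant([tokenizer.encode("The quick brown fox")])

# greedy beam search: do_sample=False, num_beams > 1,
# and num_return_sequences must not exceed num_beams (see the new checks below)
output_ids = model.generate(
    input_ids,
    max_length=20,
    do_sample=False,
    num_beams=3,
    num_return_sequences=3,
)
print(tokenizer.decode(output_ids[0].numpy().tolist()))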
src/transformers/modeling_tf_ctrl.py  (view file @ bdd3d0c7)
@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             k = tf.concat((past_key, k), axis=-2)
             v = tf.concat((past_value, v), axis=-2)
-        present = tf.stack((k, v), axis=1)
+        present = tf.stack((k, v), axis=0)

         output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
         scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
src/transformers/modeling_tf_gpt2.py  (view file @ bdd3d0c7)
@@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer):
         key = self.split_heads(key)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             key = tf.concat([past_key, key], axis=-2)
             value = tf.concat([past_value, value], axis=-2)
-        present = tf.stack([key, value], axis=1)
+        present = tf.stack([key, value], axis=0)

         attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
         a = attn_outputs[0]
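Both attention changes above do the same thing: the cached key/value pair is now stacked on axis 0 instead of axis 1, so the batch dimension of layer_past sits in second position. That is the layout the new _reorder_cache helper in modeling_tf_utils.py expects when it gathers layer_past[:, i] for each selected beam. A standalone shape sketch with made-up dimensions (illustrative only, not from the diff):

import tensorflow as tf

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
key = tf.random.normal((batch_size, num_heads, seq_len, head_dim))
value = tf.random.normal((batch_size, num_heads, seq_len, head_dim))

# old layout: tf.stack([key, value], axis=1) -> (batch, 2, heads, seq, dim)
# new layout: tf.stack([key, value], axis=0) -> (2, batch, heads, seq, dim)
layer_past = tf.stack([key, value], axis=0)
print(layer_past.shape)  # (2, 2, 4, 5, 8): the batch dimension is now at position 1

# beam search can therefore reorder the cache along that dimension, per beam index
beam_idx = [1, 0]
reordered = tf.concat([tf.expand_dims(layer_past[:, i], 1) for i in beam_idx], axis=1)
print(reordered.shape)  # unchanged shape, batch rows swapped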
src/transformers/modeling_tf_utils.py  (view file @ bdd3d0c7)
@@ -142,7 +142,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         # # initialize all new embeddings (in particular added tokens)
         # self._init_weights(new_embeddings)

-        # # Copy word embeddings from the previous weights
+        # # Copy token embeddings from the previous weights
         # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
         # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
@@ -557,6 +557,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         else:
             assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

+        if do_sample is False:
+            if num_beams == 1:
+                # no_beam_search greedy generation conditions
+                assert (
+                    num_return_sequences == 1
+                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+
+            else:
+                # beam_search greedy generation conditions
+                assert (
+                    num_beams >= num_return_sequences
+                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
         if pad_token_id is None and eos_token_ids is not None:
             logger.warning(
                 "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
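These checks only constrain the greedy path (do_sample is False): without beams, asking for more than one returned sequence would just repeat the same output, and with beams there can be at most num_beams distinct hypotheses. A small sketch that mirrors the two assertions above (hypothetical helper, not part of the library):

def check_greedy_args(do_sample, num_beams, num_return_sequences):
    # restates the two assertions added in this hunk
    if do_sample is False:
        if num_beams == 1:
            assert num_return_sequences == 1, "greedy decoding would repeat itself"
        else:
            assert num_beams >= num_return_sequences, "not enough beams for the requested sequences"

check_greedy_args(do_sample=False, num_beams=1, num_return_sequences=1)  # ok
check_greedy_args(do_sample=False, num_beams=3, num_return_sequences=3)  # ok
check_greedy_args(do_sample=True, num_beams=1, num_return_sequences=3)   # ok, sampling is unaffected
# check_greedy_args(do_sample=False, num_beams=2, num_return_sequences=3)  # would raise AssertionError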
@@ -567,7 +580,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         cur_len = shape_list(input_ids)[1]
         vocab_size = self.config.vocab_size

-        if num_return_sequences != 1:
+        if num_return_sequences != 1 and do_sample:
             # Expand input to num return sequences
             input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
             effective_batch_size = batch_size * num_return_sequences
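With the extra "and do_sample", the prompt is only tiled up front for multinomial sampling; greedy beam search keeps the original batch and draws its multiple return sequences from the beams instead. The expansion itself is a broadcast over a new axis followed by a flatten to the effective batch (the reshape happens in surrounding lines not shown in this hunk); a toy shape sketch:

import tensorflow as tf

batch_size, cur_len, num_return_sequences = 2, 4, 3
input_ids = tf.reshape(tf.range(batch_size * cur_len), (batch_size, cur_len))

# (batch_size, cur_len) -> (batch_size, num_return_sequences, cur_len)
expanded = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
# flatten to an effective batch of batch_size * num_return_sequences prompts
expanded = tf.reshape(expanded, (batch_size * num_return_sequences, cur_len))
print(expanded.shape)  # (6, 4)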
@@ -588,6 +601,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 pad_token_id,
                 eos_token_ids,
                 effective_batch_size,
+                num_return_sequences,
                 length_penalty,
                 num_beams,
                 vocab_size,
@@ -627,19 +641,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             All returned sequence are generated independantly.
         """

-        def _create_next_token_logits_penalties(input_ids, logits):
-            # create logit penalties for already seen input_ids
-            token_penalties = np.ones(shape_list(logits))
-            prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
-            for i, prev_input_id in enumerate(prev_input_ids):
-                logit_penalized = logits[i].numpy()[prev_input_id]
-                # if previous logit score is < 0 then multiply repetition penalty else divide
-                logit_penalized[logit_penalized < 0] = repetition_penalty
-                logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
-                np.put(token_penalties[i], prev_input_id, logit_penalized)
-            return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
-
-        # current position / max lengths / length of generated sentences / unfinished sentences
+        # length of generated sentences / unfinished sentences
         unfinished_sents = tf.ones_like(input_ids[:, 0])
         sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
@@ -656,7 +658,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

             if do_sample:
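The helper now takes repetition_penalty explicitly because it is no longer a nested closure; its module-level definition appears further down in this file. The rule it implements, per the comment in that function, is: a token already present in input_ids has its logit multiplied by the penalty when the logit is negative and by 1/penalty when it is positive, so repeated tokens become less likely either way. A NumPy-only restatement of that rule with made-up numbers (the np.where formulation is mine, not the library code):

import numpy as np

repetition_penalty = 1.3
logits = np.array([2.0, -1.0, 0.5, 3.0])   # scores over a toy 4-token vocabulary
prev_input_ids = np.array([0, 1])          # tokens 0 and 1 were already generated

penalties = np.ones_like(logits)
seen = logits[prev_input_ids]
# negative logit: multiply by the penalty (> 1) to push it further down;
# positive logit: multiply by 1 / penalty (< 1) to pull it down
penalties[prev_input_ids] = np.where(seen < 0, repetition_penalty, 1.0 / repetition_penalty)

print(logits * penalties)  # [1.538..., -1.3, 0.5, 3.0]: both already-seen tokens lose probability mass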
@@ -738,11 +742,249 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         pad_token_id,
         eos_token_ids,
         batch_size,
+        num_return_sequences,
         length_penalty,
         num_beams,
         vocab_size,
     ):
-        pass
+        """ Generate sequences for each example with beam search.
+        """
+
+        # Expand input to num beams
+        input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
+        input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len))  # (batch_size * num_beams, cur_len)
+
+        # generated hypotheses
+        generated_hyps = [
+            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+        ]
+
+        # scores for each sentence in the beam
+        if do_sample is False:
+            beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
+            beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+            beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
+        else:
+            beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
+        beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
+
+        # cache compute states
+        past = None
+
+        # done sentences
+        done = [False for _ in range(batch_size)]
+
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
+            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
+
+            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
+                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
+
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low probability tokens)
+                if temperature != 1.0:
+                    next_token_logits = next_token_logits / temperature
+
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
+
+                # Top-p/top-k filtering
+                _scores = tf_top_k_top_p_filtering(
+                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                )  # (batch_size * num_beams, vocab_size)
+
+                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+                _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
+                next_tokens = tf.random.categorical(
+                    _scores, dtype=tf.int32, num_samples=2 * num_beams
+                )  # (batch_size, 2 * num_beams)
+                # Compute next scores
+                next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)
+            else:
+                # do greedy beam search
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                assert shape_list(scores) == [batch_size * num_beams, vocab_size]
+                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
+                next_scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
+                # re-organize to group the beam together (we are keeping top hypothesis accross beams)
+                next_scores = tf.reshape(
+                    next_scores, (batch_size, num_beams * vocab_size)
+                )  # (batch_size, num_beams * vocab_size)
+                next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
+
+            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
+
+            # next batch beam content
+            # list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
+            next_batch_beam = []
+
+            # for each sentence
+            for batch_idx in range(batch_size):
+
+                # if we are done with this sentence
+                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+                    tf.reduce_max(next_scores[batch_idx]).numpy()
+                )
+                if done[batch_idx]:
+                    assert (
+                        len(generated_hyps[batch_idx]) >= num_beams
+                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+                    assert (
+                        eos_token_ids is not None and pad_token_id is not None
+                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
+                    continue
+
+                # next sentence beam content
+                next_sent_beam = []
+
+                # next tokens for this sentence
+                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+                    # get beam and token IDs
+                    beam_id = idx // vocab_size
+                    token_id = idx % vocab_size
+                    # add to generated hypotheses if end of sentence or last iteration
+                    if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
+                        generated_hyps[batch_idx].add(
+                            tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
+                        )
+                    else:
+                        # add next predicted token if it is not eos_token
+                        next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
+
+                    # the beam for next step is full
+                    if len(next_sent_beam) == num_beams:
+                        break
+
+                # update next beam content
+                assert len(next_sent_beam) == num_beams, "Beam should always be full"
+                next_batch_beam.extend(next_sent_beam)
+                assert len(next_batch_beam) == num_beams * (batch_idx + 1)
+
+            # sanity check / prepare next batch
+            assert len(next_batch_beam) == batch_size * num_beams
+            beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
+            beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
+            beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
+
+            # re-order batch
+            input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
+            input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
+
+            # re-order internal states
+            if past:
+                past = self._reorder_cache(past, beam_idx)
+
+            # update current length
+            cur_len = cur_len + 1
+
+            # stop when we are done with each sentence
+            if all(done):
+                break
+
+        for batch_idx in range(batch_size):
+            # Add all open beam hypothesis to generated_hyps
+            if not done[batch_idx]:
+                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+                    # get beam and token IDs
+                    beam_id = idx // vocab_size
+                    token_id = idx % vocab_size
+                    generated_hyps[batch_idx].add(
+                        tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
+                    )
+
+        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
+        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
+
+        # select the best hypotheses
+        sent_lengths_list = []
+        best = []
+
+        # retrieve best hypotheses
+        for i, hypotheses in enumerate(generated_hyps):
+            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+            for j in range(output_num_return_sequences_per_batch):
+                best_hyp = sorted_hyps.pop()[1]
+                sent_lengths_list.append(len(best_hyp))
+                best.append(best_hyp)
+        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
+            output_batch_size, len(best)
+        )
+
+        sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
+
+        # shorter batches are filled with pad_token
+        if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
+            assert pad_token_id is not None, "`Pad_token_id` has to be defined"
+            sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
+            decoded_list = []
+
+            # fill with hypothesis and eos_token_id if necessary
+            for i, hypo in enumerate(best):
+                padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
+                decoded_hypo = tf.concat([hypo, padding], axis=0)
+                if sent_lengths[i] < max_length:
+                    decoded_hypo = tf.where(
+                        tf.range(max_length) == sent_lengths[i],
+                        eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
+                        decoded_hypo,
+                    )
+                decoded_list.append(decoded_hypo)
+            decoded = tf.stack(decoded_list)
+        else:
+            # none of the hypotheses have an eos_token
+            assert (len(hypo) == max_length for hypo in best)
+            decoded = tf.stack(best)
+
+        return decoded
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = []
+        for layer_past in past:
+            # get the correct batch idx from layer past batch dim
+            # batch dim of `past` and `mems` is at 2nd position
+            reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
+            reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
+            # check that shape matches
+            assert shape_list(reordered_layer_past) == shape_list(layer_past)
+            reordered_past.append(reordered_layer_past)
+        past = tuple(reordered_past)
+        return past
+
+
+def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
+    # create logit penalties for already seen input_ids
+    token_penalties = np.ones(shape_list(logits))
+    prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
+    for i, prev_input_id in enumerate(prev_input_ids):
+        logit_penalized = logits[i].numpy()[prev_input_id]
+        # if previous logit score is < 0 then multiply repetition penalty else divide
+        logit_penalized[logit_penalized < 0] = repetition_penalty
+        logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
+        np.put(token_penalties[i], prev_input_id, logit_penalized)
+    return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
+
+
 def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
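The key bookkeeping step in the loop above is that the per-beam scores are flattened to shape (batch_size, num_beams * vocab_size) before the 2 * num_beams candidates are selected, so each flat index encodes both the beam it came from and the token it proposes: beam_id = idx // vocab_size and token_id = idx % vocab_size. A tiny standalone sketch with toy sizes (not part of the diff):

import tensorflow as tf

num_beams, vocab_size = 2, 5
# log-prob scores per beam over a toy vocabulary (one batch element)
scores = tf.constant([[-1.0, -2.0, -0.5, -3.0, -4.0],    # beam 0
                      [-0.2, -5.0, -6.0, -0.9, -7.0]])   # beam 1

flat = tf.reshape(scores, (1, num_beams * vocab_size))
top_scores, top_idx = tf.math.top_k(flat, 2 * num_beams, sorted=True)

for idx in top_idx[0].numpy():
    beam_id, token_id = idx // vocab_size, idx % vocab_size
    print("flat index {} -> beam {}, token {}".format(idx, beam_id, token_id))
# keeping 2 * num_beams candidates leaves spares in case some of them turn out to be EOS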
@@ -811,6 +1053,56 @@ def set_tensor_by_indices_to_value(tensor, indices, value):
     return tf.where(indices, value_tensor, tensor)


+class BeamHypotheses(object):
+    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp, sum_logprobs):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / len(hyp) ** self.length_penalty
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp))
+            if len(self) > self.num_beams:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                del self.beams[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs, cur_len=None):
+        """
+        If there are enough hypotheses and that none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+        """
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            if cur_len is None:
+                cur_len = self.max_length
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+
 class TFConv1D(tf.keras.layers.Layer):
     def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
         """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)

@@ -849,7 +1141,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
         self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

     def build(self, input_shape):
-        """Build shared word embedding layer
+        """Build shared token embedding layer
         Shared weights logic adapted from
         https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
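BeamHypotheses ranks finished hypotheses by sum_logprobs / len(hyp) ** length_penalty, so values of length_penalty above 1.0 soften the per-token cost and favour longer sequences, while values below 1.0 favour shorter ones. A quick numeric illustration with made-up log probabilities (not from the diff):

# two finished hypotheses: a short one and a longer one with a lower total log-prob
short_hyp = {"sum_logprobs": -4.0, "length": 5}
long_hyp = {"sum_logprobs": -6.0, "length": 10}

for length_penalty in (0.5, 1.0, 2.0):
    def score(h):
        return h["sum_logprobs"] / h["length"] ** length_penalty
    best = "short" if score(short_hyp) > score(long_hyp) else "long"
    print("length_penalty={}: short={:.3f} long={:.3f} -> prefers the {} hypothesis".format(
        length_penalty, score(short_hyp), score(long_hyp), best))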
tests/test_modeling_tf_common.py  (view file @ bdd3d0c7)
@@ -381,7 +381,6 @@ class TFModelTesterMixin:
         )  # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.

         for model_class in self.all_generative_model_classes:
-            # TODO (PVP): add beam search tests when beam search is implemented
             model = model_class(config)

             if config.bos_token_id is None:

@@ -389,15 +388,34 @@ class TFModelTesterMixin:
                     model.generate(max_length=5)
                 # batch_size = 1
                 self._check_generated_tokens(model.generate(input_ids))
+                # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(input_ids, num_beams=3))
             else:
                 # batch_size = 1
                 self._check_generated_tokens(model.generate(max_length=5))
                 # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
+
+            with self.assertRaises(AssertionError):
+                # generating multiple sequences when greedy no beam generation
+                # is not allowed as it would always generate the same sequences
+                model.generate(input_ids, do_sample=False, num_return_sequences=2)
+
+            with self.assertRaises(AssertionError):
+                # generating more sequences than having beams leads is not possible
+                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

             # batch_size > 1, sample
             self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
             # batch_size > 1, greedy
-            self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=False))
+
+            # batch_size > 1, num_beams > 1, sample
+            self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
+            # batch_size > 1, num_beams > 1, greedy
+            self._check_generated_tokens(
+                model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
+            )

     def _check_generated_tokens(self, output_ids):
         for token_id in output_ids[0].numpy().tolist():
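These mixin checks run for every TF generative model class in the suite. Assuming the test layout of this version, something like `python -m pytest tests/test_modeling_tf_gpt2.py -k generate` should exercise the new beam-search paths for GPT-2 (illustrative command, not part of the diff).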