chenpangpang / transformers · Commits

Commit b4ddd267 (unverified)
Authored Apr 18, 2022 by Joao Gante, committed by GitHub on Apr 18, 2022

TF generate refactor - XLA sample (#16713)

Parent: 02de7a8e

Showing 3 changed files with 187 additions and 88 deletions (+187 / -88):

  src/transformers/generation_tf_utils.py  +139 / -69
  tests/gpt2/test_modeling_tf_gpt2.py       +17 / -17
  tests/t5/test_modeling_tf_t5.py           +31 / -2
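
For orientation before the per-file diffs: the user-facing effect of this commit is that sampling can now run inside an XLA-compiled `generate`, with reproducibility controlled by the new `seed` argument (a pair of integers, as consumed by TensorFlow's stateless random ops). A minimal usage sketch along the lines of the new GPT-2 test further down (checkpoint name and prompt are taken from that test; the exact continuation depends on the model and seed):

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

# same public checkpoint as the new test uses
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("The dog", return_tensors="tf").input_ids

# compile generate with XLA; the first call traces and compiles
xla_generate = tf.function(model.generate, jit_compile=True)

# seed=[42, 0] makes the sampled continuation reproducible across calls
output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
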

src/transformers/generation_tf_utils.py (view file @ b4ddd267)
@@ -346,6 +346,8 @@ class TFGenerationMixin:
     A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
     """
 
+    seed_generator = tf.random.Generator.from_non_deterministic_state()
+
     def prepare_inputs_for_generation(self, inputs, **kwargs):
         """
         Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
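
The new class attribute above is what lets the sampling code fall back to fresh seed pairs when the caller does not pass `seed`. A standalone sketch of that pattern (the variable names here are illustrative, not the mixin's internals):

import tensorflow as tf

# stand-in for the mixin's class-level generator
seed_generator = tf.random.Generator.from_non_deterministic_state()

# make_seeds(count=1) returns an int64 tensor of shape (2, 1); one column is one seed pair
sample_seed = tf.cast(seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)

logits = tf.constant([[0.1, 0.5, 2.0, -1.0]])
# stateless sampling: the same (logits, seed) pair always yields the same token
next_token = tf.random.stateless_categorical(logits=logits, num_samples=1, seed=sample_seed, dtype=tf.int32)
print(next_token.numpy())
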
@@ -585,6 +587,7 @@ class TFGenerationMixin:
             attention_mask=attention_mask,
             decoder_start_token_id=decoder_start_token_id,
             use_cache=use_cache,
+            seed=model_kwargs.pop("seed", None),
             output_scores=output_scores,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1288,6 +1291,7 @@ class TFGenerationMixin:
         attention_mask=None,
         decoder_start_token_id=None,
         use_cache=None,
+        seed=None,
         output_scores=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -1365,6 +1369,9 @@ class TFGenerationMixin:
             use_cache (`bool`, *optional*, defaults to `True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                 speed up decoding.
+            seed (`List[int]`, *optional*):
+                Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
+                `seed` argument from stateless functions in `tf.random`.
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
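
As the docstring addition above says, `seed` follows the convention of TensorFlow's stateless random ops: two integers that fully determine the draw. A quick illustration of the difference, independent of `generate` (plain `tf.random`, not transformers code):

import tensorflow as tf

logits = tf.math.log([[0.1, 0.2, 0.3, 0.4]])

# stateful sampling: results depend on the global RNG state and vary from call to call
print(tf.random.categorical(logits, num_samples=3).numpy())

# stateless sampling: fully determined by the (logits, seed) pair
print(tf.random.stateless_categorical(logits, num_samples=3, seed=[42, 0]).numpy())
print(tf.random.stateless_categorical(logits, num_samples=3, seed=[42, 0]).numpy())  # identical to the line above
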
@@ -1590,6 +1597,7 @@ class TFGenerationMixin:
             max_length=max_length,
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
+            seed=seed,
             output_scores=output_scores,
             return_dict_in_generate=return_dict_in_generate,
             **model_kwargs,
@@ -1723,7 +1731,7 @@ class TFGenerationMixin:
         **model_kwargs,
     ) -> Tuple[tf.Tensor, Dict[str, Any]]:
         expanded_return_idx = tf.reshape(
-            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
+            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
         )
         input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
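
The helper touched above builds the gather index that repeats each batch row `expand_size` times (the kind of expansion used for multiple return sequences or beams); the only change is spelling the flattening shape as the explicit one-element tuple `(-1,)`. A minimal sketch of the pattern outside the mixin:

import tensorflow as tf

input_ids = tf.constant([[1, 2, 3], [4, 5, 6]])  # batch of 2 sequences
expand_size = 2

# indices [0, 0, 1, 1]: each batch row is repeated expand_size times
expanded_return_idx = tf.reshape(
    tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
)
expanded = tf.gather(input_ids, expanded_return_idx, axis=0)
print(expanded.numpy())  # [[1 2 3] [1 2 3] [4 5 6] [4 5 6]]
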
@@ -2123,6 +2131,7 @@ class TFGenerationMixin:
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        seed: Optional[Tuple[int, int]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -2149,6 +2158,9 @@ class TFGenerationMixin:
                 The id of the *padding* token.
             eos_token_id (`int`, *optional*):
                 The id of the *end-of-sequence* token.
+            seed (`List[int]`, *optional*):
+                Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
+                `seed` argument from stateless functions in `tf.random`.
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
@@ -2210,7 +2222,7 @@ class TFGenerationMixin:
         >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
         ```"""
 
-        # init values
+        # 1. init greedy_search values
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
@@ -2224,97 +2236,155 @@ class TFGenerationMixin:
         return_dict_in_generate = (
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
+        use_xla = not tf.executing_eagerly()
 
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        # 2. init `attentions`, `hidden_states`, and `scores` tuples
+        scores = [] if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
 
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
+        # 3. init tensors to use for "xla-compileable" generate function
+        # define bsz, seq_length
+        batch_size, cur_len = input_ids.shape
 
-        # keep track of which sequences are already finished
-        unfinished_sequences = tf.ones_like(input_ids[:, 0])
-        cur_len = input_ids.shape[-1]
+        # initialize `generated`, `finished_sequences`
+        generated = tf.TensorArray(
+            element_shape=(batch_size,),
+            dtype=tf.int32,
+            dynamic_size=False,
+            size=max_length,
+            clear_after_read=False,
+        )
+        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
 
-        while cur_len < max_length:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        # write prompt to generated
+        for i in range(cur_len):
+            generated = generated.write(i, input_ids[:, i])
 
-            # forward pass to get next token
+        # 4. define "xla-compile-able" stop-condition and auto-regressive function
+        def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            return ~tf.reduce_all(finished_sequences)
+
+        def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
                 return_dict=True,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
-            next_token_logits = outputs.logits[:, -1]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            next_token_logits = outputs.logits[:, -1, :]
 
             # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
+            if not use_xla and return_dict_in_generate:
                 if output_scores:
-                    scores += (next_token_scores,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
+                    scores.append(next_token_logits)
+                if output_attentions and self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.decoder_attentions)
+                elif output_attentions and not self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.attentions)
+                    if self.config.is_encoder_decoder:
+                        cross_attentions.append(outputs.cross_attentions)
+
+                if output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.decoder_hidden_states)
+                elif output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.hidden_states)
+
+            # pre-process distribution
+            # TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
+            # to be XLA compatible
+            input_ids = None
+            if not use_xla:
+                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
+                input_ids = tf.transpose(input_ids[:cur_len])
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
+            next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
 
             # sample
+            if seed is not None:
+                sample_seed = seed
+            else:
+                sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
             next_tokens = tf.squeeze(
-                tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
+                tf.random.stateless_categorical(
+                    logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32
+                ),
+                axis=1,
             )
 
-            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
+                next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
+            finished_sequences = finished_sequences | (next_tokens == eos_token_id)
+
+            # update `generated` and `cur_len`
+            generated = generated.write(cur_len, next_tokens)
+            next_tokens = next_tokens[:, None]
+            cur_len += 1
 
-            # update generated ids, model inputs, and length for next step
-            input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-            )
-            cur_len = cur_len + 1
+            # update model_kwargs
+            if use_xla:
+                model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
+            else:
+                model_kwargs = self._update_model_kwargs_for_generation(
+                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                )
+                # if we don't cache past key values we need the whole input
+                if model_kwargs.get("past", None) is None:
+                    # let's throw out `past` since we don't want `None` tensors
+                    model_kwargs.pop("past", None)
 
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                eos_in_sents = next_tokens == eos_token_id
-                # if sentence is unfinished and the token to add is eos
-                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
-                    unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
-                )
+            next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
+            next_tokens = tf.transpose(next_tokens[:cur_len])
 
-                # unfinished_sequences is set to zero if eos in sentence
-                unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
+            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
 
-            # stop when each sentence is finished, or if we exceed the maximum length
-            if tf.math.reduce_max(unfinished_sequences) == 0:
-                break
+        # 5. run generation
+        # 1st generation step has to be run before to initialize `past`
+        generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
+            generated, finished_sequences, input_ids, cur_len, model_kwargs
+        )
+
+        # 2-to-n generation steps can then be run in autoregressive fashion
+        # only in case 1st generation step does NOT yield EOS token though
+        if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            maximum_iterations = max_length - cur_len - 1
+            generated, _, _, cur_len, _ = tf.while_loop(
+                sample_cond_fn,
+                sample_body_fn,
+                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+                maximum_iterations=maximum_iterations,
+            )
+
+        # 6. prepare outputs
+        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
+
+        if not use_xla:
+            # cut for backward compatibility
+            output_ids = output_ids[:, :cur_len]
 
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
+                # if model is an encoder-decoder, retrieve encoder attention weights
+                # and hidden states
+                encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+                encoder_hidden_states = (
+                    model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                )
+
+                scores = tuple(scores) if scores is not None else None
+                decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
+                cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
+                decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
+
                 return TFSampleEncoderDecoderOutput(
-                    sequences=input_ids,
+                    sequences=output_ids,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
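
The rewritten `sample` above is built around three pieces that XLA can compile: a fixed-size `tf.TensorArray` holding one `(batch,)` token vector per position, a `tf.while_loop` whose body has a stable signature, and `tf.random.stateless_categorical` for the draw itself. A toy sketch of that loop skeleton, detached from any real model (`cond_fn`, `body_fn`, and the logits rule below are purely illustrative stand-ins, not the mixin's code):

import tensorflow as tf

batch_size, prompt_len, max_length, vocab_size = 2, 3, 8, 11
prompt = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int32)

# fixed-size buffer: one (batch,) vector of token ids per position
generated = tf.TensorArray(
    tf.int32, size=max_length, dynamic_size=False, clear_after_read=False, element_shape=(batch_size,)
)
for i in range(prompt_len):
    generated = generated.write(i, prompt[:, i])

def cond_fn(generated, cur_len):
    return cur_len < max_length

def body_fn(generated, cur_len):
    last_token = generated.read(cur_len - 1)
    # stand-in for a model forward pass: peaked logits derived from the last token
    logits = tf.one_hot((last_token + 1) % vocab_size, vocab_size) * 5.0
    next_tokens = tf.squeeze(
        tf.random.stateless_categorical(logits, num_samples=1, seed=tf.stack([0, cur_len]), dtype=tf.int32),
        axis=1,
    )
    generated = generated.write(cur_len, next_tokens)
    return generated, cur_len + 1

generated, _ = tf.while_loop(cond_fn, body_fn, (generated, tf.constant(prompt_len)))
output_ids = tf.transpose(generated.stack())  # (batch_size, max_length)
print(output_ids.numpy())
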
@@ -2324,13 +2394,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFSampleDecoderOnlyOutput(
-                    sequences=input_ids,
+                    sequences=output_ids,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return input_ids
+            return output_ids
 
     def beam_search(
         self,
@@ -2575,8 +2645,8 @@ class TFGenerationMixin:
            sequences,
            scores,
            is_sent_finished,
-           model_kwargs,
            input_ids_length,
+           model_kwargs,
        ):
            """
            Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
@@ -2604,8 +2674,8 @@ class TFGenerationMixin:
            sequences,
            scores,
            is_sent_finished,
-           model_kwargs,
            input_ids_length,
+           model_kwargs,
            intermediary_running_sequences=None,
        ):
            """
@@ -2781,8 +2851,8 @@ class TFGenerationMixin:
                next_sequences,
                next_scores,
                next_is_sent_finished,
-               next_model_kwargs,
                next_input_ids_length,
+               next_model_kwargs,
            )
 
        # 5. run generation
@@ -2799,8 +2869,8 @@ class TFGenerationMixin:
            sequences,
            scores,
            is_sent_finished,
-           model_kwargs,
            input_ids_length,
+           model_kwargs,
        ) = beam_search_body_fn(
            cur_len,
            running_sequences,
@@ -2808,8 +2878,8 @@ class TFGenerationMixin:
            sequences,
            scores,
            is_sent_finished,
-           model_kwargs,
            input_ids_length,
+           model_kwargs,
        )
 
        # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
@@ -2821,8 +2891,8 @@ class TFGenerationMixin:
            sequences,
            scores,
            is_sent_finished,
-           model_kwargs,
            input_ids_length,
+           model_kwargs,
        ):
            maximum_iterations = max_length - cur_len
            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
@@ -2835,8 +2905,8 @@ class TFGenerationMixin:
                sequences,
                scores,
                is_sent_finished,
-               model_kwargs,
                input_ids_length,
+               model_kwargs,
            ),
            maximum_iterations=maximum_iterations,
        )

tests/gpt2/test_modeling_tf_gpt2.py (view file @ b4ddd267)
@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
 @require_tf
 class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
-    @slow
-    def test_lm_generate_distilgpt2(self):
-        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
-        input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32)  # The president
-
-        # The president of the United States, and the president of the United Kingdom, have been in the White
-        # fmt: off
-        expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
-        # fmt: on
-
-        output_ids = model.generate(input_ids, do_sample=False)
-
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
-
     @slow
     def test_lm_generate_greedy_distilgpt2_batch_special(self):
         model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "temperature": 1.5,
             "top_k": 500,
             "top_p": 0.9,
+            "seed": [42, 0],  # seed set -> deterministic sampling sequence -> deterministic generation
         }
 
         # forces the generation to happen on CPU, to avoid GPU-related quirks
         with tf.device(":/CPU:0"):
-            tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
             output_ids = model.generate(input_ids, **generation_kwargs)
 
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
         expected_output_string = [
-            "Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
-            "Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
+            "Today is a beautiful day and we will make you feel very hot/terrific in all",
+            "Yesterday was another solid success as news coverage became standard American domestic television hit.",
         ]
         self.assertListEqual(output_strings, expected_output_string)
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
 
     @slow
-    def test_lm_generate_gpt2_xla(self):
+    def test_lm_generate_gpt2_xla_greedy(self):
         """This test gives the exact same results as the non-xla test above"""
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
         input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         output_ids = xla_generate(input_ids, do_sample=False)
 
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+
+    @slow
+    def test_lm_generate_gpt2_xla_sample(self):
+        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+
+        # fmt: off
+        expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
+        # fmt: on
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

tests/t5/test_modeling_tf_t5.py (view file @ b4ddd267)
@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
 
         self.assertListEqual(expected_output_string, output_strings)
 
+    @slow
+    def test_sample_xla_generate_simple(self):
+        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        sentence = "Translate English to German: Today is a beautiful day."
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+
+        # XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
+        # divergences in generate -- especially with sampling.
+        expected_output_string = ["Heute ist ein schöner Tag."]
+        expected_output_string_xla = ["Heute ist ein schöne Tage."]
+        # However, notice that the first tokens are the same, for the same seed
+        assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(expected_output_string, output_strings)
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            xla_generate = tf.function(model.generate, jit_compile=True)
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+        output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
+        self.assertListEqual(expected_output_string_xla, output_strings_xla)
+
     @slow
     def test_sample_generate(self):
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
...
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
"temperature"
:
0.8
,
"temperature"
:
0.8
,
"top_k"
:
500
,
"top_k"
:
500
,
"top_p"
:
0.9
,
"top_p"
:
0.9
,
"seed"
:
[
20
,
0
],
# seed set -> deterministic sampling sequence -> deterministic generation
}
}
# forces the generation to happen on CPU, to avoid GPU-related quirks
# forces the generation to happen on CPU, to avoid GPU-related quirks
with
tf
.
device
(
":/CPU:0"
):
with
tf
.
device
(
":/CPU:0"
):
tf
.
random
.
set_seed
(
42
)
# deterministic sampling sequence -> deterministic generation
output_ids
=
model
.
generate
(
input_ids
,
**
generation_kwargs
)
output_ids
=
model
.
generate
(
input_ids
,
**
generation_kwargs
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
expected_output_string
=
[
"
i love her
I really love my
heart
"
,
"die Transformatoren sind wirklich erstaunlich"
]
expected_output_string
=
[
"
-
I really love my
way of this.
"
,
"die Transformatoren sind wirklich erstaunlich"
]
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
...
...