chenpangpang / transformers / Commits / b4ddd267

TF generate refactor - XLA sample (#16713)

Commit b4ddd267, authored Apr 18, 2022 by Joao Gante, committed via GitHub
Parent: 02de7a8e
Showing 3 changed files with 187 additions and 88 deletions (+187 −88):

  src/transformers/generation_tf_utils.py   +139 −69
  tests/gpt2/test_modeling_tf_gpt2.py        +17 −17
  tests/t5/test_modeling_tf_t5.py            +31 −2
src/transformers/generation_tf_utils.py
@@ -346,6 +346,8 @@ class TFGenerationMixin:
     A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
     """
 
+    seed_generator = tf.random.Generator.from_non_deterministic_state()
+
     def prepare_inputs_for_generation(self, inputs, **kwargs):
         """
         Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
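Note on the new class attribute: `seed_generator` supplies a fresh stateless seed pair whenever `sample` is called without an explicit `seed`. A minimal standalone sketch of that pattern (not part of the diff), using only TensorFlow:

    import tensorflow as tf

    # A non-deterministic generator: every make_seeds() call yields fresh seed material.
    seed_generator = tf.random.Generator.from_non_deterministic_state()

    # make_seeds(count=1) returns an int64 tensor of shape (2, 1); the first column,
    # cast to int32, can feed the `seed` argument of tf.random.stateless_* functions.
    sample_seed = tf.cast(seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
    print(sample_seed.shape)  # (2,)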
@@ -585,6 +587,7 @@ class TFGenerationMixin:
             attention_mask=attention_mask,
             decoder_start_token_id=decoder_start_token_id,
             use_cache=use_cache,
+            seed=model_kwargs.pop("seed", None),
             output_scores=output_scores,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1288,6 +1291,7 @@ class TFGenerationMixin:
         attention_mask=None,
         decoder_start_token_id=None,
         use_cache=None,
+        seed=None,
         output_scores=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -1365,6 +1369,9 @@ class TFGenerationMixin:
             use_cache (`bool`, *optional*, defaults to `True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                 speed up decoding.
+            seed (`List[int]`, *optional*):
+                Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
+                `seed` argument from stateless functions in `tf.random`.
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
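With the new argument, a fixed seed makes sampling reproducible end to end. A minimal usage sketch (the checkpoint is just the one the tests below use):

    import tensorflow as tf
    from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
    input_ids = tokenizer("Today is a beautiful day and", return_tensors="tf").input_ids

    # `seed` is a pair of integers forwarded to the stateless sampling op, so repeated
    # calls with the same seed produce the same continuation.
    output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))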
@@ -1590,6 +1597,7 @@ class TFGenerationMixin:
             max_length=max_length,
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
+            seed=seed,
             output_scores=output_scores,
             return_dict_in_generate=return_dict_in_generate,
             **model_kwargs,
@@ -1723,7 +1731,7 @@ class TFGenerationMixin:
         **model_kwargs,
     ) -> Tuple[tf.Tensor, Dict[str, Any]]:
         expanded_return_idx = tf.reshape(
-            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
+            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
         )
         input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
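The only change here is the target shape: the bare integer `-1` becomes the explicit one-element tuple `(-1,)`; both flatten the index tensor. The surrounding helper repeats every batch row `expand_size` times. A standalone sketch of the tile/reshape/gather pattern:

    import tensorflow as tf

    input_ids = tf.constant([[1, 2], [3, 4]])  # batch of 2 sequences
    expand_size = 3

    expanded_return_idx = tf.reshape(
        tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
    )
    print(expanded_return_idx.numpy())  # [0 0 0 1 1 1]

    # every original row is repeated expand_size times (e.g. for num_return_sequences)
    print(tf.gather(input_ids, expanded_return_idx, axis=0).numpy())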
@@ -2123,6 +2131,7 @@ class TFGenerationMixin:
max_length
:
Optional
[
int
]
=
None
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
int
]
=
None
,
seed
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_scores
:
Optional
[
bool
]
=
None
,
...
...
@@ -2149,6 +2158,9 @@ class TFGenerationMixin:
                 The id of the *padding* token.
             eos_token_id (`int`, *optional*):
                 The id of the *end-of-sequence* token.
+            seed (`List[int]`, *optional*):
+                Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
+                `seed` argument from stateless functions in `tf.random`.
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
@@ -2210,7 +2222,7 @@ class TFGenerationMixin:
         >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
         ```"""

-        # init values
+        # 1. init greedy_search values
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
@@ -2224,97 +2236,155 @@ class TFGenerationMixin:
         return_dict_in_generate = (
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
+        use_xla = not tf.executing_eagerly()
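`use_xla` keys off the execution mode: when `generate` runs under a `tf.function` (optionally with `jit_compile=True`), eager execution is off and the XLA-friendly code paths are taken. A quick standalone illustration:

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def compiled_step(x):
        # inside a compiled function eager execution is off, so this prints False at trace time
        print("eager inside tf.function:", tf.executing_eagerly())
        return x + 1

    print("eager at top level:", tf.executing_eagerly())  # True
    compiled_step(tf.constant(1))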
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        # 2. init `attentions`, `hidden_states`, and `scores` tuples
+        scores = [] if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
+        # 3. init tensors to use for "xla-compileable" generate function
+        # define bsz, seq_length
+        batch_size, cur_len = input_ids.shape
+
+        # initialize `generated`, `finished_sequences`
+        generated = tf.TensorArray(
+            element_shape=(batch_size,),
+            dtype=tf.int32,
+            dynamic_size=False,
+            size=max_length,
+            clear_after_read=False,
+        )
+        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
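Rather than growing `input_ids` with `tf.concat` at every step (a shape change XLA cannot compile), the refactor pre-allocates a fixed-size `tf.TensorArray` of length `max_length` and writes one token vector per position. A standalone sketch of the buffer pattern, with made-up sizes:

    import tensorflow as tf

    batch_size, max_length = 2, 5

    generated = tf.TensorArray(
        element_shape=(batch_size,),  # one int32 token per sequence, per position
        dtype=tf.int32,
        dynamic_size=False,           # fixed size keeps shapes static for XLA
        size=max_length,
        clear_after_read=False,       # values are re-read when rebuilding the ids
    )

    # write the prompt: one (batch_size,) vector per position
    generated = generated.write(0, tf.constant([10, 20], dtype=tf.int32))
    generated = generated.write(1, tf.constant([11, 21], dtype=tf.int32))

    # concat() flattens to (size * batch_size,); reshape + transpose recovers (batch, time)
    ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
    print(ids.numpy())  # unwritten positions read back as zeros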
-        # keep track of which sequences are already finished
-        unfinished_sequences = tf.ones_like(input_ids[:, 0])
-        cur_len = input_ids.shape[-1]
+        # write prompt to generated
+        for i in range(cur_len):
+            generated = generated.write(i, input_ids[:, i])
-        while cur_len < max_length:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        # 4. define "xla-compile-able" stop-condition and auto-regressive function
+        def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            return ~tf.reduce_all(finished_sequences)
-            # forward pass to get next token
+        def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
                 return_dict=True,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            next_token_logits = outputs.logits[:, -1]
             # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
+            if not use_xla and return_dict_in_generate:
                 if output_scores:
-                    scores += (next_token_scores,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
+                    scores.append(next_token_logits)
+                if output_attentions and self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.decoder_attentions)
+                elif output_attentions and not self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.attentions)
                     if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
+                        cross_attentions.append(outputs.cross_attentions)
 
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
-                    )
+                if output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.decoder_hidden_states)
+                elif output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.hidden_states)
+            # pre-process distribution
+            # TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
+            # to be XLA compatible
+            input_ids = None
+            if not use_xla:
+                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
+                input_ids = tf.transpose(input_ids[:cur_len])
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
+            next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
             # sample
+            if seed is not None:
+                sample_seed = seed
+            else:
+                sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
             next_tokens = tf.squeeze(
-                tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
+                tf.random.stateless_categorical(
+                    logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32
+                ),
+                axis=1,
             )
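`tf.random.categorical` draws from TensorFlow's global, stateful RNG, while `tf.random.stateless_categorical` is a pure function of `(logits, seed)`: the same seed always yields the same draw, which is what makes seeded XLA sampling reproducible. A standalone sketch:

    import tensorflow as tf

    logits = tf.math.log([[0.1, 0.2, 0.7]])  # batch of 1, vocabulary of 3

    a = tf.random.stateless_categorical(logits=logits, num_samples=5, seed=[42, 0], dtype=tf.int32)
    b = tf.random.stateless_categorical(logits=logits, num_samples=5, seed=[42, 0], dtype=tf.int32)
    print(a.numpy(), b.numpy())  # identical draws for identical seeds

    c = tf.random.categorical(logits=logits, num_samples=5)  # stateful: varies from run to run
    print(c.numpy())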
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
+                next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
+                finished_sequences = finished_sequences | (next_tokens == eos_token_id)
+
+            # update `generated` and `cur_len`
+            generated = generated.write(cur_len, next_tokens)
+            next_tokens = next_tokens[:, None]
+            cur_len += 1
-            # update generated ids, model inputs, and length for next step
-            input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
+            # update model_kwargs
+            if use_xla:
+                model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
+            else:
                 model_kwargs = self._update_model_kwargs_for_generation(
                     outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
                 )
-            cur_len = cur_len + 1
-
                 # if we don't cache past key values we need the whole input
                 if model_kwargs.get("past", None) is None:
                     # let's throw out `past` since we don't want `None` tensors
                     model_kwargs.pop("past", None)
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                eos_in_sents = next_tokens == eos_token_id
-                # if sentence is unfinished and the token to add is eos
-                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
-                    unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
-                )
+            next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
+            next_tokens = tf.transpose(next_tokens[:cur_len])
+
+            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+        # 5. run generation
+        # 1st generation step has to be run before to initialize `past`
+        generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
+            generated, finished_sequences, input_ids, cur_len, model_kwargs
+        )
-                # unfinished_sequences is set to zero if eos in sentence
-                unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
+        # 2-to-n generation steps can then be run in autoregressive fashion
+        # only in case 1st generation step does NOT yield EOS token though
+        if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            maximum_iterations = max_length - cur_len - 1
+            generated, _, _, cur_len, _ = tf.while_loop(
+                sample_cond_fn,
+                sample_body_fn,
+                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+                maximum_iterations=maximum_iterations,
+            )
-            # stop when each sentence is finished, or if we exceed the maximum length
-            if tf.math.reduce_max(unfinished_sequences) == 0:
-                break
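The Python `while` loop is replaced with `tf.while_loop`, which XLA can compile because the loop state keeps fixed shapes and `maximum_iterations` bounds the trip count. A toy standalone sketch with the same cond/body/loop-vars structure:

    import tensorflow as tf

    def cond_fn(counter, total):
        return counter < 5

    def body_fn(counter, total):
        return counter + 1, total + counter

    # loop_vars, cond_fn and body_fn must all use the same positional ordering
    counter, total = tf.while_loop(
        cond_fn,
        body_fn,
        (tf.constant(0), tf.constant(0)),
        maximum_iterations=3,  # stops after 3 iterations even though cond_fn still holds
    )
    print(counter.numpy(), total.numpy())  # 3, 3  (0 + 1 + 2)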
+        # 6. prepare outputs
+        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
+
+        if not use_xla:
+            # cut for backward compatibility
+            output_ids = output_ids[:, :cur_len]
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
+                # if model is an encoder-decoder, retrieve encoder attention weights
+                # and hidden states
+                encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+                encoder_hidden_states = (
+                    model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                )
+
+                scores = tuple(scores) if scores is not None else None
+                decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
+                cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
+                decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
+
                 return TFSampleEncoderDecoderOutput(
-                    sequences=input_ids,
+                    sequences=output_ids,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2324,13 +2394,13 @@ class TFGenerationMixin:
)
else
:
return
TFSampleDecoderOnlyOutput
(
sequences
=
in
put_ids
,
sequences
=
out
put_ids
,
scores
=
scores
,
attentions
=
decoder_attentions
,
hidden_states
=
decoder_hidden_states
,
)
else
:
return
in
put_ids
return
out
put_ids
def
beam_search
(
self
,
...
...
@@ -2575,8 +2645,8 @@ class TFGenerationMixin:
sequences
,
scores
,
is_sent_finished
,
model_kwargs
,
input_ids_length
,
model_kwargs
,
):
"""
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
...
...
@@ -2604,8 +2674,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
             intermediary_running_sequences=None,
         ):
             """
@@ -2781,8 +2851,8 @@ class TFGenerationMixin:
                 next_sequences,
                 next_scores,
                 next_is_sent_finished,
-                next_model_kwargs,
                 next_input_ids_length,
+                next_model_kwargs,
             )

         # 5. run generation
@@ -2799,8 +2869,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
         ) = beam_search_body_fn(
             cur_len,
             running_sequences,
@@ -2808,8 +2878,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
         )

         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
@@ -2821,8 +2891,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
         ):
             maximum_iterations = max_length - cur_len
             cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
@@ -2835,8 +2905,8 @@ class TFGenerationMixin:
                 sequences,
                 scores,
                 is_sent_finished,
-                model_kwargs,
                 input_ids_length,
+                model_kwargs,
             ),
             maximum_iterations=maximum_iterations,
         )
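These beam-search hunks only move `model_kwargs` to the end of the loop state. With `tf.while_loop`, the condition function, the body function, and the `loop_vars` tuple must all use one positional order, so the reordering has to be applied to every signature and call site at once. A toy standalone illustration of that constraint:

    import tensorflow as tf

    # cond_fn, body_fn and the initial loop_vars tuple share one positional order;
    # reordering an argument in only one of them would silently misbind the loop state.
    def cond_fn(cur_len, scores, is_sent_finished):
        return ~tf.reduce_all(is_sent_finished)

    def body_fn(cur_len, scores, is_sent_finished):
        scores = scores + tf.constant([1.0, 1.5])
        is_sent_finished = scores >= 3.0
        return cur_len + 1, scores, is_sent_finished

    cur_len, scores, is_sent_finished = tf.while_loop(
        cond_fn, body_fn, (tf.constant(0), tf.constant([0.0, 0.0]), tf.constant([False, False]))
    )
    print(cur_len.numpy(), scores.numpy())  # 3 [3.  4.5]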
tests/gpt2/test_modeling_tf_gpt2.py
@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):

 @require_tf
 class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
-    @slow
-    def test_lm_generate_distilgpt2(self):
-        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
-        input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32)  # The president
-
-        # The president of the United States, and the president of the United Kingdom, have been in the White
-        # fmt: off
-        expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
-        # fmt: on
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
-
     @slow
     def test_lm_generate_greedy_distilgpt2_batch_special(self):
         model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "temperature": 1.5,
             "top_k": 500,
             "top_p": 0.9,
+            "seed": [42, 0],  # seed set -> deterministic sampling sequence -> deterministic generation
         }

         # forces the generation to happen on CPU, to avoid GPU-related quirks
         with tf.device(":/CPU:0"):
-            tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
             output_ids = model.generate(input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         expected_output_string = [
-            "Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
-            "Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
+            "Today is a beautiful day and we will make you feel very hot/terrific in all",
+            "Yesterday was another solid success as news coverage became standard American domestic television hit.",
         ]

         self.assertListEqual(output_strings, expected_output_string)
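The test no longer relies on `tf.random.set_seed`, which seeds a global, stateful generator; the XLA sampling path instead threads an explicit stateless seed through `generate`. A standalone sketch of the difference:

    import tensorflow as tf

    # stateful: set_seed fixes the starting point of a global RNG, but every call advances it
    tf.random.set_seed(42)
    print(tf.random.uniform([2]).numpy())
    print(tf.random.uniform([2]).numpy())  # different from the first call

    # stateless: the seed pair fully determines the output, call after call
    print(tf.random.stateless_uniform([2], seed=[42, 0]).numpy())
    print(tf.random.stateless_uniform([2], seed=[42, 0]).numpy())  # identical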
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
@
slow
def
test_lm_generate_gpt2_xla
(
self
):
def
test_lm_generate_gpt2_xla
_greedy
(
self
):
"""This test gives the exact same results as the non-xla test above"""
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
input_ids
=
tf
.
convert_to_tensor
([[
464
,
3290
]],
dtype
=
tf
.
int32
)
# The dog
...
...
@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         output_ids = xla_generate(input_ids, do_sample=False)
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+
+    @slow
+    def test_lm_generate_gpt2_xla_sample(self):
+        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+
+        # fmt: off
+        expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
+        # fmt: on
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
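The XLA tests all compile `generate` the same way, and the pattern generalizes to any model with the new sampling support. A minimal sketch (repeated calls with the same input shapes reuse the compiled program):

    import tensorflow as tf
    from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = TFGPT2LMHeadModel.from_pretrained("gpt2")
    input_ids = tokenizer("The dog", return_tensors="tf").input_ids

    # jit_compile=True asks TensorFlow to lower the traced generate() graph through XLA
    xla_generate = tf.function(model.generate, jit_compile=True)

    output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))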
tests/t5/test_modeling_tf_t5.py
@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
         self.assertListEqual(expected_output_string, output_strings)

+    @slow
+    def test_sample_xla_generate_simple(self):
+        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        sentence = "Translate English to German: Today is a beautiful day."
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+
+        # XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
+        # divergences in generate -- especially with sampling.
+        expected_output_string = ["Heute ist ein schöner Tag."]
+        expected_output_string_xla = ["Heute ist ein schöne Tage."]
+        # However, notice that the first tokens are the same, for the same seed
+        assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(expected_output_string, output_strings)
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            xla_generate = tf.function(model.generate, jit_compile=True)
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+        output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
+        self.assertListEqual(expected_output_string_xla, output_strings_xla)
+
     @slow
     def test_sample_generate(self):
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
             "temperature": 0.8,
             "top_k": 500,
             "top_p": 0.9,
+            "seed": [20, 0],  # seed set -> deterministic sampling sequence -> deterministic generation
         }

         # forces the generation to happen on CPU, to avoid GPU-related quirks
         with tf.device(":/CPU:0"):
-            tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
             output_ids = model.generate(input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

-        expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
+        expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]

         self.assertListEqual(expected_output_string, output_strings)