Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5cce3076
Unverified
Commit
5cce3076
authored
Jun 23, 2022
by
Joao Gante
Committed by
GitHub
Jun 23, 2022
Browse files
TF: generate without `tf.TensorArray` (#17801)
parent
ab223fc1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
97 additions
and
200 deletions
+97
-200
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+90
-193
src/transformers/models/gpt2/modeling_tf_gpt2.py
src/transformers/models/gpt2/modeling_tf_gpt2.py
+3
-2
src/transformers/models/xlnet/modeling_tf_xlnet.py
src/transformers/models/xlnet/modeling_tf_xlnet.py
+4
-5
No files found.
src/transformers/generation_tf_utils.py
View file @
5cce3076
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
import
inspect
import
inspect
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -1979,6 +1978,8 @@ class TFGenerationMixin:
...
@@ -1979,6 +1978,8 @@ class TFGenerationMixin:
return_dict_in_generate
if
return_dict_in_generate
is
not
None
else
self
.
config
.
return_dict_in_generate
return_dict_in_generate
if
return_dict_in_generate
is
not
None
else
self
.
config
.
return_dict_in_generate
)
)
use_xla
=
not
tf
.
executing_eagerly
()
use_xla
=
not
tf
.
executing_eagerly
()
# some models, like XLNet, need more than the last token in the presence of past
needs_full_input
=
"use_mems"
in
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
.
keys
())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores
=
[]
if
(
return_dict_in_generate
and
output_scores
)
else
None
scores
=
[]
if
(
return_dict_in_generate
and
output_scores
)
else
None
...
@@ -1989,34 +1990,25 @@ class TFGenerationMixin:
...
@@ -1989,34 +1990,25 @@ class TFGenerationMixin:
# 3. init tensors to use for "xla-compileable" generate function
# 3. init tensors to use for "xla-compileable" generate function
batch_size
,
cur_len
=
input_ids
.
shape
batch_size
,
cur_len
=
input_ids
.
shape
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
generated
=
tf
.
TensorArray
(
input_ids_padding
=
tf
.
ones
((
batch_size
,
max_length
-
cur_len
),
dtype
=
tf
.
int32
)
*
(
pad_token_id
or
0
)
element_shape
=
(
batch_size
,),
generated
=
tf
.
concat
([
input_ids
,
input_ids_padding
],
axis
=-
1
)
dtype
=
tf
.
int32
,
dynamic_size
=
False
,
size
=
max_length
,
clear_after_read
=
False
,
)
if
pad_token_id
:
# ignores the cases when it is 0 or None
for
i
in
range
(
max_length
):
generated
=
generated
.
write
(
i
,
tf
.
broadcast_to
(
pad_token_id
,
(
batch_size
,)))
# write prompt to generated
for
i
in
range
(
cur_len
):
generated
=
generated
.
write
(
i
,
input_ids
[:,
i
])
finished_sequences
=
tf
.
zeros
((
batch_size
,),
dtype
=
tf
.
bool
)
finished_sequences
=
tf
.
zeros
((
batch_size
,),
dtype
=
tf
.
bool
)
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# define condition fn
# define condition fn
def
greedy_search_cond_fn
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
):
def
greedy_search_cond_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
"""state termination condition fn."""
"""state termination condition fn."""
return
~
tf
.
reduce_all
(
finished_sequences
)
return
~
tf
.
reduce_all
(
finished_sequences
)
# define condition fn
# define condition fn
def
greedy_search_body_fn
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
):
def
greedy_search_body_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
"""state update fn."""
"""state update fn."""
model_inputs
=
self
.
prepare_inputs_for_generation
(
next_tokens
,
**
model_kwargs
)
if
model_kwargs
.
get
(
"past"
)
is
None
or
needs_full_input
:
input_ids
=
generated
[:,
:
cur_len
]
else
:
input_ids
=
tf
.
expand_dims
(
generated
[:,
cur_len
-
1
],
-
1
)
model_inputs
=
self
.
prepare_inputs_for_generation
(
input_ids
,
**
model_kwargs
)
# forward pass to get next token logits
# forward pass to get next token logits
outputs
=
self
(
outputs
=
self
(
**
model_inputs
,
**
model_inputs
,
...
@@ -2043,8 +2035,7 @@ class TFGenerationMixin:
...
@@ -2043,8 +2035,7 @@ class TFGenerationMixin:
decoder_hidden_states
.
append
(
outputs
.
hidden_states
)
decoder_hidden_states
.
append
(
outputs
.
hidden_states
)
# pre-process distribution
# pre-process distribution
input_ids
=
tf
.
transpose
(
tf
.
reshape
(
generated
.
concat
(),
(
-
1
,
batch_size
)))
next_tokens_scores
=
logits_processor
(
generated
,
next_token_logits
,
cur_len
)
next_tokens_scores
=
logits_processor
(
input_ids
,
next_token_logits
,
cur_len
)
# argmax
# argmax
next_tokens
=
tf
.
argmax
(
next_tokens_scores
,
axis
=-
1
,
output_type
=
tf
.
int32
)
next_tokens
=
tf
.
argmax
(
next_tokens_scores
,
axis
=-
1
,
output_type
=
tf
.
int32
)
...
@@ -2057,8 +2048,8 @@ class TFGenerationMixin:
...
@@ -2057,8 +2048,8 @@ class TFGenerationMixin:
finished_sequences
=
finished_sequences
|
(
next_tokens
==
eos_token_id
)
finished_sequences
=
finished_sequences
|
(
next_tokens
==
eos_token_id
)
# update `generated` and `cur_len`
# update `generated` and `cur_len`
generated
=
generated
.
write
(
cur_len
,
next_tokens
)
update_indices
=
tf
.
stack
([
tf
.
range
(
batch_size
),
tf
.
broadcast_to
(
cur_len
,
[
batch_size
])],
axis
=-
1
)
next_tokens
=
next_tokens
[:,
None
]
generated
=
tf
.
tensor_scatter_nd_update
(
tensor
=
generated
,
indices
=
update_indices
,
updates
=
next_tokens
)
cur_len
+=
1
cur_len
+=
1
# update model_kwargs
# update model_kwargs
...
@@ -2073,34 +2064,29 @@ class TFGenerationMixin:
...
@@ -2073,34 +2064,29 @@ class TFGenerationMixin:
# let's throw out `past` since we don't want `None` tensors
# let's throw out `past` since we don't want `None` tensors
model_kwargs
.
pop
(
"past"
,
None
)
model_kwargs
.
pop
(
"past"
,
None
)
next_tokens
=
tf
.
reshape
(
generated
.
concat
(),
(
-
1
,
batch_size
))
return
generated
,
finished_sequences
,
cur_len
,
model_kwargs
next_tokens
=
tf
.
transpose
(
next_tokens
[:
cur_len
])
return
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
# 5. run generation
# 5. run generation
# 1st generation step has to be run before to initialize `past`
# 1st generation step has to be run before to initialize `past`
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
=
greedy_search_body_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
=
greedy_search_body_fn
(
generated
,
finished_sequences
,
input_ids
,
cur_len
,
model_kwargs
generated
,
finished_sequences
,
cur_len
,
model_kwargs
)
)
# 2-to-n generation steps can then be run in autoregressive fashion
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
# only in case 1st generation step does NOT yield EOS token though
if
greedy_search_cond_fn
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
):
if
greedy_search_cond_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
maximum_iterations
=
max_length
-
cur_len
maximum_iterations
=
max_length
-
cur_len
generated
,
_
,
_
,
cur_len
,
_
=
tf
.
while_loop
(
generated
,
_
,
cur_len
,
_
=
tf
.
while_loop
(
greedy_search_cond_fn
,
greedy_search_cond_fn
,
greedy_search_body_fn
,
greedy_search_body_fn
,
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
),
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
),
maximum_iterations
=
maximum_iterations
,
maximum_iterations
=
maximum_iterations
,
)
)
# 6. prepare outputs
# 6. prepare outputs
output_ids
=
tf
.
transpose
(
tf
.
reshape
(
generated
.
concat
(),
(
-
1
,
batch_size
)))
if
not
use_xla
:
if
not
use_xla
:
# cut for backward compatibility
# cut for backward compatibility
output_ids
=
output_ids
[:,
:
cur_len
]
generated
=
generated
[:,
:
cur_len
]
if
return_dict_in_generate
:
if
return_dict_in_generate
:
if
self
.
config
.
is_encoder_decoder
:
if
self
.
config
.
is_encoder_decoder
:
...
@@ -2117,7 +2103,7 @@ class TFGenerationMixin:
...
@@ -2117,7 +2103,7 @@ class TFGenerationMixin:
decoder_hidden_states
=
tuple
(
decoder_hidden_states
)
if
decoder_hidden_states
is
not
None
else
None
decoder_hidden_states
=
tuple
(
decoder_hidden_states
)
if
decoder_hidden_states
is
not
None
else
None
return
TFGreedySearchEncoderDecoderOutput
(
return
TFGreedySearchEncoderDecoderOutput
(
sequences
=
output_ids
,
sequences
=
generated
,
scores
=
scores
,
scores
=
scores
,
encoder_attentions
=
encoder_attentions
,
encoder_attentions
=
encoder_attentions
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
...
@@ -2127,13 +2113,13 @@ class TFGenerationMixin:
...
@@ -2127,13 +2113,13 @@ class TFGenerationMixin:
)
)
else
:
else
:
return
TFGreedySearchDecoderOnlyOutput
(
return
TFGreedySearchDecoderOnlyOutput
(
sequences
=
output_ids
,
sequences
=
generated
,
scores
=
scores
,
scores
=
scores
,
attentions
=
decoder_attentions
,
attentions
=
decoder_attentions
,
hidden_states
=
decoder_hidden_states
,
hidden_states
=
decoder_hidden_states
,
)
)
else
:
else
:
return
output_ids
return
generated
def
sample
(
def
sample
(
self
,
self
,
...
@@ -2250,6 +2236,8 @@ class TFGenerationMixin:
...
@@ -2250,6 +2236,8 @@ class TFGenerationMixin:
return_dict_in_generate
if
return_dict_in_generate
is
not
None
else
self
.
config
.
return_dict_in_generate
return_dict_in_generate
if
return_dict_in_generate
is
not
None
else
self
.
config
.
return_dict_in_generate
)
)
use_xla
=
not
tf
.
executing_eagerly
()
use_xla
=
not
tf
.
executing_eagerly
()
# some models, like XLNet, need more than the last token in the presence of past
needs_full_input
=
"use_mems"
in
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
.
keys
())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores
=
[]
if
(
return_dict_in_generate
and
output_scores
)
else
None
scores
=
[]
if
(
return_dict_in_generate
and
output_scores
)
else
None
...
@@ -2261,29 +2249,20 @@ class TFGenerationMixin:
...
@@ -2261,29 +2249,20 @@ class TFGenerationMixin:
batch_size
,
cur_len
=
input_ids
.
shape
batch_size
,
cur_len
=
input_ids
.
shape
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
generated
=
tf
.
TensorArray
(
input_ids_padding
=
tf
.
ones
((
batch_size
,
max_length
-
cur_len
),
dtype
=
tf
.
int32
)
*
(
pad_token_id
or
0
)
element_shape
=
(
batch_size
,),
generated
=
tf
.
concat
([
input_ids
,
input_ids_padding
],
axis
=-
1
)
dtype
=
tf
.
int32
,
dynamic_size
=
False
,
size
=
max_length
,
clear_after_read
=
False
,
)
if
pad_token_id
:
# ignores the cases when it is 0 or None
for
i
in
range
(
max_length
):
generated
=
generated
.
write
(
i
,
tf
.
broadcast_to
(
pad_token_id
,
(
batch_size
,)))
# write prompt to generated
for
i
in
range
(
cur_len
):
generated
=
generated
.
write
(
i
,
input_ids
[:,
i
])
finished_sequences
=
tf
.
zeros
((
batch_size
,),
dtype
=
tf
.
bool
)
finished_sequences
=
tf
.
zeros
((
batch_size
,),
dtype
=
tf
.
bool
)
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# 4. define "xla-compile-able" stop-condition and auto-regressive function
def
sample_cond_fn
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
):
def
sample_cond_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
return
~
tf
.
reduce_all
(
finished_sequences
)
return
~
tf
.
reduce_all
(
finished_sequences
)
def
sample_body_fn
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
):
def
sample_body_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
model_inputs
=
self
.
prepare_inputs_for_generation
(
next_tokens
,
**
model_kwargs
)
if
model_kwargs
.
get
(
"past"
)
is
None
or
needs_full_input
:
input_ids
=
generated
[:,
:
cur_len
]
else
:
input_ids
=
tf
.
expand_dims
(
generated
[:,
cur_len
-
1
],
-
1
)
model_inputs
=
self
.
prepare_inputs_for_generation
(
input_ids
,
**
model_kwargs
)
# forward pass to get next token logits
# forward pass to get next token logits
outputs
=
self
(
outputs
=
self
(
**
model_inputs
,
**
model_inputs
,
...
@@ -2310,9 +2289,8 @@ class TFGenerationMixin:
...
@@ -2310,9 +2289,8 @@ class TFGenerationMixin:
decoder_hidden_states
.
append
(
outputs
.
hidden_states
)
decoder_hidden_states
.
append
(
outputs
.
hidden_states
)
# pre-process distribution
# pre-process distribution
input_ids
=
tf
.
transpose
(
tf
.
reshape
(
generated
.
concat
(),
(
-
1
,
batch_size
)))
next_tokens_scores
=
logits_processor
(
generated
,
next_token_logits
,
cur_len
)
next_tokens_scores
=
logits_processor
(
input_ids
,
next_token_logits
,
cur_len
)
next_tokens_scores
=
logits_warper
(
generated
,
next_tokens_scores
,
cur_len
)
next_tokens_scores
=
logits_warper
(
input_ids
,
next_tokens_scores
,
cur_len
)
# sample
# sample
if
seed
is
not
None
:
if
seed
is
not
None
:
...
@@ -2334,8 +2312,8 @@ class TFGenerationMixin:
...
@@ -2334,8 +2312,8 @@ class TFGenerationMixin:
finished_sequences
=
finished_sequences
|
(
next_tokens
==
eos_token_id
)
finished_sequences
=
finished_sequences
|
(
next_tokens
==
eos_token_id
)
# update `generated` and `cur_len`
# update `generated` and `cur_len`
generated
=
generated
.
write
(
cur_len
,
next_tokens
)
update_indices
=
tf
.
stack
([
tf
.
range
(
batch_size
),
tf
.
broadcast_to
(
cur_len
,
[
batch_size
])],
axis
=-
1
)
next_tokens
=
next_tokens
[:,
None
]
generated
=
tf
.
tensor_scatter_nd_update
(
tensor
=
generated
,
indices
=
update_indices
,
updates
=
next_tokens
)
cur_len
+=
1
cur_len
+=
1
# update model_kwargs
# update model_kwargs
...
@@ -2350,34 +2328,29 @@ class TFGenerationMixin:
...
@@ -2350,34 +2328,29 @@ class TFGenerationMixin:
# let's throw out `past` since we don't want `None` tensors
# let's throw out `past` since we don't want `None` tensors
model_kwargs
.
pop
(
"past"
,
None
)
model_kwargs
.
pop
(
"past"
,
None
)
next_tokens
=
tf
.
reshape
(
generated
.
concat
(),
(
-
1
,
batch_size
))
return
generated
,
finished_sequences
,
cur_len
,
model_kwargs
next_tokens
=
tf
.
transpose
(
next_tokens
[:
cur_len
])
return
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
# 5. run generation
# 5. run generation
# 1st generation step has to be run before to initialize `past`
# 1st generation step has to be run before to initialize `past`
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
=
sample_body_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
=
sample_body_fn
(
generated
,
finished_sequences
,
input_ids
,
cur_len
,
model_kwargs
generated
,
finished_sequences
,
cur_len
,
model_kwargs
)
)
# 2-to-n generation steps can then be run in autoregressive fashion
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
# only in case 1st generation step does NOT yield EOS token though
if
sample_cond_fn
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
):
if
sample_cond_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
maximum_iterations
=
max_length
-
cur_len
maximum_iterations
=
max_length
-
cur_len
generated
,
_
,
_
,
cur_len
,
_
=
tf
.
while_loop
(
generated
,
_
,
cur_len
,
_
=
tf
.
while_loop
(
sample_cond_fn
,
sample_cond_fn
,
sample_body_fn
,
sample_body_fn
,
(
generated
,
finished_sequences
,
next_tokens
,
cur_len
,
model_kwargs
),
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
),
maximum_iterations
=
maximum_iterations
,
maximum_iterations
=
maximum_iterations
,
)
)
# 6. prepare outputs
# 6. prepare outputs
output_ids
=
tf
.
transpose
(
tf
.
reshape
(
generated
.
concat
(),
(
-
1
,
batch_size
)))
if
not
use_xla
:
if
not
use_xla
:
# cut for backward compatibility
# cut for backward compatibility
output_ids
=
output_ids
[:,
:
cur_len
]
generated
=
generated
[:,
:
cur_len
]
if
return_dict_in_generate
:
if
return_dict_in_generate
:
if
self
.
config
.
is_encoder_decoder
:
if
self
.
config
.
is_encoder_decoder
:
...
@@ -2394,7 +2367,7 @@ class TFGenerationMixin:
...
@@ -2394,7 +2367,7 @@ class TFGenerationMixin:
decoder_hidden_states
=
tuple
(
decoder_hidden_states
)
if
decoder_hidden_states
is
not
None
else
None
decoder_hidden_states
=
tuple
(
decoder_hidden_states
)
if
decoder_hidden_states
is
not
None
else
None
return
TFSampleEncoderDecoderOutput
(
return
TFSampleEncoderDecoderOutput
(
sequences
=
output_ids
,
sequences
=
generated
,
scores
=
scores
,
scores
=
scores
,
encoder_attentions
=
encoder_attentions
,
encoder_attentions
=
encoder_attentions
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
...
@@ -2404,13 +2377,13 @@ class TFGenerationMixin:
...
@@ -2404,13 +2377,13 @@ class TFGenerationMixin:
)
)
else
:
else
:
return
TFSampleDecoderOnlyOutput
(
return
TFSampleDecoderOnlyOutput
(
sequences
=
output_ids
,
sequences
=
generated
,
scores
=
scores
,
scores
=
scores
,
attentions
=
decoder_attentions
,
attentions
=
decoder_attentions
,
hidden_states
=
decoder_hidden_states
,
hidden_states
=
decoder_hidden_states
,
)
)
else
:
else
:
return
output_ids
return
generated
def
beam_search
(
def
beam_search
(
self
,
self
,
...
@@ -2585,6 +2558,8 @@ class TFGenerationMixin:
...
@@ -2585,6 +2558,8 @@ class TFGenerationMixin:
# GPT2 and other models has a slightly different cache structure, with a different batch axis
# GPT2 and other models has a slightly different cache structure, with a different batch axis
model_name
=
str
(
self
.
decoder
)
if
"EncoderDecoder"
in
str
(
self
)
else
str
(
self
)
model_name
=
str
(
self
.
decoder
)
if
"EncoderDecoder"
in
str
(
self
)
else
str
(
self
)
cache_batch_axis
=
1
if
any
([
model_prefix
in
model_name
for
model_prefix
in
(
"TFGPT2"
,
"TFCTRL"
)])
else
0
cache_batch_axis
=
1
if
any
([
model_prefix
in
model_name
for
model_prefix
in
(
"TFGPT2"
,
"TFCTRL"
)])
else
0
# some models, like XLNet, need more than the last token in the presence of past
needs_full_input
=
"use_mems"
in
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
.
keys
())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores
=
[]
if
(
return_dict_in_generate
and
output_scores
)
else
None
scores
=
[]
if
(
return_dict_in_generate
and
output_scores
)
else
None
...
@@ -2594,41 +2569,13 @@ class TFGenerationMixin:
...
@@ -2594,41 +2569,13 @@ class TFGenerationMixin:
# 3. init tensors to use for "xla-compileable" generate function
# 3. init tensors to use for "xla-compileable" generate function
batch_size
,
num_beams
,
cur_len
=
input_ids
.
shape
batch_size
,
num_beams
,
cur_len
=
input_ids
.
shape
input_ids_length
=
cur_len
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
sequences
=
tf
.
TensorArray
(
input_ids_padding
=
tf
.
ones
((
batch_size
,
num_beams
,
max_length
-
cur_len
),
dtype
=
tf
.
int32
)
*
(
element_shape
=
(
batch_size
,
num_beams
),
pad_token_id
or
0
dtype
=
tf
.
int32
,
dynamic_size
=
False
,
size
=
max_length
,
clear_after_read
=
False
,
)
running_sequences
=
tf
.
TensorArray
(
element_shape
=
(
batch_size
,
num_beams
),
dtype
=
tf
.
int32
,
dynamic_size
=
False
,
size
=
max_length
,
clear_after_read
=
False
,
)
intermediary_running_sequences
=
tf
.
TensorArray
(
element_shape
=
(
batch_size
,
num_beams
*
2
),
dtype
=
tf
.
int32
,
dynamic_size
=
False
,
size
=
max_length
,
clear_after_read
=
False
,
)
)
if
pad_token_id
:
# ignores the cases when it is 0 or None
running_sequences
=
tf
.
concat
([
input_ids
,
input_ids_padding
],
axis
=-
1
)
for
i
in
range
(
max_length
):
sequences
=
tf
.
ones
((
batch_size
,
num_beams
,
max_length
),
dtype
=
tf
.
int32
)
*
(
pad_token_id
or
0
)
sequences
=
sequences
.
write
(
i
,
tf
.
broadcast_to
(
pad_token_id
,
(
batch_size
,
num_beams
)))
running_sequences
=
running_sequences
.
write
(
i
,
tf
.
broadcast_to
(
pad_token_id
,
(
batch_size
,
num_beams
)))
intermediary_running_sequences
=
intermediary_running_sequences
.
write
(
i
,
tf
.
broadcast_to
(
pad_token_id
,
(
batch_size
,
num_beams
*
2
))
)
# write prompt to running_sequences
for
i
in
range
(
cur_len
):
running_sequences
=
running_sequences
.
write
(
i
,
input_ids
[:,
:,
i
])
# per batch,beam-item state bit indicating if sentence has finished.
# per batch,beam-item state bit indicating if sentence has finished.
is_sent_finished
=
tf
.
zeros
((
batch_size
,
num_beams
),
dtype
=
tf
.
bool
)
is_sent_finished
=
tf
.
zeros
((
batch_size
,
num_beams
),
dtype
=
tf
.
bool
)
...
@@ -2656,7 +2603,6 @@ class TFGenerationMixin:
...
@@ -2656,7 +2603,6 @@ class TFGenerationMixin:
sequences
,
sequences
,
scores
,
scores
,
is_sent_finished
,
is_sent_finished
,
input_ids_length
,
model_kwargs
,
model_kwargs
,
):
):
"""
"""
...
@@ -2685,27 +2631,18 @@ class TFGenerationMixin:
...
@@ -2685,27 +2631,18 @@ class TFGenerationMixin:
sequences
,
sequences
,
scores
,
scores
,
is_sent_finished
,
is_sent_finished
,
input_ids_length
,
model_kwargs
,
model_kwargs
,
intermediary_running_sequences
=
None
,
):
):
"""
"""
Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
seen so far
seen so far
"""
"""
# TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
# Alternativelly, attempt to rewrite function with permuted axis, when enabling XLA.
# 1. Forward current tokens
# 1. Forward current tokens
if
model_kwargs
.
get
(
"past"
)
is
None
or
needs_full_input
:
# TF places the dynamic dimension (seq_len) in the first axis, we want it in the last
input_ids
=
running_sequences
[:,
:,
:
cur_len
]
running_sequences_seq_last
=
tf
.
transpose
(
running_sequences
.
stack
(),
perm
=
[
1
,
2
,
0
])
else
:
input_token
=
tf
.
slice
(
input_ids
=
tf
.
expand_dims
(
running_sequences
[:,
:,
cur_len
-
1
],
-
1
)
running_sequences_seq_last
,
model_inputs
=
self
.
prepare_inputs_for_generation
(
flatten_beam_dim
(
input_ids
),
**
model_kwargs
)
(
0
,
0
,
cur_len
-
input_ids_length
),
(
batch_size
,
num_beams
,
input_ids_length
),
)
model_inputs
=
self
.
prepare_inputs_for_generation
(
flatten_beam_dim
(
input_token
),
**
model_kwargs
)
model_outputs
=
self
(
model_outputs
=
self
(
**
model_inputs
,
**
model_inputs
,
return_dict
=
True
,
return_dict
=
True
,
...
@@ -2734,9 +2671,7 @@ class TFGenerationMixin:
...
@@ -2734,9 +2671,7 @@ class TFGenerationMixin:
# get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
# get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
# add new logprobs to existing running logprobs scores.
# add new logprobs to existing running logprobs scores.
log_probs
=
tf
.
nn
.
log_softmax
(
logits
)
log_probs
=
tf
.
nn
.
log_softmax
(
logits
)
log_probs
=
logits_processor
(
log_probs
=
logits_processor
(
flatten_beam_dim
(
running_sequences
),
flatten_beam_dim
(
log_probs
),
cur_len
)
flatten_beam_dim
(
running_sequences_seq_last
),
flatten_beam_dim
(
log_probs
),
cur_len
)
log_probs
=
unflatten_beam_dim
(
log_probs
,
batch_size
,
num_beams
)
log_probs
=
unflatten_beam_dim
(
log_probs
,
batch_size
,
num_beams
)
log_probs
=
log_probs
+
tf
.
expand_dims
(
running_scores
,
axis
=
2
)
log_probs
=
log_probs
+
tf
.
expand_dims
(
running_scores
,
axis
=
2
)
vocab_size
=
log_probs
.
shape
[
2
]
vocab_size
=
log_probs
.
shape
[
2
]
...
@@ -2755,23 +2690,28 @@ class TFGenerationMixin:
...
@@ -2755,23 +2690,28 @@ class TFGenerationMixin:
beams_to_keep
=
2
*
num_beams
beams_to_keep
=
2
*
num_beams
topk_log_probs
,
topk_indices
=
tf
.
math
.
top_k
(
log_probs
,
k
=
beams_to_keep
)
topk_log_probs
,
topk_indices
=
tf
.
math
.
top_k
(
log_probs
,
k
=
beams_to_keep
)
topk_beam_indices
=
topk_indices
//
vocab_size
topk_beam_indices
=
topk_indices
//
vocab_size
topk_running_sequences
_seq_last
=
gather_beams
(
running_sequences
_seq_last
,
topk_beam_indices
)
topk_running_sequences
=
gather_beams
(
running_sequences
,
topk_beam_indices
)
topk_ids
=
topk_indices
%
vocab_size
topk_ids
=
topk_indices
%
vocab_size
# writes the new token
# writes the new token
intermediary_running_sequences
=
intermediary_running_sequences
.
unstack
(
indices_batch
=
tf
.
repeat
(
tf
.
range
(
batch_size
),
[
beams_to_keep
])
tf
.
transpose
(
topk_running_sequences_seq_last
,
perm
=
[
2
,
0
,
1
])
indices_beam
=
tf
.
tile
(
tf
.
range
(
beams_to_keep
),
[
batch_size
])
update_indices
=
tf
.
stack
(
[
indices_batch
,
indices_beam
,
tf
.
broadcast_to
(
cur_len
,
[
batch_size
*
beams_to_keep
])],
axis
=-
1
)
topk_sequences
=
tf
.
tensor_scatter_nd_update
(
tensor
=
topk_running_sequences
,
indices
=
update_indices
,
updates
=
tf
.
reshape
(
topk_ids
,
[
batch_size
*
beams_to_keep
]),
)
)
topk_sequences
=
intermediary_running_sequences
.
write
(
cur_len
,
topk_ids
)
topk_sequences_seq_last
=
tf
.
transpose
(
topk_sequences
.
stack
(),
perm
=
[
1
,
2
,
0
])
# 4. Check which sequences have ended
# 4. Check which sequences have ended
# Update current sequences: Did the top `num_beams` sequences reach an end marker?
# Update current sequences: Did the top `num_beams` sequences reach an end marker?
# To prevent these just finished sequences from being added to the current sequences
# To prevent these just finished sequences from being added to the current sequences
# set of active beam search sequences, set their log probs to a very large negative value.
# set of active beam search sequences, set their log probs to a very large negative value.
eos_in_next_token
=
topk_sequences
_seq_last
[:,
:,
cur_len
]
==
eos_token_id
eos_in_next_token
=
topk_sequences
[:,
:,
cur_len
]
==
eos_token_id
if
eos_token_id
is
None
:
if
eos_token_id
is
None
:
eos_in_next_token
=
tf
.
broadcast_to
(
eos_in_next_token
,
topk_sequences
_seq_last
[:,
:,
cur_len
].
shape
)
eos_in_next_token
=
tf
.
broadcast_to
(
eos_in_next_token
,
topk_sequences
[:,
:,
cur_len
].
shape
)
did_topk_just_finished
=
eos_in_next_token
&
tf
.
broadcast_to
(
did_topk_just_finished
=
eos_in_next_token
&
tf
.
broadcast_to
(
tf
.
concat
((
tf
.
ones
((
num_beams
),
dtype
=
tf
.
bool
),
tf
.
zeros
((
num_beams
),
dtype
=
tf
.
bool
)),
axis
=
0
),
tf
.
concat
((
tf
.
ones
((
num_beams
),
dtype
=
tf
.
bool
),
tf
.
zeros
((
num_beams
),
dtype
=
tf
.
bool
)),
axis
=
0
),
eos_in_next_token
.
shape
,
eos_in_next_token
.
shape
,
...
@@ -2785,8 +2725,8 @@ class TFGenerationMixin:
...
@@ -2785,8 +2725,8 @@ class TFGenerationMixin:
# Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
# Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
# (from top 2*k beams).
# (from top 2*k beams).
next_topk_indices
=
tf
.
math
.
top_k
(
running_topk_log_probs
,
k
=
num_beams
)[
1
]
next_topk_indices
=
tf
.
math
.
top_k
(
running_topk_log_probs
,
k
=
num_beams
)[
1
]
next_running_sequences
_seq_last
,
next_running_scores
=
gather_beams
(
next_running_sequences
,
next_running_scores
=
gather_beams
(
[
topk_sequences
_seq_last
,
running_topk_log_probs
],
next_topk_indices
[
topk_sequences
,
running_topk_log_probs
],
next_topk_indices
)
)
# 6. Process topk logits
# 6. Process topk logits
...
@@ -2807,18 +2747,18 @@ class TFGenerationMixin:
...
@@ -2807,18 +2747,18 @@ class TFGenerationMixin:
# 7. Get scores, sequences, is sentence finished for next.
# 7. Get scores, sequences, is sentence finished for next.
# Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
# Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
# to existing finished scores and select the best from the new set of beams
# to existing finished scores and select the best from the new set of beams
sequences_seq_last
=
tf
.
transpose
(
sequences
.
stack
(),
perm
=
[
1
,
2
,
0
])
merged_sequences
=
tf
.
concat
([
sequences
,
topk_sequences
],
axis
=
1
)
merged_sequences
=
tf
.
concat
([
sequences_seq_last
,
topk_sequences_seq_last
],
axis
=
1
)
merged_scores
=
tf
.
concat
([
scores
,
topk_log_probs
],
axis
=
1
)
merged_scores
=
tf
.
concat
([
scores
,
topk_log_probs
],
axis
=
1
)
merged_is_sent_finished
=
tf
.
concat
([
is_sent_finished
,
did_topk_just_finished
],
axis
=
1
)
merged_is_sent_finished
=
tf
.
concat
([
is_sent_finished
,
did_topk_just_finished
],
axis
=
1
)
topk_merged_indices
=
tf
.
math
.
top_k
(
merged_scores
,
k
=
num_beams
)[
1
]
topk_merged_indices
=
tf
.
math
.
top_k
(
merged_scores
,
k
=
num_beams
)[
1
]
next_sequences
_seq_last
,
next_scores
,
next_is_sent_finished
=
gather_beams
(
next_sequences
,
next_scores
,
next_is_sent_finished
=
gather_beams
(
[
merged_sequences
,
merged_scores
,
merged_is_sent_finished
],
topk_merged_indices
[
merged_sequences
,
merged_scores
,
merged_is_sent_finished
],
topk_merged_indices
)
)
# 8. Prepare data for the next iteration
# 8. Prepare data for the next iteration
# Determine the top k beam indices from the original set of all beams. With these, gather the top k
# Determine the top k beam indices from the original set of all beams. With these, gather the top k
# beam-associated caches.
# beam-associated caches.
cur_len
=
cur_len
+
1
if
"past_key_values"
in
model_outputs
:
if
"past_key_values"
in
model_outputs
:
cache
=
tf
.
nest
.
map_structure
(
cache
=
tf
.
nest
.
map_structure
(
lambda
tensor
:
unflatten_beam_dim
(
tensor
,
batch_size
,
num_beams
,
batch_axis
=
cache_batch_axis
),
lambda
tensor
:
unflatten_beam_dim
(
tensor
,
batch_size
,
num_beams
,
batch_axis
=
cache_batch_axis
),
...
@@ -2841,35 +2781,20 @@ class TFGenerationMixin:
...
@@ -2841,35 +2781,20 @@ class TFGenerationMixin:
# if we don't cache past key values we need the whole input
# if we don't cache past key values we need the whole input
if
model_kwargs
.
get
(
"past"
,
None
)
is
None
:
if
model_kwargs
.
get
(
"past"
,
None
)
is
None
:
next_input_ids_length
=
cur_len
+
1
# let's throw out `past` since we don't want `None` tensors
# let's throw out `past` since we don't want `None` tensors
model_kwargs
.
pop
(
"past"
,
None
)
model_kwargs
.
pop
(
"past"
,
None
)
else
:
next_input_ids_length
=
1
# 9. Prepare the `tf.TensorArray` for the next iteration
next_sequences
=
sequences
.
unstack
(
tf
.
transpose
(
next_sequences_seq_last
,
perm
=
[
2
,
0
,
1
]))
next_running_sequences
=
running_sequences
.
unstack
(
tf
.
transpose
(
next_running_sequences_seq_last
,
perm
=
[
2
,
0
,
1
])
)
return
(
return
(
cur_len
+
1
,
cur_len
,
next_running_sequences
,
next_running_sequences
,
next_running_scores
,
next_running_scores
,
next_sequences
,
next_sequences
,
next_scores
,
next_scores
,
next_is_sent_finished
,
next_is_sent_finished
,
next_input_ids_length
,
next_model_kwargs
,
next_model_kwargs
,
)
)
# 5. run generation
# 5. run generation
# Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
beam_search_body_fn
=
partial
(
beam_search_body_fn
,
intermediary_running_sequences
=
intermediary_running_sequences
)
# 1st generation step has to be run before to initialize `past` (if active)
# 1st generation step has to be run before to initialize `past` (if active)
(
(
cur_len
,
cur_len
,
...
@@ -2878,66 +2803,38 @@ class TFGenerationMixin:
...
@@ -2878,66 +2803,38 @@ class TFGenerationMixin:
sequences
,
sequences
,
scores
,
scores
,
is_sent_finished
,
is_sent_finished
,
input_ids_length
,
model_kwargs
,
model_kwargs
,
)
=
beam_search_body_fn
(
)
=
beam_search_body_fn
(
cur_len
,
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
input_ids_length
,
model_kwargs
,
)
)
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
# NOT yield EOS token though)
if
beam_search_cond_fn
(
if
beam_search_cond_fn
(
cur_len
,
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
input_ids_length
,
model_kwargs
,
):
):
maximum_iterations
=
max_length
-
cur_len
maximum_iterations
=
max_length
-
cur_len
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
_
,
_
=
tf
.
while_loop
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
_
=
tf
.
while_loop
(
beam_search_cond_fn
,
beam_search_cond_fn
,
beam_search_body_fn
,
beam_search_body_fn
,
(
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
),
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
input_ids_length
,
model_kwargs
,
),
maximum_iterations
=
maximum_iterations
,
maximum_iterations
=
maximum_iterations
,
)
)
# 6. prepare outputs
# 6. prepare outputs
# convert the sequneces to tf.Tensor with shape (batch_size, num_beams, seq_len)
sequences_seq_last
=
tf
.
transpose
(
sequences
.
stack
(),
perm
=
[
1
,
2
,
0
])
running_sequences_seq_last
=
tf
.
transpose
(
running_sequences
.
stack
(),
perm
=
[
1
,
2
,
0
])
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
# running sequences for that batch item.
# running sequences for that batch item.
none_finished
=
tf
.
math
.
reduce_any
(
is_sent_finished
,
axis
=
1
)
none_finished
=
tf
.
math
.
reduce_any
(
is_sent_finished
,
axis
=
1
)
sequences
_seq_last
=
tf
.
where
(
none_finished
[:,
None
,
None
],
sequences
_seq_last
,
running_sequences
_seq_last
)
sequences
=
tf
.
where
(
none_finished
[:,
None
,
None
],
sequences
,
running_sequences
)
scores
=
tf
.
where
(
none_finished
[:,
None
],
scores
,
running_scores
)
scores
=
tf
.
where
(
none_finished
[:,
None
],
scores
,
running_scores
)
# Take best beams for each batch (the score is sorted in ascending order)
# Take best beams for each batch (the score is sorted in ascending order)
sequences
_seq_last
=
flatten_beam_dim
(
sequences
_seq_last
[:,
:
num_return_sequences
,
:])
sequences
=
flatten_beam_dim
(
sequences
[:,
:
num_return_sequences
,
:])
scores
=
flatten_beam_dim
(
scores
[:,
:
num_return_sequences
])
scores
=
flatten_beam_dim
(
scores
[:,
:
num_return_sequences
])
if
not
use_xla
:
if
not
use_xla
:
# Cut for backward compatibility
# Cut for backward compatibility
sequences
_seq_last
=
sequences
_seq_last
[:,
:
cur_len
]
sequences
=
sequences
[:,
:
cur_len
]
if
return_dict_in_generate
:
if
return_dict_in_generate
:
if
self
.
config
.
is_encoder_decoder
:
if
self
.
config
.
is_encoder_decoder
:
...
@@ -2948,7 +2845,7 @@ class TFGenerationMixin:
...
@@ -2948,7 +2845,7 @@ class TFGenerationMixin:
)
)
return
TFBeamSearchEncoderDecoderOutput
(
return
TFBeamSearchEncoderDecoderOutput
(
sequences
=
sequences
_seq_last
,
sequences
=
sequences
,
scores
=
scores
,
scores
=
scores
,
encoder_attentions
=
encoder_attentions
,
encoder_attentions
=
encoder_attentions
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
...
@@ -2958,13 +2855,13 @@ class TFGenerationMixin:
...
@@ -2958,13 +2855,13 @@ class TFGenerationMixin:
)
)
else
:
else
:
return
TFBeamSearchDecoderOnlyOutput
(
return
TFBeamSearchDecoderOnlyOutput
(
sequences
=
sequences
_seq_last
,
sequences
=
sequences
,
scores
=
scores
,
scores
=
scores
,
attentions
=
decoder_attentions
,
attentions
=
decoder_attentions
,
hidden_states
=
decoder_hidden_states
,
hidden_states
=
decoder_hidden_states
,
)
)
else
:
else
:
return
sequences
_seq_last
return
sequences
def
_create_next_token_logits_penalties
(
input_ids
,
logits
,
repetition_penalty
):
def
_create_next_token_logits_penalties
(
input_ids
,
logits
,
repetition_penalty
):
...
...
src/transformers/models/gpt2/modeling_tf_gpt2.py
View file @
5cce3076
...
@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
new_past
=
[
None
for
_
in
range
(
len
(
past
))]
new_past
=
[
None
for
_
in
range
(
len
(
past
))]
slice_start_base
=
tf
.
constant
([
0
,
0
,
0
,
1
,
0
])
slice_start_base
=
tf
.
constant
([
0
,
0
,
0
,
1
,
0
])
attention_mask_update_slice
=
tf
.
ones
((
batch_size
,
1
),
dtype
=
attention_mask
.
dtype
)
attention_mask_update_slice
=
tf
.
ones
((
batch_size
,
1
),
dtype
=
attention_mask
.
dtype
)
# correct 5 here
# -1 because current_pos has already been incremented before this function
new_past_index
=
current_pos
-
1
# -1 again because last index = len - 1
new_past_index
=
current_pos
-
2
for
i
in
range
(
len
(
past
)):
for
i
in
range
(
len
(
past
)):
update_slice
=
past
[
i
][:,
:,
:,
-
1
:]
update_slice
=
past
[
i
][:,
:,
:,
-
1
:]
...
...
src/transformers/models/xlnet/modeling_tf_xlnet.py
View file @
5cce3076
...
@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
def
prepare_inputs_for_generation
(
self
,
inputs
,
past
=
None
,
use_mems
=
None
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
inputs
,
past
=
None
,
use_mems
=
None
,
**
kwargs
):
# Add dummy token at the end (no attention on this one)
# Add dummy token at the end (no attention on this one)
effective_batch_size
=
inputs
.
shape
[
0
]
effective_batch_size
=
inputs
.
shape
[
0
]
dummy_token
=
tf
.
zeros
((
effective_batch_size
,
1
),
dtype
=
inputs
.
dtype
)
dummy_token
=
tf
.
zeros
((
effective_batch_size
,
1
),
dtype
=
inputs
.
dtype
)
...
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
offset
=
2
offset
=
2
if
past
:
if
past
:
inputs
=
tf
.
concat
([
inputs
[:,
-
offset
:],
dummy_token
],
axis
=
1
)
input
_id
s
=
tf
.
concat
([
inputs
[:,
-
offset
:],
dummy_token
],
axis
=
1
)
else
:
else
:
inputs
=
tf
.
concat
([
inputs
,
dummy_token
],
axis
=
1
)
input
_id
s
=
tf
.
concat
([
inputs
,
dummy_token
],
axis
=
1
)
# Build permutation mask so that previous tokens don't see last token
# Build permutation mask so that previous tokens don't see last token
sequence_length
=
inputs
.
shape
[
1
]
sequence_length
=
input
_id
s
.
shape
[
1
]
perm_mask
=
tf
.
zeros
((
effective_batch_size
,
sequence_length
,
sequence_length
-
1
))
perm_mask
=
tf
.
zeros
((
effective_batch_size
,
sequence_length
,
sequence_length
-
1
))
perm_mask_seq_end
=
tf
.
ones
((
effective_batch_size
,
sequence_length
,
1
))
perm_mask_seq_end
=
tf
.
ones
((
effective_batch_size
,
sequence_length
,
1
))
perm_mask
=
tf
.
concat
([
perm_mask
,
perm_mask_seq_end
],
axis
=-
1
)
perm_mask
=
tf
.
concat
([
perm_mask
,
perm_mask_seq_end
],
axis
=-
1
)
...
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
target_mapping
=
tf
.
concat
([
target_mapping
,
target_mapping_seq_end
],
axis
=-
1
)
target_mapping
=
tf
.
concat
([
target_mapping
,
target_mapping_seq_end
],
axis
=-
1
)
inputs
=
{
inputs
=
{
"input_ids"
:
inputs
,
"input_ids"
:
input
_id
s
,
"perm_mask"
:
perm_mask
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
,
"target_mapping"
:
target_mapping
,
"use_mems"
:
use_mems
,
"use_mems"
:
use_mems
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment