chenpangpang / transformers · Commits · 5cce3076

Commit 5cce3076 (unverified), authored Jun 23, 2022 by Joao Gante, committed via GitHub on Jun 23, 2022
Parent: ab223fc1

TF: generate without `tf.TensorArray` (#17801)
Showing 3 changed files with 97 additions and 200 deletions:

  src/transformers/generation_tf_utils.py             +90  -193
  src/transformers/models/gpt2/modeling_tf_gpt2.py     +3    -2
  src/transformers/models/xlnet/modeling_tf_xlnet.py   +4    -5
src/transformers/generation_tf_utils.py

@@ -16,7 +16,6 @@
 import inspect
 from dataclasses import dataclass
-from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -1979,6 +1978,8 @@ class TFGenerationMixin:
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
         use_xla = not tf.executing_eagerly()
+        # some models, like XLNet, need more than the last token in the presence of past
+        needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())

         # 2. init `attentions`, `hidden_states`, and `scores` tuples
         scores = [] if (return_dict_in_generate and output_scores) else None
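The `needs_full_input` line feature-detects the model's signature at runtime. A minimal standalone sketch of the same check (helper and stub names are illustrative, not from the diff):

    import inspect

    def model_needs_full_input(prepare_fn):
        # Models whose `prepare_inputs_for_generation` exposes a `use_mems`
        # parameter (XLNet-style) want the whole sequence at every step,
        # not just the freshly generated token.
        return "use_mems" in set(inspect.signature(prepare_fn).parameters.keys())

    def prepare_inputs_for_generation(inputs, past=None, use_mems=None, **kwargs):
        ...

    print(model_needs_full_input(prepare_inputs_for_generation))  # True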
@@ -1989,6 +1990,25 @@ class TFGenerationMixin:
         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, cur_len = input_ids.shape

-        # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
-        generated = tf.TensorArray(
-            element_shape=(batch_size,),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        if pad_token_id:  # ignores the cases when it is 0 or None
-            for i in range(max_length):
-                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
-
-        # write prompt to generated
-        for i in range(cur_len):
-            generated = generated.write(i, input_ids[:, i])
-
+        # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
+        input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
+        generated = tf.concat([input_ids, input_ids_padding], axis=-1)
         finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         # define condition fn
-        def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        def greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
             """state termination condition fn."""
             return ~tf.reduce_all(finished_sequences)

         # define condition fn
-        def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
             """state update fn."""
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
+            if model_kwargs.get("past") is None or needs_full_input:
+                input_ids = generated[:, :cur_len]
+            else:
+                input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
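A standalone sketch of the new state layout and the per-step input selection, with toy values:

    import tensorflow as tf

    # `generated` starts as the prompt right-padded to `max_length` with
    # `pad_token_id` (0 when the id is None or 0), instead of a TensorArray
    # written position by position.
    batch_size, cur_len, max_length, pad_token_id = 2, 3, 8, 99
    input_ids = tf.constant([[5, 6, 7], [8, 9, 10]], dtype=tf.int32)

    input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
    generated = tf.concat([input_ids, input_ids_padding], axis=-1)  # shape (2, 8)

    # Per-step model input, as in the new body fns: the full prefix on the
    # first step (or for models that need it), otherwise only the last token.
    past = None
    needs_full_input = False
    if past is None or needs_full_input:
        step_input = generated[:, :cur_len]                         # (batch, cur_len)
    else:
        step_input = tf.expand_dims(generated[:, cur_len - 1], -1)  # (batch, 1)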
@@ -2043,8 +2035,7 @@ class TFGenerationMixin:
             decoder_hidden_states.append(outputs.hidden_states)

             # pre-process distribution
-            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
+            next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)

             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2057,8 +2048,8 @@ class TFGenerationMixin:
             finished_sequences = finished_sequences | (next_tokens == eos_token_id)

             # update `generated` and `cur_len`
-            generated = generated.write(cur_len, next_tokens)
-            next_tokens = next_tokens[:, None]
+            update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
+            generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
             cur_len += 1

             # update model_kwargs
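`tf.tensor_scatter_nd_update` replaces `TensorArray.write` by overwriting one column of the fixed-shape buffer per step. A toy demonstration:

    import tensorflow as tf

    batch_size, max_length, cur_len = 2, 6, 3
    generated = tf.zeros((batch_size, max_length), dtype=tf.int32)
    next_tokens = tf.constant([42, 43], dtype=tf.int32)  # shape (batch_size,)

    # one (row, column) pair per batch item: [[0, 3], [1, 3]]
    update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
    generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)

    print(generated.numpy())
    # [[ 0  0  0 42  0  0]
    #  [ 0  0  0 43  0  0]]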
@@ -2073,34 +2064,29 @@ class TFGenerationMixin:
                 # let's throw out `past` since we don't want `None` tensors
                 model_kwargs.pop("past", None)
-                next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-                next_tokens = tf.transpose(next_tokens[:cur_len])

-            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+            return generated, finished_sequences, cur_len, model_kwargs

         # 5. run generation
         # 1st generation step has to be run before to initialize `past`
-        generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn(
-            generated, finished_sequences, input_ids, cur_len, model_kwargs
+        generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn(
+            generated, finished_sequences, cur_len, model_kwargs
         )

         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
             maximum_iterations = max_length - cur_len
-            generated, _, _, cur_len, _ = tf.while_loop(
+            generated, _, cur_len, _ = tf.while_loop(
                 greedy_search_cond_fn,
                 greedy_search_body_fn,
-                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+                (generated, finished_sequences, cur_len, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )

         # 6. prepare outputs
-        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
         if not use_xla:
             # cut for backward compatibility
-            output_ids = output_ids[:, :cur_len]
+            generated = generated[:, :cur_len]

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
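The run pattern (one eager step to build `past`, then `tf.while_loop` over fixed-shape state) can be exercised standalone. A minimal toy skeleton, assuming int32 ids and a stand-in "model" that always emits the EOS id:

    import tensorflow as tf

    batch_size, max_length, eos_token_id = 2, 6, 5
    generated = tf.zeros((batch_size, max_length), dtype=tf.int32)
    finished = tf.zeros((batch_size,), dtype=tf.bool)
    cur_len = tf.constant(1)

    def cond_fn(generated, finished, cur_len):
        return ~tf.reduce_all(finished)

    def body_fn(generated, finished, cur_len):
        next_tokens = tf.fill([batch_size], 5)  # stand-in for argmax over logits
        finished = finished | (next_tokens == eos_token_id)
        indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
        generated = tf.tensor_scatter_nd_update(generated, indices, next_tokens)
        return generated, finished, cur_len + 1

    generated, finished, cur_len = tf.while_loop(
        cond_fn, body_fn, (generated, finished, cur_len), maximum_iterations=max_length - 1
    )

Because every loop variable keeps a static shape across iterations, the whole loop is XLA-compilable, which is the point of dropping `tf.TensorArray`.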
@@ -2117,7 +2103,7 @@ class TFGenerationMixin:
                 decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None

                 return TFGreedySearchEncoderDecoderOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
...
@@ -2127,13 +2113,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFGreedySearchDecoderOnlyOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return output_ids
+            return generated

     def sample(
         self,
@@ -2250,6 +2236,8 @@ class TFGenerationMixin:
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
         use_xla = not tf.executing_eagerly()
+        # some models, like XLNet, need more than the last token in the presence of past
+        needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())

         # 2. init `attentions`, `hidden_states`, and `scores` tuples
         scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2261,29 +2249,20 @@ class TFGenerationMixin:
         batch_size, cur_len = input_ids.shape

         # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
-        generated = tf.TensorArray(
-            element_shape=(batch_size,),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        if pad_token_id:  # ignores the cases when it is 0 or None
-            for i in range(max_length):
-                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
-
-        # write prompt to generated
-        for i in range(cur_len):
-            generated = generated.write(i, input_ids[:, i])
-
+        input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
+        generated = tf.concat([input_ids, input_ids_padding], axis=-1)
         finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

         # 4. define "xla-compile-able" stop-condition and auto-regressive function
-        def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        def sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
             return ~tf.reduce_all(finished_sequences)

-        def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
+        def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
+            if model_kwargs.get("past") is None or needs_full_input:
+                input_ids = generated[:, :cur_len]
+            else:
+                input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
@@ -2310,9 +2289,8 @@ class TFGenerationMixin:
             decoder_hidden_states.append(outputs.hidden_states)

             # pre-process distribution
-            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
-            next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
+            next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
+            next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len)

             # sample
             if seed is not None:
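The code above branches on `seed` before drawing the next token. A sketch of how such a draw can be made reproducible with TF's stateless RNG (values illustrative; this is an assumption about the elided branch body, not a copy of it):

    import tensorflow as tf

    next_tokens_scores = tf.math.log([[0.1, 0.7, 0.2], [0.5, 0.25, 0.25]])
    seed = (1, 2)  # a fixed [2]-shaped seed makes the draw deterministic

    if seed is not None:
        next_tokens = tf.squeeze(
            tf.random.stateless_categorical(next_tokens_scores, num_samples=1, seed=seed, dtype=tf.int32),
            axis=1,
        )
    else:
        next_tokens = tf.squeeze(
            tf.random.categorical(next_tokens_scores, num_samples=1, dtype=tf.int32), axis=1
        )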
@@ -2334,8 +2312,8 @@ class TFGenerationMixin:
             finished_sequences = finished_sequences | (next_tokens == eos_token_id)

             # update `generated` and `cur_len`
-            generated = generated.write(cur_len, next_tokens)
-            next_tokens = next_tokens[:, None]
+            update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
+            generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
             cur_len += 1

             # update model_kwargs
@@ -2350,34 +2328,29 @@ class TFGenerationMixin:
                 # let's throw out `past` since we don't want `None` tensors
                 model_kwargs.pop("past", None)
-                next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-                next_tokens = tf.transpose(next_tokens[:cur_len])

-            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+            return generated, finished_sequences, cur_len, model_kwargs

         # 5. run generation
         # 1st generation step has to be run before to initialize `past`
-        generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
-            generated, finished_sequences, input_ids, cur_len, model_kwargs
+        generated, finished_sequences, cur_len, model_kwargs = sample_body_fn(
+            generated, finished_sequences, cur_len, model_kwargs
         )

         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
             maximum_iterations = max_length - cur_len
-            generated, _, _, cur_len, _ = tf.while_loop(
+            generated, _, cur_len, _ = tf.while_loop(
                 sample_cond_fn,
                 sample_body_fn,
-                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+                (generated, finished_sequences, cur_len, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )

         # 6. prepare outputs
-        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
         if not use_xla:
             # cut for backward compatibility
-            output_ids = output_ids[:, :cur_len]
+            generated = generated[:, :cur_len]

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
@@ -2394,7 +2367,7 @@ class TFGenerationMixin:
                 decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None

                 return TFSampleEncoderDecoderOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
...
@@ -2404,13 +2377,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFSampleDecoderOnlyOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return output_ids
+            return generated

     def beam_search(
         self,
@@ -2585,6 +2558,8 @@ class TFGenerationMixin:
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
         cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+        # some models, like XLNet, need more than the last token in the presence of past
+        needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())

         # 2. init `attentions`, `hidden_states`, and `scores` tuples
         scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2594,41 +2569,13 @@ class TFGenerationMixin:
         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, num_beams, cur_len = input_ids.shape
-        input_ids_length = cur_len

         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
-        sequences = tf.TensorArray(
-            element_shape=(batch_size, num_beams),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        running_sequences = tf.TensorArray(
-            element_shape=(batch_size, num_beams),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        intermediary_running_sequences = tf.TensorArray(
-            element_shape=(batch_size, num_beams * 2),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        if pad_token_id:  # ignores the cases when it is 0 or None
-            for i in range(max_length):
-                sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-                running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-                intermediary_running_sequences = intermediary_running_sequences.write(
-                    i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
-                )
-
-        # write prompt to running_sequences
-        for i in range(cur_len):
-            running_sequences = running_sequences.write(i, input_ids[:, :, i])
+        input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
+            pad_token_id or 0
+        )
+        running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
+        sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)

         # per batch,beam-item state bit indicating if sentence has finished.
         is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
@@ -2656,7 +2603,6 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            input_ids_length,
             model_kwargs,
         ):
             """
@@ -2685,27 +2631,18 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            input_ids_length,
             model_kwargs,
-            intermediary_running_sequences=None,
         ):
             """
             Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
             seen so far
             """
-            # TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
-            # Alternativelly, attempt to rewrite function with permuted axis, when enabling XLA.
             # 1. Forward current tokens
-            # TF places the dynamic dimension (seq_len) in the first axis, we want it in the last
-            running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
-            input_token = tf.slice(
-                running_sequences_seq_last,
-                (0, 0, cur_len - input_ids_length),
-                (batch_size, num_beams, input_ids_length),
-            )
-            model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
+            if model_kwargs.get("past") is None or needs_full_input:
+                input_ids = running_sequences[:, :, :cur_len]
+            else:
+                input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
+            model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs)

             model_outputs = self(
                 **model_inputs,
                 return_dict=True,
@@ -2734,9 +2671,7 @@ class TFGenerationMixin:
             # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
             # add new logprobs to existing running logprobs scores.
             log_probs = tf.nn.log_softmax(logits)
-            log_probs = logits_processor(
-                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
-            )
+            log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
             log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
             log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             vocab_size = log_probs.shape[2]
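The beam helpers called throughout this function are defined elsewhere in the same file. Simplified sketches of their behavior, assuming static shapes (the real helpers also handle a configurable batch axis for the cache):

    import tensorflow as tf

    def flatten_beam_dim(tensor):
        # (batch, beams, ...) -> (batch * beams, ...)
        shape = tensor.shape.as_list()
        return tf.reshape(tensor, [shape[0] * shape[1]] + shape[2:])

    def unflatten_beam_dim(tensor, batch_size, num_beams):
        # (batch * beams, ...) -> (batch, beams, ...)
        shape = tensor.shape.as_list()
        return tf.reshape(tensor, [batch_size, num_beams] + shape[1:])

    def gather_beams(nested, beam_indices):
        # selects beams along axis 1, e.g. (batch, beams, seq) gathered with
        # (batch, k) indices -> (batch, k, seq); works on tensors or lists
        return tf.nest.map_structure(lambda t: tf.gather(t, beam_indices, axis=1, batch_dims=1), nested)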
@@ -2755,23 +2690,28 @@ class TFGenerationMixin:
             beams_to_keep = 2 * num_beams
             topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
             topk_beam_indices = topk_indices // vocab_size
-            topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices)
+            topk_running_sequences = gather_beams(running_sequences, topk_beam_indices)
             topk_ids = topk_indices % vocab_size

             # writes the new token
-            intermediary_running_sequences = intermediary_running_sequences.unstack(
-                tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1])
-            )
-            topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids)
-            topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0])
+            indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])
+            indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])
+            update_indices = tf.stack(
+                [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+            )
+            topk_sequences = tf.tensor_scatter_nd_update(
+                tensor=topk_running_sequences,
+                indices=update_indices,
+                updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
+            )

             # 4. Check which sequences have ended
             # Update current sequences: Did the top `num_beams` sequences reach an end marker?
             # To prevent these just finished sequences from being added to the current sequences
             # set of active beam search sequences, set their log probs to a very large negative value.
-            eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id
+            eos_in_next_token = topk_sequences[:, :, cur_len] == eos_token_id
             if eos_token_id is None:
-                eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape)
+                eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
             did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
                 tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
                 eos_in_next_token.shape,
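The 3-D scatter above writes one new token into column `cur_len` of every (batch, beam) pair in a single op. A toy demonstration:

    import tensorflow as tf

    batch_size, beams_to_keep, max_length, cur_len = 2, 4, 5, 2
    topk_running_sequences = tf.zeros((batch_size, beams_to_keep, max_length), dtype=tf.int32)
    topk_ids = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.int32)

    indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])  # [0 0 0 0 1 1 1 1]
    indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])     # [0 1 2 3 0 1 2 3]
    update_indices = tf.stack(
        [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
    )
    topk_sequences = tf.tensor_scatter_nd_update(
        tensor=topk_running_sequences,
        indices=update_indices,
        updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
    )
    # topk_sequences[:, :, 2] now equals topk_ids; all other positions stay 0.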
@@ -2785,8 +2725,8 @@ class TFGenerationMixin:
             # Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
             # (from top 2*k beams).
             next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
-            next_running_sequences_seq_last, next_running_scores = gather_beams(
-                [topk_sequences_seq_last, running_topk_log_probs], next_topk_indices
+            next_running_sequences, next_running_scores = gather_beams(
+                [topk_sequences, running_topk_log_probs], next_topk_indices
             )

             # 6. Process topk logits
@@ -2807,18 +2747,18 @@ class TFGenerationMixin:
             # 7. Get scores, sequences, is sentence finished for next.
             # Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
             # to existing finished scores and select the best from the new set of beams
-            sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-            merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1)
+            merged_sequences = tf.concat([sequences, topk_sequences], axis=1)
             merged_scores = tf.concat([scores, topk_log_probs], axis=1)
             merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
             topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
-            next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams(
+            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                 [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices
             )

             # 8. Prepare data for the next iteration
             # Determine the top k beam indices from the original set of all beams. With these, gather the top k
             # beam-associated caches.
+            cur_len = cur_len + 1
             if "past_key_values" in model_outputs:
                 cache = tf.nest.map_structure(
                     lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
@@ -2841,35 +2781,20 @@ class TFGenerationMixin:
             # if we don't cache past key values we need the whole input
             if model_kwargs.get("past", None) is None:
-                next_input_ids_length = cur_len + 1
                 # let's throw out `past` since we don't want `None` tensors
                 model_kwargs.pop("past", None)
-            else:
-                next_input_ids_length = 1
-
-            # 9. Prepare the `tf.TensorArray` for the next iteration
-            next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
-            next_running_sequences = running_sequences.unstack(
-                tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
-            )

             return (
-                cur_len + 1,
+                cur_len,
                 next_running_sequences,
                 next_running_scores,
                 next_sequences,
                 next_scores,
                 next_is_sent_finished,
-                next_input_ids_length,
                 next_model_kwargs,
             )

         # 5. run generation
-        # Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
-        beam_search_body_fn = partial(beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences)
         # 1st generation step has to be run before to initialize `past` (if active)
         (
             cur_len,
...
@@ -2878,66 +2803,38 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            input_ids_length,
             model_kwargs,
         ) = beam_search_body_fn(
-            cur_len,
-            running_sequences,
-            running_scores,
-            sequences,
-            scores,
-            is_sent_finished,
-            input_ids_length,
-            model_kwargs,
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
         )

         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
         # NOT yield EOS token though)
         if beam_search_cond_fn(
-            cur_len,
-            running_sequences,
-            running_scores,
-            sequences,
-            scores,
-            is_sent_finished,
-            input_ids_length,
-            model_kwargs,
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
         ):
             maximum_iterations = max_length - cur_len
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
                 beam_search_cond_fn,
                 beam_search_body_fn,
-                (
-                    cur_len,
-                    running_sequences,
-                    running_scores,
-                    sequences,
-                    scores,
-                    is_sent_finished,
-                    input_ids_length,
-                    model_kwargs,
-                ),
+                (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )
         # 6. prepare outputs
-        # convert the sequneces to tf.Tensor with shape (batch_size, num_beams, seq_len)
-        sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-        running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
-
         # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
         # running sequences for that batch item.
         none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
-        sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
+        sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
         scores = tf.where(none_finished[:, None], scores, running_scores)

         # Take best beams for each batch (the score is sorted in ascending order)
-        sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
+        sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
         scores = flatten_beam_dim(scores[:, :num_return_sequences])

         if not use_xla:
             # Cut for backward compatibility
-            sequences_seq_last = sequences_seq_last[:, :cur_len]
+            sequences = sequences[:, :cur_len]

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
@@ -2948,7 +2845,7 @@ class TFGenerationMixin:
                 )
                 return TFBeamSearchEncoderDecoderOutput(
-                    sequences=sequences_seq_last,
+                    sequences=sequences,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
...
@@ -2958,13 +2855,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFBeamSearchDecoderOnlyOutput(
-                    sequences=sequences_seq_last,
+                    sequences=sequences,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return sequences_seq_last
+            return sequences


 def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
...
src/transformers/models/gpt2/modeling_tf_gpt2.py

@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
             new_past = [None for _ in range(len(past))]
             slice_start_base = tf.constant([0, 0, 0, 1, 0])
             attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
-            # correct 5 here
-            new_past_index = current_pos - 1
+            # -1 because current_pos has already been incremented before this function
+            # -1 again because last index = len - 1
+            new_past_index = current_pos - 2

             for i in range(len(past)):
                 update_slice = past[i][:, :, :, -1:]
...
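A toy walk-through of the corrected index arithmetic, following the two comments in the diff (values illustrative):

    # `current_pos` is incremented *before* this function runs, so it already
    # points one past the newest token; 0-based indexing costs another -1.
    current_pos = 5
    newest_token_pos = current_pos - 1     # 1-based position of the newest token
    new_past_index = newest_token_pos - 1  # == current_pos - 2, its 0-based slot
    assert new_past_index == current_pos - 2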
src/transformers/models/xlnet/modeling_tf_xlnet.py

@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
         # Add dummy token at the end (no attention on this one)
         effective_batch_size = inputs.shape[0]
         dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
...
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         offset = 2

         if past:
-            inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
+            input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
         else:
-            inputs = tf.concat([inputs, dummy_token], axis=1)
+            input_ids = tf.concat([inputs, dummy_token], axis=1)

         # Build permutation mask so that previous tokens don't see last token
-        sequence_length = inputs.shape[1]
+        sequence_length = input_ids.shape[1]
         perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
         perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
         perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
...
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

         inputs = {
-            "input_ids": inputs,
+            "input_ids": input_ids,
             "perm_mask": perm_mask,
             "target_mapping": target_mapping,
             "use_mems": use_mems,
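A quick shape check of the mask construction above, with assumed toy sizes: the 1.0 column appended last blocks every position from attending to the final (dummy) token.

    import tensorflow as tf

    effective_batch_size, sequence_length = 2, 4
    perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
    perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
    perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
    assert perm_mask.shape == (2, 4, 4)  # ones only in the last column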