Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5cce3076
Unverified
Commit
5cce3076
authored
Jun 23, 2022
by
Joao Gante
Committed by
GitHub
Jun 23, 2022
Browse files
TF: generate without `tf.TensorArray` (#17801)
parent
ab223fc1
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
97 additions
and
200 deletions
+97
-200
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+90
-193
src/transformers/models/gpt2/modeling_tf_gpt2.py
src/transformers/models/gpt2/modeling_tf_gpt2.py
+3
-2
src/transformers/models/xlnet/modeling_tf_xlnet.py
src/transformers/models/xlnet/modeling_tf_xlnet.py
+4
-5
No files found.
src/transformers/generation_tf_utils.py
View file @
5cce3076
This diff is collapsed.
Click to expand it.
src/transformers/models/gpt2/modeling_tf_gpt2.py
View file @
5cce3076
...
@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
new_past
=
[
None
for
_
in
range
(
len
(
past
))]
new_past
=
[
None
for
_
in
range
(
len
(
past
))]
slice_start_base
=
tf
.
constant
([
0
,
0
,
0
,
1
,
0
])
slice_start_base
=
tf
.
constant
([
0
,
0
,
0
,
1
,
0
])
attention_mask_update_slice
=
tf
.
ones
((
batch_size
,
1
),
dtype
=
attention_mask
.
dtype
)
attention_mask_update_slice
=
tf
.
ones
((
batch_size
,
1
),
dtype
=
attention_mask
.
dtype
)
# correct 5 here
# -1 because current_pos has already been incremented before this function
new_past_index
=
current_pos
-
1
# -1 again because last index = len - 1
new_past_index
=
current_pos
-
2
for
i
in
range
(
len
(
past
)):
for
i
in
range
(
len
(
past
)):
update_slice
=
past
[
i
][:,
:,
:,
-
1
:]
update_slice
=
past
[
i
][:,
:,
:,
-
1
:]
...
...
src/transformers/models/xlnet/modeling_tf_xlnet.py
View file @
5cce3076
...
@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
def
prepare_inputs_for_generation
(
self
,
inputs
,
past
=
None
,
use_mems
=
None
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
inputs
,
past
=
None
,
use_mems
=
None
,
**
kwargs
):
# Add dummy token at the end (no attention on this one)
# Add dummy token at the end (no attention on this one)
effective_batch_size
=
inputs
.
shape
[
0
]
effective_batch_size
=
inputs
.
shape
[
0
]
dummy_token
=
tf
.
zeros
((
effective_batch_size
,
1
),
dtype
=
inputs
.
dtype
)
dummy_token
=
tf
.
zeros
((
effective_batch_size
,
1
),
dtype
=
inputs
.
dtype
)
...
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
offset
=
2
offset
=
2
if
past
:
if
past
:
inputs
=
tf
.
concat
([
inputs
[:,
-
offset
:],
dummy_token
],
axis
=
1
)
input
_id
s
=
tf
.
concat
([
inputs
[:,
-
offset
:],
dummy_token
],
axis
=
1
)
else
:
else
:
inputs
=
tf
.
concat
([
inputs
,
dummy_token
],
axis
=
1
)
input
_id
s
=
tf
.
concat
([
inputs
,
dummy_token
],
axis
=
1
)
# Build permutation mask so that previous tokens don't see last token
# Build permutation mask so that previous tokens don't see last token
sequence_length
=
inputs
.
shape
[
1
]
sequence_length
=
input
_id
s
.
shape
[
1
]
perm_mask
=
tf
.
zeros
((
effective_batch_size
,
sequence_length
,
sequence_length
-
1
))
perm_mask
=
tf
.
zeros
((
effective_batch_size
,
sequence_length
,
sequence_length
-
1
))
perm_mask_seq_end
=
tf
.
ones
((
effective_batch_size
,
sequence_length
,
1
))
perm_mask_seq_end
=
tf
.
ones
((
effective_batch_size
,
sequence_length
,
1
))
perm_mask
=
tf
.
concat
([
perm_mask
,
perm_mask_seq_end
],
axis
=-
1
)
perm_mask
=
tf
.
concat
([
perm_mask
,
perm_mask_seq_end
],
axis
=-
1
)
...
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
target_mapping
=
tf
.
concat
([
target_mapping
,
target_mapping_seq_end
],
axis
=-
1
)
target_mapping
=
tf
.
concat
([
target_mapping
,
target_mapping_seq_end
],
axis
=-
1
)
inputs
=
{
inputs
=
{
"input_ids"
:
inputs
,
"input_ids"
:
input
_id
s
,
"perm_mask"
:
perm_mask
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
,
"target_mapping"
:
target_mapping
,
"use_mems"
:
use_mems
,
"use_mems"
:
use_mems
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment