chenpangpang / transformers · Commit c4c4c999

make GPT2 and CTRL shape consistent between torch and TF

Authored Mar 04, 2020 by Patrick von Platen
Parent: 2529b2d3
Showing 3 changed files with 30 additions and 13 deletions:

  src/transformers/modeling_tf_ctrl.py   +2  -2
  src/transformers/modeling_tf_gpt2.py   +2  -2
  src/transformers/modeling_tf_utils.py  +26 -9
src/transformers/modeling_tf_ctrl.py

@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             k = tf.concat((past_key, k), axis=-2)
             v = tf.concat((past_value, v), axis=-2)
-        present = tf.stack((k, v), axis=1)
+        present = tf.stack((k, v), axis=0)
 
         output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
         scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
src/transformers/modeling_tf_gpt2.py

@@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer):
         key = self.split_heads(key)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             key = tf.concat([past_key, key], axis=-2)
             value = tf.concat([past_value, value], axis=-2)
-        present = tf.stack([key, value], axis=1)
+        present = tf.stack([key, value], axis=0)
 
         attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
         a = attn_outputs[0]
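Both files make the same change: the cached key/value pair is now stacked along axis 0 instead of axis 1, so the TF `present` tensor uses the layout PyTorch already uses for `past`, i.e. (2, batch, num_heads, seq_len, head_dim). A minimal sketch of the difference (toy shapes, not taken from a real checkpoint):

    import tensorflow as tf

    batch, num_heads, seq_len, head_dim = 3, 4, 5, 8
    key = tf.zeros((batch, num_heads, seq_len, head_dim))
    value = tf.zeros((batch, num_heads, seq_len, head_dim))

    # Before this commit: the key/value pair was buried after the batch dim.
    old_present = tf.stack([key, value], axis=1)
    print(old_present.shape)  # (3, 2, 4, 5, 8)

    # After this commit: the pair comes first, matching torch's past layout.
    present = tf.stack([key, value], axis=0)
    print(present.shape)      # (2, 3, 4, 5, 8)

    # The matching unstack on the next decoding step recovers the pair.
    past_key, past_value = tf.unstack(present, axis=0)
    assert past_key.shape == key.shape

The unstack axis has to change in lockstep with the stack axis, which is why each hunk touches both lines.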
src/transformers/modeling_tf_utils.py

@@ -658,7 +658,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
 
             if do_sample:
@@ -779,7 +781,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
 
             if do_sample:
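The two hunks above only rewrap the call to `_create_next_token_logits_penalties`; the behavior is unchanged. For context, the repetition penalty from the CTRL paper damps every token that already appears in `input_ids`. The helper is internal to `modeling_tf_utils.py`, so the following standalone re-implementation is an illustrative assumption, not the library's exact code:

    import numpy as np
    import tensorflow as tf

    def repetition_penalty_multipliers(input_ids, logits, penalty):
        # Hypothetical reconstruction for illustration only.
        multipliers = np.ones(logits.shape)
        for batch_idx, prev_tokens in enumerate(input_ids.numpy()):
            for token in set(prev_tokens.tolist()):
                logit = float(logits[batch_idx, token])
                # CTRL paper: shrink positive logits, amplify negative ones,
                # so a repeated token always becomes less likely.
                multipliers[batch_idx, token] = 1 / penalty if logit > 0 else penalty
        return tf.convert_to_tensor(multipliers, dtype=logits.dtype)

    input_ids = tf.constant([[5, 7, 5]])
    next_token_logits = tf.random.normal((1, 10))
    penalties = repetition_penalty_multipliers(input_ids, next_token_logits, penalty=1.2)
    next_token_logits = tf.math.multiply(next_token_logits, penalties)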
@@ -791,11 +795,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                 next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)  # (batch_size * num_beams, vocab_size)
                 # Compute next scores
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                 _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
-                next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2))  # (batch_size * num_beams, 2)
+                next_scores = _scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, 2)
+                )  # (batch_size * num_beams, 2)
                 # Match shape of greedy beam search
                 next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
                 next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
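In the sampling branch above, each beam draws two candidate tokens so there is a spare if the first one turns out to be EOS, and so the tensors end up with the same shape the greedy branch produces. A self-contained sketch of the gather-and-reshape step (toy sizes, random logits):

    import tensorflow as tf

    batch_size, num_beams, vocab_size = 2, 3, 11
    next_token_logits = tf.random.normal((batch_size * num_beams, vocab_size))
    beam_scores = tf.zeros((batch_size * num_beams,))

    # Two samples per beam: one spare in case the first drawn token is EOS.
    next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)
    scores = tf.nn.log_softmax(next_token_logits, axis=-1)
    # Pick out the log-probs of exactly the sampled tokens, row by row.
    _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch*beams, 2)
    next_scores = _scores + tf.broadcast_to(
        beam_scores[:, None], (batch_size * num_beams, 2)
    )

    # Regroup per example so each row exposes 2 * num_beams candidates,
    # the same layout the greedy top-k branch yields.
    next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))
    next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))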
@@ -804,10 +812,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                 assert shape_list(scores) == [batch_size * num_beams, vocab_size]
                 # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
-                next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))  # (batch_size * num_beams, vocab_size)
+                next_scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
                 # re-organize to group the beam together (we are keeping top hypothesis accross beams)
-                next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))  # (batch_size, num_beams * vocab_size)
+                next_scores = tf.reshape(
+                    next_scores, (batch_size, num_beams * vocab_size)
+                )  # (batch_size, num_beams * vocab_size)
                 next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
                 assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
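The greedy branch instead flattens all beams of an example into one row, so a single top_k call ranks candidates across beams; the winning index then encodes both the beam and the token. A toy sketch reusing the names from the hunk:

    import tensorflow as tf

    batch_size, num_beams, vocab_size = 2, 3, 11
    next_token_logits = tf.random.normal((batch_size * num_beams, vocab_size))
    beam_scores = tf.zeros((batch_size * num_beams,))

    scores = tf.nn.log_softmax(next_token_logits, axis=-1)
    next_scores = scores + tf.broadcast_to(
        beam_scores[:, None], (batch_size * num_beams, vocab_size)
    )
    # One row per example; candidate id = beam_idx * vocab_size + token_id.
    next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))
    next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)

    # Recover which beam and which token each winner came from.
    beam_ids = next_tokens // vocab_size
    token_ids = next_tokens % vocab_size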
@@ -909,7 +921,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            best_hyp = sorted_hyps.pop()[1]
            sent_lengths_list.append(len(best_hyp))
            best.append(best_hyp)
-        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
+        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
+            output_batch_size, len(best)
+        )
 
         sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
@@ -925,7 +939,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 decoded_hypo = tf.concat([hypo, padding], axis=0)
 
                 if sent_lengths[i] < max_length:
-                    decoded_hypo = tf.where(tf.range(max_length) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
+                    decoded_hypo = tf.where(
+                        tf.range(max_length) == sent_lengths[i],
+                        eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
+                        decoded_hypo,
+                    )
                 decoded_list.append(decoded_hypo)
             decoded = tf.stack(decoded_list)
         else:
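TF tensors cannot be written in place, so placing the EOS token at position `sent_lengths[i]` is done with a mask: `tf.where` selects the EOS value exactly where `tf.range(max_length)` equals the sentence length. A small sketch with assumed toy values:

    import tensorflow as tf

    max_length = sent_max_len = 8
    eos_token_id, pad_token_id, sent_length = 2, 0, 5

    hypo = tf.constant([11, 12, 13, 14, 15], dtype=tf.int32)
    padding = pad_token_id * tf.ones((sent_max_len - len(hypo),), dtype=tf.int32)
    decoded_hypo = tf.concat([hypo, padding], axis=0)

    if sent_length < max_length:
        # No in-place assignment in TF: pick EOS at the single masked position.
        decoded_hypo = tf.where(
            tf.range(max_length) == sent_length,
            eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
            decoded_hypo,
        )
    print(decoded_hypo.numpy())  # [11 12 13 14 15  2  0  0]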
@@ -942,7 +960,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` and `mems` is at 2nd position
             reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
-            # TODO: check whether it is an error that TF past.shape != Torch past.shape
             reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
             # check that shape matches
             assert shape_list(reordered_layer_past) == shape_list(layer_past)
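When beams are reshuffled during search, every layer's cached states must be reordered the same way. With the layout this commit introduces, the batch dimension of `layer_past` sits at position 1 (the "2nd position" the comment refers to), so the reorder slices along that axis; the axis change also resolves the removed TODO, since TF and Torch past shapes now agree. A toy reorder with a hypothetical `beam_idx`:

    import tensorflow as tf

    # layer_past layout after this commit: (2, batch * num_beams, heads, seq, head_dim)
    layer_past = tf.random.normal((2, 6, 4, 5, 8))
    beam_idx = [0, 0, 2, 3, 3, 5]  # hypothetical: beams 1 and 4 were dropped

    reordered_layer_past = [
        tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx
    ]
    reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
    assert reordered_layer_past.shape == layer_past.shape

The same result could be had with `tf.gather(layer_past, beam_idx, axis=1)`; the expand-and-concat form simply mirrors the hunk above.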