Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
34df26ec
Unverified
Commit
34df26ec
authored
Feb 19, 2021
by
Julien Plu
Committed by
GitHub
Feb 19, 2021
Browse files
Making TF OpenAI GPT model compliant with AMP and XLA (#10261)
* Fix AMP and XLA * Remove useless var
parent
3e116ed3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
34 deletions
+29
-34
src/transformers/models/openai/modeling_tf_openai.py
src/transformers/models/openai/modeling_tf_openai.py
+29
-26
tests/test_modeling_tf_openai.py
tests/test_modeling_tf_openai.py
+0
-8
No files found.
src/transformers/models/openai/modeling_tf_openai.py
View file @
34df26ec
...
@@ -81,7 +81,7 @@ class TFAttention(tf.keras.layers.Layer):
...
@@ -81,7 +81,7 @@ class TFAttention(tf.keras.layers.Layer):
pass
pass
@
staticmethod
@
staticmethod
def
causal_attention_mask
(
nd
,
ns
,
dtype
):
def
causal_attention_mask
(
nd
,
ns
):
"""
"""
1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
-1, ns-nd), but doesn't produce garbage on TPUs.
-1, ns-nd), but doesn't produce garbage on TPUs.
...
@@ -89,23 +89,24 @@ class TFAttention(tf.keras.layers.Layer):
...
@@ -89,23 +89,24 @@ class TFAttention(tf.keras.layers.Layer):
i
=
tf
.
range
(
nd
)[:,
None
]
i
=
tf
.
range
(
nd
)[:,
None
]
j
=
tf
.
range
(
ns
)
j
=
tf
.
range
(
ns
)
m
=
i
>=
j
-
ns
+
nd
m
=
i
>=
j
-
ns
+
nd
return
tf
.
cast
(
m
,
dtype
)
return
m
def
_attn
(
self
,
q
,
k
,
v
,
attention_mask
,
head_mask
,
output_attentions
,
training
=
False
):
def
_attn
(
self
,
q
,
k
,
v
,
attention_mask
,
head_mask
,
output_attentions
,
training
=
False
):
# q, k, v have shape [batch, heads, sequence, features]
# q, k, v have shape [batch, heads, sequence, features]
w
=
tf
.
matmul
(
q
,
k
,
transpose_b
=
True
)
w
=
tf
.
matmul
(
q
,
k
,
transpose_b
=
True
)
if
self
.
scale
:
if
self
.
scale
:
dk
=
tf
.
cast
(
shape_list
(
k
)[
-
1
],
tf
.
float32
)
# scale attention_scores
dk
=
tf
.
cast
(
shape_list
(
k
)[
-
1
],
dtype
=
w
.
dtype
)
# scale attention_scores
w
=
w
/
tf
.
math
.
sqrt
(
dk
)
w
=
w
/
tf
.
math
.
sqrt
(
dk
)
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
_
,
_
,
nd
,
ns
=
shape_list
(
w
)
_
,
_
,
nd
,
ns
=
shape_list
(
w
)
b
=
self
.
causal_attention_mask
(
nd
,
ns
,
dtype
=
w
.
dtype
)
b
=
tf
.
cast
(
self
.
causal_attention_mask
(
nd
,
ns
)
,
dtype
=
w
.
dtype
)
b
=
tf
.
reshape
(
b
,
[
1
,
1
,
nd
,
ns
])
b
=
tf
.
reshape
(
b
,
[
1
,
1
,
nd
,
ns
])
w
=
w
*
b
-
1e4
*
(
1
-
b
)
w
=
w
*
b
-
1e4
*
(
1
-
b
)
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
# Apply the attention mask
# Apply the attention mask
attention_mask
=
tf
.
cast
(
attention_mask
,
dtype
=
w
.
dtype
)
w
=
w
+
attention_mask
w
=
w
+
attention_mask
w
=
tf
.
nn
.
softmax
(
w
,
axis
=-
1
)
w
=
tf
.
nn
.
softmax
(
w
,
axis
=-
1
)
...
@@ -201,19 +202,25 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
...
@@ -201,19 +202,25 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
self
.
num_hidden_layers
=
config
.
n_layer
self
.
num_hidden_layers
=
config
.
n_layer
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
n_embd
=
config
.
n_embd
self
.
n_embd
=
config
.
n_embd
self
.
n_positions
=
config
.
n_positions
self
.
initializer_range
=
config
.
initializer_range
self
.
tokens_embed
=
TFSharedEmbeddings
(
self
.
tokens_embed
=
TFSharedEmbeddings
(
config
.
vocab_size
,
config
.
n_embd
,
initializer_range
=
config
.
initializer_range
,
name
=
"tokens_embed"
config
.
vocab_size
,
config
.
n_embd
,
initializer_range
=
config
.
initializer_range
,
name
=
"tokens_embed"
)
)
self
.
positions_embed
=
tf
.
keras
.
layers
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
,
embeddings_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"positions_embed"
,
)
self
.
drop
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
self
.
drop
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
self
.
h
=
[
TFBlock
(
config
.
n_ctx
,
config
,
scale
=
True
,
name
=
"h_._{}"
.
format
(
i
))
for
i
in
range
(
config
.
n_layer
)]
self
.
h
=
[
TFBlock
(
config
.
n_ctx
,
config
,
scale
=
True
,
name
=
"h_._{}"
.
format
(
i
))
for
i
in
range
(
config
.
n_layer
)]
def
build
(
self
,
input_shape
):
with
tf
.
name_scope
(
"positions_embed"
):
self
.
positions_embed
=
self
.
add_weight
(
name
=
"embeddings"
,
shape
=
[
self
.
n_positions
,
self
.
n_embd
],
initializer
=
get_initializer
(
self
.
initializer_range
),
)
super
().
build
(
input_shape
)
def
get_input_embeddings
(
self
):
def
get_input_embeddings
(
self
):
return
self
.
tokens_embed
return
self
.
tokens_embed
...
@@ -268,7 +275,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
...
@@ -268,7 +275,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
inputs
[
"position_ids"
]
is
None
:
if
inputs
[
"position_ids"
]
is
None
:
inputs
[
"position_ids"
]
=
tf
.
expand_dims
(
tf
.
range
(
input_shape
[
-
1
]
,
dtype
=
tf
.
int32
),
axis
=
0
)
inputs
[
"position_ids"
]
=
tf
.
expand_dims
(
tf
.
range
(
input_shape
[
-
1
]),
axis
=
0
)
if
inputs
[
"attention_mask"
]
is
not
None
:
if
inputs
[
"attention_mask"
]
is
not
None
:
# We create a 3D attention mask from a 2D tensor mask.
# We create a 3D attention mask from a 2D tensor mask.
...
@@ -284,8 +291,11 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
...
@@ -284,8 +291,11 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# effectively the same as removing these entirely.
inputs
[
"attention_mask"
]
=
tf
.
cast
(
inputs
[
"attention_mask"
],
tf
.
float32
)
one_cst
=
tf
.
constant
(
1.0
)
inputs
[
"attention_mask"
]
=
(
1.0
-
inputs
[
"attention_mask"
])
*
-
10000.0
inputs
[
"attention_mask"
]
=
tf
.
cast
(
inputs
[
"attention_mask"
],
dtype
=
one_cst
.
dtype
)
inputs
[
"attention_mask"
]
=
tf
.
multiply
(
tf
.
subtract
(
one_cst
,
inputs
[
"attention_mask"
]),
tf
.
constant
(
-
10000.0
)
)
else
:
else
:
inputs
[
"attention_mask"
]
=
None
inputs
[
"attention_mask"
]
=
None
...
@@ -304,7 +314,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
...
@@ -304,7 +314,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
if
inputs
[
"inputs_embeds"
]
is
None
:
if
inputs
[
"inputs_embeds"
]
is
None
:
inputs
[
"inputs_embeds"
]
=
self
.
tokens_embed
(
inputs
[
"input_ids"
],
mode
=
"embedding"
)
inputs
[
"inputs_embeds"
]
=
self
.
tokens_embed
(
inputs
[
"input_ids"
],
mode
=
"embedding"
)
position_embeds
=
self
.
positions_embed
(
inputs
[
"position_ids"
])
position_embeds
=
tf
.
gather
(
self
.
positions_embed
,
inputs
[
"position_ids"
])
if
inputs
[
"token_type_ids"
]
is
not
None
:
if
inputs
[
"token_type_ids"
]
is
not
None
:
inputs
[
"token_type_ids"
]
=
tf
.
reshape
(
inputs
[
"token_type_ids"
]
=
tf
.
reshape
(
inputs
[
"token_type_ids"
],
[
-
1
,
shape_list
(
inputs
[
"token_type_ids"
])[
-
1
]]
inputs
[
"token_type_ids"
],
[
-
1
,
shape_list
(
inputs
[
"token_type_ids"
])[
-
1
]]
...
@@ -903,7 +913,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
...
@@ -903,7 +913,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
logits
=
self
.
score
(
hidden_states
)
logits
=
self
.
score
(
hidden_states
)
logits_shape
=
shape_list
(
logits
)
in_logits
=
None
in_logits
=
None
if
self
.
config
.
pad_token_id
is
None
:
if
self
.
config
.
pad_token_id
is
None
:
sequence_lengths
=
-
1
sequence_lengths
=
-
1
...
@@ -911,22 +920,16 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
...
@@ -911,22 +920,16 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
if
inputs
[
"input_ids"
]
is
not
None
:
if
inputs
[
"input_ids"
]
is
not
None
:
sequence_lengths
=
(
sequence_lengths
=
(
tf
.
reduce_sum
(
tf
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
not_equal
(
inputs
[
"input_ids"
],
self
.
config
.
pad_token_id
),
tf
.
int32
),
tf
.
cast
(
tf
.
math
.
not_equal
(
inputs
[
"input_ids"
],
self
.
config
.
pad_token_id
),
dtype
=
inputs
[
"input_ids"
].
dtype
,
),
-
1
,
-
1
,
keepdims
=
False
,
keepdims
=
False
,
)
)
-
1
-
1
)
)
in_logits
=
tf
.
gather
(
logits
,
sequence_lengths
,
batch_dims
=
1
,
axis
=
1
)
def
get_seq_element
(
sequence_position
,
input_batch
):
return
tf
.
strided_slice
(
input_batch
,
[
sequence_position
,
0
],
[
sequence_position
+
1
,
input_batch
.
shape
[
-
1
]],
[
1
,
1
]
)
result
=
tf
.
map_fn
(
fn
=
lambda
t
:
get_seq_element
(
t
[
0
],
t
[
1
]),
elems
=
[
sequence_lengths
,
logits
],
dtype
=
"float"
)
in_logits
=
tf
.
reshape
(
result
,
[
logits_shape
[
0
],
logits_shape
[
-
1
]])
else
:
else
:
sequence_lengths
=
-
1
sequence_lengths
=
-
1
logger
.
warning
(
logger
.
warning
(
...
...
tests/test_modeling_tf_openai.py
View file @
34df26ec
...
@@ -246,14 +246,6 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -246,14 +246,6 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_openai_gpt_for_sequence_classification
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_openai_gpt_for_sequence_classification
(
*
config_and_inputs
)
def
test_mixed_precision
(
self
):
# TODO JP: Make OpenAIGPT float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make OpenAIGPT XLA compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment