ModelZoo / ResNet50_tensorflow · Commit c3bd5082

Authored Nov 12, 2019 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 280123567
Parent: 377c5285
Showing 6 changed files, with 21 additions and 23 deletions:
  official/nlp/bert/run_squad.py                              +4 −4
  official/nlp/bert_models.py                                 +8 −3
  official/nlp/modeling/layers/attention.py                   +1 −5
  official/nlp/modeling/layers/transformer.py                 +3 −5
  official/nlp/modeling/networks/transformer_encoder.py       +2 −6
  official/nlp/modeling/networks/transformer_encoder_test.py  +3 −0

official/nlp/bert/run_squad.py
@@ -81,7 +81,7 @@ flags.DEFINE_integer(
     'The maximum length of an answer that can be generated. This is needed '
     'because the start and end predictions are not conditioned on one another.')
 flags.DEFINE_bool(
-    'use_keras_bert_for_squad', False, 'Whether to use keras BERT for squad '
+    'use_keras_bert_for_squad', True, 'Whether to use keras BERT for squad '
     'task. Note that when the FLAG "hub_module_url" is specified, '
     '"use_keras_bert_for_squad" cannot be True.')
@@ -200,8 +200,7 @@ def train_squad(strategy,
   use_float16 = common_flags.use_float16()
   if use_float16:
-    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs

@@ -223,7 +222,8 @@ def train_squad(strategy,
         max_seq_length,
         float_type=tf.float16 if use_float16 else tf.float32,
         hub_module_url=FLAGS.hub_module_url,
-        use_keras_bert=FLAGS.use_keras_bert_for_squad)
+        use_keras_bert=False
+        if FLAGS.hub_module_url else FLAGS.use_keras_bert_for_squad)
   squad_model.optimizer = optimization.create_optimizer(
       FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
   if use_float16:
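
The second hunk above collapses the two-step mixed-precision setup (build a Policy object, then install it) into a single set_policy call, which accepts the policy name directly. A minimal sketch of the two equivalent forms, assuming the TF 2.0/2.1-era experimental API used in this revision (newer TensorFlow releases expose the same idea as tf.keras.mixed_precision.set_global_policy):

import tensorflow as tf

# Old form from the left side of the hunk: build a Policy, then install it.
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)

# New form from the right side: set_policy also accepts the policy name.
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')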

official/nlp/bert_models.py
@@ -227,12 +227,15 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss


-def _get_transformer_encoder(bert_config, sequence_length):
+def _get_transformer_encoder(bert_config, sequence_length,
+                             float_dtype=tf.float32):
   """Gets a 'TransformerEncoder' object.

   Args:
     bert_config: A 'modeling.BertConfig' object.
     sequence_length: Maximum sequence length of the training data.
+    float_dtype: tf.dtype, tf.float32 or tf.float16.

   Returns:
     A networks.TransformerEncoder object.
@@ -250,7 +253,8 @@ def _get_transformer_encoder(bert_config, sequence_length):
       max_sequence_length=bert_config.max_position_embeddings,
       type_vocab_size=bert_config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=bert_config.initializer_range))
+          stddev=bert_config.initializer_range),
+      float_dtype=float_dtype.name)


 def pretrain_model(bert_config,
@@ -387,7 +391,8 @@ def squad_model(bert_config,
         'Cannot use hub_module_url and keras BERT at the same time.')
   if use_keras_bert:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
+                                            float_type)
     return bert_span_labeler.BertSpanLabeler(
         network=bert_encoder), bert_encoder
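
The helper now threads the requested compute type down to the encoder, forwarding it by name (float_dtype.name) rather than as a tf.DType object. A small illustrative sketch of that detail; the commented call is hypothetical, since _get_transformer_encoder is a private helper and bert_config / max_seq_length are placeholders:

import tensorflow as tf

# tf.DType objects expose a string name, which is what the encoder's
# float_dtype argument receives via float_dtype.name in this hunk.
print(tf.float16.name)  # 'float16'
print(tf.float32.name)  # 'float32'

# Hypothetical call mirroring how squad_model selects the dtype:
# encoder = _get_transformer_encoder(bert_config, max_seq_length,
#                                    float_dtype=tf.float16)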

official/nlp/modeling/layers/attention.py
@@ -90,7 +90,6 @@ class Attention(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="query")
     self._key_dense = dense_einsum.DenseEinsum(
@@ -102,7 +101,6 @@ class Attention(tf.keras.layers.Layer):
...
@@ -102,7 +101,6 @@ class Attention(tf.keras.layers.Layer):
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
dtype
=
self
.
dtype
,
name
=
"key"
)
name
=
"key"
)
self
.
_value_dense
=
dense_einsum
.
DenseEinsum
(
self
.
_value_dense
=
dense_einsum
.
DenseEinsum
(
...
@@ -114,13 +112,11 @@ class Attention(tf.keras.layers.Layer):
...
@@ -114,13 +112,11 @@ class Attention(tf.keras.layers.Layer):
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
dtype
=
self
.
dtype
,
name
=
"value"
)
name
=
"value"
)
self
.
_masked_softmax
=
masked_softmax
.
MaskedSoftmax
(
mask_expansion_axes
=
[
1
])
self
.
_masked_softmax
=
masked_softmax
.
MaskedSoftmax
(
mask_expansion_axes
=
[
1
])
self
.
_dropout
=
tf
.
keras
.
layers
.
Dropout
(
self
.
_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
rate
=
self
.
_dropout_rate
,
dtype
=
self
.
dtype
)
def
compute_output_shape
(
self
,
input_shape
):
def
compute_output_shape
(
self
,
input_shape
):
# TODO(momernick): validate tensor dimensioos
# TODO(momernick): validate tensor dimensioos
...
...
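
One plausible reading of dropping dtype=self.dtype from the query/key/value projections is that Keras dtype-policy propagation already covers it: sub-layers constructed without an explicit dtype pick up the global policy, which is the same policy the enclosing Attention layer was built under. A minimal, self-contained sketch of that behaviour (experimental policy API as used in this diff; a toy Dense stands in for DenseEinsum):

import tensorflow as tf


class TinyBlock(tf.keras.layers.Layer):
  """Toy layer whose sub-layer is created without an explicit dtype."""

  def build(self, input_shape):
    # No dtype= here; the sub-layer picks up the global policy in effect,
    # which is why passing dtype=self.dtype explicitly can be redundant.
    self._dense = tf.keras.layers.Dense(4, name='inner')
    super(TinyBlock, self).build(input_shape)

  def call(self, inputs):
    return self._dense(inputs)


tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
block = TinyBlock()
print(block(tf.zeros([2, 3])).dtype)  # float16 under mixed_float16
tf.keras.mixed_precision.experimental.set_policy('float32')  # reset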

official/nlp/modeling/layers/transformer.py
@@ -110,7 +110,6 @@ class Transformer(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="self_attention")
     self._attention_output_dense = dense_einsum.DenseEinsum(
         output_shape=hidden_size,
@@ -122,12 +121,12 @@ class Transformer(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="self_attention_output")
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     self._attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
+            name="self_attention_layer_norm", axis=-1, epsilon=1e-12,
+            dtype=tf.float32))
     self._intermediate_dense = dense_einsum.DenseEinsum(
         output_shape=self._intermediate_size,
         activation=self._intermediate_activation,
@@ -149,11 +148,10 @@ class Transformer(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
         name="output")
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12)
+        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
     super(Transformer, self).build(input_shape)
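
Note that the layer-norm sub-layers are now pinned to dtype=tf.float32 even when the surrounding block runs in float16, a common numerical-stability pattern under mixed precision. A minimal sketch of the effect, assuming the experimental policy API of this era:

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

# Pinned to float32, as in the updated Transformer block.
layer_norm = tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=1e-12, dtype=tf.float32)

x = tf.zeros([2, 8], dtype=tf.float16)
y = layer_norm(x)
print(layer_norm.dtype, y.dtype)  # float32 variables, float32 output

tf.keras.mixed_precision.experimental.set_policy('float32')  # reset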

official/nlp/modeling/networks/transformer_encoder.py
@@ -111,7 +111,6 @@ class TransformerEncoder(network.Network):
         vocab_size=vocab_size,
         embedding_width=hidden_size,
         initializer=initializer,
-        dtype=float_dtype,
         name='word_embeddings')
     word_embeddings = self._embedding_layer(word_ids)
@@ -119,8 +118,7 @@ class TransformerEncoder(network.Network):
     self._position_embedding_layer = layers.PositionEmbedding(
         initializer=initializer,
         use_dynamic_slicing=True,
-        max_sequence_length=max_sequence_length,
-        dtype=float_dtype)
+        max_sequence_length=max_sequence_length)
     position_embeddings = self._position_embedding_layer(word_embeddings)
     type_embeddings = (
@@ -129,7 +127,6 @@ class TransformerEncoder(network.Network):
             embedding_width=hidden_size,
             initializer=initializer,
             use_one_hot=True,
-            dtype=float_dtype,
             name='type_embeddings')(type_ids))
     embeddings = tf.keras.layers.Add()(
@@ -139,7 +136,7 @@ class TransformerEncoder(network.Network):
             name='embeddings/layer_norm',
             axis=-1,
             epsilon=1e-12,
-            dtype=float_dtype)(embeddings))
+            dtype=tf.float32)(embeddings))
     embeddings = (
         tf.keras.layers.Dropout(rate=dropout_rate,
                                 dtype=tf.float32)(embeddings))
@@ -168,7 +165,6 @@ class TransformerEncoder(network.Network):
         units=hidden_size,
         activation='tanh',
         kernel_initializer=initializer,
-        dtype=float_dtype,
         name='pooler_transform')(
             first_token_tensor)

official/nlp/modeling/networks/transformer_encoder_test.py
@@ -58,6 +58,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
   def test_network_creation_with_float16_dtype(self):
     hidden_size = 32
     sequence_length = 21
+    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     # Create a small TransformerEncoder for testing.
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
@@ -86,6 +87,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     sequence_length = 21
     vocab_size = 57
     num_types = 7
+    tf.keras.mixed_precision.experimental.set_policy("float32")
     # Create a small TransformerEncoder for testing.
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
@@ -166,4 +168,5 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
 if __name__ == "__main__":
+  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
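
The added set_policy calls make each dtype-sensitive test pin the global policy it needs instead of inheriting whatever an earlier test left behind, and the new assert guards against running the suite on TF 1.x. A hedged sketch of that pattern in isolation (the test class and layers below are illustrative, not from the repo):

import tensorflow as tf


class PolicyHygieneTest(tf.test.TestCase):
  """Illustrative: each test pins the global policy it depends on."""

  def test_float16_path(self):
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    dense = tf.keras.layers.Dense(2)
    out = dense(tf.zeros([1, 4]))
    self.assertEqual(out.dtype, tf.float16)

  def test_float32_path(self):
    tf.keras.mixed_precision.experimental.set_policy('float32')
    dense = tf.keras.layers.Dense(2)
    out = dense(tf.zeros([1, 4]))
    self.assertEqual(out.dtype, tf.float32)


if __name__ == '__main__':
  assert tf.version.VERSION.startswith('2.')
  tf.test.main()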