ModelZoo / ResNet50_tensorflow / Commits

Commit 2565629c, authored Mar 05, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Mar 05, 2020

Internal change

PiperOrigin-RevId: 299169021
parent d3d7f15f
Showing 5 changed files with 51 additions and 4 deletions (+51 −4)
official/nlp/bert/bert_models.py                      +32 −1
official/nlp/bert/common_flags.py                      +8 −0
official/nlp/bert/run_pretraining.py                   +3 −1
official/nlp/modeling/layers/transformer_scaffold.py   +2 −0
official/nlp/modeling/networks/encoder_scaffold.py     +6 −2
official/nlp/bert/bert_models.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import gin
 import tensorflow as tf
 import tensorflow_hub as hub
 ...
@@ -85,16 +86,46 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss


-def get_transformer_encoder(bert_config, sequence_length):
+@gin.configurable
+def get_transformer_encoder(bert_config,
+                            sequence_length,
+                            transformer_encoder_cls=None):
   """Gets a 'TransformerEncoder' object.

   Args:
     bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
     sequence_length: Maximum sequence length of the training data.
+    transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
+      default BERT encoder implementation.

   Returns:
     A networks.TransformerEncoder object.
   """
+  if transformer_encoder_cls is not None:
+    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
+    embedding_cfg = dict(
+        vocab_size=bert_config.vocab_size,
+        type_vocab_size=bert_config.type_vocab_size,
+        hidden_size=bert_config.hidden_size,
+        seq_length=sequence_length,
+        max_seq_length=bert_config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=bert_config.initializer_range),
+        dropout_rate=bert_config.hidden_dropout_prob,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=bert_config.num_attention_heads,
+        intermediate_size=bert_config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
+        dropout_rate=bert_config.hidden_dropout_prob,
+        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=bert_config.num_hidden_layers,)
+    # Relies on gin configuration to define the Transformer encoder arguments.
+    return transformer_encoder_cls(**kwargs)
+
   kwargs = dict(
       vocab_size=bert_config.vocab_size,
       hidden_size=bert_config.hidden_size,
 ...
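The point of the new transformer_encoder_cls hook is that the encoder class can now be chosen from a gin binding instead of an edit to bert_models.py. Below is a minimal sketch of that selection, assuming EncoderScaffold has been registered with gin (which the encoder_scaffold.py change further down does) and that a bert_config object (a modeling.BertConfig) has been built elsewhere; nothing here beyond the binding string is taken from the commit itself.

import gin

from official.nlp.bert import bert_models
# Importing the module executes the @gin.configurable registration for
# EncoderScaffold, so the @EncoderScaffold reference below can resolve.
from official.nlp.modeling.networks import encoder_scaffold  # noqa: F401

# The same binding string a --gin_param flag could carry on the command line.
gin.parse_config(
    'get_transformer_encoder.transformer_encoder_cls = @EncoderScaffold')

# bert_config stands in for a modeling.BertConfig built elsewhere; with the
# binding in place, the otherwise-unchanged call site returns an
# EncoderScaffold instead of the default BERT encoder.
encoder = bert_models.get_transformer_encoder(bert_config, sequence_length=128)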
official/nlp/bert/common_flags.py

@@ -20,6 +20,14 @@ import tensorflow as tf
 from official.utils.flags import core as flags_core


+def define_gin_flags():
+  """Define common gin configurable flags."""
+  flags.DEFINE_multi_string('gin_file', None,
+                            'List of paths to the config files.')
+  flags.DEFINE_multi_string('gin_param', None,
+                            'Newline separated list of Gin parameter bindings.')
+
+
 def define_common_bert_flags():
   """Define common flags for BERT tasks."""
   flags_core.define_base(
 ...
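Both new flags are DEFINE_multi_string, so they may be repeated on the command line and arrive in Python as lists (or None when unset). A self-contained sketch of that behavior, independent of the BERT code:

from absl import app
from absl import flags

flags.DEFINE_multi_string('gin_file', None,
                          'List of paths to the config files.')
flags.DEFINE_multi_string('gin_param', None,
                          'Newline separated list of Gin parameter bindings.')
FLAGS = flags.FLAGS


def main(_):
  # Repeating a multi_string flag accumulates values:
  #   --gin_param='a.b = 1' --gin_param='c.d = 2'  ->  ['a.b = 1', 'c.d = 2']
  # Unset flags stay None.
  print('gin_file  =', FLAGS.gin_file)
  print('gin_param =', FLAGS.gin_param)


if __name__ == '__main__':
  app.run(main)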
official/nlp/bert/run_pretraining.py

@@ -20,6 +20,7 @@ from __future__ import print_function
 from absl import app
 from absl import flags
 from absl import logging
+import gin
 import tensorflow as tf

 from official.modeling import model_training_utils
 ...
@@ -49,6 +50,7 @@ flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')

 common_flags.define_common_bert_flags()
+common_flags.define_gin_flags()

 FLAGS = flags.FLAGS
 ...
@@ -158,7 +160,7 @@ def run_bert_pretrain(strategy):
 def main(_):
   # Users should always run this script under TF 2.x
   assert tf.version.VERSION.startswith('2.')
+  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
   strategy = distribution_utils.get_distribution_strategy(
 ...
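gin.parse_config_files_and_bindings accepts None for either argument, which is what makes the unconditional call in main() safe when neither --gin_file nor --gin_param is passed. A small self-contained sketch of the mechanism; the pretrain_step function is a stand-in for illustration, not part of the commit:

import gin


@gin.configurable
def pretrain_step(learning_rate=1e-4):
  return learning_rate


# Mirrors the call added to main(): config_files may be None, and bindings
# may be None or a list of 'configurable.param = value' strings.
gin.parse_config_files_and_bindings(
    config_files=None, bindings=['pretrain_step.learning_rate = 3e-5'])

assert pretrain_step() == 3e-5  # the binding overrides the default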
official/nlp/modeling/layers/transformer_scaffold.py

@@ -19,6 +19,7 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

+import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
 ...
@@ -26,6 +27,7 @@ from official.nlp.modeling.layers import dense_einsum
 @tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
 class TransformerScaffold(tf.keras.layers.Layer):
   """Transformer scaffold layer.
 ...
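Order matters in the decorator stack: @gin.configurable sits closest to the class, so gin wraps the constructor first and the Keras serialization registry then sees the gin-wrapped class. The effect is that every __init__ argument of TransformerScaffold becomes bindable by name. A toy illustration of the same pattern, with a stand-in layer rather than the real scaffold:

import gin
import tensorflow as tf


@gin.configurable
class ToyLayer(tf.keras.layers.Layer):
  """Stand-in for TransformerScaffold: any ctor arg is gin-bindable."""

  def __init__(self, units=8, **kwargs):
    super(ToyLayer, self).__init__(**kwargs)
    self.units = units


gin.parse_config('ToyLayer.units = 32')
assert ToyLayer().units == 32  # default overridden by the binding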
official/nlp/modeling/networks/encoder_scaffold.py

+# Lint as: python3
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -20,6 +21,8 @@ from __future__ import division
 from __future__ import print_function

 import inspect

+import gin
 import tensorflow as tf

 from tensorflow.python.keras.engine import network  # pylint: disable=g-direct-tensorflow-import
 ...
@@ -27,6 +30,7 @@ from official.nlp.modeling import layers
 @tf.keras.utils.register_keras_serializable(package='Text')
+@gin.configurable
 class EncoderScaffold(network.Network):
   """Bi-directional Transformer-based encoder network scaffold.
 ...
@@ -96,7 +100,6 @@ class EncoderScaffold(network.Network):
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
                **kwargs):
-    print(embedding_cfg)
     self._self_setattr_tracking = False
     self._hidden_cls = hidden_cls
     self._hidden_cfg = hidden_cfg
 ...
@@ -171,7 +174,8 @@ class EncoderScaffold(network.Network):
     for _ in range(num_hidden_instances):
       if inspect.isclass(hidden_cls):
-        layer = self._hidden_cls(**hidden_cfg)
+        layer = self._hidden_cls(
+            **hidden_cfg) if hidden_cfg else self._hidden_cls()
       else:
         layer = self._hidden_cls
       data = layer([data, attention_mask])
 ...
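The last hunk is the one behavioral fix in this file: with the constructor's hidden_cfg=None default, the old self._hidden_cls(**hidden_cfg) would raise a TypeError, since **None cannot be unpacked as keyword arguments. A minimal reproduction of the guarded branch, using an illustrative block class rather than the real Transformer layer:

import inspect


class ToyBlock:
  """Illustrative stand-in for a Transformer block class."""

  def __init__(self, num_heads=4):
    self.num_heads = num_heads


def make_layer(hidden_cls, hidden_cfg=None):
  # Mirrors the fixed branch: only splat hidden_cfg when it is truthy;
  # hidden_cls(**None) raises TypeError.
  if inspect.isclass(hidden_cls):
    return hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
  # An already-instantiated layer is used as-is.
  return hidden_cls


assert make_layer(ToyBlock).num_heads == 4
assert make_layer(ToyBlock, {'num_heads': 8}).num_heads == 8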