ModelZoo / ResNet50_tensorflow / Commits

Commit 2565629c
Authored Mar 05, 2020 by Hongkun Yu
Committed by A. Unique TensorFlower, Mar 05, 2020

Internal change

PiperOrigin-RevId: 299169021
Parent: d3d7f15f
Showing 5 changed files with 51 additions and 4 deletions (+51 / -4):
official/nlp/bert/bert_models.py                       +32 -1
official/nlp/bert/common_flags.py                       +8 -0
official/nlp/bert/run_pretraining.py                    +3 -1
official/nlp/modeling/layers/transformer_scaffold.py    +2 -0
official/nlp/modeling/networks/encoder_scaffold.py      +6 -2
official/nlp/bert/bert_models.py  (view file @ 2565629c)

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import gin
 import tensorflow as tf
 import tensorflow_hub as hub
@@ -85,16 +86,46 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss


-def get_transformer_encoder(bert_config, sequence_length):
+@gin.configurable
+def get_transformer_encoder(bert_config,
+                            sequence_length,
+                            transformer_encoder_cls=None):
   """Gets a 'TransformerEncoder' object.

   Args:
     bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
     sequence_length: Maximum sequence length of the training data.
+    transformer_encoder_cls: An EncoderScaffold class. If it is None, uses the
+      default BERT encoder implementation.

   Returns:
     A networks.TransformerEncoder object.
   """
+  if transformer_encoder_cls is not None:
+    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
+    embedding_cfg = dict(
+        vocab_size=bert_config.vocab_size,
+        type_vocab_size=bert_config.type_vocab_size,
+        hidden_size=bert_config.hidden_size,
+        seq_length=sequence_length,
+        max_seq_length=bert_config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=bert_config.initializer_range),
+        dropout_rate=bert_config.hidden_dropout_prob,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=bert_config.num_attention_heads,
+        intermediate_size=bert_config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
+        dropout_rate=bert_config.hidden_dropout_prob,
+        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=bert_config.num_hidden_layers,)
+    # Relies on gin configuration to define the Transformer encoder arguments.
+    return transformer_encoder_cls(**kwargs)
+
   kwargs = dict(
       vocab_size=bert_config.vocab_size,
       hidden_size=bert_config.hidden_size,
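The new transformer_encoder_cls hook is meant to be driven through gin rather than hard-coded at call sites. A minimal sketch of exercising it from Python, assuming EncoderScaffold is registered with gin under its bare class name (the actual configurable path may carry a module prefix) and that BertConfig lives in official.nlp.bert.configs:

import gin

from official.nlp.bert import bert_models
from official.nlp.bert import configs  # assumed home of BertConfig

# Bind the new parameter so get_transformer_encoder() returns an
# EncoderScaffold instead of the default BERT encoder.
gin.parse_config(
    'get_transformer_encoder.transformer_encoder_cls = @EncoderScaffold')

bert_config = configs.BertConfig(vocab_size=30522)  # library defaults otherwise
encoder = bert_models.get_transformer_encoder(bert_config, sequence_length=128)

Because the function itself is now @gin.configurable, the same binding can come from a .gin file passed on the command line; no Python changes are needed.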
official/nlp/bert/common_flags.py  (view file @ 2565629c)

@@ -20,6 +20,14 @@ import tensorflow as tf
 from official.utils.flags import core as flags_core


+def define_gin_flags():
+  """Define common gin configurable flags."""
+  flags.DEFINE_multi_string('gin_file', None,
+                            'List of paths to the config files.')
+  flags.DEFINE_multi_string('gin_param', None,
+                            'Newline separated list of Gin parameter bindings.')
+
+
 def define_common_bert_flags():
   """Define common flags for BERT tasks."""
   flags_core.define_base(
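Both flags use DEFINE_multi_string, so each may be repeated on the command line and absl collects the occurrences into a list (None when never passed). A small sketch of that behavior, with illustrative file names:

from absl import flags
from official.nlp.bert import common_flags

common_flags.define_gin_flags()
FLAGS = flags.FLAGS

# Parse an illustrative argv; repeated --gin_file values accumulate.
FLAGS([
    'run_pretraining.py',
    '--gin_file=base.gin',
    '--gin_file=tpu_overrides.gin',
    '--gin_param=get_transformer_encoder.transformer_encoder_cls = @EncoderScaffold',
])
assert FLAGS.gin_file == ['base.gin', 'tpu_overrides.gin']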
official/nlp/bert/run_pretraining.py  (view file @ 2565629c)

@@ -20,6 +20,7 @@ from __future__ import print_function
 from absl import app
 from absl import flags
 from absl import logging
+import gin
 import tensorflow as tf

 from official.modeling import model_training_utils

@@ -49,6 +50,7 @@ flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')

 common_flags.define_common_bert_flags()
+common_flags.define_gin_flags()

 FLAGS = flags.FLAGS

@@ -158,7 +160,7 @@ def run_bert_pretrain(strategy):
 def main(_):
   # Users should always run this script under TF 2.x
   assert tf.version.VERSION.startswith('2.')
-
+  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
   strategy = distribution_utils.get_distribution_strategy(
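A hedged note on the new call in main(): gin.parse_config_files_and_bindings treats None for either argument as "nothing to parse", so runs that pass neither --gin_file nor --gin_param behave exactly as before. A self-contained sketch (the configurable function here is illustrative, not from the commit):

import gin

# No gin flags passed: both arguments are None and the call is a no-op.
gin.parse_config_files_and_bindings(None, None)

# Bindings-only is also valid; config_files may stay None while
# --gin_param strings are applied.
@gin.configurable
def make_optimizer(learning_rate=1e-4):
  return learning_rate

gin.parse_config_files_and_bindings(
    None, ['make_optimizer.learning_rate = 2e-5'])
assert make_optimizer() == 2e-5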
official/nlp/modeling/layers/transformer_scaffold.py  (view file @ 2565629c)

@@ -19,6 +19,7 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

+import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention

@@ -26,6 +27,7 @@ from official.nlp.modeling.layers import dense_einsum
 @tf.keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
 class TransformerScaffold(tf.keras.layers.Layer):
   """Transformer scaffold layer.
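What @gin.configurable buys on a Keras layer class: constructor defaults become bindable from gin files without touching any call site. A minimal sketch with a hypothetical ScaledDense layer (not part of the commit):

import gin
import tensorflow as tf

@gin.configurable
class ScaledDense(tf.keras.layers.Layer):  # hypothetical example layer
  """Toy layer whose constructor arguments gin can override."""

  def __init__(self, units=64, scale=1.0, **kwargs):
    super(ScaledDense, self).__init__(**kwargs)
    self.units = units
    self.scale = scale

# A binding like this overrides the default wherever the class is
# instantiated without an explicit `units`:
gin.parse_config('ScaledDense.units = 128')
assert ScaledDense().units == 128

TransformerScaffold gets the same treatment here, which appears to be what lets hidden_cfg stay None in EncoderScaffold below while gin supplies the constructor arguments instead.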
official/nlp/modeling/networks/encoder_scaffold.py  (view file @ 2565629c)

@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -20,6 +21,8 @@ from __future__ import division
 from __future__ import print_function

+import inspect
+import gin
 import tensorflow as tf

 from tensorflow.python.keras.engine import network  # pylint: disable=g-direct-tensorflow-import

@@ -27,6 +30,7 @@ from official.nlp.modeling import layers
 @tf.keras.utils.register_keras_serializable(package='Text')
+@gin.configurable
 class EncoderScaffold(network.Network):
   """Bi-directional Transformer-based encoder network scaffold.

@@ -96,7 +100,6 @@ class EncoderScaffold(network.Network):
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
                **kwargs):
-    print(embedding_cfg)
     self._self_setattr_tracking = False
     self._hidden_cls = hidden_cls
     self._hidden_cfg = hidden_cfg

@@ -171,7 +174,8 @@ class EncoderScaffold(network.Network):
     for _ in range(num_hidden_instances):
       if inspect.isclass(hidden_cls):
-        layer = self._hidden_cls(**hidden_cfg)
+        layer = self._hidden_cls(
+            **hidden_cfg) if hidden_cfg else self._hidden_cls()
       else:
         layer = self._hidden_cls
       data = layer([data, attention_mask])
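The reworked loop accepts hidden_cls as either a class or a ready-made callable, and now tolerates hidden_cfg being None. A standalone sketch of that dispatch (names are illustrative, not from the library):

import inspect

def _build_hidden_layer(hidden_cls, hidden_cfg=None):
  # Mirrors the new EncoderScaffold branch: classes are instantiated,
  # with hidden_cfg when provided (gin can supply the arguments otherwise);
  # anything that is not a class is used as-is.
  if inspect.isclass(hidden_cls):
    return hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
  return hidden_cls

class Dummy:  # stand-in for a Transformer layer class
  def __init__(self, scale=1):
    self.scale = scale

assert _build_hidden_layer(Dummy, {'scale': 2}).scale == 2  # class + cfg
assert _build_hidden_layer(Dummy).scale == 1                # class, defaults/gin
assert _build_hidden_layer(Dummy(3)).scale == 3             # instance passed through

Note that the instance path reuses one object for every hidden instance inside the loop, so the layers it stands in for share weights across depth.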