ModelZoo / ResNet50_tensorflow · Commit f7fd59b8

Authored Dec 10, 2019 by Chen Chen; committed by A. Unique TensorFlower, Dec 10, 2019

Internal change

PiperOrigin-RevId: 284792715
Parent: 558bab5d
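Despite the terse commit message, the diff below is self-describing: it threads ALBERT support through the BERT classifier path by adding a model_type flag, selecting the matching config class in run_bert(), and dispatching to the matching encoder in bert_models.py.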
Showing 3 changed files, with 20 additions and 4 deletions (+20, -4):
official/nlp/bert/common_flags.py   (+4, -0)
official/nlp/bert/run_classifier.py (+5, -1)
official/nlp/bert_models.py         (+11, -3)
official/nlp/bert/common_flags.py

@@ -65,6 +65,10 @@ def define_common_bert_flags():
   flags.DEFINE_string('hub_module_url', None, 'TF-Hub path/url to Bert module. '
                       'If specified, init_checkpoint flag should not be used.')
+  flags.DEFINE_enum(
+      'model_type', 'bert', ['bert', 'albert'],
+      'Specifies the type of the model. '
+      'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
   # Adds flags for mixed precision training.
   flags_core.define_performance(
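For context, here is a minimal standalone sketch (a hypothetical demo file, not part of this commit) of how an absl enum flag like the new model_type behaves: absl validates the value at parse time, so code running after app.run() only ever sees 'bert' or 'albert'.

# demo_flags.py -- hypothetical demo, not code from this repo.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

# The same declaration the hunk above adds inside define_common_bert_flags().
flags.DEFINE_enum(
    'model_type', 'bert', ['bert', 'albert'],
    'Specifies the type of the model. '
    'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')


def main(_):
  # absl has already validated the flag, so this prints 'bert' or 'albert'.
  print('model_type =', FLAGS.model_type)


if __name__ == '__main__':
  app.run(main)  # e.g. python demo_flags.py --model_type=albert

Running it with --model_type=roberta fails during flag parsing rather than deep inside model construction, which is the point of using DEFINE_enum over DEFINE_string here.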
official/nlp/bert/run_classifier.py

@@ -287,7 +287,11 @@ def run_bert(strategy,
              train_input_fn=None,
              eval_input_fn=None):
   """Run BERT training."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  if FLAGS.model_type == 'bert':
+    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  else:
+    assert FLAGS.model_type == 'albert'
+    bert_config = modeling.AlbertConfig.from_json_file(FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
     # internally uses model.save_weights() to save checkpoints, we must
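The pattern in this hunk (one flag choosing which config class parses the same JSON file) reduces to the sketch below; the config classes here are stand-ins, not the real modeling.BertConfig / modeling.AlbertConfig.

# config_dispatch.py -- illustrative stubs only.
import json


class BertConfig:
  """Stand-in; the real class defines and validates many more fields."""

  def __init__(self, **kwargs):
    self.__dict__.update(kwargs)

  @classmethod
  def from_json_file(cls, json_file):
    with open(json_file) as f:
      return cls(**json.load(f))


class AlbertConfig(BertConfig):
  """Stand-in for the ALBERT config variant."""


def config_for(model_type, config_file):
  # Mirrors the new branch in run_bert(): the enum flag admits only two
  # values, so the assert documents an internal invariant, not user input.
  if model_type == 'bert':
    return BertConfig.from_json_file(config_file)
  assert model_type == 'albert'
  return AlbertConfig.from_json_file(config_file)

The assert (rather than a ValueError) is defensible here precisely because DEFINE_enum already rejected any other value at the flag boundary.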
official/nlp/bert_models.py

@@ -22,6 +22,7 @@ import tensorflow as tf
 import tensorflow_hub as hub
 from official.modeling import tf_utils
+from official.nlp import bert_modeling
 from official.nlp.modeling import losses
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier

@@ -139,14 +140,14 @@ def _get_transformer_encoder(bert_config,
   """Gets a 'TransformerEncoder' object.
 
   Args:
-    bert_config: A 'modeling.BertConfig' object.
+    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
     sequence_length: Maximum sequence length of the training data.
     float_dtype: tf.dtype, tf.float32 or tf.float16.
 
   Returns:
     A networks.TransformerEncoder object.
   """
-  return networks.TransformerEncoder(
+  kwargs = dict(
       vocab_size=bert_config.vocab_size,
       hidden_size=bert_config.hidden_size,
       num_layers=bert_config.num_hidden_layers,

@@ -161,6 +162,12 @@ def _get_transformer_encoder(bert_config,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range),
       float_dtype=float_dtype.name)
+  if isinstance(bert_config, bert_modeling.AlbertConfig):
+    kwargs['embedding_width'] = bert_config.embedding_size
+    return networks.AlbertTransformerEncoder(**kwargs)
+  else:
+    assert isinstance(bert_config, bert_modeling.BertConfig)
+    return networks.TransformerEncoder(**kwargs)
 
 
 def pretrain_model(bert_config,

@@ -332,7 +339,8 @@ def classifier_model(bert_config,
     maximum sequence length `max_seq_length`.
 
   Args:
-    bert_config: BertConfig, the config defines the core BERT model.
+    bert_config: BertConfig or AlbertConfig, the config defines the core
+      BERT or ALBERT model.
     float_type: dtype, tf.float32 or tf.bfloat16.
     num_labels: integer, the number of classes.
     max_seq_length: integer, the maximum input sequence length.
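The _get_transformer_encoder refactor is a standard kwargs-then-dispatch move: collect the constructor arguments both encoders share into one dict, then widen it only for the ALBERT branch (ALBERT factorizes its embedding matrix, so the embedding width is separate from hidden_size). A reduced sketch with stub classes, not the real official.nlp.modeling.networks, shows the shape:

# encoder_dispatch.py -- illustrative stubs, not official.nlp.modeling.networks.
class TransformerEncoder:
  """Stand-in for a BERT-style encoder."""

  def __init__(self, vocab_size, hidden_size):
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size


class AlbertTransformerEncoder(TransformerEncoder):
  """Stand-in ALBERT encoder; takes a factorized embedding width."""

  def __init__(self, vocab_size, hidden_size, embedding_width):
    super().__init__(vocab_size, hidden_size)
    self.embedding_width = embedding_width


class BertConfig:
  def __init__(self, vocab_size=30522, hidden_size=768):
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size


class AlbertConfig(BertConfig):
  def __init__(self, embedding_size=128, **kwargs):
    super().__init__(**kwargs)
    self.embedding_size = embedding_size


def get_encoder(config):
  # Shared arguments built once, as in the refactored _get_transformer_encoder.
  kwargs = dict(vocab_size=config.vocab_size, hidden_size=config.hidden_size)
  if isinstance(config, AlbertConfig):
    kwargs['embedding_width'] = config.embedding_size  # ALBERT-only argument.
    return AlbertTransformerEncoder(**kwargs)
  assert isinstance(config, BertConfig)
  return TransformerEncoder(**kwargs)


print(type(get_encoder(AlbertConfig())).__name__)  # AlbertTransformerEncoder
print(type(get_encoder(BertConfig())).__name__)    # TransformerEncoder

Since AlbertConfig subclasses BertConfig in the stub (and in bert_modeling), the isinstance check for ALBERT must come first or every config would take the BERT branch; the diff orders its checks the same way.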