ModelZoo / ResNet50_tensorflow · Commit 5f296bbe

Authored Feb 14, 2020 by Chen Chen; committed by A. Unique TensorFlower on Feb 14, 2020.
Add a run_classifier.py in albert folder.
PiperOrigin-RevId: 295202644
Parent: 730035d6
Showing 2 changed files with 76 additions and 10 deletions (+76 / -10):

  official/nlp/albert/run_classifier.py   +70 / -0
  official/nlp/bert/run_classifier.py      +6 / -10
official/nlp/albert/run_classifier.py  (new file, mode 100644)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ALBERT classification finetuning runner in tf2.x."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

from absl import app
from absl import flags
import tensorflow as tf

from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_classifier as run_classifier_bert
from official.utils.misc import distribution_utils

FLAGS = flags.FLAGS


def main(_):
  # Users should always run this script under TF 2.x
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  max_seq_length = input_meta_data['max_seq_length']
  train_input_fn = run_classifier_bert.get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)
  eval_input_fn = run_classifier_bert.get_dataset_fn(
      FLAGS.eval_data_path,
      max_seq_length,
      FLAGS.eval_batch_size,
      is_training=False)

  albert_config = albert_configs.AlbertConfig.from_json_file(
      FLAGS.bert_config_file)
  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                               train_input_fn, eval_input_fn)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)
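The new runner reads a JSON metadata file from --input_meta_data_path; only 'max_seq_length' is consumed directly above, while the shared run_bert path reads further fields. A minimal sketch of producing such a file, assuming the additional key names (num_labels, train_data_size, eval_data_size) conventionally emitted by the BERT classifier data pipeline; treat them as assumptions rather than a confirmed schema:

# Hypothetical helper, not part of this commit: writes the metadata JSON
# consumed via FLAGS.input_meta_data_path. Only 'max_seq_length' is read
# directly in the ALBERT runner above; the remaining keys are assumptions.
import json


def write_input_meta_data(path, max_seq_length=128, num_labels=2,
                          train_data_size=10000, eval_data_size=1000):
  meta = {
      'max_seq_length': max_seq_length,    # used to build train/eval datasets
      'num_labels': num_labels,            # assumed: number of target classes
      'train_data_size': train_data_size,  # assumed: used to derive train steps
      'eval_data_size': eval_data_size,    # assumed: used to derive eval steps
  }
  with open(path, 'w') as f:
    json.dump(meta, f)


write_input_meta_data('/tmp/albert_meta.json')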
official/nlp/bert/run_classifier.py
@@ -28,7 +28,6 @@ import tensorflow as tf
 from official.modeling import model_training_utils
 from official.nlp import optimization
-from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
 from official.nlp.bert import configs as bert_configs
@@ -285,22 +284,17 @@ def export_classifier(model_export_path, input_meta_data,
 def run_bert(strategy,
              input_meta_data,
+             model_config,
              train_input_fn=None,
              eval_input_fn=None):
   """Run BERT training."""
-  if FLAGS.model_type == 'bert':
-    bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
-  else:
-    assert FLAGS.model_type == 'albert'
-    bert_config = albert_configs.AlbertConfig.from_json_file(
-        FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
     # internally uses model.save_weights() to save checkpoints, we must
     # use model.load_weights() when Keras compile/fit() is used.
     export_classifier(FLAGS.model_export_path, input_meta_data,
                       FLAGS.use_keras_compile_fit,
-                      bert_config, FLAGS.model_dir)
+                      model_config, FLAGS.model_dir)
     return

   if FLAGS.mode != 'train_and_eval':
@@ -320,7 +314,7 @@ def run_bert(strategy,
   trained_model = run_bert_classifier(
       strategy,
-      bert_config,
+      model_config,
       input_meta_data,
       FLAGS.model_dir,
       epochs,
@@ -372,7 +366,9 @@ def main(_):
       FLAGS.eval_batch_size,
       is_training=False)
-  run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)
+  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+  run_bert(strategy, input_meta_data, bert_config, train_input_fn,
+           eval_input_fn)

 if __name__ == '__main__':
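Taken together, the two files show the new calling convention: run_bert no longer inspects FLAGS.model_type to choose between BERT and ALBERT; each entry point builds its own config object and passes it in as model_config. A minimal sketch of that dispatch, assuming the official TF Model Garden packages shown above are importable; build_model_config is a hypothetical helper introduced here for illustration, not part of the repository:

# Hypothetical helper mirroring the branch removed from run_bert: the caller
# now selects the config class and hands the result to run_bert as
# `model_config`.
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs as bert_configs


def build_model_config(model_type, config_file):
  """Builds the config object that run_bert now expects from its caller."""
  if model_type == 'bert':
    return bert_configs.BertConfig.from_json_file(config_file)
  assert model_type == 'albert', 'model_type must be "bert" or "albert"'
  return albert_configs.AlbertConfig.from_json_file(config_file)


# Usage (paths are placeholders):
#   model_config = build_model_config('albert', '/path/to/albert_config.json')
#   run_classifier_bert.run_bert(strategy, input_meta_data, model_config,
#                                train_input_fn, eval_input_fn)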