Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
32e4ca51
Commit
32e4ca51
authored
Nov 28, 2023
by
qianyj
Browse files
Update code to v2.11.0
parents
9485aa1d
71060f67
Changes
775
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
698 additions
and
35 deletions
+698
-35
official/legacy/bert/model_training_utils_test.py
official/legacy/bert/model_training_utils_test.py
+3
-3
official/legacy/bert/run_classifier.py
official/legacy/bert/run_classifier.py
+515
-0
official/legacy/bert/run_pretraining.py
official/legacy/bert/run_pretraining.py
+6
-6
official/legacy/bert/run_squad.py
official/legacy/bert/run_squad.py
+148
-0
official/legacy/bert/run_squad_helper.py
official/legacy/bert/run_squad_helper.py
+8
-8
official/legacy/bert/serving.py
official/legacy/bert/serving.py
+3
-3
official/legacy/detection/__init__.py
official/legacy/detection/__init__.py
+1
-1
official/legacy/detection/configs/__init__.py
official/legacy/detection/configs/__init__.py
+1
-1
official/legacy/detection/configs/base_config.py
official/legacy/detection/configs/base_config.py
+1
-1
official/legacy/detection/configs/factory.py
official/legacy/detection/configs/factory.py
+1
-1
official/legacy/detection/configs/maskrcnn_config.py
official/legacy/detection/configs/maskrcnn_config.py
+1
-1
official/legacy/detection/configs/olnmask_config.py
official/legacy/detection/configs/olnmask_config.py
+1
-1
official/legacy/detection/configs/retinanet_config.py
official/legacy/detection/configs/retinanet_config.py
+1
-1
official/legacy/detection/configs/shapemask_config.py
official/legacy/detection/configs/shapemask_config.py
+1
-1
official/legacy/detection/dataloader/__init__.py
official/legacy/detection/dataloader/__init__.py
+1
-1
official/legacy/detection/dataloader/anchor.py
official/legacy/detection/dataloader/anchor.py
+2
-2
official/legacy/detection/dataloader/factory.py
official/legacy/detection/dataloader/factory.py
+1
-1
official/legacy/detection/dataloader/input_reader.py
official/legacy/detection/dataloader/input_reader.py
+1
-1
official/legacy/detection/dataloader/maskrcnn_parser.py
official/legacy/detection/dataloader/maskrcnn_parser.py
+1
-1
official/legacy/detection/dataloader/mode_keys.py
official/legacy/detection/dataloader/mode_keys.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
775 of 775+
files are displayed.
Plain diff
Email patch
official/
nlp
/bert/model_training_utils_test.py
→
official/
legacy
/bert/model_training_utils_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -25,8 +25,8 @@ import tensorflow as tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.
nlp
.bert
import
common_flags
from
official.
nlp
.bert
import
model_training_utils
from
official.
legacy
.bert
import
common_flags
from
official.
legacy
.bert
import
model_training_utils
common_flags
.
define_common_bert_flags
()
...
...
official/legacy/bert/run_classifier.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT classification or regression finetuning runner in TF 2.x."""
import
functools
import
json
import
math
import
os
# Import libraries
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
gin
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.legacy.bert
import
bert_models
from
official.legacy.bert
import
common_flags
from
official.legacy.bert
import
configs
as
bert_configs
from
official.legacy.bert
import
input_pipeline
from
official.legacy.bert
import
model_saving_utils
from
official.modeling
import
performance
from
official.nlp
import
optimization
from
official.utils.misc
import
keras_utils
flags
.
DEFINE_enum
(
'mode'
,
'train_and_eval'
,
[
'train_and_eval'
,
'export_only'
,
'predict'
],
'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
'trains the model and evaluates in the meantime. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
'restores the model to output predictions on the test set.'
)
flags
.
DEFINE_string
(
'train_data_path'
,
None
,
'Path to training data for BERT classifier.'
)
flags
.
DEFINE_string
(
'eval_data_path'
,
None
,
'Path to evaluation data for BERT classifier.'
)
flags
.
DEFINE_string
(
'input_meta_data_path'
,
None
,
'Path to file that contains meta data about input '
'to be used for training and evaluation.'
)
flags
.
DEFINE_integer
(
'train_data_size'
,
None
,
'Number of training samples '
'to use. If None, uses the full train data. '
'(default: None).'
)
flags
.
DEFINE_string
(
'predict_checkpoint_path'
,
None
,
'Path to the checkpoint for predictions.'
)
flags
.
DEFINE_integer
(
'num_eval_per_epoch'
,
1
,
'Number of evaluations per epoch. The purpose of this flag is to provide '
'more granular evaluation scores and checkpoints. For example, if original '
'data has N samples and num_eval_per_epoch is n, then each epoch will be '
'evaluated every N/n samples.'
)
flags
.
DEFINE_integer
(
'train_batch_size'
,
32
,
'Batch size for training.'
)
flags
.
DEFINE_integer
(
'eval_batch_size'
,
32
,
'Batch size for evaluation.'
)
common_flags
.
define_common_bert_flags
()
FLAGS
=
flags
.
FLAGS
LABEL_TYPES_MAP
=
{
'int'
:
tf
.
int64
,
'float'
:
tf
.
float32
}
def
get_loss_fn
(
num_classes
):
"""Gets the classification loss function."""
def
classification_loss_fn
(
labels
,
logits
):
"""Classification loss."""
labels
=
tf
.
reshape
(
labels
,
[
-
1
])
log_probs
=
tf
.
nn
.
log_softmax
(
logits
,
axis
=-
1
)
one_hot_labels
=
tf
.
one_hot
(
tf
.
cast
(
labels
,
dtype
=
tf
.
int32
),
depth
=
num_classes
,
dtype
=
tf
.
float32
)
per_example_loss
=
-
tf
.
reduce_sum
(
tf
.
cast
(
one_hot_labels
,
dtype
=
tf
.
float32
)
*
log_probs
,
axis
=-
1
)
return
tf
.
reduce_mean
(
per_example_loss
)
return
classification_loss_fn
def
get_dataset_fn
(
input_file_pattern
,
max_seq_length
,
global_batch_size
,
is_training
,
label_type
=
tf
.
int64
,
include_sample_weights
=
False
,
num_samples
=
None
):
"""Gets a closure to create a dataset."""
def
_dataset_fn
(
ctx
=
None
):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size
=
ctx
.
get_per_replica_batch_size
(
global_batch_size
)
if
ctx
else
global_batch_size
dataset
=
input_pipeline
.
create_classifier_dataset
(
tf
.
io
.
gfile
.
glob
(
input_file_pattern
),
max_seq_length
,
batch_size
,
is_training
=
is_training
,
input_pipeline_context
=
ctx
,
label_type
=
label_type
,
include_sample_weights
=
include_sample_weights
,
num_samples
=
num_samples
)
return
dataset
return
_dataset_fn
def
run_bert_classifier
(
strategy
,
bert_config
,
input_meta_data
,
model_dir
,
epochs
,
steps_per_epoch
,
steps_per_loop
,
eval_steps
,
warmup_steps
,
initial_lr
,
init_checkpoint
,
train_input_fn
,
eval_input_fn
,
training_callbacks
=
True
,
custom_callbacks
=
None
,
custom_metrics
=
None
):
"""Run BERT classifier training using low-level API."""
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
num_classes
=
input_meta_data
.
get
(
'num_labels'
,
1
)
is_regression
=
num_classes
==
1
def
_get_classifier_model
():
"""Gets a classifier model."""
classifier_model
,
core_model
=
(
bert_models
.
classifier_model
(
bert_config
,
num_classes
,
max_seq_length
,
hub_module_url
=
FLAGS
.
hub_module_url
,
hub_module_trainable
=
FLAGS
.
hub_module_trainable
))
optimizer
=
optimization
.
create_optimizer
(
initial_lr
,
steps_per_epoch
*
epochs
,
warmup_steps
,
FLAGS
.
end_lr
,
FLAGS
.
optimizer_type
)
classifier_model
.
optimizer
=
performance
.
configure_optimizer
(
optimizer
,
use_float16
=
common_flags
.
use_float16
())
return
classifier_model
,
core_model
# tf.keras.losses objects accept optional sample_weight arguments (eg. coming
# from the dataset) to compute weighted loss, as used for the regression
# tasks. The classification tasks, using the custom get_loss_fn don't accept
# sample weights though.
loss_fn
=
(
tf
.
keras
.
losses
.
MeanSquaredError
()
if
is_regression
else
get_loss_fn
(
num_classes
))
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
if
custom_metrics
:
metric_fn
=
custom_metrics
elif
is_regression
:
metric_fn
=
functools
.
partial
(
tf
.
keras
.
metrics
.
MeanSquaredError
,
'mean_squared_error'
,
dtype
=
tf
.
float32
)
else
:
metric_fn
=
functools
.
partial
(
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
,
'accuracy'
,
dtype
=
tf
.
float32
)
# Start training using Keras compile/fit API.
logging
.
info
(
'Training using TF 2.x Keras compile/fit API with '
'distribution strategy.'
)
return
run_keras_compile_fit
(
model_dir
,
strategy
,
_get_classifier_model
,
train_input_fn
,
eval_input_fn
,
loss_fn
,
metric_fn
,
init_checkpoint
,
epochs
,
steps_per_epoch
,
steps_per_loop
,
eval_steps
,
training_callbacks
=
training_callbacks
,
custom_callbacks
=
custom_callbacks
)
def
run_keras_compile_fit
(
model_dir
,
strategy
,
model_fn
,
train_input_fn
,
eval_input_fn
,
loss_fn
,
metric_fn
,
init_checkpoint
,
epochs
,
steps_per_epoch
,
steps_per_loop
,
eval_steps
,
training_callbacks
=
True
,
custom_callbacks
=
None
):
"""Runs BERT classifier model using Keras compile/fit API."""
with
strategy
.
scope
():
training_dataset
=
train_input_fn
()
evaluation_dataset
=
eval_input_fn
()
if
eval_input_fn
else
None
bert_model
,
sub_model
=
model_fn
()
optimizer
=
bert_model
.
optimizer
if
init_checkpoint
:
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
sub_model
,
encoder
=
sub_model
)
checkpoint
.
read
(
init_checkpoint
).
assert_existing_objects_matched
()
if
not
isinstance
(
metric_fn
,
(
list
,
tuple
)):
metric_fn
=
[
metric_fn
]
bert_model
.
compile
(
optimizer
=
optimizer
,
loss
=
loss_fn
,
metrics
=
[
fn
()
for
fn
in
metric_fn
],
steps_per_execution
=
steps_per_loop
)
summary_dir
=
os
.
path
.
join
(
model_dir
,
'summaries'
)
summary_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
summary_dir
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
bert_model
,
optimizer
=
optimizer
)
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
directory
=
model_dir
,
max_to_keep
=
None
,
step_counter
=
optimizer
.
iterations
,
checkpoint_interval
=
0
)
checkpoint_callback
=
keras_utils
.
SimpleCheckpoint
(
checkpoint_manager
)
if
training_callbacks
:
if
custom_callbacks
is
not
None
:
custom_callbacks
+=
[
summary_callback
,
checkpoint_callback
]
else
:
custom_callbacks
=
[
summary_callback
,
checkpoint_callback
]
history
=
bert_model
.
fit
(
x
=
training_dataset
,
validation_data
=
evaluation_dataset
,
steps_per_epoch
=
steps_per_epoch
,
epochs
=
epochs
,
validation_steps
=
eval_steps
,
callbacks
=
custom_callbacks
)
stats
=
{
'total_training_steps'
:
steps_per_epoch
*
epochs
}
if
'loss'
in
history
.
history
:
stats
[
'train_loss'
]
=
history
.
history
[
'loss'
][
-
1
]
if
'val_accuracy'
in
history
.
history
:
stats
[
'eval_metrics'
]
=
history
.
history
[
'val_accuracy'
][
-
1
]
return
bert_model
,
stats
def
get_predictions_and_labels
(
strategy
,
trained_model
,
eval_input_fn
,
is_regression
=
False
,
return_probs
=
False
):
"""Obtains predictions of trained model on evaluation data.
Note that list of labels is returned along with the predictions because the
order changes on distributing dataset over TPU pods.
Args:
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
is_regression: Whether it is a regression task.
return_probs: Whether to return probabilities of classes.
Returns:
predictions: List of predictions.
labels: List of gold labels corresponding to predictions.
"""
@
tf
.
function
def
test_step
(
iterator
):
"""Computes predictions on distributed devices."""
def
_test_step_fn
(
inputs
):
"""Replicated predictions."""
inputs
,
labels
=
inputs
logits
=
trained_model
(
inputs
,
training
=
False
)
if
not
is_regression
:
probabilities
=
tf
.
nn
.
softmax
(
logits
)
return
probabilities
,
labels
else
:
return
logits
,
labels
outputs
,
labels
=
strategy
.
run
(
_test_step_fn
,
args
=
(
next
(
iterator
),))
# outputs: current batch logits as a tuple of shard logits
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
outputs
)
labels
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
labels
)
return
outputs
,
labels
def
_run_evaluation
(
test_iterator
):
"""Runs evaluation steps."""
preds
,
golds
=
list
(),
list
()
try
:
with
tf
.
experimental
.
async_scope
():
while
True
:
probabilities
,
labels
=
test_step
(
test_iterator
)
for
cur_probs
,
cur_labels
in
zip
(
probabilities
,
labels
):
if
return_probs
:
preds
.
extend
(
cur_probs
.
numpy
().
tolist
())
else
:
preds
.
extend
(
tf
.
math
.
argmax
(
cur_probs
,
axis
=
1
).
numpy
())
golds
.
extend
(
cur_labels
.
numpy
().
tolist
())
except
(
StopIteration
,
tf
.
errors
.
OutOfRangeError
):
tf
.
experimental
.
async_clear_error
()
return
preds
,
golds
test_iter
=
iter
(
strategy
.
distribute_datasets_from_function
(
eval_input_fn
))
predictions
,
labels
=
_run_evaluation
(
test_iter
)
return
predictions
,
labels
def
export_classifier
(
model_export_path
,
input_meta_data
,
bert_config
,
model_dir
):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
Raises:
Export path is not specified, got an empty string or None.
"""
if
not
model_export_path
:
raise
ValueError
(
'Export path is not specified: %s'
%
model_export_path
)
if
not
model_dir
:
raise
ValueError
(
'Export path is not specified: %s'
%
model_dir
)
# Export uses float32 for now, even if training uses mixed precision.
tf
.
keras
.
mixed_precision
.
set_global_policy
(
'float32'
)
classifier_model
=
bert_models
.
classifier_model
(
bert_config
,
input_meta_data
.
get
(
'num_labels'
,
1
),
hub_module_url
=
FLAGS
.
hub_module_url
,
hub_module_trainable
=
False
)[
0
]
model_saving_utils
.
export_bert_model
(
model_export_path
,
model
=
classifier_model
,
checkpoint_dir
=
model_dir
)
def
run_bert
(
strategy
,
input_meta_data
,
model_config
,
train_input_fn
=
None
,
eval_input_fn
=
None
,
init_checkpoint
=
None
,
custom_callbacks
=
None
,
custom_metrics
=
None
):
"""Run BERT training."""
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
epochs
=
FLAGS
.
num_train_epochs
*
FLAGS
.
num_eval_per_epoch
train_data_size
=
(
input_meta_data
[
'train_data_size'
]
//
FLAGS
.
num_eval_per_epoch
)
if
FLAGS
.
train_data_size
:
train_data_size
=
min
(
train_data_size
,
FLAGS
.
train_data_size
)
logging
.
info
(
'Updated train_data_size: %s'
,
train_data_size
)
steps_per_epoch
=
int
(
train_data_size
/
FLAGS
.
train_batch_size
)
warmup_steps
=
int
(
epochs
*
train_data_size
*
0.1
/
FLAGS
.
train_batch_size
)
eval_steps
=
int
(
math
.
ceil
(
input_meta_data
[
'eval_data_size'
]
/
FLAGS
.
eval_batch_size
))
if
not
strategy
:
raise
ValueError
(
'Distribution strategy has not been specified.'
)
if
not
custom_callbacks
:
custom_callbacks
=
[]
if
FLAGS
.
log_steps
:
custom_callbacks
.
append
(
keras_utils
.
TimeHistory
(
batch_size
=
FLAGS
.
train_batch_size
,
log_steps
=
FLAGS
.
log_steps
,
logdir
=
FLAGS
.
model_dir
))
trained_model
,
_
=
run_bert_classifier
(
strategy
,
model_config
,
input_meta_data
,
FLAGS
.
model_dir
,
epochs
,
steps_per_epoch
,
FLAGS
.
steps_per_loop
,
eval_steps
,
warmup_steps
,
FLAGS
.
learning_rate
,
init_checkpoint
or
FLAGS
.
init_checkpoint
,
train_input_fn
,
eval_input_fn
,
custom_callbacks
=
custom_callbacks
,
custom_metrics
=
custom_metrics
)
if
FLAGS
.
model_export_path
:
model_saving_utils
.
export_bert_model
(
FLAGS
.
model_export_path
,
model
=
trained_model
)
return
trained_model
def
custom_main
(
custom_callbacks
=
None
,
custom_metrics
=
None
):
"""Run classification or regression.
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
custom_metrics: list of metrics passed to the training loop.
"""
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_param
)
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
input_meta_data_path
,
'rb'
)
as
reader
:
input_meta_data
=
json
.
loads
(
reader
.
read
().
decode
(
'utf-8'
))
label_type
=
LABEL_TYPES_MAP
[
input_meta_data
.
get
(
'label_type'
,
'int'
)]
include_sample_weights
=
input_meta_data
.
get
(
'has_sample_weights'
,
False
)
if
not
FLAGS
.
model_dir
:
FLAGS
.
model_dir
=
'/tmp/bert20/'
bert_config
=
bert_configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
if
FLAGS
.
mode
==
'export_only'
:
export_classifier
(
FLAGS
.
model_export_path
,
input_meta_data
,
bert_config
,
FLAGS
.
model_dir
)
return
strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
distribution_strategy
,
num_gpus
=
FLAGS
.
num_gpus
,
tpu_address
=
FLAGS
.
tpu
)
eval_input_fn
=
get_dataset_fn
(
FLAGS
.
eval_data_path
,
input_meta_data
[
'max_seq_length'
],
FLAGS
.
eval_batch_size
,
is_training
=
False
,
label_type
=
label_type
,
include_sample_weights
=
include_sample_weights
)
if
FLAGS
.
mode
==
'predict'
:
num_labels
=
input_meta_data
.
get
(
'num_labels'
,
1
)
with
strategy
.
scope
():
classifier_model
=
bert_models
.
classifier_model
(
bert_config
,
num_labels
)[
0
]
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
classifier_model
)
latest_checkpoint_file
=
(
FLAGS
.
predict_checkpoint_path
or
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
))
assert
latest_checkpoint_file
logging
.
info
(
'Checkpoint file %s found and restoring from '
'checkpoint'
,
latest_checkpoint_file
)
checkpoint
.
restore
(
latest_checkpoint_file
).
assert_existing_objects_matched
()
preds
,
_
=
get_predictions_and_labels
(
strategy
,
classifier_model
,
eval_input_fn
,
is_regression
=
(
num_labels
==
1
),
return_probs
=
True
)
output_predict_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'test_results.tsv'
)
with
tf
.
io
.
gfile
.
GFile
(
output_predict_file
,
'w'
)
as
writer
:
logging
.
info
(
'***** Predict results *****'
)
for
probabilities
in
preds
:
output_line
=
'
\t
'
.
join
(
str
(
class_probability
)
for
class_probability
in
probabilities
)
+
'
\n
'
writer
.
write
(
output_line
)
return
if
FLAGS
.
mode
!=
'train_and_eval'
:
raise
ValueError
(
'Unsupported mode is specified: %s'
%
FLAGS
.
mode
)
train_input_fn
=
get_dataset_fn
(
FLAGS
.
train_data_path
,
input_meta_data
[
'max_seq_length'
],
FLAGS
.
train_batch_size
,
is_training
=
True
,
label_type
=
label_type
,
include_sample_weights
=
include_sample_weights
,
num_samples
=
FLAGS
.
train_data_size
)
run_bert
(
strategy
,
input_meta_data
,
bert_config
,
train_input_fn
,
eval_input_fn
,
custom_callbacks
=
custom_callbacks
,
custom_metrics
=
custom_metrics
)
def
main
(
_
):
custom_main
(
custom_callbacks
=
None
,
custom_metrics
=
None
)
if
__name__
==
'__main__'
:
flags
.
mark_flag_as_required
(
'bert_config_file'
)
flags
.
mark_flag_as_required
(
'input_meta_data_path'
)
flags
.
mark_flag_as_required
(
'model_dir'
)
app
.
run
(
main
)
official/
nlp
/bert/run_pretraining.py
→
official/
legacy
/bert/run_pretraining.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -21,13 +21,13 @@ from absl import logging
import
gin
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.legacy.bert
import
bert_models
from
official.legacy.bert
import
common_flags
from
official.legacy.bert
import
configs
from
official.legacy.bert
import
input_pipeline
from
official.legacy.bert
import
model_training_utils
from
official.modeling
import
performance
from
official.nlp
import
optimization
from
official.nlp.bert
import
bert_models
from
official.nlp.bert
import
common_flags
from
official.nlp.bert
import
configs
from
official.nlp.bert
import
input_pipeline
from
official.nlp.bert
import
model_training_utils
flags
.
DEFINE_string
(
'input_files'
,
None
,
...
...
official/legacy/bert/run_squad.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
import
json
import
os
import
time
# Import libraries
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
gin
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.legacy.bert
import
configs
as
bert_configs
from
official.legacy.bert
import
run_squad_helper
from
official.nlp.data
import
squad_lib
as
squad_lib_wp
from
official.nlp.tools
import
tokenization
from
official.utils.misc
import
keras_utils
flags
.
DEFINE_string
(
'vocab_file'
,
None
,
'The vocabulary file that the BERT model was trained on.'
)
# More flags can be found in run_squad_helper.
run_squad_helper
.
define_common_squad_flags
()
FLAGS
=
flags
.
FLAGS
def
train_squad
(
strategy
,
input_meta_data
,
custom_callbacks
=
None
,
run_eagerly
=
False
,
init_checkpoint
=
None
,
sub_model_export_name
=
None
):
"""Run bert squad training."""
bert_config
=
bert_configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
init_checkpoint
=
init_checkpoint
or
FLAGS
.
init_checkpoint
run_squad_helper
.
train_squad
(
strategy
,
input_meta_data
,
bert_config
,
custom_callbacks
,
run_eagerly
,
init_checkpoint
,
sub_model_export_name
=
sub_model_export_name
)
def
predict_squad
(
strategy
,
input_meta_data
):
"""Makes predictions for the squad dataset."""
bert_config
=
bert_configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
FLAGS
.
vocab_file
,
do_lower_case
=
FLAGS
.
do_lower_case
)
run_squad_helper
.
predict_squad
(
strategy
,
input_meta_data
,
tokenizer
,
bert_config
,
squad_lib_wp
)
def
eval_squad
(
strategy
,
input_meta_data
):
"""Evaluate on the squad dataset."""
bert_config
=
bert_configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
FLAGS
.
vocab_file
,
do_lower_case
=
FLAGS
.
do_lower_case
)
eval_metrics
=
run_squad_helper
.
eval_squad
(
strategy
,
input_meta_data
,
tokenizer
,
bert_config
,
squad_lib_wp
)
return
eval_metrics
def
export_squad
(
model_export_path
,
input_meta_data
):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
Export path is not specified, got an empty string or None.
"""
bert_config
=
bert_configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
run_squad_helper
.
export_squad
(
model_export_path
,
input_meta_data
,
bert_config
)
def
main
(
_
):
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_param
)
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
input_meta_data_path
,
'rb'
)
as
reader
:
input_meta_data
=
json
.
loads
(
reader
.
read
().
decode
(
'utf-8'
))
if
FLAGS
.
mode
==
'export_only'
:
export_squad
(
FLAGS
.
model_export_path
,
input_meta_data
)
return
# Configures cluster spec for multi-worker distribution strategy.
if
FLAGS
.
num_gpus
>
0
:
_
=
distribute_utils
.
configure_cluster
(
FLAGS
.
worker_hosts
,
FLAGS
.
task_index
)
strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
distribution_strategy
,
num_gpus
=
FLAGS
.
num_gpus
,
all_reduce_alg
=
FLAGS
.
all_reduce_alg
,
tpu_address
=
FLAGS
.
tpu
)
if
'train'
in
FLAGS
.
mode
:
if
FLAGS
.
log_steps
:
custom_callbacks
=
[
keras_utils
.
TimeHistory
(
batch_size
=
FLAGS
.
train_batch_size
,
log_steps
=
FLAGS
.
log_steps
,
logdir
=
FLAGS
.
model_dir
,
)]
else
:
custom_callbacks
=
None
train_squad
(
strategy
,
input_meta_data
,
custom_callbacks
=
custom_callbacks
,
run_eagerly
=
FLAGS
.
run_eagerly
,
sub_model_export_name
=
FLAGS
.
sub_model_export_name
,
)
if
'predict'
in
FLAGS
.
mode
:
predict_squad
(
strategy
,
input_meta_data
)
if
'eval'
in
FLAGS
.
mode
:
eval_metrics
=
eval_squad
(
strategy
,
input_meta_data
)
f1_score
=
eval_metrics
[
'final_f1'
]
logging
.
info
(
'SQuAD eval F1-score: %f'
,
f1_score
)
summary_dir
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'summaries'
,
'eval'
)
summary_writer
=
tf
.
summary
.
create_file_writer
(
summary_dir
)
with
summary_writer
.
as_default
():
# TODO(lehou): write to the correct step number.
tf
.
summary
.
scalar
(
'F1-score'
,
f1_score
,
step
=
0
)
summary_writer
.
flush
()
# Also write eval_metrics to json file.
squad_lib_wp
.
write_to_json_files
(
eval_metrics
,
os
.
path
.
join
(
summary_dir
,
'eval_metrics.json'
))
time
.
sleep
(
60
)
if
__name__
==
'__main__'
:
flags
.
mark_flag_as_required
(
'bert_config_file'
)
flags
.
mark_flag_as_required
(
'model_dir'
)
app
.
run
(
main
)
official/
nlp
/bert/run_squad_helper.py
→
official/
legacy
/bert/run_squad_helper.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -21,16 +21,16 @@ import os
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
from
official.legacy.bert
import
bert_models
from
official.legacy.bert
import
common_flags
from
official.legacy.bert
import
input_pipeline
from
official.legacy.bert
import
model_saving_utils
from
official.legacy.bert
import
model_training_utils
from
official.modeling
import
performance
from
official.nlp
import
optimization
from
official.nlp.bert
import
bert_models
from
official.nlp.bert
import
common_flags
from
official.nlp.bert
import
input_pipeline
from
official.nlp.bert
import
model_saving_utils
from
official.nlp.bert
import
model_training_utils
from
official.nlp.bert
import
squad_evaluate_v1_1
from
official.nlp.bert
import
squad_evaluate_v2_0
from
official.nlp.data
import
squad_lib_sp
from
official.nlp.tools
import
squad_evaluate_v1_1
from
official.nlp.tools
import
squad_evaluate_v2_0
from
official.utils.misc
import
keras_utils
...
...
official/
nlp
/bert/serving.py
→
official/
legacy
/bert/serving.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,8 +18,8 @@ from absl import app
from
absl
import
flags
import
tensorflow
as
tf
from
official.
nlp
.bert
import
bert_models
from
official.
nlp
.bert
import
configs
from
official.
legacy
.bert
import
bert_models
from
official.
legacy
.bert
import
configs
flags
.
DEFINE_integer
(
"sequence_length"
,
None
,
"Sequence length to parse the tf.Example. If "
...
...
official/legacy/detection/__init__.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/configs/__init__.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/configs/base_config.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/configs/factory.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/configs/maskrcnn_config.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/configs/olnmask_config.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/configs/retinanet_config.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/configs/shapemask_config.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/dataloader/__init__.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/dataloader/anchor.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -22,7 +22,7 @@ import collections
import
tensorflow
as
tf
from
official.legacy.detection.utils
import
box_utils
from
official.vision.
beta.
ops
import
iou_similarity
from
official.vision.ops
import
iou_similarity
from
official.vision.utils.object_detection
import
argmax_matcher
from
official.vision.utils.object_detection
import
balanced_positive_negative_sampler
from
official.vision.utils.object_detection
import
box_list
...
...
official/legacy/detection/dataloader/factory.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/dataloader/input_reader.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/dataloader/maskrcnn_parser.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/legacy/detection/dataloader/mode_keys.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
Prev
1
2
3
4
5
6
7
8
9
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment