Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -117,6 +117,61 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
model_dir=model_dir,
run_post_eval=run_post_eval)
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end_class(self, distribution_strategy, flag_mode,
run_post_eval):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
_, logs = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval).run()
if 'eval' in flag_mode:
self.assertTrue(
tf.io.gfile.exists(
os.path.join(model_dir,
params.trainer.validation_summary_subdir)))
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
if flag_mode == 'eval':
return
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
# Tests continuous evaluation.
_, logs = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode='continuous_eval',
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval).run()
@combinations.generate(
combinations.combine(
distribution_strategy=[
......@@ -148,12 +203,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
task.build_losses = build_losses
with self.assertRaises(RuntimeError):
train_lib.run_experiment(
train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
model_dir=model_dir).run()
@combinations.generate(
combinations.combine(
......@@ -194,12 +249,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
task.build_losses = build_losses
model, _ = train_lib.run_experiment(
model, _ = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
model_dir=model_dir).run()
after_weights = model.get_weights()
for left, right in zip(before_weights, after_weights):
self.assertAllEqual(left, right)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,6 +15,7 @@
"""Training utils."""
import copy
import dataclasses
import inspect
import json
import os
import pprint
......@@ -35,6 +36,9 @@ from official.core import exp_factory
from official.modeling import hyperparams
BEST_CHECKPOINT_NAME = 'best_ckpt'
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
"""Get leaf from a dictionary with arbitrary depth with a list of keys.
......@@ -101,7 +105,6 @@ def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
return best_ckpt_exporter
# TODO(b/180147589): Add tests for this module.
class BestCheckpointExporter:
"""Keeps track of the best result, and saves its checkpoint.
......@@ -138,7 +141,7 @@ class BestCheckpointExporter:
checkpoint,
directory=self._export_dir,
max_to_keep=1,
checkpoint_name='best_ckpt')
checkpoint_name=BEST_CHECKPOINT_NAME)
return self._checkpoint_manager
......@@ -209,6 +212,28 @@ class BestCheckpointExporter:
return tf.train.latest_checkpoint(self._export_dir)
def create_optimizer(task: base_task.Task,
params: config_definitions.ExperimentConfig
) -> tf.keras.optimizers.Optimizer:
"""A create optimizer util to be backward compatability with new args."""
if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
dp_config = None
if hasattr(params.task, 'differential_privacy_config'):
dp_config = params.task.differential_privacy_config
optimizer = task.create_optimizer(
params.trainer.optimizer_config, params.runtime,
dp_config=dp_config)
else:
if hasattr(params.task, 'differential_privacy_config'
) and params.task.differential_privacy_config is not None:
raise ValueError('Differential privacy config is specified but '
'task.create_optimizer api does not accept it.')
optimizer = task.create_optimizer(
params.trainer.optimizer_config,
params.runtime)
return optimizer
@gin.configurable
def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task,
......@@ -219,8 +244,7 @@ def create_trainer(params: config_definitions.ExperimentConfig,
"""Create trainer."""
logging.info('Running default trainer.')
model = task.build_model()
optimizer = task.create_optimizer(params.trainer.optimizer_config,
params.runtime)
optimizer = create_optimizer(task, params)
return trainer_cls(
params,
task,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for official.core.train_utils."""
import json
import os
import pprint
......@@ -138,5 +139,60 @@ class TrainUtilsTest(tf.test.TestCase):
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
class BestCheckpointExporterTest(tf.test.TestCase):
def test_maybe_export(self):
model_dir = self.create_tempdir().full_path
best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1')
metric_name = 'test_metric|metric_1'
exporter = train_utils.BestCheckpointExporter(
model_dir, metric_name, 'higher')
v = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 5.0}}, 100)
with self.subTest(name='Successful first save.'):
self.assertEqual(ret, True)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 1.0)
v = tf.Variable(3.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 6.0}}, 200)
with self.subTest(name='Successful better metic save.'):
self.assertEqual(ret, True)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 3.0)
v = tf.Variable(5.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 1.0}}, 300)
with self.subTest(name='Worse metic no save.'):
self.assertEqual(ret, False)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 3.0)
def test_export_best_eval_metric(self):
model_dir = self.create_tempdir().full_path
metric_name = 'test_metric|metric_1'
exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
'higher')
exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100)
with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
'rb') as reader:
metric = json.loads(reader.read())
self.assertAllEqual(
metric,
{'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
if __name__ == '__main__':
tf.test.main()
Models in this `legacy` directory are mainly are used for benchmarking the
models.
Please note that the models in this `legacy` directory are not supported like
the models in official/nlp and official/vision.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The ALBERT configurations."""
import six
from official.legacy.bert import configs
class AlbertConfig(configs.BertConfig):
"""Configuration for `ALBERT`."""
def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
"""Constructs AlbertConfig.
Args:
num_hidden_groups: Number of group for the hidden layers, parameters in
the same group are shared. Note that this value and also the following
'inner_group_num' has to be 1 for now, because all released ALBERT
models set them to 1. We may support arbitary valid values in future.
inner_group_num: Number of inner repetition of attention and ffn.
**kwargs: The remaining arguments are the same as above 'BertConfig'.
"""
super(AlbertConfig, self).__init__(**kwargs)
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
# in the released ALBERT. Support other values in AlbertEncoder if needed.
if inner_group_num != 1 or num_hidden_groups != 1:
raise ValueError("We only support 'inner_group_num' and "
"'num_hidden_groups' as 1.")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
config = AlbertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
# BERT (Bidirectional Encoder Representations from Transformers)
**WARNING**: We are on the way to deprecate most of the code in this directory.
Please see
[this link](../g3doc/tutorials/bert_new.md)
for the new tutorial and use the new code in `nlp/modeling`. This README is
still correct for this legacy implementation.
The academic paper which describes BERT in detail and provides full results on a
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
This repository contains TensorFlow 2.x implementation for BERT.
## Contents
* [Contents](#contents)
* [Pre-trained Models](#pre-trained-models)
* [Restoring from Checkpoints](#restoring-from-checkpoints)
* [Set Up](#set-up)
* [Process Datasets](#process-datasets)
* [Fine-tuning with BERT](#fine-tuning-with-bert)
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
* [SQuAD 1.1](#squad-1.1)
## Pre-trained Models
We released both checkpoints and tf.hub modules as the pretrained models for
fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
released in TF 1.x official BERT repository
[google-research/bert](https://github.com/google-research/bert)
in order to keep consistent with BERT paper.
### Access to Pretrained Checkpoints
Pretrained checkpoints can be found in the following links:
**Note: We have switched BERT implementation
to use Keras functional-style networks in [nlp/modeling](../modeling).
The new checkpoints are:**
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads , 110M parameters
* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/multi_cased_L-12_H-768_A-12.tar.gz)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
We recommend to host checkpoints on Google Cloud storage buckets when you use
Cloud GPU/TPU.
### Restoring from Checkpoints
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
weights from provided pre-trained checkpoints, you can use the following code:
```python
init_checkpoint='the pretrained model checkpoint path.'
model=tf.keras.Model() # Bert pre-trained model as feature extractor.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_checkpoint)
```
Checkpoints featuring native serialized Keras models
(i.e. model.load()/load_weights()) will be available soon.
### Access to Pretrained hub modules.
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
following links:
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads , 110M parameters
* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
110M parameters
## Set Up
```shell
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
Install `tf-nightly` to get latest updates:
```shell
pip install tf-nightly-gpu
```
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
```shell
ctpu up -name <instance name> --tf-version=”nightly”
```
Second, you need to install TF 2 `tf-nightly` on your VM:
```shell
pip install tf-nightly
```
## Process Datasets
### Pre-training
There is no change to generate pre-training data. Please use the script
[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py)
which is essentially branched from [BERT research repo](https://github.com/google-research/bert)
to get processed pre-training data and it adapts to TF2 symbols and python3
compatibility.
Running the pre-training script requires an input and output directory, as well as a vocab file. Note that max_seq_length will need to match the sequence length parameter you specify when you run pre-training.
Example shell script to call create_pretraining_data.py
```
export WORKING_DIR='local disk or cloud location'
export BERT_DIR='local disk or cloud location'
python models/official/nlp/data/create_pretraining_data.py \
--input_file=$WORKING_DIR/input/input.txt \
--output_file=$WORKING_DIR/output/tf_examples.tfrecord \
--vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
--do_lower_case=True \
--max_seq_length=512 \
--max_predictions_per_seq=76 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
```
### Fine-tuning
To prepare the fine-tuning data for final model training, use the
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
Resulting datasets in `tf_record` format and training meta data should be later
passed to training or evaluation scripts. The task-specific arguments are
described in following sections:
* GLUE
Users can download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
Also, users can download [Pretrained Checkpoint](#access-to-pretrained-checkpoints) and locate on some directory `$BERT_DIR` instead of using checkpoints on Google Cloud Storage.
```shell
export GLUE_DIR=~/glue
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TASK_NAME=MNLI
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
--fine_tuning_task_type=classification --max_seq_length=128 \
--classification_task_name=${TASK_NAME}
```
* SQUAD
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
detailed information about the SQuAD datasets and evaluation.
The necessary files can be found here:
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
```shell
export SQUAD_DIR=~/squad
export SQUAD_VERSION=v1.1
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
--fine_tuning_task_type=squad --max_seq_length=384
```
Note: To create fine-tuning data with SQUAD 2.0, you need to add flag `--version_2_with_negative=True`.
## Fine-tuning with BERT
### Cloud GPUs and TPUs
* Cloud Storage
The unzipped pre-trained model files can also be found in the Google Cloud
Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
```
Currently, users are able to access to `tf-nightly` TPUs and the following TPU
script should run with `tf-nightly`.
* GPU -> TPU
Just add the following flags to `run_classifier.py` or `run_squad.py`:
```shell
--distribution_strategy=tpu
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
### Sentence and Sentence-pair Classification Tasks
This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
Corpus (MRPC) corpus, which only contains 3,600 examples and can fine-tune in a
few minutes on most GPUs.
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--eval_batch_size=4 \
--steps_per_loop=1 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Alternatively, instead of specifying `init_checkpoint`, you can specify
`hub_module_url` to employ a pretraind BERT hub module, e.g.,
` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
After training a model, to get predictions from the classifier, you can set the
`--mode=predict` and offer the test set tfrecords to `--eval_data_path`.
Output will be created in file called test_results.tsv in the output folder.
Each line will contain output for each sample, columns are the class
probabilities.
```shell
python run_classifier.py \
--mode='predict' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--eval_batch_size=4 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
To use TPU, you only need to switch distribution strategy type to `tpu` with TPU
information and use remote storage for model checkpoints.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--steps_per_loop=1000 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
training steps inside a `tf.function` can significantly increase TPU utilization
and callbacks will not be called inside the loop.
### SQuAD 1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export SQUAD_DIR=gs://some_bucket/datasets
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--predict_batch_size=4 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Similarily, you can replace `init_checkpoint` FLAG with `hub_module_url` to
specify a hub module path.
`run_squad.py` writes the prediction for `--predict_file` by default. If you set
the `--model=predict` and offer the SQuAD test data, the scripts will generate
the prediction json file.
To use TPU, you need switch distribution strategy type to `tpu` with TPU
information.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_DIR=gs://some_bucket/datasets
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
The dev set predictions will be saved into a file called predictions.json in the
model_dir:
```shell
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
```
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,9 +17,9 @@
import gin
import tensorflow as tf
import tensorflow_hub as hub
from official.legacy.nlp.albert import configs as albert_configs
from official.legacy.albert import configs as albert_configs
from official.legacy.bert import configs
from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.modeling import models
from official.nlp.modeling import networks
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,8 +14,8 @@
import tensorflow as tf
from official.nlp.bert import bert_models
from official.nlp.bert import configs as bert_configs
from official.legacy.bert import bert_models
from official.legacy.bert import configs as bert_configs
from official.nlp.modeling import networks
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defining common flags used across all BERT models/applications."""
from absl import flags
import tensorflow as tf
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
def define_common_bert_flags():
"""Define common flags for BERT tasks."""
flags_core.define_base(
data_dir=False,
model_dir=True,
clean=False,
train_epochs=False,
epochs_between_evals=False,
stop_threshold=False,
batch_size=False,
num_gpu=True,
export_dir=False,
distribution_strategy=True,
run_eagerly=True)
flags_core.define_distribution()
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_export_path', None,
'Path to the directory, where trainined model will be '
'exported.')
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer(
'steps_per_loop', None,
'Number of steps per graph-mode loop. Only training step '
'happens inside the loop. Callbacks will not be called '
'inside. If not set the value will be configured depending on the '
'devices available.')
flags.DEFINE_float('learning_rate', 5e-5,
'The initial learning rate for Adam.')
flags.DEFINE_float('end_lr', 0.0,
'The end learning rate for learning rate decay.')
flags.DEFINE_string('optimizer_type', 'adamw',
'The type of optimizer to use for training (adamw|lamb)')
flags.DEFINE_boolean(
'scale_loss', False,
'Whether to divide the loss by number of replica inside the per-replica '
'loss function.')
flags.DEFINE_boolean(
'use_keras_compile_fit', False,
'If True, uses Keras compile/fit() API for training logic. Otherwise '
'use custom training loop.')
flags.DEFINE_string(
'hub_module_url', None, 'TF-Hub path/url to Bert module. '
'If specified, init_checkpoint flag should not be used.')
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
flags.DEFINE_string(
'sub_model_export_name', None,
'If set, `sub_model` checkpoints are exported into '
'FLAGS.model_dir/FLAGS.sub_model_export_name.')
flags.DEFINE_bool('explicit_allreduce', False,
'True to use explicit allreduce instead of the implicit '
'allreduce in optimizer.apply_gradients(). If fp16 mixed '
'precision training is used, this also enables allreduce '
'gradients in fp16.')
flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
'Number of bytes of a gradient pack for allreduce. '
'Should be positive integer, if set to 0, all '
'gradients are in one pack. Breaking gradient into '
'packs could enable overlap between allreduce and '
'backprop computation. This flag only takes effect '
'when explicit_allreduce is set to True.')
flags_core.define_log_steps()
# Adds flags for mixed precision and multi-worker training.
flags_core.define_performance(
num_parallel_calls=False,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=True,
loss_scale=True,
all_reduce_alg=True,
num_packs=False,
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
enable_xla=True,
fp16_implementation=True,
)
# Adds gin configuration flags.
hyperparams_flags.define_gin_flags()
def dtype():
return flags_core.get_tf_dtype(flags.FLAGS)
def use_float16():
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The main BERT model and related functions."""
import copy
import json
import six
import tensorflow as tf
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
embedding_size=None,
backward_compatible=True):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
embedding_size: (Optional) width of the factorized word embeddings.
backward_compatible: Boolean, whether the variables shape are compatible
with checkpoints converted from TF 1.x BERT.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.embedding_size = embedding_size
self.backward_compatible = backward_compatible
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.io.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export BERT as a TF-Hub SavedModel.
This script is **DEPRECATED** for exporting BERT encoder models;
see the error message in by main() for details.
"""
from typing import Text
# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.legacy.bert import bert_models
from official.legacy.bert import configs
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", None, "Whether to lowercase. If None, "
"do_lower_case will be enabled if 'uncased' appears in the "
"name of --vocab_file")
flags.DEFINE_enum("model_type", "encoder", ["encoder", "squad"],
"What kind of BERT model to export.")
def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
"""Creates a BERT keras core model from BERT configuration.
Args:
bert_config: A `BertConfig` to create the core model.
Returns:
A keras model.
"""
# Adds input layers just as placeholders.
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_type_ids")
transformer_encoder = bert_models.get_transformer_encoder(
bert_config, sequence_length=None)
sequence_output, pooled_output = transformer_encoder(
[input_word_ids, input_mask, input_type_ids])
# To keep consistent with legacy hub modules, the outputs are
# "pooled_output" and "sequence_output".
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[pooled_output, sequence_output]), transformer_encoder
def export_bert_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text,
hub_destination: Text,
vocab_file: Text,
do_lower_case: bool = None):
"""Restores a tf.keras.Model and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if do_lower_case is None:
do_lower_case = "uncased" in vocab_file
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(
model=encoder, # Legacy checkpoints.
encoder=encoder)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def export_bert_squad_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text,
hub_destination: Text,
vocab_file: Text,
do_lower_case: bool = None):
"""Restores a tf.keras.Model for BERT with SQuAD and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if do_lower_case is None:
do_lower_case = "uncased" in vocab_file
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file)
span_labeling, _ = bert_models.squad_model(bert_config, max_seq_length=None)
checkpoint = tf.train.Checkpoint(model=span_labeling)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
span_labeling.vocab_file = tf.saved_model.Asset(vocab_file)
span_labeling.do_lower_case = tf.Variable(do_lower_case, trainable=False)
span_labeling.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.model_type == "encoder":
deprecation_note = (
"nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
"models. Please switch to nlp/tools/export_tfhub for exporting BERT "
"(and other) encoders with dict inputs/outputs conforming to "
"https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
)
logging.error(deprecation_note)
print("\n\nNOTICE:", deprecation_note, "\n")
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
elif FLAGS.model_type == "squad":
export_bert_squad_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file,
FLAGS.do_lower_case)
else:
raise ValueError("Unsupported model_type %s." % FLAGS.model_type)
if __name__ == "__main__":
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests official.nlp.bert.export_tfhub."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.legacy.bert import configs
from official.legacy.bert import export_tfhub
class ExportTfhubTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters("model", "encoder")
def test_export_tfhub(self, ckpt_key_name):
# Exports a savedmodel for TF-Hub
hidden_size = 16
bert_config = configs.BertConfig(
vocab_size=100,
hidden_size=hidden_size,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=1)
bert_model, encoder = export_tfhub.create_bert_model(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(**{ckpt_key_name: encoder})
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
with tf.io.gfile.GFile(vocab_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
hub_destination, vocab_file)
# Restores a hub KerasLayer.
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
if hasattr(hub_layer, "resolved_object"):
# Checks meta attributes.
self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
with tf.io.gfile.GFile(
hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
self.assertEqual("dummy content", f.read())
# Checks the hub KerasLayer.
for source_weight, hub_weight in zip(bert_model.trainable_weights,
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
seq_length = 10
dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
# The outputs of hub module are "pooled_output" and "sequence_output",
# while the outputs of encoder is in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
# Test that training=True makes a difference (activates dropout).
def _dropout_mean_stddev(training, num_runs=20):
input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
inputs = [input_ids, np.ones_like(input_ids), np.zeros_like(input_ids)]
outputs = np.concatenate(
[hub_layer(inputs, training=training)[0] for _ in range(num_runs)])
return np.mean(np.std(outputs, axis=0))
self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
# Test propagation of seq_length in shape inference.
input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
self.assertEqual(sequence_output.shape.as_list(),
[None, seq_length, hidden_size])
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,10 +15,9 @@
"""Utilities to save models."""
import os
import typing
from absl import logging
import tensorflow as tf
import typing
def export_bert_model(model_export_path: typing.Text,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment