"vscode:/vscode.git/clone" did not exist on "63c763685f1dc94f7efe4742b00b226be99505d0"
Commit 3ce2f61b authored by Kaushik Shivakumar

Merge branch 'master' of https://github.com/tensorflow/models into context_tf2

parents bb16d5ca 8e9296ff
......@@ -10,11 +10,13 @@ can take full advantage of TensorFlow for their research and product development
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |
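
Since the table row above only names what Orbit does, here is a minimal sketch, assuming recent TF 2 distribution APIs (`strategy.run`, `experimental_distribute_dataset`), of the kind of customized training loop it targets; none of this is Orbit's actual API.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64
strategy = tf.distribute.MirroredStrategy()  # or TPUStrategy on a TPU

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  optimizer = tf.keras.optimizers.SGD(0.01)
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([256, 8]),
     tf.random.uniform([256], maxval=10, dtype=tf.int64))
).batch(GLOBAL_BATCH_SIZE).repeat()
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def train_step(iterator):
  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      per_example_loss = loss_fn(labels, model(features, training=True))
      loss = tf.nn.compute_average_loss(
          per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  return strategy.run(step_fn, args=(next(iterator),))

for _ in range(10):  # the bare training loop a library like Orbit wraps
  train_step(dist_iterator)
```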
## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
| Date | News |
|------|------|
| July 10, 2020 | TensorFlow 2 meets the [Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) ([Blog](https://blog.tensorflow.org/2020/07/tensorflow-2-meets-object-detection-api.html)) |
| June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) |
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
......@@ -23,12 +25,6 @@ can take full advantage of TensorFlow for their research and product development
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
## [Milestones](https://github.com/tensorflow/models/milestones)
| Date | Milestone |
|------|-----------|
| July 8, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |
## Contributions
[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
......
......@@ -17,12 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
The team is actively developing new models.
In the near future, we will add:
* State-of-the-art language understanding models:
More members in Transformer family
* State-of-the-art image classification models:
EfficientNet, MnasNet, and variants
* State-of-the-art object detection and instance segmentation models:
RetinaNet, Mask R-CNN, SpineNet, and variants
* State-of-the-art language understanding models.
* State-of-the-art image classification models.
* State-of-the-art object detection and instance segmentation models.
## Table of Contents
......
......@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
"""Runs and reports the benchmark given the provided configuration."""
params = unet_training_lib.extract_params(FLAGS)
strategy = unet_training_lib.create_distribution_strategy(params)
if params.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
stats = {}
......
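
The benchmark hunk above derives the Keras mixed-precision policy from the configured dtype instead of a bfloat16-only flag. A hedged sketch of the same mapping in isolation, using the experimental mixed-precision API that the diff itself uses (it was later promoted out of `experimental`):

```python
import tensorflow as tf

def set_mixed_precision_policy(input_dtype: str) -> None:
  """Maps a dtype string to a global Keras mixed-precision policy."""
  if input_dtype in ('float16', 'bfloat16'):
    policy = tf.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)
  # 'float32' leaves the default policy in place.

set_mixed_precision_policy('float16')
# Layers created from now on compute in float16 but keep float32 variables.
print(tf.keras.mixed_precision.experimental.global_policy().name)  # mixed_float16
```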
This diff is collapsed.
This diff is collapsed.
......@@ -59,7 +59,7 @@ class Task(tf.Module):
def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint found for the model.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir.
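
The docstring above describes `initialize` as the `init_fn` handed to a CheckpointManager. A minimal sketch of that wiring, assuming `tf.train.CheckpointManager`'s `init_fn` and `restore_or_initialize` (available in recent TF 2 releases); the pretrained path is hypothetical:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(2,))])
checkpoint = tf.train.Checkpoint(model=model)

def initialize():
  # Called only when no checkpoint exists under model_dir, e.g. to warm-start
  # from a pretrained checkpoint saved somewhere else (hypothetical path).
  pretrained = tf.train.latest_checkpoint('/tmp/pretrained_dir')
  if pretrained:
    checkpoint.restore(pretrained).expect_partial()

manager = tf.train.CheckpointManager(
    checkpoint, directory='/tmp/model_dir', max_to_keep=3, init_fn=initialize)
manager.restore_or_initialize()  # restores the latest checkpoint, else runs init_fn
```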
......@@ -71,7 +71,7 @@ class Task(tf.Module):
@abc.abstractmethod
def build_model(self) -> tf.keras.Model:
"""Creates the model architecture.
"""Creates model architecture.
Returns:
A model instance.
......@@ -135,7 +135,7 @@ class Task(tf.Module):
Args:
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
......@@ -232,7 +232,7 @@ class Task(tf.Module):
return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step.
"""Validation step.
With distribution strategies, this method runs on devices.
......
......@@ -171,6 +171,9 @@ class InputReader:
as_supervised=self._tfds_as_supervised,
decoders=decoders,
read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset
@property
......
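
The InputReader change above repeats the TFDS dataset only during training, so a fixed-step training loop never runs out of data while evaluation still sees each example exactly once. A toy `tf.data` illustration of the same guard:

```python
import tensorflow as tf

def maybe_repeat(dataset: tf.data.Dataset, is_training: bool) -> tf.data.Dataset:
  if is_training:
    dataset = dataset.repeat()  # infinite stream for step-based training loops
  return dataset

ds = tf.data.Dataset.range(3)
train_ds = maybe_repeat(ds, is_training=True)   # 0, 1, 2, 0, 1, 2, ...
eval_ds = maybe_repeat(ds, is_training=False)   # 0, 1, 2 and then stops
print([int(x) for x in eval_ds])                # [0, 1, 2]
```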
......@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict):
subconfig_type = Config
if k in cls.__annotations__:
# Directly Config subtype.
type_annotation = cls.__annotations__[k]
type_annotation = cls.__annotations__[k] # pytype: disable=invalid-annotation
if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)):
subconfig_type = cls.__annotations__[k]
subconfig_type = cls.__annotations__[k] # pytype: disable=invalid-annotation
else:
# Check if the field is a sequence of subtypes.
field_type = getattr(type_annotation, '__origin__', type(None))
......
......@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Union
import dataclasses
......@@ -123,8 +124,8 @@ class RuntimeConfig(base_config.Config):
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[Union[str, float]] = None
mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
......@@ -172,23 +173,27 @@ class TrainerConfig(base_config.Config):
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_intervals: number of steps between checkpoints.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinetely.
checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 100
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
steps_per_loop: int = 1000
summary_interval: int = 1000
checkpoint_interval: int = 1000
max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 1000
@dataclasses.dataclass
......
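
The TrainerConfig defaults above are ordinary dataclass fields. A hedged sketch of reading and overriding them; it assumes the module path shown in the imports later in this diff and the params-dict construction pattern used by `OptimizationConfig(params)` in the tests further down:

```python
from official.modeling.hyperparams import config_definitions as cfg

trainer = cfg.TrainerConfig()        # all defaults listed above
print(trainer.checkpoint_interval)   # 1000
print(trainer.validation_interval)   # 1000

# Nested-dict construction, following the same pattern as
# OptimizationConfig(params) in the optimizer factory tests in this commit.
trainer = cfg.TrainerConfig({'checkpoint_interval': 500, 'max_to_keep': 10})
print(trainer.as_dict()['checkpoint_interval'])  # 500
```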
......@@ -20,6 +20,20 @@ import dataclasses
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ConstantLrConfig(base_config.Config):
"""Configuration for constant learning rate.
This class is a container for the constant learning rate configs.
Attributes:
name: The name of the learning rate schedule. Defaults to Constant.
learning_rate: A float. The learning rate. Defaults to 0.1.
"""
name: str = 'Constant'
learning_rate: float = 0.1
@dataclasses.dataclass
class StepwiseLrConfig(base_config.Config):
"""Configuration for stepwise learning rate decay.
......
......@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig):
Attributes:
type: 'str', type of lr schedule to be used, one of the fields below.
constant: constant learning rate config.
stepwise: stepwise learning rate config.
exponential: exponential learning rate config.
polynomial: polynomial learning rate config.
cosine: cosine learning rate config.
"""
type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
......
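
LrConfig is a one-of config: `type` names which sibling field is consulted and the rest are ignored. The dict form below mirrors the test params that appear later in this commit:

```python
# 'type' selects the active sub-config; sibling entries are ignored.
learning_rate = {
    'type': 'constant',
    'constant': {'learning_rate': 0.1},
}

# Switching schedules changes only 'type' and the matching sub-config.
learning_rate = {
    'type': 'stepwise',
    'stepwise': {'boundaries': [10000, 20000],
                 'values': [0.1, 0.01, 0.001]},
}
```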
......@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for SGD optimizer.
decay: decay rate for SGD optimizer.
nesterov: nesterov for SGD optimizer.
momentum: momentum for SGD optimizer.
"""
name: str = "SGD"
learning_rate: float = 0.01
decay: float = 0.0
nesterov: bool = False
momentum: float = 0.0
......@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for RMSprop optimizer.
rho: discounting factor for RMSprop optimizer.
momentum: momentum for RMSprop optimizer.
epsilon: epsilon value for RMSprop optimizer, helps with numerical stability.
centered: Whether to normalize gradients or not.
"""
name: str = "RMSprop"
learning_rate: float = 0.001
rho: float = 0.9
momentum: float = 0.0
epsilon: float = 1e-7
......@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in Adam optimizer.
......@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config):
the paper "On the Convergence of Adam and beyond".
"""
name: str = "Adam"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-07
......@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for the optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in the optimizer.
......@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config):
include in weight decay.
"""
name: str = "AdamWeightDecay"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-07
......@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in LAMB optimizer.
......@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config):
be excluded.
"""
name: str = "LAMB"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-6
......
......@@ -60,7 +60,7 @@ class OptimizerFactory(object):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......@@ -88,12 +88,15 @@ class OptimizerFactory(object):
self._optimizer_config = config.optimizer.get()
self._optimizer_type = config.optimizer.type
if self._optimizer_config is None:
if self._optimizer_type is None:
raise ValueError('Optimizer type must be specified')
self._lr_config = config.learning_rate.get()
self._lr_type = config.learning_rate.type
if self._lr_type is None:
raise ValueError('Learning rate type must be specified')
self._warmup_config = config.warmup.get()
self._warmup_type = config.warmup.type
......@@ -101,18 +104,15 @@ class OptimizerFactory(object):
"""Build learning rate.
Builds learning rate from config. Learning rate schedule is built according
to the learning rate config. If there is no learning rate config, optimizer
learning rate is returned.
to the learning rate config. If the learning rate type is constant,
lr_config.learning_rate is returned.
Returns:
tf.keras.optimizers.schedules.LearningRateSchedule instance. If no
learning rate schedule defined, optimizer_config.learning_rate is
returned.
tf.keras.optimizers.schedules.LearningRateSchedule instance. If
the learning rate type is constant, lr_config.learning_rate is returned.
"""
# TODO(arashwan): Explore if we want to only allow explicit const lr sched.
if not self._lr_config:
lr = self._optimizer_config.learning_rate
if self._lr_type == 'constant':
lr = self._lr_config.learning_rate
else:
lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
......
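
With this commit the SGD config no longer carries its own learning_rate; a constant rate is expressed through the learning-rate one-of instead, and build_learning_rate returns `lr_config.learning_rate` for the 'constant' type. A hedged end-to-end sketch mirroring the factory tests below; the import paths and `build_optimizer` call are assumptions not shown in the visible hunks:

```python
from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

params = {
    'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
    'learning_rate': {'type': 'constant',
                      'constant': {'learning_rate': 0.1}},
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)

lr = opt_factory.build_learning_rate()          # 0.1 for the 'constant' type
optimizer = opt_factory.build_optimizer(lr)     # SGD with momentum=0.9
print(optimizer.get_config()['learning_rate'])  # 0.1
```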
......@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}
}
optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
......@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_missing_types(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
params = {
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
def test_stepwise_lr_schedule(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'exponential',
......@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'polynomial',
......@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'cosine',
......@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
},
'warmup': {
'type': 'linear',
......@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......
......@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
# TODO(hongkuny): consider moving custom string-map lookup to keras api.
def get_activation(identifier):
"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
......
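
The tf_utils hunk above only shows the top of `get_activation`; per its docstring it maps a string identifier to the corresponding function. A small usage sketch (the 'gelu' key is an assumption based on how the BERT configs later in this diff call it):

```python
import tensorflow as tf
from official.modeling import tf_utils

relu = tf_utils.get_activation('relu')   # => tf.nn.relu, per the docstring
gelu = tf_utils.get_activation('gelu')   # assumed key; used by the BERT configs
print(relu(tf.constant([-1.0, 2.0])).numpy())  # [0. 2.]
```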
......@@ -14,23 +14,61 @@
# ==============================================================================
"""ALBERT classification finetuning runner in tf2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import bert_models
from official.nlp.bert import run_classifier as run_classifier_bert
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
def predict(strategy, albert_config, input_meta_data, predict_input_fn):
"""Function outputs both the ground truth predictions as .tsv files."""
with strategy.scope():
classifier_model = bert_models.classifier_model(
albert_config, input_meta_data['num_labels'])[0]
checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_checkpoint_file = (
FLAGS.predict_checkpoint_path or
tf.train.latest_checkpoint(FLAGS.model_dir))
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
strategy, classifier_model, predict_input_fn, return_probs=True)
output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
logging.info('***** Predict results *****')
for probabilities in preds:
output_line = '\t'.join(
str(class_probability)
for class_probability in probabilities) + '\n'
writer.write(output_line)
ground_truth_labels_file = os.path.join(FLAGS.model_dir,
'output_labels.tsv')
with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
logging.info('***** Ground truth results *****')
for label in ground_truth:
output_line = '\t'.join(str(label)) + '\n'
writer.write(output_line)
return
def main(_):
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......@@ -56,9 +94,14 @@ def main(_):
albert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
if FLAGS.mode == 'train_and_eval':
run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
train_input_fn, eval_input_fn)
elif FLAGS.mode == 'predict':
predict(strategy, albert_config, input_meta_data, eval_input_fn)
else:
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
return
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
......
......@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed()
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
......
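
The export change above swaps `assert_consumed` for `assert_existing_objects_matched`: the former requires every value in the checkpoint to be used, while the latter only requires that every object on the restoring side found a value, so extra saved objects (optimizer slots, pretraining heads) no longer fail the export. A toy illustration of the difference, not the BERT export itself:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
extra = tf.Variable(1.0)  # stands in for e.g. optimizer or MLM-head state
path = tf.train.Checkpoint(model=model, extra=extra).save('/tmp/ckpt_demo/ckpt')

# Restore only the model part, as export_bert_tfhub does for the encoder.
restored = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
status = tf.train.Checkpoint(model=restored).restore(path)

status.assert_existing_objects_matched()  # passes: every local object matched
# status.assert_consumed()                # would raise: `extra` was never used
```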
......@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
@deprecation.deprecated(
None, 'This function is deprecated. Please use Keras compile/fit instead.')
None, 'This function is deprecated and we do not expect to add new '
'functionality. Please do not make your code depend '
'on this library.')
def run_customized_training_loop(
# pylint: disable=invalid-name
_sentinel=None,
......@@ -557,7 +559,6 @@ def run_customized_training_loop(
for metric in model.metrics:
training_summary[metric.name] = _float_metric_value(metric)
if eval_metrics:
# TODO(hongkuny): Cleans up summary reporting in text.
training_summary['last_train_metrics'] = _float_metric_value(
train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
......
......@@ -24,7 +24,6 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer
......@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
"""BERT encoder configuration."""
num_masked_tokens: int = 76
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
......@@ -56,103 +54,18 @@ def instantiate_classification_heads_from_cfgs(
] if cls_head_configs else []
def instantiate_bertpretrainer_from_cfg(
def instantiate_pretrainer_from_cfg(
config: BertPretrainerConfig,
encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
) -> bert_pretrainer.BertPretrainerV2:
"""Instantiates a BertPretrainer from the config."""
encoder_cfg = config.encoder
if encoder_network is None:
encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
return bert_pretrainer.BertPretrainerV2(
config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
encoder_network=encoder_network,
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads))
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True
use_position_id: bool = False
@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
"""Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = False
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = True
seq_length: int = 128
@dataclasses.dataclass
class SentencePredictionDevDataConfig(cfg.DataConfig):
"""Dev Data config for sentence prediction (tasks/sentence_prediction)."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = False
seq_length: int = 128
drop_remainder: bool = False
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
@dataclasses.dataclass
class QADevDataConfig(cfg.DataConfig):
"""Dev Data config for queston answering (tasks/question_answering)."""
input_path: str = ""
input_preprocessed_data_path: str = ""
version_2_with_negative: bool = False
doc_stride: int = 128
global_batch_size: int = 48
is_training: bool = False
seq_length: int = 384
query_length: int = 64
drop_remainder: bool = False
vocab_file: str = ""
tokenization: str = "WordPiece" # WordPiece or SentencePiece
do_lower_case: bool = True
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
"""Data config for tagging (tasks/tagging)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
@dataclasses.dataclass
class TaggingDevDataConfig(cfg.DataConfig):
"""Dev Data config for tagging (tasks/tagging)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = False
seq_length: int = 384
drop_remainder: bool = False
......@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
......@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
......@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
......@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_bertpretrainer_from_cfg(config)
self.assertSameElements(encoder.checkpoint_items.keys(),
["encoder", "next_sentence.pooler_dense"])
encoder = bert.instantiate_pretrainer_from_cfg(config)
self.assertSameElements(
encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__":
......