Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
...@@ -20,6 +20,20 @@ import dataclasses ...@@ -20,6 +20,20 @@ import dataclasses
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ConstantLrConfig(base_config.Config):
  """Configuration for constant learning rate.

  This class is a container for the constant learning rate config.

  Attributes:
    name: The name of the learning rate schedule. Defaults to Constant.
    learning_rate: A float. The learning rate. Defaults to 0.1.
  """
  name: str = 'Constant'
  learning_rate: float = 0.1
@dataclasses.dataclass @dataclasses.dataclass
class StepwiseLrConfig(base_config.Config): class StepwiseLrConfig(base_config.Config):
"""Configuration for stepwise learning rate decay. """Configuration for stepwise learning rate decay.
......
...@@ -39,12 +39,14 @@ class OptimizerConfig(oneof.OneOfConfig): ...@@ -39,12 +39,14 @@ class OptimizerConfig(oneof.OneOfConfig):
adam: adam optimizer config. adam: adam optimizer config.
adamw: adam with weight decay. adamw: adam with weight decay.
lamb: lamb optimizer. lamb: lamb optimizer.
rmsprop: rmsprop optimizer.
""" """
type: Optional[str] = None type: Optional[str] = None
sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig() sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig() adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig() adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig() lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
@dataclasses.dataclass @dataclasses.dataclass
...@@ -53,12 +55,14 @@ class LrConfig(oneof.OneOfConfig): ...@@ -53,12 +55,14 @@ class LrConfig(oneof.OneOfConfig):
Attributes: Attributes:
type: 'str', type of lr schedule to be used, one of the fields below. type: 'str', type of lr schedule to be used, one of the fields below.
constant: constant learning rate config.
stepwise: stepwise learning rate config. stepwise: stepwise learning rate config.
exponential: exponential learning rate config. exponential: exponential learning rate config.
polynomial: polynomial learning rate config. polynomial: polynomial learning rate config.
cosine: cosine learning rate config. cosine: cosine learning rate config.
""" """
type: Optional[str] = None type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig() stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig() exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig() polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
......
...@@ -28,18 +28,37 @@ class SGDConfig(base_config.Config): ...@@ -28,18 +28,37 @@ class SGDConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for SGD optimizer.
decay: decay rate for SGD optimizer. decay: decay rate for SGD optimizer.
nesterov: nesterov for SGD optimizer. nesterov: nesterov for SGD optimizer.
momentum: momentum for SGD optimizer. momentum: momentum for SGD optimizer.
""" """
name: str = "SGD" name: str = "SGD"
learning_rate: float = 0.01
decay: float = 0.0 decay: float = 0.0
nesterov: bool = False nesterov: bool = False
momentum: float = 0.0 momentum: float = 0.0
@dataclasses.dataclass
class RMSPropConfig(base_config.Config):
  """Configuration for RMSProp optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer. Defaults to "RMSprop".
    rho: discounting factor for RMSprop optimizer. Defaults to 0.9.
    momentum: momentum for RMSprop optimizer. Defaults to 0.0.
    epsilon: epsilon value for RMSprop optimizer, helps with numerical
      stability. Defaults to 1e-7.
    centered: Whether to normalize gradients or not. Defaults to False.
  """
  name: str = "RMSprop"
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False
@dataclasses.dataclass @dataclasses.dataclass
class AdamConfig(base_config.Config): class AdamConfig(base_config.Config):
"""Configuration for Adam optimizer. """Configuration for Adam optimizer.
...@@ -49,7 +68,6 @@ class AdamConfig(base_config.Config): ...@@ -49,7 +68,6 @@ class AdamConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments. beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in Adam optimizer. epsilon: epsilon value used for numerical stability in Adam optimizer.
...@@ -57,7 +75,6 @@ class AdamConfig(base_config.Config): ...@@ -57,7 +75,6 @@ class AdamConfig(base_config.Config):
the paper "On the Convergence of Adam and beyond". the paper "On the Convergence of Adam and beyond".
""" """
name: str = "Adam" name: str = "Adam"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-07 epsilon: float = 1e-07
...@@ -70,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config): ...@@ -70,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for the optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments. beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in the optimizer. epsilon: epsilon value used for numerical stability in the optimizer.
...@@ -83,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config): ...@@ -83,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config):
include in weight decay. include in weight decay.
""" """
name: str = "AdamWeightDecay" name: str = "AdamWeightDecay"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-07 epsilon: float = 1e-07
...@@ -102,7 +117,6 @@ class LAMBConfig(base_config.Config): ...@@ -102,7 +117,6 @@ class LAMBConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments. beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in LAMB optimizer. epsilon: epsilon value used for numerical stability in LAMB optimizer.
...@@ -116,7 +130,6 @@ class LAMBConfig(base_config.Config): ...@@ -116,7 +130,6 @@ class LAMBConfig(base_config.Config):
be excluded. be excluded.
""" """
name: str = "LAMB" name: str = "LAMB"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-6 epsilon: float = 1e-6
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Optimizer factory class.""" """Optimizer factory class."""
from typing import Union from typing import Union
import tensorflow as tf import tensorflow as tf
...@@ -29,7 +28,8 @@ OPTIMIZERS_CLS = { ...@@ -29,7 +28,8 @@ OPTIMIZERS_CLS = {
'sgd': tf.keras.optimizers.SGD, 'sgd': tf.keras.optimizers.SGD,
'adam': tf.keras.optimizers.Adam, 'adam': tf.keras.optimizers.Adam,
'adamw': nlp_optimization.AdamWeightDecay, 'adamw': nlp_optimization.AdamWeightDecay,
'lamb': tfa_optimizers.LAMB 'lamb': tfa_optimizers.LAMB,
'rmsprop': tf.keras.optimizers.RMSprop
} }
LR_CLS = { LR_CLS = {
...@@ -60,7 +60,7 @@ class OptimizerFactory(object): ...@@ -60,7 +60,7 @@ class OptimizerFactory(object):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -88,12 +88,15 @@ class OptimizerFactory(object): ...@@ -88,12 +88,15 @@ class OptimizerFactory(object):
self._optimizer_config = config.optimizer.get() self._optimizer_config = config.optimizer.get()
self._optimizer_type = config.optimizer.type self._optimizer_type = config.optimizer.type
if self._optimizer_config is None: if self._optimizer_type is None:
raise ValueError('Optimizer type must be specified') raise ValueError('Optimizer type must be specified')
self._lr_config = config.learning_rate.get() self._lr_config = config.learning_rate.get()
self._lr_type = config.learning_rate.type self._lr_type = config.learning_rate.type
if self._lr_type is None:
raise ValueError('Learning rate type must be specified')
self._warmup_config = config.warmup.get() self._warmup_config = config.warmup.get()
self._warmup_type = config.warmup.type self._warmup_type = config.warmup.type
...@@ -101,18 +104,15 @@ class OptimizerFactory(object): ...@@ -101,18 +104,15 @@ class OptimizerFactory(object):
"""Build learning rate. """Build learning rate.
Builds learning rate from config. Learning rate schedule is built according Builds learning rate from config. Learning rate schedule is built according
to the learning rate config. If there is no learning rate config, optimizer to the learning rate config. If learning rate type is constant,
learning rate is returned. lr_config.learning_rate is returned.
Returns: Returns:
tf.keras.optimizers.schedules.LearningRateSchedule instance. If no tf.keras.optimizers.schedules.LearningRateSchedule instance. If
learning rate schedule defined, optimizer_config.learning_rate is learning rate type is constant, lr_config.learning_rate is returned.
returned.
""" """
if self._lr_type == 'constant':
# TODO(arashwan): Explore if we want to only allow explicit const lr sched. lr = self._lr_config.learning_rate
if not self._lr_config:
lr = self._optimizer_config.learning_rate
else: else:
lr = LR_CLS[self._lr_type](**self._lr_config.as_dict()) lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
......
...@@ -15,91 +15,72 @@ ...@@ -15,91 +15,72 @@
# ============================================================================== # ==============================================================================
"""Tests for optimizer_factory.py.""" """Tests for optimizer_factory.py."""
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
from official.modeling.optimization import optimizer_factory from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config from official.modeling.optimization.configs import optimization_config
from official.nlp import optimization as nlp_optimization
class OptimizerFactoryTest(tf.test.TestCase):
def test_sgd_optimizer(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
}
}
expected_optimizer_config = {
'name': 'SGD',
'learning_rate': 0.1,
'decay': 0.0,
'momentum': 0.9,
'nesterov': False
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
self.assertIsInstance(optimizer, tf.keras.optimizers.SGD)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_adam_optimizer(self): class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
# Define adam optimizer with default values. @parameterized.parameters(
('sgd'),
('rmsprop'),
('adam'),
('adamw'),
('lamb'))
def test_optimizers(self, optimizer_type):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'adam' 'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
} }
} }
expected_optimizer_config = tf.keras.optimizers.Adam().get_config() optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
opt_config = optimization_config.OptimizationConfig(params) opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config) opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate() lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr) optimizer = opt_factory.build_optimizer(lr)
self.assertIsInstance(optimizer, tf.keras.optimizers.Adam) self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config()) self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_adam_weight_decay_optimizer(self): def test_missing_types(self):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'adamw' 'type': 'sgd',
'sgd': {'momentum': 0.9}
} }
} }
expected_optimizer_config = nlp_optimization.AdamWeightDecay().get_config() with self.assertRaises(ValueError):
opt_config = optimization_config.OptimizationConfig(params) optimizer_factory.OptimizerFactory(
opt_factory = optimizer_factory.OptimizerFactory(opt_config) optimization_config.OptimizationConfig(params))
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
self.assertIsInstance(optimizer, nlp_optimization.AdamWeightDecay)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_lamb_optimizer(self):
params = { params = {
'optimizer': { 'learning_rate': {
'type': 'lamb' 'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
} }
} }
expected_optimizer_config = tfa_optimizers.LAMB().get_config() with self.assertRaises(ValueError):
opt_config = optimization_config.OptimizationConfig(params) optimizer_factory.OptimizerFactory(
opt_factory = optimizer_factory.OptimizerFactory(opt_config) optimization_config.OptimizationConfig(params))
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
self.assertIsInstance(optimizer, tfa_optimizers.LAMB)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_stepwise_lr_schedule(self): def test_stepwise_lr_schedule(self):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -126,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase): ...@@ -126,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -159,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase): ...@@ -159,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'exponential', 'type': 'exponential',
...@@ -189,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase): ...@@ -189,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'polynomial', 'type': 'polynomial',
...@@ -213,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase): ...@@ -213,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'cosine', 'type': 'cosine',
...@@ -239,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase): ...@@ -239,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}, },
'warmup': { 'warmup': {
'type': 'linear', 'type': 'linear',
...@@ -263,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase): ...@@ -263,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
......
...@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor): ...@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32 return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
# TODO(hongkuny): consider moving custom string-map lookup to keras api.
def get_activation(identifier): def get_activation(identifier):
"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`. """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
...@@ -173,3 +172,18 @@ def assert_rank(tensor, expected_rank, name=None): ...@@ -173,3 +172,18 @@ def assert_rank(tensor, expected_rank, name=None):
"For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not " "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
"equal to the expected tensor rank `%s`" % "equal to the expected tensor rank `%s`" %
(name, actual_rank, str(tensor.shape), str(expected_rank))) (name, actual_rank, str(tensor.shape), str(expected_rank)))
def safe_mean(losses):
  """Returns the mean of `losses`, guarding against division by zero.

  Args:
    losses: `Tensor` whose elements contain individual loss measurements.

  Returns:
    A scalar equal to the sum of `losses` divided by its element count. The
    division goes through `tf.math.divide_no_nan`, so zero is returned when
    `losses` has no elements.
  """
  # The element count is cast to the loss dtype so both operands of the
  # division share a dtype.
  count = tf.cast(tf.size(losses), dtype=losses.dtype)
  return tf.math.divide_no_nan(tf.reduce_sum(losses), count)
...@@ -63,8 +63,8 @@ def metrics_as_dict(metric): ...@@ -63,8 +63,8 @@ def metrics_as_dict(metric):
"""Puts input metric(s) into a list. """Puts input metric(s) into a list.
Args: Args:
metric: metric(s) to be put into the list. `metric` could be a object, a metric: metric(s) to be put into the list. `metric` could be an object, a
list or a dict of tf.keras.metrics.Metric or has the `required_method`. list, or a dict of tf.keras.metrics.Metric or has the `required_method`.
Returns: Returns:
A dictionary of valid metrics. A dictionary of valid metrics.
...@@ -351,7 +351,8 @@ class DistributedExecutor(object): ...@@ -351,7 +351,8 @@ class DistributedExecutor(object):
train_input_fn: (params: dict) -> tf.data.Dataset training data input train_input_fn: (params: dict) -> tf.data.Dataset training data input
function. function.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluting metric on eval data. If None, will not run eval step. trigger evaluating metric on eval data. If None, will not run the eval
step.
model_dir: the folder path for model checkpoints. model_dir: the folder path for model checkpoints.
total_steps: total training steps. total_steps: total training steps.
iterations_per_loop: train steps per loop. After each loop, this job will iterations_per_loop: train steps per loop. After each loop, this job will
...@@ -672,7 +673,7 @@ class DistributedExecutor(object): ...@@ -672,7 +673,7 @@ class DistributedExecutor(object):
raise ValueError('if `eval_metric_fn` is specified, ' raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.') 'eval_metric_fn must be a callable.')
old_phrase = tf.keras.backend.learning_phase() old_phase = tf.keras.backend.learning_phase()
tf.keras.backend.set_learning_phase(0) tf.keras.backend.set_learning_phase(0)
params = self._params params = self._params
strategy = self._strategy strategy = self._strategy
...@@ -698,7 +699,8 @@ class DistributedExecutor(object): ...@@ -698,7 +699,8 @@ class DistributedExecutor(object):
logging.info( logging.info(
'Checkpoint file %s found and restoring from ' 'Checkpoint file %s found and restoring from '
'checkpoint', checkpoint_path) 'checkpoint', checkpoint_path)
checkpoint.restore(checkpoint_path) status = checkpoint.restore(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
self.global_train_step = model.optimizer.iterations self.global_train_step = model.optimizer.iterations
eval_iterator = self._get_input_iterator(eval_input_fn, strategy) eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
...@@ -709,7 +711,7 @@ class DistributedExecutor(object): ...@@ -709,7 +711,7 @@ class DistributedExecutor(object):
summary_writer(metrics=eval_metric_result, step=current_step) summary_writer(metrics=eval_metric_result, step=current_step)
reset_states(eval_metric) reset_states(eval_metric)
tf.keras.backend.set_learning_phase(old_phrase) tf.keras.backend.set_learning_phase(old_phase)
return eval_metric_result, current_step return eval_metric_result, current_step
def predict(self): def predict(self):
...@@ -759,7 +761,7 @@ class ExecutorBuilder(object): ...@@ -759,7 +761,7 @@ class ExecutorBuilder(object):
Args: Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
If None. User is responsible to set the strategy before calling If None, the user is responsible to set the strategy before calling
build_executor(...). build_executor(...).
strategy_config: necessary config for constructing the proper Strategy. strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure. Check strategy_flags_dict() for examples of the structure.
......
...@@ -14,23 +14,61 @@ ...@@ -14,23 +14,61 @@
# ============================================================================== # ==============================================================================
"""ALBERT classification finetuning runner in tf2.x.""" """ALBERT classification finetuning runner in tf2.x."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import json import json
import os
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.nlp.albert import configs as albert_configs from official.nlp.albert import configs as albert_configs
from official.nlp.bert import bert_models
from official.nlp.bert import run_classifier as run_classifier_bert from official.nlp.bert import run_classifier as run_classifier_bert
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def predict(strategy, albert_config, input_meta_data, predict_input_fn):
  """Writes predictions and ground truth labels as .tsv files.

  Restores a classifier from `FLAGS.predict_checkpoint_path` (or, when that
  flag is unset, from the latest checkpoint found in `FLAGS.model_dir`), runs
  it over the prediction input and writes two files into `FLAGS.model_dir`:
    * `test_results.tsv`: one line per example with tab-separated values
      taken from the returned predictions.
    * `output_labels.tsv`: one line per example holding its label.

  Args:
    strategy: Distribution strategy; the model is created, restored and run
      inside its scope.
    albert_config: Config forwarded to `bert_models.classifier_model`.
    input_meta_data: Dict of input metadata; only `num_labels` is read here.
    predict_input_fn: Input function forwarded to
      `run_classifier_bert.get_predictions_and_labels`.

  Raises:
    ValueError: If no checkpoint file can be located.
  """
  with strategy.scope():
    classifier_model = bert_models.classifier_model(
        albert_config, input_meta_data['num_labels'])[0]
    checkpoint = tf.train.Checkpoint(model=classifier_model)
    latest_checkpoint_file = (
        FLAGS.predict_checkpoint_path or
        tf.train.latest_checkpoint(FLAGS.model_dir))
    # Raise explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would let a missing checkpoint slip through.
    if not latest_checkpoint_file:
      raise ValueError(
          'No checkpoint found: set `predict_checkpoint_path` or make sure '
          '`model_dir` contains a checkpoint.')
    logging.info('Checkpoint file %s found and restoring from '
                 'checkpoint', latest_checkpoint_file)
    checkpoint.restore(
        latest_checkpoint_file).assert_existing_objects_matched()
    preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
        strategy, classifier_model, predict_input_fn, return_probs=True)

  output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
  with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
    logging.info('***** Predict results *****')
    for probabilities in preds:
      output_line = '\t'.join(
          str(class_probability)
          for class_probability in probabilities) + '\n'
      writer.write(output_line)

  ground_truth_labels_file = os.path.join(FLAGS.model_dir,
                                          'output_labels.tsv')
  with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
    logging.info('***** Ground truth results *****')
    for label in ground_truth:
      # Write the label as a single field. Joining the characters of
      # `str(label)` with tabs would split multi-digit labels (12 -> "1\t2").
      writer.write(str(label) + '\n')
def main(_): def main(_):
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
...@@ -56,9 +94,14 @@ def main(_): ...@@ -56,9 +94,14 @@ def main(_):
albert_config = albert_configs.AlbertConfig.from_json_file( albert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file) FLAGS.bert_config_file)
run_classifier_bert.run_bert(strategy, input_meta_data, albert_config, if FLAGS.mode == 'train_and_eval':
train_input_fn, eval_input_fn) run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
train_input_fn, eval_input_fn)
elif FLAGS.mode == 'predict':
predict(strategy, albert_config, input_meta_data, eval_input_fn)
else:
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
return
if __name__ == '__main__': if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file') flags.mark_flag_as_required('bert_config_file')
......
...@@ -86,7 +86,7 @@ def _create_albert_model(cfg): ...@@ -86,7 +86,7 @@ def _create_albert_model(cfg):
activation=activations.gelu, activation=activations.gelu,
dropout_rate=cfg.hidden_dropout_prob, dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob, attention_dropout_rate=cfg.attention_probs_dropout_prob,
sequence_length=cfg.max_position_embeddings, max_sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size, type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range)) stddev=cfg.initializer_range))
......
...@@ -25,7 +25,6 @@ import tensorflow_hub as hub ...@@ -25,7 +25,6 @@ import tensorflow_hub as hub
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models from official.nlp.modeling import models
from official.nlp.modeling import networks from official.nlp.modeling import networks
...@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): ...@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
next_sentence_loss, name='next_sentence_loss', aggregation='mean') next_sentence_loss, name='next_sentence_loss', aggregation='mean')
def call(self, def call(self,
lm_output, lm_output_logits,
sentence_output, sentence_output_logits,
lm_label_ids, lm_label_ids,
lm_label_weights, lm_label_weights,
sentence_labels=None): sentence_labels=None):
"""Implements call() for the layer.""" """Implements call() for the layer."""
lm_label_weights = tf.cast(lm_label_weights, tf.float32) lm_label_weights = tf.cast(lm_label_weights, tf.float32)
lm_output = tf.cast(lm_output, tf.float32) lm_output_logits = tf.cast(lm_output_logits, tf.float32)
mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss( lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights) lm_label_ids, lm_output_logits, from_logits=True)
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
lm_denominator_loss)
if sentence_labels is not None: if sentence_labels is not None:
sentence_output = tf.cast(sentence_output, tf.float32) sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss( sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels=sentence_labels, predictions=sentence_output) sentence_labels, sentence_output_logits, from_logits=True)
sentence_loss = tf.reduce_mean(sentence_loss)
loss = mask_label_loss + sentence_loss loss = mask_label_loss + sentence_loss
else: else:
sentence_loss = None sentence_loss = None
...@@ -92,22 +96,22 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): ...@@ -92,22 +96,22 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
# TODO(hongkuny): Avoids the hack and switches add_loss. # TODO(hongkuny): Avoids the hack and switches add_loss.
final_loss = tf.fill(batch_shape, loss) final_loss = tf.fill(batch_shape, loss)
self._add_metrics(lm_output, lm_label_ids, lm_label_weights, self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
mask_label_loss, sentence_output, sentence_labels, mask_label_loss, sentence_output_logits, sentence_labels,
sentence_loss) sentence_loss)
return final_loss return final_loss
@gin.configurable @gin.configurable
def get_transformer_encoder(bert_config, def get_transformer_encoder(bert_config,
sequence_length, sequence_length=None,
transformer_encoder_cls=None, transformer_encoder_cls=None,
output_range=None): output_range=None):
"""Gets a 'TransformerEncoder' object. """Gets a 'TransformerEncoder' object.
Args: Args:
bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object. bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
sequence_length: Maximum sequence length of the training data. sequence_length: [Deprecated].
transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
default BERT encoder implementation. default BERT encoder implementation.
output_range: the sequence output range, [0, output_range). Default setting output_range: the sequence output range, [0, output_range). Default setting
...@@ -116,13 +120,13 @@ def get_transformer_encoder(bert_config, ...@@ -116,13 +120,13 @@ def get_transformer_encoder(bert_config,
Returns: Returns:
A networks.TransformerEncoder object. A networks.TransformerEncoder object.
""" """
del sequence_length
if transformer_encoder_cls is not None: if transformer_encoder_cls is not None:
# TODO(hongkuny): evaluate if it is better to put cfg definition in gin. # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
embedding_cfg = dict( embedding_cfg = dict(
vocab_size=bert_config.vocab_size, vocab_size=bert_config.vocab_size,
type_vocab_size=bert_config.type_vocab_size, type_vocab_size=bert_config.type_vocab_size,
hidden_size=bert_config.hidden_size, hidden_size=bert_config.hidden_size,
seq_length=sequence_length,
max_seq_length=bert_config.max_position_embeddings, max_seq_length=bert_config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range), stddev=bert_config.initializer_range),
...@@ -157,7 +161,6 @@ def get_transformer_encoder(bert_config, ...@@ -157,7 +161,6 @@ def get_transformer_encoder(bert_config,
activation=tf_utils.get_activation(bert_config.hidden_act), activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob, dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob, attention_dropout_rate=bert_config.attention_probs_dropout_prob,
sequence_length=sequence_length,
max_sequence_length=bert_config.max_position_embeddings, max_sequence_length=bert_config.max_position_embeddings,
type_vocab_size=bert_config.type_vocab_size, type_vocab_size=bert_config.type_vocab_size,
embedding_width=bert_config.embedding_size, embedding_width=bert_config.embedding_size,
...@@ -228,7 +231,7 @@ def pretrain_model(bert_config, ...@@ -228,7 +231,7 @@ def pretrain_model(bert_config,
activation=tf_utils.get_activation(bert_config.hidden_act), activation=tf_utils.get_activation(bert_config.hidden_act),
num_token_predictions=max_predictions_per_seq, num_token_predictions=max_predictions_per_seq,
initializer=initializer, initializer=initializer,
output='predictions') output='logits')
outputs = pretrainer_model( outputs = pretrainer_model(
[input_word_ids, input_mask, input_type_ids, masked_lm_positions]) [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
......
...@@ -56,8 +56,6 @@ class BertModelsTest(tf.test.TestCase): ...@@ -56,8 +56,6 @@ class BertModelsTest(tf.test.TestCase):
# Expect two output from encoder: sequence and classification output. # Expect two output from encoder: sequence and classification output.
self.assertIsInstance(encoder.output, list) self.assertIsInstance(encoder.output, list)
self.assertLen(encoder.output, 2) self.assertLen(encoder.output, 2)
# shape should be [batch size, seq_length, hidden_size]
self.assertEqual(encoder.output[0].shape.as_list(), [None, 5, 16])
# shape should be [batch size, hidden_size] # shape should be [batch size, hidden_size]
self.assertEqual(encoder.output[1].shape.as_list(), [None, 16]) self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
...@@ -74,16 +72,12 @@ class BertModelsTest(tf.test.TestCase): ...@@ -74,16 +72,12 @@ class BertModelsTest(tf.test.TestCase):
# Expect two output from model: start positions and end positions # Expect two output from model: start positions and end positions
self.assertIsInstance(model.output, list) self.assertIsInstance(model.output, list)
self.assertLen(model.output, 2) self.assertLen(model.output, 2)
# shape should be [batch size, seq_length]
self.assertEqual(model.output[0].shape.as_list(), [None, 5])
# shape should be [batch size, seq_length]
self.assertEqual(model.output[1].shape.as_list(), [None, 5])
# Expect two output from core_model: sequence and classification output. # Expect two output from core_model: sequence and classification output.
self.assertIsInstance(core_model.output, list) self.assertIsInstance(core_model.output, list)
self.assertLen(core_model.output, 2) self.assertLen(core_model.output, 2)
# shape should be [batch size, seq_length, hidden_size] # shape should be [batch size, None, hidden_size]
self.assertEqual(core_model.output[0].shape.as_list(), [None, 5, 16]) self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
# shape should be [batch size, hidden_size] # shape should be [batch size, hidden_size]
self.assertEqual(core_model.output[1].shape.as_list(), [None, 16]) self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
...@@ -104,8 +98,8 @@ class BertModelsTest(tf.test.TestCase): ...@@ -104,8 +98,8 @@ class BertModelsTest(tf.test.TestCase):
# Expect two output from core_model: sequence and classification output. # Expect two output from core_model: sequence and classification output.
self.assertIsInstance(core_model.output, list) self.assertIsInstance(core_model.output, list)
self.assertLen(core_model.output, 2) self.assertLen(core_model.output, 2)
# shape should be [batch size, 1, hidden_size] # shape should be [batch size, None, hidden_size]
self.assertEqual(core_model.output[0].shape.as_list(), [None, 1, 16]) self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
# shape should be [batch size, hidden_size] # shape should be [batch size, hidden_size]
self.assertEqual(core_model.output[1].shape.as_list(), [None, 16]) self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
......
...@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig, ...@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
do_lower_case, vocab_file) do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config) core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder) checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed() checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file) core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False) core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf") core_model.save(hub_destination, include_optimizer=False, save_format="tf")
......
...@@ -247,3 +247,39 @@ def create_squad_dataset(file_path, ...@@ -247,3 +247,39 @@ def create_squad_dataset(file_path,
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset return dataset
def create_retrieval_dataset(file_path,
                             seq_length,
                             batch_size,
                             input_pipeline_context=None):
  """Builds a batched `(features, id)` dataset from a record file for scoring.

  Args:
    file_path: Input path, forwarded to `single_file_dataset` together with
      the feature spec built here.
    seq_length: Length of the fixed-size `input_ids`, `input_mask` and
      `segment_ids` int64 features.
    batch_size: Number of examples per batch. The last, possibly smaller,
      batch is kept (`drop_remainder=False`).
    input_pipeline_context: Optional object exposing `num_input_pipelines`
      and `input_pipeline_id`. When it reports more than one input pipeline,
      the dataset is sharded accordingly.

  Returns:
    A dataset of `(x, y)` batches, where `x` is a dict with the keys
    `input_word_ids`, `input_mask` and `input_type_ids`, and `y` is the
    `int_iden` feature of each record.
  """
  autotune = tf.data.experimental.AUTOTUNE

  # The three per-token features share one spec; `int_iden` holds one value.
  feature_spec = {
      key: tf.io.FixedLenFeature([seq_length], tf.int64)
      for key in ('input_ids', 'input_mask', 'segment_ids')
  }
  feature_spec['int_iden'] = tf.io.FixedLenFeature([1], tf.int64)

  def _to_features_and_id(example):
    """Renames record fields to the model input names and splits off the id."""
    model_inputs = {
        'input_word_ids': example['input_ids'],
        'input_mask': example['input_mask'],
        'input_type_ids': example['segment_ids']
    }
    return (model_inputs, example['int_iden'])

  records = single_file_dataset(file_path, feature_spec)

  # The dataset is always sharded by number of hosts.
  # num_input_pipelines is the number of hosts rather than number of cores.
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    records = records.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)

  records = records.map(_to_features_and_id, num_parallel_calls=autotune)
  records = records.batch(batch_size, drop_remainder=False)
  return records.prefetch(autotune)
...@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text, ...@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
raise ValueError('model must be a tf.keras.Model object.') raise ValueError('model must be a tf.keras.Model object.')
if checkpoint_dir: if checkpoint_dir:
# Keras compile/fit() was used to save checkpoint using
# model.save_weights().
if restore_model_using_load_weights: if restore_model_using_load_weights:
model_weight_path = os.path.join(checkpoint_dir, 'checkpoint') model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
assert tf.io.gfile.exists(model_weight_path) assert tf.io.gfile.exists(model_weight_path)
model.load_weights(model_weight_path) model.load_weights(model_weight_path)
# tf.train.Checkpoint API was used via custom training loop logic.
else: else:
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(model=model)
......
...@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir): ...@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
@deprecation.deprecated( @deprecation.deprecated(
None, 'This function is deprecated. Please use Keras compile/fit instead.') None, 'This function is deprecated and we do not expect adding new '
'functionalities. Please do not have your code depending '
'on this library.')
def run_customized_training_loop( def run_customized_training_loop(
# pylint: disable=invalid-name # pylint: disable=invalid-name
_sentinel=None, _sentinel=None,
...@@ -557,7 +559,6 @@ def run_customized_training_loop( ...@@ -557,7 +559,6 @@ def run_customized_training_loop(
for metric in model.metrics: for metric in model.metrics:
training_summary[metric.name] = _float_metric_value(metric) training_summary[metric.name] = _float_metric_value(metric)
if eval_metrics: if eval_metrics:
# TODO(hongkuny): Cleans up summary reporting in text.
training_summary['last_train_metrics'] = _float_metric_value( training_summary['last_train_metrics'] = _float_metric_value(
train_metrics[0]) train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0]) training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
......
...@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config, ...@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
# Export uses float32 for now, even if training uses mixed precision. # Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32') tf.keras.mixed_precision.experimental.set_policy('float32')
classifier_model = bert_models.classifier_model( classifier_model = bert_models.classifier_model(
bert_config, input_meta_data.get('num_labels', 1))[0] bert_config,
input_meta_data.get('num_labels', 1),
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=False)[0]
model_saving_utils.export_bert_model( model_saving_utils.export_bert_model(
model_export_path, model=classifier_model, checkpoint_dir=model_dir) model_export_path, model=classifier_model, checkpoint_dir=model_dir)
......
...@@ -61,7 +61,11 @@ def define_common_squad_flags(): ...@@ -61,7 +61,11 @@ def define_common_squad_flags():
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.') flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related. # Predict processing related.
flags.DEFINE_string('predict_file', None, flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.') 'SQuAD prediction json file path. '
'`predict` mode supports multiple files: one can use '
'wildcard to specify multiple files and it can also be '
'multiple file patterns separated by comma. Note that '
'`eval` mode only supports a single predict file.')
flags.DEFINE_bool( flags.DEFINE_bool(
'do_lower_case', True, 'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased ' 'Whether to lower case the input text. Should be True for uncased '
...@@ -159,22 +163,9 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size, ...@@ -159,22 +163,9 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
return _dataset_fn return _dataset_fn
def predict_squad_customized(strategy, def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
input_meta_data, input_meta_data):
bert_config, """Gets a squad model to make predictions."""
checkpoint_path,
predict_tfrecord_path,
num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope(): with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision. # Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32') tf.keras.mixed_precision.experimental.set_policy('float32')
...@@ -188,6 +179,23 @@ def predict_squad_customized(strategy, ...@@ -188,6 +179,23 @@ def predict_squad_customized(strategy,
logging.info('Restoring checkpoints from %s', checkpoint_path) logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model) checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial() checkpoint.restore(checkpoint_path).expect_partial()
return squad_model
def predict_squad_customized(strategy,
input_meta_data,
predict_tfrecord_path,
num_steps,
squad_model):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
@tf.function @tf.function
def predict_step(iterator): def predict_step(iterator):
...@@ -287,8 +295,8 @@ def train_squad(strategy, ...@@ -287,8 +295,8 @@ def train_squad(strategy,
post_allreduce_callbacks=[clip_by_global_norm_callback]) post_allreduce_callbacks=[clip_by_global_norm_callback])
def prediction_output_squad( def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint): predict_file, squad_model):
"""Makes predictions for a squad dataset.""" """Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride'] doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length'] max_query_length = input_meta_data['max_query_length']
...@@ -296,7 +304,7 @@ def prediction_output_squad( ...@@ -296,7 +304,7 @@ def prediction_output_squad(
version_2_with_negative = input_meta_data.get('version_2_with_negative', version_2_with_negative = input_meta_data.get('version_2_with_negative',
False) False)
eval_examples = squad_lib.read_squad_examples( eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file, input_file=predict_file,
is_training=False, is_training=False,
version_2_with_negative=version_2_with_negative) version_2_with_negative=version_2_with_negative)
...@@ -337,8 +345,7 @@ def prediction_output_squad( ...@@ -337,8 +345,7 @@ def prediction_output_squad(
num_steps = int(dataset_size / FLAGS.predict_batch_size) num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized( all_results = predict_squad_customized(
strategy, input_meta_data, bert_config, strategy, input_meta_data, eval_writer.filename, num_steps, squad_model)
checkpoint, eval_writer.filename, num_steps)
all_predictions, all_nbest_json, scores_diff_json = ( all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output( squad_lib.postprocess_output(
...@@ -356,11 +363,14 @@ def prediction_output_squad( ...@@ -356,11 +363,14 @@ def prediction_output_squad(
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json, def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib, version_2_with_negative): squad_lib, version_2_with_negative, file_prefix=''):
"""Save output to json files.""" """Save output to json files."""
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json') output_prediction_file = os.path.join(FLAGS.model_dir,
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json') '%spredictions.json' % file_prefix)
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json') output_nbest_file = os.path.join(FLAGS.model_dir,
'%snbest_predictions.json' % file_prefix)
output_null_log_odds_file = os.path.join(FLAGS.model_dir, file_prefix,
'%snull_odds.json' % file_prefix)
logging.info('Writing predictions to: %s', (output_prediction_file)) logging.info('Writing predictions to: %s', (output_prediction_file))
logging.info('Writing nbest to: %s', (output_nbest_file)) logging.info('Writing nbest to: %s', (output_nbest_file))
...@@ -370,6 +380,22 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json, ...@@ -370,6 +380,22 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file) squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def _get_matched_files(input_path):
"""Returns all files that matches the input_path."""
input_patterns = input_path.strip().split(',')
all_matched_files = []
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
matched_files = tf.io.gfile.glob(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
else:
all_matched_files.extend(matched_files)
return sorted(all_matched_files)
def predict_squad(strategy, def predict_squad(strategy,
input_meta_data, input_meta_data,
tokenizer, tokenizer,
...@@ -379,11 +405,24 @@ def predict_squad(strategy, ...@@ -379,11 +405,24 @@ def predict_squad(strategy,
"""Get prediction results and evaluate them to hard drive.""" """Get prediction results and evaluate them to hard drive."""
if init_checkpoint is None: if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, all_predict_files = _get_matched_files(FLAGS.predict_file)
bert_config, squad_lib, init_checkpoint) squad_model = get_squad_model_to_predict(strategy, bert_config,
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib, init_checkpoint, input_meta_data)
input_meta_data.get('version_2_with_negative', False)) for idx, predict_file in enumerate(all_predict_files):
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, squad_lib, predict_file,
squad_model)
if len(all_predict_files) == 1:
file_prefix = ''
else:
# if predict_file is /path/xquad.ar.json, the `file_prefix` will be
# "xquad.ar-"
file_prefix = '%s-' % os.path.splitext(
os.path.basename(all_predict_files[idx]))[0]
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False),
file_prefix)
def eval_squad(strategy, def eval_squad(strategy,
...@@ -395,9 +434,17 @@ def eval_squad(strategy, ...@@ -395,9 +434,17 @@ def eval_squad(strategy,
"""Get prediction results and evaluate them against ground truth.""" """Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None: if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predict_files = _get_matched_files(FLAGS.predict_file)
if len(all_predict_files) != 1:
raise ValueError('`eval_squad` only supports one predict file, '
'but got %s' % all_predict_files)
squad_model = get_squad_model_to_predict(strategy, bert_config,
init_checkpoint, input_meta_data)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad( all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
bert_config, squad_lib, init_checkpoint) squad_model)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib, dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False)) input_meta_data.get('version_2_with_negative', False))
......
...@@ -61,7 +61,7 @@ def _create_bert_model(cfg): ...@@ -61,7 +61,7 @@ def _create_bert_model(cfg):
activation=activations.gelu, activation=activations.gelu,
dropout_rate=cfg.hidden_dropout_prob, dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob, attention_dropout_rate=cfg.attention_probs_dropout_prob,
sequence_length=cfg.max_position_embeddings, max_sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size, type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range), stddev=cfg.initializer_range),
...@@ -73,6 +73,7 @@ def _create_bert_model(cfg): ...@@ -73,6 +73,7 @@ def _create_bert_model(cfg):
def convert_checkpoint(bert_config, output_path, v1_checkpoint): def convert_checkpoint(bert_config, output_path, v1_checkpoint):
"""Converts a V1 checkpoint into an OO V2 checkpoint.""" """Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path) output_dir, _ = os.path.split(output_path)
tf.io.gfile.makedirs(output_dir)
# Create a temporary V1 name-converted checkpoint in the output directory. # Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1") temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
......
...@@ -13,7 +13,10 @@ ...@@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A multi-head BERT encoder network for pretraining.""" """Multi-head BERT encoder network with classification heads.
Includes configurations and instantiation methods.
"""
from typing import List, Optional, Text from typing import List, Optional, Text
import dataclasses import dataclasses
...@@ -21,10 +24,8 @@ import tensorflow as tf ...@@ -21,10 +24,8 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.modeling import layers from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer from official.nlp.modeling.models import bert_pretrainer
...@@ -41,80 +42,30 @@ class ClsHeadConfig(base_config.Config): ...@@ -41,80 +42,30 @@ class ClsHeadConfig(base_config.Config):
@dataclasses.dataclass @dataclasses.dataclass
class BertPretrainerConfig(base_config.Config): class BertPretrainerConfig(base_config.Config):
"""BERT encoder configuration.""" """BERT encoder configuration."""
num_masked_tokens: int = 76
encoder: encoders.TransformerEncoderConfig = ( encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig()) encoders.TransformerEncoderConfig())
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list) cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_from_cfg( def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
] if cls_head_configs else []
def instantiate_pretrainer_from_cfg(
config: BertPretrainerConfig, config: BertPretrainerConfig,
encoder_network: Optional[tf.keras.Model] = None): encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
"""Instantiates a BertPretrainer from the config.""" """Instantiates a BertPretrainer from the config."""
encoder_cfg = config.encoder encoder_cfg = config.encoder
if encoder_network is None: if encoder_network is None:
encoder_network = networks.TransformerEncoder( encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range))
if config.cls_heads:
classification_heads = [
layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
]
else:
classification_heads = []
return bert_pretrainer.BertPretrainerV2( return bert_pretrainer.BertPretrainerV2(
config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation), mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal( mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range), stddev=encoder_cfg.initializer_range),
encoder_network=encoder_network, encoder_network=encoder_network,
classification_heads=classification_heads) classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads))
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True
use_position_id: bool = False
@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
"""Data config for the eval set in BERT pretraining task."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = False
@dataclasses.dataclass
class BertSentencePredictionDataConfig(cfg.DataConfig):
"""Data of sentence prediction dataset."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = True
seq_length: int = 128
@dataclasses.dataclass
class BertSentencePredictionDevDataConfig(cfg.DataConfig):
"""Dev data of MNLI sentence prediction dataset."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = False
seq_length: int = 128
drop_remainder: bool = False
...@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase): ...@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self): def test_network_invocation(self):
config = bert.BertPretrainerConfig( config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1)) encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_from_cfg(config) _ = bert.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads. # Invokes with classification heads.
config = bert.BertPretrainerConfig( config = bert.BertPretrainerConfig(
...@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase): ...@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence") inner_dim=10, num_classes=2, name="next_sentence")
]) ])
_ = bert.instantiate_from_cfg(config) _ = bert.instantiate_pretrainer_from_cfg(config)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig( config = bert.BertPretrainerConfig(
...@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase): ...@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence") inner_dim=10, num_classes=2, name="next_sentence")
]) ])
_ = bert.instantiate_from_cfg(config) _ = bert.instantiate_pretrainer_from_cfg(config)
def test_checkpoint_items(self): def test_checkpoint_items(self):
config = bert.BertPretrainerConfig( config = bert.BertPretrainerConfig(
...@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase): ...@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence") inner_dim=10, num_classes=2, name="next_sentence")
]) ])
encoder = bert.instantiate_from_cfg(config) encoder = bert.instantiate_pretrainer_from_cfg(config)
self.assertSameElements(encoder.checkpoint_items.keys(), self.assertSameElements(
["encoder", "next_sentence.pooler_dense"]) encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment