Commit cf80ed4e authored by anivegesana's avatar anivegesana

Merge branch 'purdue-yolo' of https://github.com/tensorflow/models into detection_generator_pr_2

parents 394cefcc 461b3587
......@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
@dataclasses.dataclass
class TaskConfig(base_config.Config):
init_checkpoint: str = ""
model: base_config.Config = None
model: Optional[base_config.Config] = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
name: Optional[str] = None
@dataclasses.dataclass
......
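With `model` now typed as Optional, a concrete task config can narrow the field to its own schema while keeping None as a valid default. A minimal sketch, assuming the usual config imports (MyModelConfig and MyTaskConfig are hypothetical names):

import dataclasses
from typing import Optional

from official.core import config_definitions as cfg
from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class MyModelConfig(base_config.Config):
  # Hypothetical model schema, for illustration only.
  hidden_size: int = 64


@dataclasses.dataclass
class MyTaskConfig(cfg.TaskConfig):
  # Narrow the Optional base field; None remains valid for tasks that build
  # their model elsewhere.
  model: Optional[MyModelConfig] = None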
......@@ -142,14 +142,19 @@ class BestCheckpointExporter:
return self._checkpoint_manager
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
def maybe_export_checkpoint(
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs
self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
global_step)
if write_logs:
self.export_best_eval_metric(self._best_ckpt_logs, global_step)
self._get_checkpoint_manager(checkpoint).save()
return True
return False
def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
......@@ -180,7 +185,7 @@ class BestCheckpointExporter:
return True
return False
def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
def export_best_eval_metric(self, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step
......@@ -190,8 +195,6 @@ class BestCheckpointExporter:
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
self._get_checkpoint_manager(checkpoint).save()
@property
def best_ckpt_logs(self):
return self._best_ckpt_logs
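A usage sketch of the updated exporter interface; `my_exporter`, `checkpoint`, `eval_logs`, and `global_step` are hypothetical objects assumed to come from a normal training loop:

improved = my_exporter.maybe_export_checkpoint(
    checkpoint, eval_logs, global_step, write_logs=False)  # now returns a bool
if improved:
  # The checkpoint was saved; with write_logs=False the JSON metric summary
  # can be written later through the newly public method.
  my_exporter.export_best_eval_metric(my_exporter.best_ckpt_logs, global_step)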
......@@ -377,11 +380,15 @@ def remove_ckpts(model_dir):
tf.io.gfile.remove(file_to_remove)
def try_count_params(model: tf.keras.Model):
def try_count_params(
model: Union[tf.Module, tf.keras.Model],
trainable_only: bool = False):
"""Count the number of parameters if model is possible.
Args:
model: Try to count the number of params in this model.
trainable_only: Whether to count trainable parameters only. This flag is
not used when the model has a `count_params` attribute.
Returns:
The number of parameters or None.
......@@ -395,7 +402,13 @@ def try_count_params(model: tf.keras.Model):
'because the model was not fed any input, e.g., the max '
'train step already reached before this run.')
return None
return None
else:
total_params = 0
variables = model.trainable_variables if trainable_only else model.variables
for var in variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
return total_params
def try_count_flops(model: Union[tf.Module, tf.keras.Model],
......
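A short sketch of the extended helper, assuming it lives in official.core.train_utils as the surrounding hunks suggest: Keras models go through `count_params`, while plain tf.Module objects fall back to summing variable shapes and honor `trainable_only`:

import tensorflow as tf

from official.core import train_utils

# Keras path: delegates to model.count_params().
keras_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
print(train_utils.try_count_params(keras_model))  # 8 * 4 + 4 = 36


class TinyModule(tf.Module):
  # tf.Module path: no count_params attribute, so variables are summed.

  def __init__(self):
    super().__init__()
    self.w = tf.Variable(tf.zeros([8, 4]))                # trainable
    self.b = tf.Variable(tf.zeros([4]), trainable=False)  # frozen


module = TinyModule()
print(train_utils.try_count_params(module))                       # 36
print(train_utils.try_count_params(module, trainable_only=True))  # 32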
......@@ -23,6 +23,7 @@ from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
# TODO(hongkuny): deprecate the task_name once we have migrated client code.
task_name: str = ""
task_config: cfg.TaskConfig = None
eval_steps: Optional[int] = None
......@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
Attributes:
eval_tasks: individual evaluation tasks.
"""
eval_tasks: MultiTaskConfig = MultiTaskConfig()
eval_tasks: Tuple[TaskRoutine, ...] = ()
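With eval_tasks now a plain tuple of routines, the experiment config can be written directly; a sketch where FooConfig stands in for a real cfg.TaskConfig subclass (the updated tests further down use the same pattern):

from official.modeling.multitask import configs

experiment = configs.MultiEvalExperimentConfig(
    eval_tasks=(
        configs.TaskRoutine(
            task_name='foo', task_config=FooConfig(), eval_steps=2),
    ))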
......@@ -16,14 +16,14 @@
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Optional, Union
from typing import Dict, List, Optional, Union
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
......@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
def __init__(
self,
task: multitask.MultiTask,
eval_tasks: List[base_task.Task],
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None,
eval_steps: Optional[Dict[str, int]] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models.
Args:
task: A multitask.MultiTask instance.
eval_tasks: A list of tasks to evaluate.
model: A tf.keras.Model or base_model.MultiTaskBaseModel instance.
global_step: the global step variable.
eval_steps: a dictionary of steps to run eval keyed by task names.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._task = task
self._tasks = eval_tasks
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model)
global_step=self.global_step, model=self.model)
self._validation_losses = None
self._validation_metrics = None
# Builds per-task datasets.
self.eval_datasets = {}
for name, task in self.task.tasks.items():
self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
self.eval_steps = eval_steps or {}
for task in self.tasks:
self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.validation_data)
# Builds per-task validation loops.
......@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
return orbit.utils.create_loop_fn(eval_step_fn)
self.task_fns = {
name: get_function(name, task)
for name, task in self.task.tasks.items()
task.name: get_function(task.name, task) for task in self.tasks
}
@property
......@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
return self._strategy
@property
def task(self):
return self._task
def tasks(self):
return self._tasks
@property
def model(self):
......@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
if self._validation_losses is None:
# Builds the per-task metrics and losses.
self._validation_losses = {}
for name in self.task.tasks:
self._validation_losses[name] = tf.keras.metrics.Mean(
for task in self.tasks:
self._validation_losses[task.name] = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
return self._validation_losses
......@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
if self._validation_metrics is None:
# Builds the per-task metrics and losses.
self._validation_metrics = {}
for name, task in self.task.tasks.items():
self._validation_metrics[name] = task.build_metrics(training=False)
for task in self.tasks:
self._validation_metrics[task.name] = task.build_metrics(training=False)
return self._validation_metrics
@property
......@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
results = {}
eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
for name, task_eval_loop in self.task_fns.items():
for task in self.tasks:
outputs = None
name = task.name
eval_iter = eval_iters[name]
task = self.task.tasks[name]
task_eval_steps = self.task.task_eval_steps(name) or num_steps
outputs = task_eval_loop(
task_eval_steps = self.eval_steps.get(name, None) or num_steps
outputs = self.task_fns[name](
eval_iter,
task_eval_steps,
state=outputs,
......
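A constructor sketch for the reworked evaluator, which now takes a flat list of tasks plus an optional per-task step budget instead of a multitask.MultiTask container; `task_foo`, `task_bar`, and `shared_model` are hypothetical objects assumed to be built elsewhere:

import tensorflow as tf

from official.modeling.multitask import evaluator as evaluator_lib

multi_evaluator = evaluator_lib.MultiTaskEvaluator(
    eval_tasks=[task_foo, task_bar],    # base_task.Task instances
    model=shared_model,                 # tf.keras.Model or MultiTaskBaseModel
    eval_steps={'foo': 10, 'bar': 20},  # keyed by task.name; others use num_steps
)
results = multi_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))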
......@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
from official.modeling.multitask import multitask
def all_strategy_combinations():
......@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self,
aggregated_logs,
global_step=None):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
......@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
eval_tasks=tasks, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
......@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
eval_tasks=tasks, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertEqual(results["bar"]["counter"],
5. * distribution.num_replicas_in_sync)
......
......@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
optimizer: tf.optimizers.Optimizer,
task_sampler: sampler.TaskSampler,
trainer_options=None):
super(MultiTaskInterleavingTrainer, self).__init__(
super().__init__(
multi_task=multi_task,
multi_task_model=multi_task_model,
optimizer=optimizer,
......@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
self._task_train_step_map[name], args=(next(iterator_map[name]),))
self.global_step.assign_add(1)
self.task_step_counter(name).assign_add(1)
def train_loop_end(self):
"""Record loss and metric values per task."""
result = super().train_loop_end()
# Interleaving training does not have well-defined semantics for `total_loss`. In
# fact, it is always zero. To avoid confusion, we filter the `total_loss`
# from the result logs.
if 'total_loss' in result:
result.pop('total_loss')
return result
......@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
self.assertNotIn("total_loss", results)
@combinations.generate(all_strategy_combinations())
def test_trainer_with_configs(self, distribution):
......
......@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
else:
raise ValueError("The tasks argument has an invalid type: %s" %
type(tasks))
self._task_eval_steps = task_eval_steps or {}
self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks
])
self.task_eval_steps = task_eval_steps or {}
self._task_weights = task_weights or {}
self._task_weights = dict([
(name, self._task_weights.get(name, 1.0)) for name in self.tasks
......@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
task_eval_steps = {}
task_weights = {}
for task_routine in config.task_routines:
task_name = task_routine.task_name
task_name = task_routine.task_name or task_routine.task_config.name
tasks[task_name] = task_factory.get_task(
task_routine.task_config, logging_dir=logging_dir)
task_routine.task_config, logging_dir=logging_dir, name=task_name)
task_eval_steps[task_name] = task_routine.eval_steps
task_weights[task_name] = task_routine.task_weight
return cls(
......@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
def tasks(self):
return self._tasks
def task_eval_steps(self, task_name):
return self._task_eval_steps[task_name]
def task_weight(self, task_name):
return self._task_weights[task_name]
......
......@@ -15,7 +15,7 @@
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import Optional
from typing import List, Optional
from absl import logging
import orbit
import tensorflow as tf
......@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None
if is_eval:
eval_steps = task.task_eval_steps
evaluator = evaluator_lib.MultiTaskEvaluator(
task=task,
eval_tasks=task.tasks.values(),
model=model,
eval_steps=eval_steps,
global_step=trainer.global_step if is_training else None,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
......@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
*,
distribution_strategy: tf.distribute.Strategy,
train_task: base_task.Task,
eval_tasks: multitask.MultiTask,
eval_tasks: List[base_task.Task],
mode: str,
params: configs.MultiEvalExperimentConfig,
model_dir: str,
......@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
Args:
distribution_strategy: A distribution strategy.
train_task: A base_task.Task instance.
eval_tasks: A multitask.MultiTask with evaluation tasks.
eval_tasks: A list of evaluation tasks.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: MultiEvalExperimentConfig instance.
......@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
config=params,
task=train_task,
model=train_task.build_model(),
optimizer=train_task.create_optimizer(
params.trainer.optimizer_config, params.runtime),
optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
params.runtime),
train=True,
evaluate=False)
else:
......@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
model = trainer.model if trainer else train_task.build_model()
if is_eval:
eval_steps = dict([(task_routine.task_config.name,
task_routine.eval_steps)
for task_routine in params.eval_tasks])
evaluator = evaluator_lib.MultiTaskEvaluator(
task=eval_tasks,
eval_tasks=eval_tasks,
model=model,
global_step=trainer.global_step if is_training else None,
eval_steps=eval_steps,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
else:
......
......@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
task=configs.MultiTaskConfig(
task_routines=(
configs.TaskRoutine(
task_name='foo',
task_config=test_utils.FooConfig()),
task_name='foo', task_config=test_utils.FooConfig()),
configs.TaskRoutine(
task_name='bar', task_config=test_utils.BarConfig()))))
experiment_config = params_dict.override_params_dict(
......@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
model_dir = self.get_temp_dir()
experiment_config = configs.MultiEvalExperimentConfig(
task=test_utils.FooConfig(),
eval_tasks=configs.MultiTaskConfig(
task_routines=(
configs.TaskRoutine(
task_name='foo',
task_config=test_utils.FooConfig()),
configs.TaskRoutine(
task_name='bar', task_config=test_utils.BarConfig()))))
eval_tasks=(configs.TaskRoutine(
task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
configs.TaskRoutine(
task_name='bar',
task_config=test_utils.BarConfig(),
eval_steps=3)))
experiment_config = params_dict.override_params_dict(
experiment_config, self._test_config, is_strict=False)
with distribution_strategy.scope():
train_task = task_factory.get_task(experiment_config.task)
eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
eval_tasks = [
task_factory.get_task(config.task_config, name=config.task_name)
for config in experiment_config.eval_tasks
]
train_lib.run_experiment_with_multitask_eval(
distribution_strategy=distribution_strategy,
train_task=train_task,
......
......@@ -28,7 +28,6 @@ from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib as multitask_train_lib
......@@ -167,7 +166,10 @@ def run_continuous_finetune(
with distribution_strategy.scope():
if isinstance(params, configs.MultiEvalExperimentConfig):
task = task_factory.get_task(params_replaced.task)
eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
eval_tasks = [
task_factory.get_task(config.task_config, name=config.task_name)
for config in params.eval_tasks
]
(_,
eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
distribution_strategy=distribution_strategy,
......
......@@ -89,8 +89,7 @@ def _get_ngrams_with_counter(segment, max_order):
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
max_order: maximum length in tokens of the n-grams returned by this method.
Returns:
The Counter containing all n-grams up to max_order in segment.
......@@ -104,15 +103,17 @@ def _get_ngrams_with_counter(segment, max_order):
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
def compute_bleu(reference_corpus,
translation_corpus,
max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
reference_corpus: list of references for each translation. Each reference
should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation should
be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
......@@ -134,15 +135,14 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
ngram]
possible_matches_by_order[len(ngram) -
1] += translation_ngram_counts[ngram]
precisions = [0] * max_order
smooth = 1.0
......@@ -151,8 +151,8 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
if possible_matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
i]
precisions[i] = float(
matches_by_order[i]) / possible_matches_by_order[i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
......@@ -165,7 +165,8 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
if use_bp:
ratio = translation_length / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bp = 0. if ratio < 1e-6 else math.exp(1 -
1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
......
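A quick usage sketch, calling the compute_bleu shown above: both corpora are lists of token lists, one entry per sentence, and the new `ratio < 1e-6` guard keeps an empty translation from dividing by zero in the brevity penalty:

references = [['the', 'cat', 'sat', 'on', 'the', 'mat']]
translations = [['the', 'cat', 'sat', 'on', 'the', 'mat']]
print(compute_bleu(references, translations))  # 1.0 for an exact match

empty = [[]]
print(compute_bleu(references, empty))  # 0.0: ratio == 0 hits the new guard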
......@@ -22,3 +22,4 @@ from official.nlp.modeling import layers
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.modeling import ops
......@@ -39,6 +39,23 @@ class NoNorm(tf.keras.layers.Layer):
return output
@tf.keras.utils.register_keras_serializable(package='Text')
class NoNormClipped(NoNorm):
"""Quantization friendly implementation for the NoNorm.
The output of NoNorm layer is clipped to [-6.0, 6.0] to make it quantization
friendly.
"""
def __init__(self, name=None):
super(NoNormClipped, self).__init__(name=name)
def call(self, feature):
output = feature * self.scale + self.bias
clipped_output = tf.clip_by_value(output, -6.0, 6.0)
return clipped_output
def _get_norm_layer(normalization_type='no_norm', name=None):
"""Get normlization layer.
......@@ -52,6 +69,8 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
"""
if normalization_type == 'no_norm':
layer = NoNorm(name=name)
elif normalization_type == 'no_norm_clipped':
layer = NoNormClipped(name=name)
elif normalization_type == 'layer_norm':
layer = tf.keras.layers.LayerNormalization(
name=name,
......
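A usage sketch, assuming NoNormClipped is exported from the same mobile_bert_layers module that the test below imports; the clipped variant keeps NoNorm's elementwise scale and bias but bounds activations to [-6, 6], and the factory above now accepts the 'no_norm_clipped' name:

import tensorflow as tf

from official.nlp.modeling.layers import mobile_bert_layers  # assumed path

layer = mobile_bert_layers.NoNormClipped()
features = tf.random.uniform([2, 3, 4], minval=-8.0, maxval=8.0)
outputs = layer(features)  # same shape as the input, values clipped to [-6.0, 6.0]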
......@@ -33,6 +33,22 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
return fake_input
class EdgeTPUNoNormTest(tf.test.TestCase):
def test_no_norm(self):
layer = mobile_bert_layers.NoNormClipped()
feature = tf.random.uniform(
[2, 3, 4], minval=-8, maxval=8, dtype=tf.float32)
output = layer(feature)
output_shape = output.shape.as_list()
expected_shape = [2, 3, 4]
self.assertListEqual(output_shape, expected_shape, msg=None)
output_min = tf.reduce_min(output)
output_max = tf.reduce_max(output)
self.assertGreaterEqual(6.0, output_max)
self.assertLessEqual(-6.0, output_min)
class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
def test_embedding_layer_with_token_type(self):
......
......@@ -106,16 +106,19 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
def call(self, inputs, *, training=None):
training = self.do_power_iteration if training is None else training
u_update_op, v_update_op, w_update_op = self.update_weights(
training=training)
output = self.layer(inputs)
w_restore_op = self.restore_weights()
# Register update ops.
self.add_update(u_update_op)
self.add_update(v_update_op)
self.add_update(w_update_op)
self.add_update(w_restore_op)
if training:
u_update_op, v_update_op, w_update_op = self.update_weights(
training=training)
output = self.layer(inputs)
w_restore_op = self.restore_weights()
# Register update ops.
self.add_update(u_update_op)
self.add_update(v_update_op)
self.add_update(w_update_op)
self.add_update(w_restore_op)
else:
output = self.layer(inputs)
return output
......
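A hypothetical usage sketch of the training-gated behavior (the import path and any constructor arguments beyond the wrapped layer are assumptions): power-iteration updates and the kernel restore op are only registered on training calls, so inference no longer mutates the wrapped layer's weights:

import tensorflow as tf

from official.nlp.modeling.layers import spectral_normalization  # assumed path

dense = tf.keras.layers.Dense(10)
sn_dense = spectral_normalization.SpectralNormalization(dense)

x = tf.random.uniform([4, 16])
train_out = sn_dense(x, training=True)   # runs update_weights/restore_weights
eval_out = sn_dense(x, training=False)   # plain forward pass, no update ops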
......@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
head_name: Name of the classification head.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from the network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
'use_encoder_pooler', 'head_name') will be ignored.
"""
def __init__(self,
......@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
initializer='glorot_uniform',
dropout_rate=0.1,
use_encoder_pooler=True,
head_name='sentence_prediction',
cls_head=None,
**kwargs):
self.num_classes = num_classes
self.head_name = head_name
self.initializer = initializer
self.use_encoder_pooler = use_encoder_pooler
......@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
name=head_name)
predictions = classifier(cls_inputs)
......@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
return {
'network': self._network,
'num_classes': self.num_classes,
'head_name': self.head_name,
'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self._cls_head,
......
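A sketch of the new head_name argument; the tiny encoder configuration mirrors the updated tests below, and 'my_prediction_head' is a hypothetical name (the default remains 'sentence_prediction'):

from official.nlp.modeling import models
from official.nlp.modeling import networks

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
classifier = models.BertClassifier(
    network=encoder,
    num_classes=4,
    head_name='my_prediction_head',  # also round-trips through get_config()
)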
......@@ -87,10 +87,8 @@ class BertClassifierTest(keras_parameterized.TestCase):
inner_dim=0, num_classes=4)))
def test_serialize_deserialize(self, cls_head):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
# Build a transformer network to use within the BERT trainer.
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
......
......@@ -67,10 +67,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
def test_bert_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=2)
# Build a transformer network to use within the BERT trainer.
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
# Create a BERT trainer with the created network.
bert_trainer_model = bert_pretrainer.BertPretrainer(
......@@ -213,10 +211,8 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
def test_v2_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
# Build a transformer network to use within the BERT trainer.
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
......
......@@ -93,10 +93,8 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
def test_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.BertEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
# Build a transformer network to use within the BERT trainer.
test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
......