Commit a04d9e0e authored by Vishnu Banna

merged

parents 64f16d61 bcbce005
* @tensorflow/tf-garden-team @tensorflow/tf-model-garden-team
/official/ @rachellj218 @saberkun @jaeyounkim
/official/nlp/ @saberkun @lehougoogle @rachellj218 @jaeyounkim
/official/recommendation/ranking/ @gagika
/official/vision/ @xianzhidu @yeqingli @arashwan @saberkun @rachellj218 @jaeyounkim
/official/vision/beta/projects/assemblenet/ @mryoo
/official/vision/beta/projects/deepmac_maskrcnn/ @vighneshbirodkar
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -38,7 +38,10 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
# Special keys in the logs returned from train/validate steps.
loss = "loss"
def __init__(self, params, logging_dir: str = None, name: str = None):
def __init__(self,
params,
logging_dir: Optional[str] = None,
name: Optional[str] = None):
"""Task initialization.
Args:
......@@ -294,11 +297,38 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
return model(inputs, training=False)
def aggregate_logs(self, state, step_logs):
"""Optional aggregation over logs returned from a validation step."""
"""Optional aggregation over logs returned from a validation step.
Given step_logs from a validation step, this function aggregates the logs
after each eval_step() (see eval_reduce() function in
official/core/base_trainer.py). It runs on CPU and can be used to aggregate
metrics during validation, when there are too many metrics that cannot fit
into TPU memory. Note that this may increase latency due to data transfer
between TPU and CPU. Also, the step output from a validation step may be a
tuple with elements from replicas, and a concatenation of the elements is
needed in such case.
Args:
state: The current state of training, for example, it can be a sequence of
metrics.
step_logs: Logs from a validation step. Can be a dictionary.
"""
pass
def reduce_aggregated_logs(self,
aggregated_logs,
global_step: Optional[tf.Tensor] = None):
"""Optional reduce of aggregated logs over validation steps."""
"""Optional reduce of aggregated logs over validation steps.
This function reduces aggregated logs at the end of validation, and can be
used to compute the final metrics. It runs on CPU and in each eval_end() in
base trainer (see eval_end() function in official/core/base_trainer.py).
Args:
aggregated_logs: Aggregated logs over multiple validation steps.
global_step: An optional variable of global step.
Returns:
A dictionary of reduced results.
"""
return {}
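To make the two hooks above concrete, here is a minimal sketch of a hypothetical Task subclass that gathers per-step outputs on CPU and reduces them once at the end of validation. The MyTask name and the 'outputs' log key are illustrative assumptions, not part of this change.

# Hypothetical subclass, for illustration only.
import numpy as np
import tensorflow as tf


class MyTask(Task):

  def aggregate_logs(self, state, step_logs):
    # First call: start with an empty list.
    if state is None:
      state = []
    # A distributed eval step may return a tuple of per-replica tensors.
    outputs = step_logs['outputs']  # 'outputs' key is an assumed example.
    if isinstance(outputs, tuple):
      outputs = tf.concat(outputs, axis=0)
    state.append(outputs.numpy())
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    # Reduce everything gathered on CPU into the final metric dict.
    all_outputs = np.concatenate(aggregated_logs, axis=0)
    return {'mean_output': float(all_outputs.mean())}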
......@@ -246,10 +246,11 @@ class Trainer(_AsyncTrainer):
self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
model_metrics = model.metrics if hasattr(model, "metrics") else []
self._train_metrics = self.task.build_metrics(
training=True) + self.model.metrics
training=True) + model_metrics
self._validation_metrics = self.task.build_metrics(
training=False) + self.model.metrics
training=False) + model_metrics
self.init_async()
......
......@@ -28,4 +28,4 @@ def hard_sigmoid(features):
The activation value.
"""
features = tf.convert_to_tensor(features)
return tf.nn.relu6(features + tf.constant(3.)) * 0.16667
return tf.nn.relu6(features + tf.cast(3., features.dtype)) * 0.16667
......@@ -52,7 +52,8 @@ def hard_swish(features):
The activation value.
"""
features = tf.convert_to_tensor(features)
return features * tf.nn.relu6(features + tf.constant(3.)) * (1. / 6.)
fdtype = features.dtype
return features * tf.nn.relu6(features + tf.cast(3., fdtype)) * (1. / 6.)
@tf.keras.utils.register_keras_serializable(package='Text')
......
......@@ -41,6 +41,7 @@ class OptimizerConfig(oneof.OneOfConfig):
rmsprop: rmsprop optimizer.
lars: lars optimizer.
adagrad: adagrad optimizer.
slide: slide optimizer.
"""
type: Optional[str] = None
sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
......@@ -50,6 +51,7 @@ class OptimizerConfig(oneof.OneOfConfig):
rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()
slide: opt_cfg.SLIDEConfig = opt_cfg.SLIDEConfig()
@dataclasses.dataclass
......
......@@ -226,3 +226,24 @@ class LARSConfig(BaseOptimizerConfig):
classic_momentum: bool = True
exclude_from_weight_decay: Optional[List[str]] = None
exclude_from_layer_adaptation: Optional[List[str]] = None
@dataclasses.dataclass
class SLIDEConfig(BaseOptimizerConfig):
"""Configuration for SLIDE optimizer.
Details coming soon.
"""
name: str = "SLIDE"
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-6
weight_decay_rate: float = 0.0
weight_decay_type: str = "inner"
exclude_from_weight_decay: Optional[List[str]] = None
exclude_from_layer_adaptation: Optional[List[str]] = None
include_in_sparse_layer_adaptation: Optional[List[str]] = None
sparse_layer_learning_rate: float = 0.1
do_gradient_rescaling: bool = True
norm_type: str = "layer"
ratio_clip_norm: float = 1e5
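A hedged sketch of selecting the new config through the oneof OptimizerConfig shown earlier in this diff; the module paths are assumptions based on the file layout implied here, and the SLIDE optimizer itself is still a placeholder at this point, so this only builds the config object.

# Illustrative only; module paths assumed from this diff's layout.
from official.modeling.optimization.configs import optimization_config as opt_config_lib
from official.modeling.optimization.configs import optimizer_config as opt_cfg

optimizer_config = opt_config_lib.OptimizerConfig(
    type='slide',
    slide=opt_cfg.SLIDEConfig(
        beta_1=0.9,
        weight_decay_rate=0.01,
        exclude_from_weight_decay=['LayerNorm', 'bias']))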
......@@ -14,7 +14,7 @@
"""Exponential moving average optimizer."""
from typing import Text, List
from typing import List, Optional, Text
import tensorflow as tf
......@@ -106,7 +106,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
def apply_gradients(self, grads_and_vars, name: Text = None):
def apply_gradients(self, grads_and_vars, name: Optional[Text] = None):
result = self._optimizer.apply_gradients(grads_and_vars, name)
self.update_average(self.iterations)
return result
......
......@@ -13,12 +13,13 @@
# limitations under the License.
"""Optimizer factory class."""
from typing import Callable, Union
from typing import Callable, Optional, Union
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
from official.modeling.optimization import slide_optimizer
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import lars_optimizer
from official.modeling.optimization import lr_schedule
......@@ -33,6 +34,7 @@ OPTIMIZERS_CLS = {
'rmsprop': tf.keras.optimizers.RMSprop,
'lars': lars_optimizer.LARS,
'adagrad': tf.keras.optimizers.Adagrad,
'slide': slide_optimizer.SLIDE
}
LR_CLS = {
......@@ -134,8 +136,8 @@ class OptimizerFactory:
def build_optimizer(
self,
lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
postprocessor: Callable[[tf.keras.optimizers.Optimizer],
tf.keras.optimizers.Optimizer] = None):
postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
tf.keras.optimizers.Optimizer]] = None):
"""Build optimizer.
Builds optimizer from config. It takes learning rate as input, and builds
......
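For context, a rough sketch of how the factory is typically driven end to end; the OptimizationConfig dictionary below is an assumption for illustration, and an SGD example is used because the SLIDE class registered above is still a placeholder.

# Illustrative usage sketch; config values are assumptions.
from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config as opt_config_lib

config = opt_config_lib.OptimizationConfig({
    'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
})
factory = optimizer_factory.OptimizerFactory(config)
lr = factory.build_learning_rate()
optimizer = factory.build_optimizer(lr)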
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SLIDE optimizer.
A new optimizer that will be open sourced soon.
"""
SLIDE = "Unimplemented"
......@@ -284,8 +284,11 @@ class ProgressiveTrainer(trainer_lib.Trainer):
checkpoint_interval=checkpoint_interval,
)
# Make sure we export the last checkpoint.
last_checkpoint = (
self.global_step.numpy() == self._config.trainer.train_steps)
checkpoint_path = self._export_ckpt_manager.save(
checkpoint_number=self.global_step.numpy(),
check_interval=True)
check_interval=not last_checkpoint)
if checkpoint_path:
logging.info('Checkpoints exported: %s.', checkpoint_path)
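The intent of the change above: CheckpointManager.save() skips the save when the configured checkpoint interval has not yet elapsed, so passing check_interval=False on the final step guarantees the last checkpoint is exported. A minimal standalone illustration (directory and step values are placeholders):

import tensorflow as tf

train_steps = 2500
step = tf.Variable(train_steps, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory='/tmp/export_ckpts', max_to_keep=3,
    step_counter=step, checkpoint_interval=1000)

last_checkpoint = int(step.numpy()) == train_steps
# check_interval=False forces a save even if fewer than
# `checkpoint_interval` steps have passed since the previous save.
path = manager.save(checkpoint_number=int(step.numpy()),
                    check_interval=not last_checkpoint)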
......@@ -181,20 +181,21 @@ class AxProcessor(DataProcessor):
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
super(ColaProcessor, self).__init__(process_text_fn)
self.dataset = tfds.load("glue/cola", try_gcs=True)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def get_labels(self):
"""See base class."""
......@@ -205,22 +206,19 @@ class ColaProcessor(DataProcessor):
"""See base class."""
return "COLA"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = self.dataset[set_type].as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
# Only the test set has a header.
if set_type == "test" and i == 0:
continue
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = self.process_text_fn(line[1])
label = "0"
else:
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
text_a = self.process_text_fn(example["sentence"])
if set_type != "test":
label = str(example["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
return examples
......
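With this change the processor pulls CoLA from TFDS instead of local TSV files, so the data_dir arguments are effectively ignored; a hedged usage sketch (assumes the tensorflow_datasets package is installed and glue/cola can be downloaded):

# Illustrative only; relies on TFDS being able to fetch glue/cola.
processor = ColaProcessor()
train_examples = processor.get_train_examples(data_dir=None)
dev_examples = processor.get_dev_examples(data_dir=None)
print(len(train_examples), train_examples[0].text_a, train_examples[0].label)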
......@@ -14,7 +14,7 @@
"""Loads dataset for the sentence prediction (classification) task."""
import functools
from typing import List, Mapping, Optional
from typing import List, Mapping, Optional, Tuple
import dataclasses
import tensorflow as tf
......@@ -40,6 +40,10 @@ class SentencePredictionDataConfig(cfg.DataConfig):
label_type: str = 'int'
# Whether to include the example id number.
include_example_id: bool = False
label_field: str = 'label_ids'
# Maps the key in the tf.Example to the feature name,
# e.g. 'label_ids' to 'next_sentence_labels'.
label_name: Optional[Tuple[str, str]] = None
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
......@@ -50,6 +54,11 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
self._params = params
self._seq_length = params.seq_length
self._include_example_id = params.include_example_id
self._label_field = params.label_field
if params.label_name:
self._label_name_mapping = dict([params.label_name])
else:
self._label_name_mapping = dict()
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
......@@ -58,7 +67,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], label_type),
self._label_field: tf.io.FixedLenFeature([], label_type),
}
if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
......@@ -85,8 +94,12 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
if self._include_example_id:
x['example_id'] = record['example_id']
y = record['label_ids']
return (x, y)
x[self._label_field] = record[self._label_field]
if self._label_field in self._label_name_mapping:
x[self._label_name_mapping[self._label_field]] = record[self._label_field]
return x
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
......@@ -204,8 +217,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
model_inputs = self._text_processor(segments)
if self._include_example_id:
model_inputs['example_id'] = record['example_id']
y = record[self._label_field]
return model_inputs, y
model_inputs[self._label_field] = record[self._label_field]
return model_inputs
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
......
......@@ -132,14 +132,40 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
global_batch_size=batch_size,
label_type=label_type)
dataset = loader.SentencePredictionDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
features = next(iter(dataset))
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
self.assertEqual(labels.dtype, expected_label_type)
self.assertEqual(features['label_ids'].shape, (batch_size,))
self.assertEqual(features['label_ids'].dtype, expected_label_type)
def test_load_dataset_with_label_mapping(self):
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
batch_size = 10
seq_length = 128
_create_fake_preprocessed_dataset(input_path, seq_length, 'int')
data_config = loader.SentencePredictionDataConfig(
input_path=input_path,
seq_length=seq_length,
global_batch_size=batch_size,
label_type='int',
label_name=('label_ids', 'next_sentence_labels'))
dataset = loader.SentencePredictionDataLoader(data_config).load()
features = next(iter(dataset))
self.assertCountEqual([
'input_word_ids', 'input_mask', 'input_type_ids',
'next_sentence_labels', 'label_ids'
], features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['label_ids'].shape, (batch_size,))
self.assertEqual(features['label_ids'].dtype, tf.int32)
self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
......@@ -170,13 +196,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
lower_case=lower_case,
vocab_file=vocab_file_path)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
self.assertEqual(features[label_field].shape, (batch_size,))
@parameterized.parameters(True, False)
def test_python_sentencepiece_preprocessing(self, use_tfds):
......@@ -203,13 +231,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
vocab_file=sp_model_file_path,
)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
self.assertEqual(features[label_field].shape, (batch_size,))
@parameterized.parameters(True, False)
def test_saved_model_preprocessing(self, use_tfds):
......@@ -236,13 +266,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
label_type='int' if use_tfds else 'float',
)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
self.assertEqual(features[label_field].shape, (batch_size,))
if __name__ == '__main__':
......
......@@ -15,7 +15,7 @@
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Mapping, Union
from typing import Any, Mapping, Optional, Union
import tensorflow as tf
......@@ -99,7 +99,7 @@ class XLNetPretrainer(tf.keras.Model):
network: Union[tf.keras.layers.Layer, tf.keras.Model],
mlm_activation=None,
mlm_initializer='glorot_uniform',
name: str = None,
name: Optional[str] = None,
**kwargs):
super().__init__(name=name, **kwargs)
self._config = {
......
......@@ -431,17 +431,17 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def _continue_search(self, state) -> tf.Tensor:
i = state[decoding_module.StateKeys.CUR_INDEX]
return tf.less(i, self.max_decode_length)
# Have we reached max decoding length?
not_at_end = tf.less(i, self.max_decode_length)
# Have all sampled sequences reached an EOS?
all_has_eos = tf.reduce_all(
state[decoding_module.StateKeys.FINISHED_FLAGS],
axis=None,
name="search_finish_cond")
return tf.logical_and(not_at_end, tf.logical_not(all_has_eos))
def _finished_flags(self, topk_ids, state) -> tf.Tensor:
new_finished_flags = tf.equal(topk_ids, self.eos_id)
new_finished_flags = tf.logical_or(
new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
return new_finished_flags
......@@ -22,7 +22,7 @@ modeling library:
* [mobile_bert_encoder.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/mobile_bert_encoder.py)
contains `MobileBERTEncoder` implementation.
* [mobile_bert_layers.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/mobile_bert_layers.py)
contains `MobileBertEmbedding`, `MobileBertMaskedLM` and `MobileBertMaskedLM`
contains `MobileBertEmbedding`, `MobileBertTransformer` and `MobileBertMaskedLM`
implementation.
## Pre-trained Models
......
# TEAMS (Training ELECTRA Augmented with Multi-word Selection)
**Note:** This project is a work in progress; please stay tuned.
TEAMS is a text encoder pre-training method that simultaneously learns a
generator and a discriminator using multi-task learning. We propose a new
pre-training task, multi-word selection, and combine it with previous
pre-training tasks for efficient encoder pre-training. We also develop two
techniques, attention-based task-specific heads and partial layer sharing,
to further improve pre-training effectiveness.
Our academic paper [[1]](#1), which describes TEAMS in detail, can be found here:
https://arxiv.org/abs/2106.00139.
## References
<a id="1">[1]</a>
Jiaming Shen, Jialu Liu, Tianqi Liu, Cong Yu and Jiawei Han, "Training ELECTRA
Augmented with Multi-word Selection", Findings of the Association for
Computational Linguistics: ACL 2021.
......@@ -69,6 +69,10 @@ class SentencePredictionTask(base_task.Task):
if params.metric_type not in METRIC_TYPES:
raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
self.metric_type = params.metric_type
if hasattr(params.train_data, 'label_field'):
self.label_field = params.train_data.label_field
else:
self.label_field = 'label_ids'
def build_model(self):
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
......@@ -95,11 +99,12 @@ class SentencePredictionTask(base_task.Task):
use_encoder_pooler=self.task_config.model.use_encoder_pooler)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
label_ids = labels[self.label_field]
if self.task_config.model.num_classes == 1:
loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
else:
loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, tf.cast(model_outputs, tf.float32), from_logits=True)
label_ids, tf.cast(model_outputs, tf.float32), from_logits=True)
if aux_losses:
loss += tf.add_n(aux_losses)
......@@ -120,7 +125,8 @@ class SentencePredictionTask(base_task.Task):
y = tf.zeros((1,), dtype=tf.float32)
else:
y = tf.zeros((1, 1), dtype=tf.int32)
return x, y
x[self.label_field] = y
return x
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
......@@ -142,16 +148,16 @@ class SentencePredictionTask(base_task.Task):
def process_metrics(self, metrics, labels, model_outputs):
for metric in metrics:
metric.update_state(labels, model_outputs)
metric.update_state(labels[self.label_field], model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
compiled_metrics.update_state(labels, model_outputs)
compiled_metrics.update_state(labels[self.label_field], model_outputs)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
if self.metric_type == 'accuracy':
return super(SentencePredictionTask,
self).validation_step(inputs, model, metrics)
features, labels = inputs
features, labels = inputs, inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
......@@ -161,12 +167,12 @@ class SentencePredictionTask(base_task.Task):
'sentence_prediction': # Ensure one prediction along batch dimension.
tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
'labels':
labels,
labels[self.label_field],
})
if self.metric_type == 'pearson_spearman_corr':
logs.update({
'sentence_prediction': outputs,
'labels': labels,
'labels': labels[self.label_field],
})
return logs
......@@ -206,10 +212,10 @@ class SentencePredictionTask(base_task.Task):
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
pretrain2finetune_mapping = {
'encoder': model.checkpoint_items['encoder'],
......@@ -250,7 +256,7 @@ def predict(task: SentencePredictionTask,
def predict_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
x = inputs
example_id = x.pop('example_id')
outputs = task.inference_step(x, model)
return dict(example_id=example_id, predictions=outputs)
......