"vscode:/vscode.git/clone" did not exist on "bda18166ec804e33e88fd0c4b33decc1308926a6"
Commit b92025a9 authored by anivegesana

Merge branch 'master' of https://github.com/tensorflow/models into detection_generator_pr_2

parents 1b425791 37536370
@@ -211,6 +211,44 @@ class PowerDecayWithOffsetLrConfig(base_config.Config):
pre_offset_learning_rate: float = 1.0e6
@dataclasses.dataclass
class StepCosineLrConfig(base_config.Config):
"""Configuration for stepwise learning rate decay.
This class is a container for the piecewise cosine learning rate scheduling
configs. It will configure an instance of StepConsineDecayWithOffset keras
learning rate schedule.
```python
boundaries: [100000, 110000]
values: [1.0, 0.5]
lr_decayed_fn = (
lr_schedule.StepConsineDecayWithOffset(
boundaries,
values))
```
From step 0 to 100000, the learning rate cosine decays from 1.0 to 0.5.
From step 100000 to 110000, it cosine decays from 0.5 to 0.0.
Attributes:
name: The name of the learning rate schedule.
Defaults to StepConsineDecayWithOffset.
boundaries: A list of ints of strictly increasing entries, where
boundaries[i] is the step at which the i-th cosine segment starts.
Defaults to None.
values: A list of floats with the same length as `boundaries` that
specifies the starting learning rate of each segment.
The learning rate is computed as follows:
[boundaries[0], boundaries[1]] -> cosine decay from values[0] to values[1]
[boundaries[1], boundaries[2]] -> cosine decay from values[1] to values[2]
...
[boundaries[n], end] -> cosine decay from values[n] to 0.
offset: An int. The offset applied to steps. Defaults to 0.
"""
name: str = 'StepConsineDecayWithOffset'
boundaries: Optional[List[int]] = None
values: Optional[List[float]] = None
offset: int = 0
@dataclasses.dataclass
class LinearWarmupConfig(base_config.Config):
"""Configuration for linear warmup schedule config.
...
@@ -70,6 +70,7 @@ class LrConfig(oneof.OneOfConfig):
power_linear: learning rate config of step^power followed by
step^power*linear.
power_with_offset: power decay with a step offset.
step_cosine_with_offset: Step cosine with a step offset.
""" """
type: Optional[str] = None type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig() constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
...@@ -82,6 +83,8 @@ class LrConfig(oneof.OneOfConfig): ...@@ -82,6 +83,8 @@ class LrConfig(oneof.OneOfConfig):
lr_cfg.PowerAndLinearDecayLrConfig()) lr_cfg.PowerAndLinearDecayLrConfig())
power_with_offset: lr_cfg.PowerDecayWithOffsetLrConfig = ( power_with_offset: lr_cfg.PowerDecayWithOffsetLrConfig = (
lr_cfg.PowerDecayWithOffsetLrConfig()) lr_cfg.PowerDecayWithOffsetLrConfig())
step_cosine_with_offset: lr_cfg.StepCosineLrConfig = (
lr_cfg.StepCosineLrConfig())
@dataclasses.dataclass
...
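For reference, a minimal sketch of how the new `step_cosine_with_offset` one-of option might be selected in an optimization params dict. The values mirror the test added later in this commit; they are illustrative placeholders, not recommendations.

```python
# Illustrative learning-rate params fragment for the new one-of option.
params = {
    'learning_rate': {
        'type': 'step_cosine_with_offset',
        'step_cosine_with_offset': {
            'boundaries': (0, 500000),   # start step of each cosine segment
            'values': (1.0e-4, 5.0e-5),  # starting lr of each segment
            'offset': 10000,             # steps consumed by warmup
        }
    }
}
```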
@@ -14,6 +14,7 @@
"""Learning rate schedule classes."""
import math
from typing import Mapping, Any, Union, Optional
import tensorflow as tf
@@ -383,3 +384,113 @@ class PowerDecayWithOffset(tf.keras.optimizers.schedules.LearningRateSchedule):
"pre_offset_learning_rate": self._pre_offset_lr,
"name": self._name,
}
class StepConsineDecayWithOffset(
tf.keras.optimizers.schedules.LearningRateSchedule):
"""Stepwise cosine learning rate decay with offset.
Learning rate is equivalent to one or more consine decay(s) starting and
ending at each interval.
ExampleL
```python
boundaries: [100000, 110000]
values: [1.0, 0.5]
lr_decayed_fn = (
lr_schedule.StepConsineDecayWithOffset(
boundaries,
values))
```
From step 0 to 100000, the learning rate cosine decays from 1.0 to 0.5.
From step 100000 to 110000, it cosine decays from 0.5 to 0.0.
"""
def __init__(self,
boundaries,
values,
offset: int = 0,
name: str = "StepConsineDecayWithOffset"):
"""Initialize configuration of the learning rate schedule.
Args:
boundaries: A list of `Tensor`s or `int`s with strictly
increasing entries, and with all elements having the same type as the
optimizer step.
values: A list of `Tensor`s or `float`s that specifies the
values for the intervals defined by `boundaries`. It should have the
same number of elements as `boundaries`, and all elements should have
the same type.
offset: The offset applied to the step count before computing the decay.
name: Optional, name of learning rate schedule.
"""
super().__init__()
self.values = values
self.boundaries = boundaries
self.offset = offset
self.name = name
if len(self.values) < 1:
raise ValueError(f"Expect non-empty values, got {self.values}")
if len(self.boundaries) != len(self.values):
raise ValueError(
"Boundaries length must equal learning rate levels length: "
f"{len(self.boundaries)} != {len(self.values)}")
self.total_steps = (
[boundaries[i + 1] - boundaries[i] for i in range(len(boundaries) - 1)
] + [0])
def __call__(self, global_step):
with tf.name_scope(self.name or "StepConsineDecayWithOffset"):
global_step = tf.cast(global_step - self.offset, tf.float32)
lr_levels = self.values
lr_steps = self.boundaries
level_total_steps = self.total_steps
num_levels = len(lr_levels)
init_lr = lr_levels[0]
next_init_lr = lr_levels[1] if num_levels > 1 else 0.
init_total_steps = level_total_steps[0]
cosine_learning_rate = ((init_lr - next_init_lr) * (tf.cos(
tf.constant(math.pi) * (global_step) /
(init_total_steps)) + 1.0) / 2.0 + next_init_lr)
learning_rate = cosine_learning_rate
tf.compat.v1.logging.info("DEBUG lr %r next lr %r", learning_rate,
cosine_learning_rate)
tf.compat.v1.logging.info("DEBUG lr %r next lr %r inittotalstep %r",
init_lr, next_init_lr, init_total_steps)
for i in range(1, num_levels):
next_init_lr = lr_levels[i]
next_start_step = lr_steps[i]
next_total_steps = level_total_steps[i]
next_next_init_lr = lr_levels[i + 1] if num_levels > i + 1 else 0.
tf.compat.v1.logging.info(
"DEBUG step %r nilr %r nss %r nts %r nnilr %r", global_step,
next_init_lr, next_start_step, next_total_steps, next_next_init_lr)
next_cosine_learning_rate = ((next_init_lr - next_next_init_lr) *
(tf.cos(
tf.constant(math.pi) *
(global_step - next_start_step) /
(next_total_steps)) + 1.0) / 2.0 +
next_next_init_lr)
learning_rate = tf.where(global_step >= next_start_step,
next_cosine_learning_rate, learning_rate)
tf.compat.v1.logging.info("DEBUG lr %r next lr %r", learning_rate,
next_cosine_learning_rate)
return learning_rate
def get_config(self):
return {
"boundaries": self.boundaries,
"values": self.values,
"offset": self.offset,
"name": self.name
}
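As a rough illustration of the math implemented above, here is a small, self-contained sketch of the piecewise cosine formula in pure Python. The helper name `step_cosine_lr` is made up for this example, and it only mirrors the non-degenerate segments of the schedule, not the class itself.

```python
import math

def step_cosine_lr(step, boundaries, values, offset=0):
  """Illustrative sketch of the piecewise cosine decay described above."""
  step = step - offset
  total_steps = [boundaries[i + 1] - boundaries[i]
                 for i in range(len(boundaries) - 1)] + [0]
  lr = None
  for i, (start, init_lr) in enumerate(zip(boundaries, values)):
    next_lr = values[i + 1] if i + 1 < len(values) else 0.0
    interval = total_steps[i]
    if step >= start and interval > 0:
      # Cosine-interpolate from init_lr down to next_lr over this interval.
      lr = (init_lr - next_lr) * (
          math.cos(math.pi * (step - start) / interval) + 1.0) / 2.0 + next_lr
  return lr

# With boundaries=[0, 500000], values=[1e-4, 5e-5], offset=10000, the rate
# starts near 1e-4 right after the offset and eases toward 5e-5.
print(step_cosine_lr(10000, [0, 500000], [1e-4, 5e-5], offset=10000))
print(step_cosine_lr(260000, [0, 500000], [1e-4, 5e-5], offset=10000))
```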
@@ -47,6 +47,7 @@ LR_CLS = {
'power': lr_schedule.DirectPowerDecay,
'power_linear': lr_schedule.PowerAndLinearDecay,
'power_with_offset': lr_schedule.PowerDecayWithOffset,
'step_cosine_with_offset': lr_schedule.StepConsineDecayWithOffset,
}
WARMUP_CLS = {
...
@@ -394,5 +394,38 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
def test_step_cosine_lr_schedule_with_warmup(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'step_cosine_with_offset',
'step_cosine_with_offset': {
'values': (0.0001, 0.00005),
'boundaries': (0, 500000),
'offset': 10000,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 10000,
'warmup_learning_rate': 0.0
}
}
}
expected_lr_step_values = [[0, 0.0], [5000, 1e-4/2.0], [10000, 1e-4],
[20000, 9.994863e-05], [499999, 5e-05]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
if __name__ == '__main__':
tf.test.main()
@@ -44,6 +44,8 @@ class SentencePredictionDataConfig(cfg.DataConfig):
# Maps the key in TfExample to feature name.
# E.g 'label_ids' to 'next_sentence_labels'
label_name: Optional[Tuple[str, str]] = None
# Either tfrecord, sstable, or recordio.
file_type: str = 'tfrecord'
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
@@ -111,7 +113,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
params=self._params,
decoder_fn=self._decode,
parser_fn=self._parse)
return reader.read(input_context)
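The new `file_type` field is routed through `dataset_fn.pick_dataset_fn`. Below is a minimal sketch of what such a file-type dispatcher looks like; it is illustrative only, and the real helper in the Model Garden may support more formats and differ in detail.

```python
import tensorflow as tf

def pick_dataset_fn_sketch(file_type: str):
  """Illustrative mapping from a file_type string to a dataset constructor."""
  if file_type == 'tfrecord':
    return tf.data.TFRecordDataset
  # 'sstable' and 'recordio' are internal formats; the open-source sketch
  # simply rejects them.
  raise ValueError(f'Unsupported file_type: {file_type}')

# Usage: the input reader calls the returned constructor on matched files.
dataset = pick_dataset_fn_sketch('tfrecord')(['/tmp/data.tfrecord'])
```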
@@ -168,7 +173,8 @@ class TextProcessor(tf.Module):
vocab_file=vocab_file, lower_case=lower_case)
elif tokenization == 'SentencePiece':
self._tokenizer = modeling.layers.SentencepieceTokenizer(
model_file_path=vocab_file, lower_case=lower_case,
model_file_path=vocab_file,
lower_case=lower_case,
strip_diacritics=True)  # Strip diacritics to follow ALBERT model
else:
raise ValueError('Unsupported tokenization: %s' % tokenization)
...
@@ -66,14 +66,8 @@ class PositionEmbedding(tf.keras.layers.Layer):
def build(self, input_shape):
dimension_list = input_shape.as_list()
seq_length = dimension_list[self._seq_axis]
width = dimension_list[-1]
if self._max_length is not None:
weight_sequence_length = self._max_length
else:
weight_sequence_length = seq_length
self._position_embeddings = self.add_weight(
"embeddings",
...
@@ -39,23 +39,6 @@ class NoNorm(tf.keras.layers.Layer):
return output
@tf.keras.utils.register_keras_serializable(package='Text')
class NoNormClipped(NoNorm):
"""Quantization friendly implementation for the NoNorm.
The output of NoNorm layer is clipped to [-6.0, 6.0] to make it quantization
friendly.
"""
def __init__(self, name=None):
super(NoNormClipped, self).__init__(name=name)
def call(self, feature):
output = feature * self.scale + self.bias
clipped_output = tf.clip_by_value(output, -6.0, 6.0)
return clipped_output
def _get_norm_layer(normalization_type='no_norm', name=None):
"""Get normlization layer.
@@ -69,8 +52,6 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
"""
if normalization_type == 'no_norm':
layer = NoNorm(name=name)
elif normalization_type == 'no_norm_clipped':
layer = NoNormClipped(name=name)
elif normalization_type == 'layer_norm':
layer = tf.keras.layers.LayerNormalization(
name=name,
...
@@ -33,22 +33,6 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
return fake_input
class EdgeTPUNoNormTest(tf.test.TestCase):
def test_no_norm(self):
layer = mobile_bert_layers.NoNormClipped()
feature = tf.random.uniform(
[2, 3, 4], minval=-8, maxval=8, dtype=tf.float32)
output = layer(feature)
output_shape = output.shape.as_list()
expected_shape = [2, 3, 4]
self.assertListEqual(output_shape, expected_shape, msg=None)
output_min = tf.reduce_min(output)
output_max = tf.reduce_max(output)
self.assertGreaterEqual(6.0, output_max)
self.assertLessEqual(-6.0, output_min)
class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
def test_embedding_layer_with_token_type(self):
...
@@ -544,7 +544,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
self_attention_mask=None,
cross_attention_mask=None,
cache=None,
decode_loop_step=None):
decode_loop_step=None,
return_all_decoder_outputs=False):
"""Return the output of the decoder layer stacks. """Return the output of the decoder layer stacks.
Args: Args:
...@@ -561,6 +562,9 @@ class TransformerDecoder(tf.keras.layers.Layer): ...@@ -561,6 +562,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
...} ...}
decode_loop_step: An integer, the step number of the decoding loop. Used decode_loop_step: An integer, the step number of the decoding loop. Used
only for autoregressive inference on TPU. only for autoregressive inference on TPU.
return_all_decoder_outputs: If True, return the outputs of all decoder
layers. Note that the outputs are layer normalized.
This is useful when introducing a per-layer auxiliary loss.
Returns:
Output of decoder.
@@ -568,6 +572,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
"""
output_tensor = target
decoder_outputs = []
for layer_idx in range(self.num_layers):
transformer_inputs = [
output_tensor, memory, cross_attention_mask, self_attention_mask
@@ -581,6 +586,12 @@ class TransformerDecoder(tf.keras.layers.Layer):
transformer_inputs,
cache=cache[cache_layer_idx],
decode_loop_step=decode_loop_step)
if return_all_decoder_outputs:
decoder_outputs.append(self.output_normalization(output_tensor))
if return_all_decoder_outputs:
return decoder_outputs
else:
return self.output_normalization(output_tensor)
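A hedged sketch of the pattern `return_all_decoder_outputs` enables: collecting each layer's normalized output so a per-layer auxiliary loss can be added. The toy stack below uses plain Dense layers, not the real TransformerDecoder.

```python
import tensorflow as tf

class ToyDecoderStack(tf.keras.layers.Layer):
  """Toy stack illustrating per-layer output collection (not the real decoder)."""

  def __init__(self, num_layers=3, width=8):
    super().__init__()
    self.blocks = [tf.keras.layers.Dense(width, activation='relu')
                   for _ in range(num_layers)]
    self.output_normalization = tf.keras.layers.LayerNormalization()

  def call(self, x, return_all_decoder_outputs=False):
    outputs = []
    for block in self.blocks:
      x = block(x)
      if return_all_decoder_outputs:
        outputs.append(self.output_normalization(x))
    if return_all_decoder_outputs:
      return outputs  # one normalized tensor per layer, e.g. for aux losses
    return self.output_normalization(x)

layer_outputs = ToyDecoderStack()(tf.ones([2, 4, 8]),
                                  return_all_decoder_outputs=True)
aux_loss = tf.add_n([tf.reduce_mean(tf.square(o)) for o in layer_outputs])
```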
...
@@ -16,6 +16,7 @@
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from tensorflow.python.util import deprecation
@tf.keras.utils.register_keras_serializable(package='Text')
@@ -39,6 +40,8 @@ class Classification(tf.keras.Model):
`predictions`.
"""
@deprecation.deprecated(None, 'Classification as a network is deprecated. '
'Please use the layers.ClassificationHead instead.')
def __init__(self,
input_width,
num_classes,
...
@@ -13,18 +13,18 @@
# limitations under the License.
"""Progressive distillation for MobileBERT student model."""
import dataclasses
from typing import List, Optional
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
from official.modeling.fast_training.progressive import policies
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp import keras_nlp
from official.nlp.configs import bert
from official.nlp.configs import encoders
...
@@ -22,7 +22,7 @@ import tensorflow as tf
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
from official.modeling.progressive import trainer as prog_trainer_lib
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
...
@@ -28,8 +28,8 @@ from official.core import train_utils
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling import performance
from official.modeling.progressive import train_lib
from official.modeling.fast_training.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.mobilebert import distillation
...
@@ -84,7 +84,8 @@ class SentencePrediction(export_base.ExportModule):
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
input_type_ids=None,
use_prob=False) -> Dict[str, tf.Tensor]:
if input_type_ids is None:
# Requires CLS token is the first token of inputs.
input_type_ids = tf.zeros_like(input_word_ids)
@@ -97,7 +98,10 @@ class SentencePrediction(export_base.ExportModule):
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
if not use_prob:
return dict(outputs=self.inference_step(inputs))
else:
return dict(outputs=tf.nn.softmax(self.inference_step(inputs)))
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
...
@@ -13,10 +13,10 @@
# limitations under the License.
"""Sentence prediction (classification) task."""
import dataclasses
from typing import List, Union, Optional
from absl import logging
import dataclasses
import numpy as np
import orbit
from scipy import stats
@@ -140,14 +140,25 @@ class SentencePredictionTask(base_task.Task):
del training
if self.task_config.model.num_classes == 1:
metrics = [tf.keras.metrics.MeanSquaredError()]
elif self.task_config.model.num_classes == 2:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
tf.keras.metrics.AUC(name='auc', curve='PR'),
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
for metric in metrics:
if metric.name == 'auc':
# Convert the logits to probabilities and extract the probability of the
# positive class.
metric.update_state(
labels[self.label_field],
tf.expand_dims(tf.nn.softmax(model_outputs)[:, 1], axis=1))
if metric.name == 'cls_accuracy':
metric.update_state(labels[self.label_field], model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
...
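The metric handling above converts two-class logits to a positive-class probability before updating the PR-AUC metric. Here is a small standalone sketch of that conversion with toy tensors; the shapes are assumptions for illustration.

```python
import tensorflow as tf

# Toy two-class logits (batch of 4) and integer labels.
logits = tf.constant([[2.0, -1.0], [0.1, 0.3], [-1.5, 2.5], [0.0, 0.0]])
labels = tf.constant([[0], [1], [1], [0]])

auc = tf.keras.metrics.AUC(name='auc', curve='PR')
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')

# Probability of the positive class, shaped [batch, 1] as in process_metrics.
positive_prob = tf.expand_dims(tf.nn.softmax(logits)[:, 1], axis=1)
auc.update_state(labels, positive_prob)
accuracy.update_state(labels, logits)

print(auc.result().numpy(), accuracy.result().numpy())
```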
@@ -13,11 +13,11 @@
# limitations under the License.
"""Defines the translation task."""
import dataclasses
import os
from typing import Optional
from absl import logging
import dataclasses
import sacrebleu
import tensorflow as tf
import tensorflow_text as tftxt
...
@@ -85,7 +85,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_task(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
@@ -102,7 +103,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_no_sentencepiece_path(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
@@ -122,7 +124,8 @@ class TranslationTaskTest(tf.test.TestCase):
sentencepeice_model_prefix)
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
@@ -137,7 +140,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_evaluation(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder(),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1),
padded_decode=False,
decode_max_length=64),
validation_data=wmt_dataloader.WMTDataConfig(
...
@@ -27,9 +27,15 @@ from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
flags.DEFINE_integer(
'pretrain_steps',
default=None,
help='The number of total training steps for the pretraining job.')
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
@@ -40,12 +46,18 @@ def main(_):
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
if FLAGS.mode == 'continuous_train_and_eval':
continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
else:
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case
# of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
# when dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(
params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM continuous finetuning+eval training driver."""
from absl import app
from absl import flags
import gin
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import train_utils
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
flags.DEFINE_integer(
'pretrain_steps',
default=None,
help='The number of total training steps for the pretraining job.')
def main(_):
# TODO(b/177863554): consolidate to nlp/train.py
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
train_utils.serialize_config(params, model_dir)
continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)