"app/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c4ba1921870f40659d0b12ca00b3bcc6aa6cece9"
Commit 0fe89de9 authored by Frederick Liu, committed by A. Unique TensorFlower

Opensource translation task.

PiperOrigin-RevId: 351379789
parent 04af8d35
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipeline for the transformer model to read, filter, and batch examples.
1. Batching scheme
Prior to batching, elements in the dataset are grouped by length (the maximum
of the 'inputs' and 'targets' lengths). Each group is then batched such that:
group_batch_size * length <= batch_size.
In other words, batch_size is a token budget: the maximum number of tokens in
each batch.
Once batched, each element in the dataset will have the shape:
{'inputs': [group_batch_size, padded_input_length],
'targets': [group_batch_size, padded_target_length]}
Lengths are padded to the longest 'inputs' or 'targets' sequence in the batch
(padded_input_length and padded_target_length can be different).
This batching scheme decreases the fraction of padding tokens per training
batch, thus improving the training speed significantly.
"""
from typing import Dict, Optional
import dataclasses
import tensorflow as tf
import tensorflow_text as tftxt
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
# Example grouping constants. Defines length boundaries for each group.
# These values are the defaults used in Tensor2Tensor.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1
def _get_example_length(example):
"""Returns the maximum length between the example inputs and targets."""
length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
return length
def _create_min_max_boundaries(max_length,
min_boundary=_MIN_BOUNDARY,
boundary_scale=_BOUNDARY_SCALE):
"""Create min and max boundary lists up to max_length.
For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
returned values will be:
buckets_min = [0, 4, 8, 16]
buckets_max = [4, 8, 16, 25]
Args:
max_length: The maximum length of example in dataset.
min_boundary: Minimum length in boundary.
boundary_scale: Amount to scale consecutive boundaries in the list.
Returns:
min and max boundary lists
"""
# Create bucket boundaries list by scaling the previous boundary or adding 1
# (to ensure increasing boundary sizes).
bucket_boundaries = []
x = min_boundary
while x < max_length:
bucket_boundaries.append(x)
x = max(x + 1, int(x * boundary_scale))
# Create min and max boundary lists from the initial list.
buckets_min = [0] + bucket_boundaries
buckets_max = bucket_boundaries + [max_length + 1]
return buckets_min, buckets_max
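# Example (consistent with the implementation above):
#   _create_min_max_boundaries(24, min_boundary=4, boundary_scale=2)
# returns buckets_min = [0, 4, 8, 16] and buckets_max = [4, 8, 16, 25], so an
# example of length 10 satisfies buckets_min[2] <= 10 < buckets_max[2] and is
# assigned bucket_id 2.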
def _batch_examples(dataset, batch_size, max_length):
"""Group examples by similar lengths, and return batched dataset.
Each batch of similar-length examples is padded to the same length, and
batches may contain different numbers of elements, such that:
group_batch_size * padded_length <= batch_size.
This decreases the number of padding tokens per batch, which improves the
training speed.
Args:
dataset: Dataset of unbatched examples.
batch_size: Max number of tokens per batch of examples.
max_length: Max number of tokens in an example input or target sequence.
Returns:
Dataset of batched examples with similar lengths.
"""
# Get the min and max boundary lists. For each example they are used to
# calculate the `bucket_id`, which is the index at which:
# buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
# Note that using both min and max lists improves the performance.
buckets_min, buckets_max = _create_min_max_boundaries(max_length)
# Create list of batch sizes for each bucket_id, so that
# bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]
# Validate bucket batch sizes.
if any(bucket_batch_size <= 0 for bucket_batch_size in bucket_batch_sizes):
raise ValueError(
'The token budget, global batch size, is too small to yield a nonzero '
'bucket window: %s' % str(bucket_batch_sizes))
# bucket_id will be a tensor, so convert this list to a tensor as well.
bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
def example_to_bucket_id(example):
"""Return int64 bucket id for this example, calculated based on length."""
example_input = example['inputs']
example_target = example['targets']
seq_length = _get_example_length((example_input, example_target))
conditions_c = tf.logical_and(
tf.less_equal(buckets_min, seq_length), tf.less(seq_length,
buckets_max))
bucket_id = tf.reduce_min(tf.where(conditions_c))
return bucket_id
def window_size_fn(bucket_id):
"""Return number of examples to be grouped when given a bucket id."""
return bucket_batch_sizes[bucket_id]
def batching_fn(bucket_id, grouped_dataset):
"""Batch and add padding to a dataset of elements with similar lengths."""
bucket_batch_size = window_size_fn(bucket_id)
# Batch the dataset and add padding so that all input sequences in the
# examples have the same length, and all target sequences have the same
# lengths as well. Resulting lengths of inputs and targets can differ.
padded_shapes = dict([
(name, [None] * len(spec.shape))
for name, spec in grouped_dataset.element_spec.items()
])
return grouped_dataset.padded_batch(bucket_batch_size, padded_shapes)
return dataset.apply(
tf.data.experimental.group_by_window(
key_func=example_to_bucket_id,
reduce_func=batching_fn,
window_size=None,
window_size_func=window_size_fn))
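# Usage sketch (editor's illustration; `dataset` is any unbatched dataset of
# {'inputs': ..., 'targets': ...} token tensors):
#   batched = _batch_examples(dataset, batch_size=4096, max_length=64)
# Each element of `batched` is then a dict of [group_batch_size, padded_length]
# tensors with group_batch_size * padded_length <= 4096.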
@dataclasses.dataclass
class WMTDataConfig(cfg.DataConfig):
"""Data config for WMT translation."""
max_seq_length: int = 64
static_batch: bool = False
sentencepiece_model_path: str = ''
src_lang: str = ''
tgt_lang: str = ''
transform_and_batch: bool = True
has_unique_id: bool = False
@data_loader_factory.register_data_loader_cls(WMTDataConfig)
class WMTDataLoader(data_loader.DataLoader):
"""A class to load dataset for WMT translation task."""
def __init__(self, params: WMTDataConfig):
self._params = params
self._max_seq_length = params.max_seq_length
self._static_batch = params.static_batch
self._global_batch_size = params.global_batch_size
if self._params.transform_and_batch:
self._tokenizer = tftxt.SentencepieceTokenizer(
model=tf.io.gfile.GFile(params.sentencepiece_model_path, 'rb').read(),
add_eos=True)
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
self._params.src_lang: tf.io.FixedLenFeature([], tf.string),
self._params.tgt_lang: tf.io.FixedLenFeature([], tf.string),
}
if self._params.has_unique_id:
name_to_features['unique_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _tokenize(self, inputs) -> Dict[str, tf.Tensor]:
tokenized_inputs = {}
for k, v in inputs.items():
if k == self._params.src_lang:
tokenized_inputs['inputs'] = self._tokenizer.tokenize(v)
elif k == self._params.tgt_lang:
tokenized_inputs['targets'] = self._tokenizer.tokenize(v)
else:
tokenized_inputs[k] = v
return tokenized_inputs
def _filter_max_length(self, inputs):
return tf.logical_and(
tf.shape(inputs['inputs'])[0] <= self._max_seq_length,
tf.shape(inputs['targets'])[0] <= self._max_seq_length)
def _maybe_truncate(self, inputs):
truncated_inputs = {}
for k, v in inputs.items():
if k == 'inputs' or k == 'targets':
truncated_inputs[k] = tf.pad(
v[:self._max_seq_length - 1], [[0, 1]],
constant_values=1) if tf.shape(v)[0] > self._max_seq_length else v
else:
truncated_inputs[k] = v
return truncated_inputs
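# Truncation example (assumes EOS id 1, matching the sentencepiece flags used
# in the tests): with max_seq_length=4, the sequence [5, 6, 7, 8, 9] becomes
# [5, 6, 7, 1], i.e. the first max_seq_length - 1 tokens plus a new EOS;
# sequences already within the limit pass through unchanged.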
def _tokenize_bucketize_and_batch(
self,
dataset,
input_context: Optional[tf.distribute.InputContext] = None):
dataset = dataset.map(
self._tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if self._params.is_training:
dataset = dataset.filter(self._filter_max_length)
else:
dataset = dataset.map(
self._maybe_truncate,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
if self._static_batch:
padded_shapes = {}
for name, _ in dataset.element_spec.items():
if name == 'unique_id':
padded_shapes[name] = []
else:
padded_shapes[name] = [self._max_seq_length]
batch_size = per_replica_batch_size
if self._params.is_training:
batch_size = int(batch_size // self._max_seq_length)
dataset = dataset.padded_batch(
batch_size,
padded_shapes,
drop_remainder=True)
else:
# Group and batch such that each batch has examples of similar length.
dataset = _batch_examples(dataset, per_replica_batch_size,
self._max_seq_length)
# Prefetch the next element to improve speed of input pipeline.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
decoder_fn = None
# Only decode for TFRecords.
if self._params.input_path:
decoder_fn = self._decode
def _identity(
dataset, input_context: Optional[tf.distribute.InputContext] = None):
del input_context
return dataset
transform_and_batch_fn = _identity
if self._params.transform_and_batch:
transform_and_batch_fn = self._tokenize_bucketize_and_batch
reader = input_reader.InputReader(
params=self._params,
decoder_fn=decoder_fn,
transform_and_batch_fn=transform_and_batch_fn)
return reader.read(input_context)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.wmt_dataloader."""
import os
from absl.testing import parameterized
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.data import wmt_dataloader
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, 'w') as f:
for l in lines:
f.write('{}\n'.format(l))
def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
writer = tf.io.TFRecordWriter(filepath)
for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
features = {
'en': tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[src.encode()])),
'reverse_en': tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[tgt.encode()])),
}
if unique_id:
features['unique_id'] = tf.train.Feature(
int64_list=tf.train.Int64List(value=[i]))
example = tf.train.Example(
features=tf.train.Features(
feature=features))
writer.write(example.SerializeToString())
writer.close()
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = ' '.join([
f'--input={input_path}', f'--vocab_size={vocab_size}',
'--character_coverage=0.995',
f'--model_prefix={model_path}', '--model_type=bpe',
'--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}', '--unk_id=2'
])
SentencePieceTrainer.Train(argstr)
class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(WMTDataLoaderTest, self).setUp()
self._temp_dir = self.get_temp_dir()
src_lines = [
'abc ede fg',
'bbcd ef a g',
'de f a a g'
]
tgt_lines = [
'dd cc a ef g',
'bcd ef a g',
'gef cd ba'
]
self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
_generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
_generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
unique_id=True)
self._sentencepiece_input_path = os.path.join(self._temp_dir, 'inputs.txt')
_generate_line_file(self._sentencepiece_input_path, src_lines + tgt_lines)
sentencepiece_model_prefix = os.path.join(self._temp_dir, 'sp')
_train_sentencepiece(self._sentencepiece_input_path, 20,
sentencepiece_model_prefix)
self._sentencepiece_model_path = '{}.model'.format(
sentencepiece_model_prefix)
@parameterized.named_parameters(
('train_static', True, True, 100, (2, 35)),
('train_non_static', True, False, 100, (12, 7)),
('non_train_static', False, True, 3, (3, 35)),
('non_train_non_static', False, False, 50, (2, 7)),)
def test_load_dataset(
self, is_training, static_batch, batch_size, expected_shape):
data_config = wmt_dataloader.WMTDataConfig(
input_path=self._record_train_input_path
if is_training else self._record_test_input_path,
max_seq_length=35,
global_batch_size=batch_size,
is_training=is_training,
static_batch=static_batch,
src_lang='en',
tgt_lang='reverse_en',
sentencepiece_model_path=self._sentencepiece_model_path)
dataset = wmt_dataloader.WMTDataLoader(data_config).load()
examples = next(iter(dataset))
inputs, targets = examples['inputs'], examples['targets']
self.assertEqual(inputs.shape, expected_shape)
self.assertEqual(targets.shape, expected_shape)
def test_load_dataset_raise_invalid_window(self):
batch_tokens_size = 10 # this is too small to form buckets.
data_config = wmt_dataloader.WMTDataConfig(
input_path=self._record_train_input_path,
max_seq_length=100,
global_batch_size=batch_tokens_size,
is_training=True,
static_batch=False,
src_lang='en',
tgt_lang='reverse_en',
sentencepiece_model_path=self._sentencepiece_model_path)
with self.assertRaisesRegex(
ValueError, 'The token budget, global batch size, is too small.*'):
_ = wmt_dataloader.WMTDataLoader(data_config).load()
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the translation task."""
import os
from typing import Optional
from absl import logging
import dataclasses
import sacrebleu
import tensorflow as tf
import tensorflow_text as tftxt
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import base_config
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.transformer import compute_bleu
def _pad_tensors_to_same_length(x, y):
"""Pad x and y so that the results have the same length (second dimension)."""
x_length = tf.shape(x)[1]
y_length = tf.shape(y)[1]
max_length = tf.maximum(x_length, y_length)
x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
return x, y
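# Shape sketch (hypothetical sizes): x of shape [8, 10, 32000] (logits) and y
# of shape [8, 12] (labels) are padded along the length axis to
# max(10, 12) = 12, returning x as [8, 12, 32000] and y as [8, 12].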
def _padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
"""Calculate cross entropy loss while ignoring padding.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch_size, length_labels]
smoothing: Label smoothing constant, used to determine the on and off values.
vocab_size: int size of the vocabulary.
Returns:
The cross entropy loss and weight tensors: float32 tensors with
shape [batch_size, max(length_logits, length_labels)].
"""
logits, labels = _pad_tensors_to_same_length(logits, labels)
# Calculate smoothing cross entropy
confidence = 1.0 - smoothing
low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
soft_targets = tf.one_hot(
tf.cast(labels, tf.int32),
depth=vocab_size,
on_value=confidence,
off_value=low_confidence)
xentropy = tf.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=soft_targets)
# Calculate the best (lowest) possible value of cross entropy, and
# subtract from the cross entropy loss.
normalizing_constant = -(
confidence * tf.math.log(confidence) + tf.cast(vocab_size - 1, tf.float32)
* low_confidence * tf.math.log(low_confidence + 1e-20))
xentropy -= normalizing_constant
weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
return xentropy * weights, weights
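# Worked numbers (editor's illustration): with smoothing=0.1 and vocab_size=5,
# confidence = 0.9 and low_confidence = 0.1 / 4 = 0.025, so the normalizing
# constant is -(0.9 * log(0.9) + 4 * 0.025 * log(0.025)) ~= 0.464; a model
# that exactly predicts the soft targets thus incurs ~0 loss after the shift.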
@dataclasses.dataclass
class EncDecoder(base_config.Config):
"""Configurations for Encoder/Decoder."""
num_layers: int = 6
num_attention_heads: int = 8
intermediate_size: int = 2048
activation: str = "relu"
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
intermediate_dropout: float = 0.1
use_bias: bool = False
norm_first: bool = True
norm_epsilon: float = 1e-6
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A base Seq2Seq model configuration."""
encoder: EncDecoder = EncDecoder()
decoder: EncDecoder = EncDecoder()
embedding_width: int = 512
dropout_rate: float = 0.1
# Decoding.
padded_decode: bool = False
decode_max_length: Optional[int] = None
beam_size: int = 4
alpha: float = 0.6
# Training.
label_smoothing: float = 0.1
@dataclasses.dataclass
class TranslationConfig(cfg.TaskConfig):
"""The translation task config."""
model: ModelConfig = ModelConfig()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
# Tokenization
sentencepiece_model_path: str = ""
# Evaluation.
print_translations: Optional[bool] = None
def write_test_record(params, model_dir):
"""Writes the test input to a tfrecord."""
# Get raw data from tfds.
params = params.replace(transform_and_batch=False)
dataset = data_loader_factory.get_data_loader(params).load()
references = []
total_samples = 0
output_file = os.path.join(model_dir, "eval.tf_record")
writer = tf.io.TFRecordWriter(output_file)
for d in dataset:
references.append(d[params.tgt_lang].numpy().decode())
example = tf.train.Example(
features=tf.train.Features(
feature={
"unique_id": tf.train.Feature(
int64_list=tf.train.Int64List(value=[total_samples])),
params.src_lang: tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[d[params.src_lang].numpy()])),
params.tgt_lang: tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[d[params.tgt_lang].numpy()])),
}))
writer.write(example.SerializeToString())
total_samples += 1
batch_size = params.global_batch_size
num_dummy_example = batch_size - total_samples % batch_size
for i in range(num_dummy_example):
example = tf.train.Example(
features=tf.train.Features(
feature={
"unique_id": tf.train.Feature(
int64_list=tf.train.Int64List(value=[total_samples + i])),
params.src_lang: tf.train.Feature(
bytes_list=tf.train.BytesList(value=[b""])),
params.tgt_lang: tf.train.Feature(
bytes_list=tf.train.BytesList(value=[b""])),
}))
writer.write(example.SerializeToString())
writer.close()
return references, output_file
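# Note (editor's reading of the code above): the dummy examples pad the eval
# record so the number of examples is a multiple of the global batch size;
# during metric reduction, predictions whose unique_id is not smaller than
# len(references) are the dummies and are skipped.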
@task_factory.register_task_cls(TranslationConfig)
class TranslationTask(base_task.Task):
"""A single-replica view of training procedure.
Tasks provide artifacts for training/evaluation procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss
and customized metrics with reduction.
"""
def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
super().__init__(params, logging_dir, name=name)
self._sentencepiece_model_path = params.sentencepiece_model_path
if params.sentencepiece_model_path:
self._sp_tokenizer = tftxt.SentencepieceTokenizer(
model=tf.io.gfile.GFile(params.sentencepiece_model_path, "rb").read(),
add_eos=True)
try:
empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
except tf.errors.InternalError:
raise ValueError(
"EOS token not in tokenizer vocab."
"Please make sure the tokenizer generates a single token for an "
"empty string.")
self._eos_id = empty_str_tokenized.item()
self._vocab_size = self._sp_tokenizer.vocab_size().numpy()
else:
raise ValueError("Setencepiece model path not provided.")
if (params.validation_data.input_path or
params.validation_data.tfds_name) and self._logging_dir:
self._references, self._tf_record_input_path = write_test_record(
params.validation_data, self.logging_dir)
def build_model(self) -> tf.keras.Model:
"""Creates model architecture.
Returns:
A model instance.
"""
model_cfg = self.task_config.model
encoder_kwargs = model_cfg.encoder.as_dict()
encoder_layer = models.TransformerEncoder(**encoder_kwargs)
decoder_kwargs = model_cfg.decoder.as_dict()
decoder_layer = models.TransformerDecoder(**decoder_kwargs)
return models.Seq2SeqTransformer(
vocab_size=self._vocab_size,
embedding_width=model_cfg.embedding_width,
dropout_rate=model_cfg.dropout_rate,
padded_decode=model_cfg.padded_decode,
decode_max_length=model_cfg.decode_max_length,
beam_size=model_cfg.beam_size,
alpha=model_cfg.alpha,
encoder_layer=encoder_layer,
decoder_layer=decoder_layer,
eos_id=self._eos_id)
def build_inputs(self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a dataset."""
if params.is_training:
dataloader_params = params
else:
input_path = self._tf_record_input_path
# Read from padded tf records instead.
dataloader_params = params.replace(
input_path=input_path,
tfds_name="",
tfds_split="",
has_unique_id=True)
dataloader_params = dataloader_params.replace(
sentencepiece_model_path=self._sentencepiece_model_path)
return data_loader_factory.get_data_loader(dataloader_params).load(
input_context)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses.
Args:
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
del aux_losses
smoothing = self.task_config.model.label_smoothing
xentropy, weights = _padded_cross_entropy_loss(model_outputs, labels,
smoothing, self._vocab_size)
return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
def train_step(self,
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None):
"""Does forward and backward.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
# For mixed precision, when a LossScaleOptimizer is used, the loss is
# scaled to avoid numeric underflow.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, inputs["targets"], outputs)
logs.update({m.name: m.result() for m in metrics})
return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
unique_ids = inputs.pop("unique_id")
# Validation loss
outputs = model(inputs, training=False)
# Computes per-replica loss to help understand if we are overfitting.
loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
inputs.pop("targets")
# Beam search to calculate metrics.
outputs = model(inputs, training=False)
logs = {
self.loss: loss,
"inputs": inputs["inputs"],
"unique_ids": unique_ids,
}
logs.update(outputs)
return logs
def aggregate_logs(self, state=None, step_outputs=None):
"""Aggregates over logs returned from a validation step."""
if state is None:
state = {}
for in_token_ids, out_token_ids, unique_ids in zip(
step_outputs["inputs"],
step_outputs["outputs"],
step_outputs["unique_ids"]):
for in_ids, out_ids, u_id in zip(
in_token_ids.numpy(), out_token_ids.numpy(), unique_ids.numpy()):
state[u_id] = (in_ids, out_ids)
return state
def reduce_aggregated_logs(self, aggregated_logs):
def _decode(ids):
return self._sp_tokenizer.detokenize(ids).numpy().decode()
def _trim_and_decode(ids):
"""Trim EOS and PAD tokens from ids, and decode to return a string."""
try:
index = list(ids).index(self._eos_id)
return _decode(ids[:index])
except ValueError: # No EOS found in sequence
return _decode(ids)
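# Example: with eos_id=1, ids [5, 7, 1, 0, 0] are trimmed to [5, 7] before
# detokenization; if no EOS token is found, the full sequence is decoded.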
translations = []
for u_id in sorted(aggregated_logs):
if u_id >= len(self._references):
continue
src = _trim_and_decode(aggregated_logs[u_id][0])
translation = _trim_and_decode(aggregated_logs[u_id][1])
translations.append(translation)
if self.task_config.print_translations:
# Decode the input ids to reflect what the model sees.
logging.info("Translating:\n\tInput: %s\n\tOutput: %s\n\tReference: %s",
src, translation, self._references[u_id])
sacrebleu_score = sacrebleu.corpus_bleu(
translations, [self._references]).score
bleu_score = compute_bleu.bleu_on_list(self._references, translations)
return {"sacrebleu_score": sacrebleu_score,
"bleu_score": bleu_score}
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.translation."""
import functools
import os
import orbit
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import translation
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, "w") as f:
for l in lines:
f.write("{}\n".format(l))
def _generate_record_file(filepath, src_lines, tgt_lines):
writer = tf.io.TFRecordWriter(filepath)
for src, tgt in zip(src_lines, tgt_lines):
example = tf.train.Example(
features=tf.train.Features(
feature={
"en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[src.encode()])),
"reverse_en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[tgt.encode()])),
}))
writer.write(example.SerializeToString())
writer.close()
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = " ".join([
f"--input={input_path}", f"--vocab_size={vocab_size}",
"--character_coverage=0.995",
f"--model_prefix={model_path}", "--model_type=bpe",
"--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
])
SentencePieceTrainer.Train(argstr)
class TranslationTaskTest(tf.test.TestCase):
def setUp(self):
super(TranslationTaskTest, self).setUp()
self._temp_dir = self.get_temp_dir()
src_lines = [
"abc ede fg",
"bbcd ef a g",
"de f a a g"
]
tgt_lines = [
"dd cc a ef g",
"bcd ef a g",
"gef cd ba"
]
self._record_input_path = os.path.join(self._temp_dir, "inputs.record")
_generate_record_file(self._record_input_path, src_lines, tgt_lines)
self._sentencepiece_input_path = os.path.join(self._temp_dir, "inputs.txt")
_generate_line_file(self._sentencepiece_input_path, src_lines + tgt_lines)
sentencepiece_model_prefix = os.path.join(self._temp_dir, "sp")
_train_sentencepiece(self._sentencepiece_input_path, 11,
sentencepiece_model_prefix)
self._sentencepiece_model_path = "{}.model".format(
sentencepiece_model_prefix)
def test_task(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
is_training=True, static_batch=True, global_batch_size=24,
max_seq_length=12),
sentencepiece_model_path=self._sentencepiece_model_path)
task = translation.TranslationTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
task.train_step(next(iterator), model, optimizer)
def test_no_sentencepiece_path(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
is_training=True, static_batch=True, global_batch_size=4,
max_seq_length=4),
sentencepiece_model_path=None)
with self.assertRaisesRegex(
ValueError,
"Setencepiece model path not provided."):
translation.TranslationTask(config)
def test_sentencepiece_no_eos(self):
sentencepiece_model_prefix = os.path.join(self._temp_dir, "sp_no_eos")
_train_sentencepiece(self._sentencepiece_input_path, 20,
sentencepiece_model_prefix, eos_id=-1)
sentencepiece_model_path = "{}.model".format(
sentencepiece_model_prefix)
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
is_training=True, static_batch=True, global_batch_size=4,
max_seq_length=4),
sentencepiece_model_path=sentencepiece_model_path)
with self.assertRaisesRegex(
ValueError,
"EOS token not in tokenizer vocab.*"):
translation.TranslationTask(config)
def test_evaluation(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder(),
padded_decode=False,
decode_max_length=64),
validation_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path, src_lang="en",
tgt_lang="reverse_en", static_batch=True, global_batch_size=4),
sentencepiece_model_path=self._sentencepiece_model_path)
logging_dir = self.get_temp_dir()
task = translation.TranslationTask(config, logging_dir=logging_dir)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs,
config.validation_data)
model = task.build_model()
strategy = tf.distribute.get_strategy()
aggregated = None
for data in dataset:
distributed_outputs = strategy.run(
functools.partial(task.validation_step, model=model),
args=(data,))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
distributed_outputs)
aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
metrics = task.reduce_aggregated_logs(aggregated)
self.assertIn("sacrebleu_score", metrics)
self.assertIn("bleu_score", metrics)
if __name__ == "__main__":
tf.test.main()