"vscode:/vscode.git/clone" did not exist on "1b82388ab855a37bd4bc28ba9a4672c7e843b9e8"
Commit b9cab01b authored by Tian Lin, committed by Toby Boyd

Merged commit that fixes transformer's predict and eval. (#6874)

* Merged commit includes the following changes:
249776315  by tianlin<tianlin@google.com>:

    Internal change

249763206  by tianlin<tianlin@google.com>:

    For TF 2.0 (related to Beam Search), expand cond dims in tf.where(cond, x, y) to make all parameters broadcastable.
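
    For illustration only (a minimal sketch, not part of the commit): in
    TF 2.x, tf.where requires cond, x, and y to be mutually broadcastable,
    so a rank-1 per-batch condition needs trailing dims before selecting.

        import tensorflow as tf

        cond = tf.constant([True, False])  # shape [2]
        x = tf.zeros([2, 3])
        y = tf.ones([2, 3])
        # [2] does not broadcast against [2, 3]; expand to [2, 1] first.
        cond = tf.expand_dims(cond, -1)    # shape [2, 1]
        result = tf.where(cond, x, y)      # row of x or y per batch item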

--
249392724  by hongkuny<hongkuny@google.com>:

    Internal change

PiperOrigin-RevId: 249776315

* Merged commit includes the following changes:
249823043  by tianlin<tianlin@google.com>:

    Bring back v2 test for predict and eval.

--

PiperOrigin-RevId: 249823043
parent 92bad0d2
......@@ -27,7 +27,7 @@ import os
import sys
# pylint: disable=g-bad-import-order
-from absl import app as absl_app
+from absl import app as absl_app  # pylint: disable=unused-import
import tensorflow as tf
# pylint: enable=g-bad-import-order
......
......@@ -112,7 +112,6 @@ def _get_train_and_eval_data(producer, params):
      preprocess_train_input)
  train_input_dataset = train_input_dataset.repeat(FLAGS.train_epochs)

  def preprocess_eval_input(features):
    """Pre-process the eval data.
......
......@@ -569,8 +569,9 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    self._run_and_report_benchmark()

  def benchmark_xla_8_gpu_fp16_tweaked_optional_next(self):
-    """Test Keras model with manual config tuning, XLA, 8 GPUs, fp16 and
-    enabling get_next_as_optional.
+    """Test Keras model with manual config tuning, XLA, 8 GPUs, fp16.
+
+    This test also enables get_next_as_optional.
    """
    self._setup()
......@@ -589,8 +590,9 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    self._run_and_report_benchmark()

  def benchmark_xla_8_gpu_fp16_slack(self):
-    """Test Keras model with tf.data's experimental_slack functionality, XLA,
-    8 GPUs and fp16.
+    """Test Keras model with XLA, 8 GPUs and fp16.
+
+    This test also enables tf.data's experimental_slack functionality.
    """
    self._setup()
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search in TF v2.
"""
import tensorflow as tf
from official.transformer.model import beam_search as v1
from official.transformer.v2 import misc
_StateKeys = v1._StateKeys # pylint: disable=protected-access
class SequenceBeamSearchV2(v1.SequenceBeamSearch):
"""Implementation of beam search loop in v2."""
def search(self, initial_ids, initial_cache):
"""Beam search for sequences with highest scores."""
state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
finished_state = tf.while_loop(
self._continue_search, self._search_step, loop_vars=[state],
shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False)
finished_state = finished_state[0]
alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]
# 2.0 changes tf.where behavior. Should make parameters broadcastable.
finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
seq_cond = _expand_to_same_rank(finished_cond, finished_seq)
score_cond = _expand_to_same_rank(finished_cond, finished_scores)
# Account for corner case where there are no finished sequences for a
# particular batch item. In that case, return alive sequences for that batch
# item.
finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
return finished_seq, finished_scores
def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
index -> [] (scalar)
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
The function must return logits and new cache.
logits -> [batch * beam_size, vocab_size]
new cache -> same shape/structure as inputted cache
initial_ids: Starting ids for each batch item.
int32 tensor with shape [batch_size]
initial_cache: dict containing starting decoder variables information
vocab_size: int size of tokens
beam_size: int number of beams
alpha: float defining the strength of length normalization
max_decode_length: maximum length to decoded sequence
eos_id: int id of eos token, used to determine when a sequence has finished
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
batch_size = tf.shape(initial_ids)[0]
if misc.is_v2():
sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id)
else:
sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id)
return sbs.search(initial_ids, initial_cache)
def _expand_to_same_rank(tensor, target):
"""Expands a given tensor to target's rank to be broadcastable.
Args:
tensor: input tensor to tile. Shape: [b, d1, ..., da]
target: target tensor. Shape: [b, d1, ..., da, ..., dn]
Returns:
Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target.
Raises:
ValueError, if the shape rank of rank tensor/target is None.
"""
if tensor.shape.rank is None:
raise ValueError("Expect rank for tensor shape, but got None.")
if target.shape.rank is None:
raise ValueError("Expect rank for target shape, but got None.")
with tf.name_scope("expand_rank"):
diff_rank = target.shape.rank - tensor.shape.rank
for _ in range(diff_rank):
tensor = tf.expand_dims(tensor, -1)
return tensor
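
Usage sketch (illustrative only; toy sizes and a dummy symbols_to_logits_fn,
not part of the diff):

import tensorflow as tf

vocab_size, beam_size, batch_size = 6, 2, 2

def symbols_to_logits_fn(ids, index, cache):
  # A real model would run its decoder here; uniform logits keep the toy
  # example self-contained. Shape: [batch_size * beam_size, vocab_size].
  logits = tf.zeros([tf.shape(ids)[0], vocab_size])
  return logits, cache

initial_ids = tf.zeros([batch_size], dtype=tf.int32)
seqs, scores = sequence_beam_search(
    symbols_to_logits_fn, initial_ids, initial_cache={}, vocab_size=vocab_size,
    beam_size=beam_size, alpha=0.6, max_decode_length=5, eos_id=1)
# seqs: [batch_size, beam_size, decoded_length]
# scores: [batch_size, beam_size]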
......@@ -56,10 +56,7 @@ import os
import tensorflow as tf
-# TODO(tianlin) Import internal library. Remove this when different behaviors
-# of keras_model.fit(dataset, ...) for different TF versions are fixed.
-from tensorflow.python import tf2 as tf2_internal
from official.transformer.v2 import misc
from official.utils.misc import model_helpers
# Buffer size for reading records from a TFRecord file. Each training file is
......@@ -292,7 +289,7 @@ def eval_input_fn(params):
def map_data_for_transformer_fn(x, y):
  """Maps data for training, and handles weird behaviors for different vers."""
  # Will transform input x and targets y into tuple(x, y) as new model inputs.
-  if tf2_internal.enabled():
+  if misc.is_v2():
    # For TF v2, the 2nd parameter is omitted to make Keras training work.
    return ((x, y),)
  else:
......
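
Usage sketch (illustrative only; hypothetical toy tensors, not part of the
diff) showing where this mapping is applied before keras_model.fit:

import tensorflow as tf

inputs = tf.zeros([8, 10], dtype=tf.int64)   # toy tokenized sources
targets = tf.zeros([8, 12], dtype=tf.int64)  # toy tokenized targets

dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
# Each (x, y) element becomes ((x, y),) in v2, or stays (x, y) in v1.
dataset = dataset.map(map_data_for_transformer_fn).batch(4)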
......@@ -143,6 +143,7 @@ class MetricLayer(tf.keras.layers.Layer):
    self.metric_mean_fns = []

  def build(self, input_shape):
+    """Builds metric layer."""
    neg_log_perplexity = functools.partial(
        padded_neg_log_perplexity, vocab_size=self.vocab_size)
    self.metric_mean_fns = [
......
......@@ -20,6 +20,10 @@ from __future__ import print_function
from absl import flags
+# TODO(tianlin) Import internal library. Remove this when some functions for
+# different TF versions are fixed.
+from tensorflow.python import tf2 as tf2_internal
from official.transformer.model import model_params
from official.utils.flags import core as flags_core
......@@ -30,6 +34,11 @@ PARAMS_MAP = {
}
+def is_v2():
+  """Returns whether it is v2."""
+  return tf2_internal.enabled()


def get_model_params(param_set, num_gpus):
  """Gets predefined model params."""
  if num_gpus > 1:
......
......@@ -132,6 +132,7 @@ class LearningRateScheduler(tf.keras.callbacks.Callback):
      raise ValueError('Optimizer must have a "iterations" attribute.')

  def on_train_batch_begin(self, batch, logs=None):
+    """Adjusts learning rate for each train batch."""
    if self.verbose > 0:
      iterations = K.get_value(self.model.optimizer.iterations)
      print('Original iteration %d' % iterations)
......
......@@ -23,10 +23,10 @@ from __future__ import print_function
import tensorflow as tf
-from official.transformer.model import beam_search
from official.transformer.model import model_utils
from official.transformer.utils.tokenizer import EOS_ID
from official.transformer.v2 import attention_layer
+from official.transformer.v2 import beam_search
from official.transformer.v2 import embedding_layer
from official.transformer.v2 import ffn_layer
from official.transformer.v2 import metrics
......
......@@ -23,7 +23,6 @@ import re
from absl import flags
import tensorflow as tf
-from tensorflow.python.framework import test_util
from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as tm
......@@ -108,13 +107,11 @@ class TransformerTaskTest(tf.test.TestCase):
    update_flags.extend(extra_flags)
    FLAGS(update_flags)

-  @test_util.run_v1_only("V1 should work. Issue: V2 w/ graph transformed.")
  def test_predict(self):
    self._prepare_files_and_flags()
    t = tm.TransformerTask(FLAGS)
    t.predict()

-  @test_util.run_v1_only("V1 should work. Issue: V2 w/ graph transformed.")
  def test_eval(self):
    self._prepare_files_and_flags()
    t = tm.TransformerTask(FLAGS)
......