Commit bb124157 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 2e9bb539 0edeb7f6
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,34 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util functions to integrate with Keras internals."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend
try:
  from tensorflow.python.keras.engine import keras_tensor  # pylint: disable=g-import-not-at-top,unused-import
  keras_tensor.disable_keras_tensors()
except ImportError:
  keras_tensor = None


class NoOpContextManager(object):

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass


def maybe_enter_backend_graph():
  if (keras_tensor is not None) and keras_tensor.keras_tensors_enabled():
    return NoOpContextManager()
  else:
    return backend.get_graph().as_default()
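Both branches of maybe_enter_backend_graph() return a context manager, so callers can wrap graph-building work uniformly; a minimal usage sketch (the wrapped body is a placeholder, not code from this commit):

with maybe_enter_backend_graph():
  # Symbolic Keras graph work goes here. When KerasTensors are enabled
  # (newer TF), this context is a no-op; otherwise it enters the legacy
  # Keras backend graph.
  pass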
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.
This tool creates preprocessor and encoder SavedModels suitable for uploading
......@@ -145,7 +145,7 @@ flags.DEFINE_integer(
"sequence length for the bert_pack_inputs subobject."
"Needed for --export_type preprocessing.")
flags.DEFINE_bool(
"tokenize_with_offsets", False, # Broken by b/149576200.
"tokenize_with_offsets", False, # TODO(b/181866850)
"Whether to export a .tokenize_with_offsets subobject for "
"--export_type preprocessing.")
flags.DEFINE_multi_string(
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of components of export_tfhub.py. See docstring there for more."""
import contextlib
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests export_tfhub_lib."""
import os
......@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from sentencepiece import SentencePieceTrainer
from official.modeling import tf_utils
......@@ -32,11 +33,11 @@ from official.nlp.tools import export_tfhub_lib
def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
num_hidden_layers):
num_hidden_layers, vocab_size=100):
"""Returns config args for export_tfhub_lib._create_model()."""
if use_bert_config:
bert_config = configs.BertConfig(
vocab_size=100,
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=32,
max_position_embeddings=128,
......@@ -48,7 +49,7 @@ def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
encoder_config = encoders.EncoderConfig(
type="albert",
albert=encoders.AlbertEncoderConfig(
vocab_size=100,
vocab_size=vocab_size,
embedding_width=16,
hidden_size=hidden_size,
intermediate_size=32,
......@@ -450,11 +451,12 @@ _STRING_NOT_TO_LEAK = "private_path_component_"
class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def _make_vocab_file(self, vocab, filename="vocab.txt"):
def _make_vocab_file(self, vocab, filename="vocab.txt", add_mask_token=False):
"""Creates wordpiece vocab file with given words plus special tokens.
The tokens of the resulting model are, in this order:
[PAD], [UNK], [CLS], [SEP], ...vocab...
[PAD], [UNK], [CLS], [SEP], [MASK]*, ...vocab...
*=if requested by args.
This function also accepts wordpieces that start with the ## continuation
marker, but avoiding those makes this function interchangeable with
......@@ -465,11 +467,13 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
model's vocabulary. Do not include special tokens here.
filename: Optionally, a filename (relative to the temporary directory
created by this function).
add_mask_token: an optional bool, whether to include a [MASK] token.
Returns:
The absolute filename of the created vocab file.
"""
full_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"] + vocab
full_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"
] + ["[MASK]"]*add_mask_token + vocab
path = os.path.join(
tempfile.mkdtemp(dir=self.get_temp_dir(), # New subdir each time.
prefix=_STRING_NOT_TO_LEAK),
......@@ -478,11 +482,12 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
f.write("\n".join(full_vocab + [""]))
return path
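The ["[MASK]"]*add_mask_token expression above leans on Python bools being ints (True == 1, False == 0); a tiny illustration of the idiom:

assert ["[MASK]"] * True == ["[MASK]"]  # True multiplies as 1: token kept.
assert ["[MASK]"] * False == []         # False multiplies as 0: token dropped.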
def _make_sp_model_file(self, vocab, prefix="spm"):
def _make_sp_model_file(self, vocab, prefix="spm", add_mask_token=False):
"""Creates Sentencepiece word model with given words plus special tokens.
The tokens of the resulting model are, in this order:
<pad>, <unk>, [CLS], [SEP], ...vocab..., <s>, </s>
<pad>, <unk>, [CLS], [SEP], [MASK]*, ...vocab..., <s>, </s>
*=if requested by args.
The words in the input vocab are plain text, without the whitespace marker.
That makes this function interchangeable with _make_vocab_file().
......@@ -492,6 +497,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
vocabulary. Do not include special tokens here.
prefix: an optional string, to change the filename prefix for the model
(relative to the temporary directory created by this function).
add_mask_token: an optional bool, whether to include a [MASK] token.
Returns:
The absolute filename of the created Sentencepiece model file.
......@@ -507,12 +513,16 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
input_text.append(" ".join([token] * (len(vocab) - i)))
with tf.io.gfile.GFile(input_file, "w") as f:
  f.write("\n".join(input_text + [""]))
control_symbols = "[CLS],[SEP]"
full_vocab_size = len(vocab) + 6  # <pad>, <unk>, [CLS], [SEP], <s>, </s>.
if add_mask_token:
  control_symbols += ",[MASK]"
  full_vocab_size += 1
flags = dict(
model_prefix=model_prefix,
model_type="word",
input=input_file,
pad_id=0, unk_id=1, control_symbols="[CLS],[SEP]",
pad_id=0, unk_id=1, control_symbols=control_symbols,
vocab_size=full_vocab_size,
bos_id=full_vocab_size-2, eos_id=full_vocab_size-1)
SentencePieceTrainer.Train(
......@@ -521,14 +531,15 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def _do_export(self, vocab, do_lower_case, default_seq_length=128,
tokenize_with_offsets=True, use_sp_model=False,
experimental_disable_assert=False):
experimental_disable_assert=False, add_mask_token=False):
"""Runs SavedModel export and returns the export_path."""
export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
vocab_file = sp_model_file = None
if use_sp_model:
sp_model_file = self._make_sp_model_file(vocab)
sp_model_file = self._make_sp_model_file(vocab,
add_mask_token=add_mask_token)
else:
vocab_file = self._make_vocab_file(vocab)
vocab_file = self._make_vocab_file(vocab, add_mask_token=add_mask_token)
export_tfhub_lib.export_preprocessing(
export_path,
vocab_file=vocab_file,
......@@ -553,7 +564,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def test_exported_callables(self, use_sp_model):
preprocess = tf.saved_model.load(self._do_export(
["d", "ef", "abc", "xy"], do_lower_case=True,
tokenize_with_offsets=not use_sp_model, # TODO(b/149576200): drop this.
tokenize_with_offsets=not use_sp_model, # TODO(b/181866850): drop this.
experimental_disable_assert=True, # TODO(b/175369555): drop this.
use_sp_model=use_sp_model))
......@@ -579,7 +590,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
# .tokenize_with_offsets()
if use_sp_model:
# TODO(b/149576200): Enable tokenize_with_offsets when it works and test.
# TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
else:
token_ids, start_offsets, limit_offsets = (
......@@ -680,7 +691,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
def test_shapes(self, use_sp_model):
preprocess = tf.saved_model.load(self._do_export(
["abc", "def"], do_lower_case=True,
tokenize_with_offsets=not use_sp_model, # TODO(b/149576200): drop this.
tokenize_with_offsets=not use_sp_model, # TODO(b/181866850): drop this.
experimental_disable_assert=True, # TODO(b/175369555): drop this.
use_sp_model=use_sp_model))
......@@ -700,7 +711,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
tf.TensorSpec([batch_size], tf.string)),
token_out_shape,
"with batch_size=%s" % batch_size)
# TODO(b/149576200): Enable tokenize_with_offsets when it works and test.
# TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
if use_sp_model:
  self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
else:
......@@ -751,6 +762,137 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
@parameterized.named_parameters(("Bert", True), ("Albert", False))
def test_preprocessing_for_mlm(self, use_bert):
  """Combines both SavedModel types and TF.text helpers for MLM."""
  # Create the preprocessing SavedModel with a [MASK] token.
  non_special_tokens = ["hello", "world",
                        "nice", "movie", "great", "actors",
                        "quick", "fox", "lazy", "dog"]
  preprocess = tf.saved_model.load(self._do_export(
      non_special_tokens, do_lower_case=True,
      tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
      experimental_disable_assert=True,  # TODO(b/175369555): drop this.
      add_mask_token=True, use_sp_model=not use_bert))
  vocab_size = len(non_special_tokens) + (5 if use_bert else 7)

  # Create the encoder SavedModel with an .mlm subobject.
  hidden_size = 16
  num_hidden_layers = 2
  bert_config, encoder_config = _get_bert_config_or_encoder_config(
      use_bert, hidden_size, num_hidden_layers, vocab_size)
  _, pretrainer = export_tfhub_lib._create_model(
      bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
  model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
  checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
  checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
  model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
  vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(  # Not used below.
      self.get_temp_dir(), use_sp_model=not use_bert)
  encoder_export_path = os.path.join(self.get_temp_dir(), "encoder_export")
  export_tfhub_lib.export_model(
      export_path=encoder_export_path,
      bert_config=bert_config,
      encoder_config=encoder_config,
      model_checkpoint_path=model_checkpoint_path,
      with_mlm=True,
      vocab_file=vocab_file,
      sp_model_file=sp_model_file,
      do_lower_case=True)
  encoder = tf.saved_model.load(encoder_export_path)

  # Get special tokens from the vocab (and vocab size).
  special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
  self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
  padding_id = int(special_tokens_dict["padding_id"])
  self.assertEqual(padding_id, 0)
  start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
  self.assertEqual(start_of_sequence_id, 2)
  end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
  self.assertEqual(end_of_segment_id, 3)
  mask_id = int(special_tokens_dict["mask_id"])
  self.assertEqual(mask_id, 4)

  # A batch of 3 segment pairs.
  raw_segments = [tf.constant(["hello", "nice movie", "quick fox"]),
                  tf.constant(["world", "great actors", "lazy dog"])]
  batch_size = 3

  # Misc hyperparameters.
  seq_length = 10
  max_selections_per_seq = 2

  # Tokenize inputs.
  tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]

  # Trim inputs to eventually fit seq_length.
  num_special_tokens = len(raw_segments) + 1
  trimmed_segments = text.WaterfallTrimmer(
      seq_length - num_special_tokens).trim(tokenized_segments)

  # Combine input segments into one input sequence.
  input_ids, segment_ids = text.combine_segments(
      trimmed_segments,
      start_of_sequence_id=start_of_sequence_id,
      end_of_segment_id=end_of_segment_id)

  # Apply random masking controlled by policy objects.
  (masked_input_ids, masked_lm_positions,
   masked_ids) = text.mask_language_model(
       input_ids=input_ids,
       item_selector=text.RandomItemSelector(
           max_selections_per_seq,
           selection_rate=0.5,  # Adjusted for the short test examples.
           unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
       mask_values_chooser=text.MaskValuesChooser(
           vocab_size=vocab_size, mask_token=mask_id,
           # Always put [MASK] to have a predictable result.
           mask_token_rate=1.0, random_token_rate=0.0))

  # Pad to fixed-length Transformer encoder inputs.
  input_word_ids, _ = text.pad_model_inputs(masked_input_ids,
                                            seq_length,
                                            pad_value=padding_id)
  input_type_ids, input_mask = text.pad_model_inputs(segment_ids, seq_length,
                                                     pad_value=0)
  masked_lm_positions, _ = text.pad_model_inputs(masked_lm_positions,
                                                 max_selections_per_seq,
                                                 pad_value=0)
  masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
  num_predictions = int(tf.shape(masked_lm_positions)[1])

  # Test transformer inputs.
  self.assertEqual(num_predictions, max_selections_per_seq)
  expected_word_ids = np.array([
      # [CLS] hello [SEP] world [SEP]
      [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
      # [CLS] nice movie [SEP] great actors [SEP]
      [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
      # [CLS] quick fox [SEP] lazy dog [SEP]
      [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]])
  for i in range(batch_size):
    for j in range(num_predictions):
      k = int(masked_lm_positions[i, j])
      if k != 0:
        expected_word_ids[i, k] = 4  # [MASK]
  self.assertAllEqual(input_word_ids, expected_word_ids)

  # Call the MLM head of the Transformer encoder.
  mlm_inputs = dict(
      input_word_ids=input_word_ids,
      input_mask=input_mask,
      input_type_ids=input_type_ids,
      masked_lm_positions=masked_lm_positions,
  )
  mlm_outputs = encoder.mlm(mlm_inputs)
  self.assertEqual(mlm_outputs["pooled_output"].shape,
                   (batch_size, hidden_size))
  self.assertEqual(mlm_outputs["sequence_output"].shape,
                   (batch_size, seq_length, hidden_size))
  self.assertEqual(mlm_outputs["mlm_logits"].shape,
                   (batch_size, num_predictions, vocab_size))
  self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)

  # A real trainer would now compute the loss of mlm_logits
  # trying to predict the masked_ids.
  del masked_ids  # Unused.
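For reference, a minimal sketch of that loss computation (illustrative only, not part of the test; it assumes masked_ids is padded the same way as masked_lm_positions above, and would replace the del statement):

padded_masked_ids, _ = text.pad_model_inputs(masked_ids,
                                             max_selections_per_seq,
                                             pad_value=0)
# Cross-entropy of the predicted logits against the original token ids.
per_position_loss = tf.keras.losses.sparse_categorical_crossentropy(
    padded_masked_ids, mlm_outputs["mlm_logits"], from_logits=True)
# Weight out padded prediction slots (pad_value 0 marks unused positions).
weights = tf.cast(tf.not_equal(masked_lm_positions, 0), tf.float32)
mlm_loss = (tf.reduce_sum(per_position_loss * weights) /
            tf.maximum(tf.reduce_sum(weights), 1.0))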
@parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
def test_special_tokens_in_estimator(self, use_sp_model):
"""Tests getting special tokens without an Eager init context."""
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM common training driver."""
from absl import app
......@@ -47,7 +46,8 @@ def main(_):
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
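Judging by the parameter name, use_experimental_api=True keeps this helper on the older tf.keras.mixed_precision.experimental code path (an assumption; the helper's body is not shown in this diff). On TF 2.4+, the non-experimental equivalent of enabling a float16 policy is a single call; a hedged sketch independent of the performance helper:

import tensorflow as tf

# Non-experimental mixed-precision API (TF 2.4+).
tf.keras.mixed_precision.set_global_policy("mixed_float16")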
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM continuous finetuning+eval training driver."""
from absl import app
from absl import flags
......@@ -39,8 +38,8 @@ def main(_):
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
train_utils.serialize_config(params, model_dir)
continuous_finetune_lib.run_continuous_finetune(FLAGS.mode, params, model_dir,
FLAGS.pretrain_steps)
continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
train_utils.save_gin_config(FLAGS.mode, model_dir)
......
......@@ -3,9 +3,11 @@ This is an implementation of the Transformer translation model as described in
the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. The
implementation leverages tf.keras and makes sure it is compatible with TF 2.x.
**Note: this transformer folder is subject to be integrated into official/nlp
folder. Due to its dependencies, we will finish the refactoring after the model
garden 2.1 release.**
**Warning: the features in the `transformer/` folder have been fully integrated
into nlp/modeling. Due to its dependencies, we will remove this folder after
the model garden 2.5 release. The model in
`nlp/modeling/models/seq2seq_transformer.py` is identical to the model in this
folder.**
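For example, the integrated model can be used directly from there (a sketch; assumes the `official` package layout of this repository):

```python
# The integrated equivalent of the model in this folder.
from official.nlp.modeling import models

transformer_cls = models.Seq2SeqTransformer  # same architecture as transformer/
```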
## Contents
* [Contents](#contents)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""
import math
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search to find the translated sequence with the highest probability."""
import tensorflow.compat.v1 as tf
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.
Source:
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test functions in compute_blue.py."""
import tempfile
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and preprocess WMT17 ende training and evaluation datasets."""
import os
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipeline for the transformer model to read, filter, and batch examples.
Two things to note in the pipeline:
......@@ -242,7 +242,7 @@ def _read_and_batch_from_files(file_pattern,
num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)
# Parse each tf.Example into a dictionary
# TODO: Look into prefetch_input_elements for performance optimization.
# TODO: Look into prefetch_input_elements for performance optimization. # pylint: disable=g-bad-todo
dataset = dataset.map(
_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of embedding layer with shared weights."""
import tensorflow as tf
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of fully connected network."""
import tensorflow as tf
......@@ -62,8 +62,6 @@ class FeedForwardNetwork(tf.keras.layers.Layer):
tensor with shape [batch_size, length, hidden_size]
"""
# Retrieve dynamically known shapes
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
output = self.filter_dense_layer(x)
if training:
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for calculating loss, accuracy, and other model metrics.
Metrics:
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Misc for Transformer."""
# pylint: disable=g-bad-import-order
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Transformer model parameters."""
from collections import defaultdict
import collections
BASE_PARAMS = defaultdict(
BASE_PARAMS = collections.defaultdict(
lambda: None, # Set default value to None.
# Input params
......
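The switch from a bare `defaultdict` import to `collections.defaultdict` is purely stylistic; behavior is unchanged. The `lambda: None` default factory means unknown hyperparameters read as None instead of raising KeyError; a tiny illustration:

import collections

params = collections.defaultdict(lambda: None)
params["hidden_size"] = 512
assert params["hidden_size"] == 512
assert params["no_such_param"] is None  # Missing keys yield None, not KeyError.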
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer model helper methods."""
import math
......