Commit 5a898973 authored by Frederick Liu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 361957289
parent 7e3d5270
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the Transformer model in TF 2.0.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Forward pass test for Transformer model refactoring."""
import numpy as np
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for layers in Transformer."""
import tensorflow as tf
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train and evaluate the Transformer model.
See README for description of setting the training schedule and evaluating the
......@@ -38,6 +38,7 @@ from official.nlp.transformer import translate
from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
# pylint:disable=logging-format-interpolation
INF = int(1e9)
BLEU_DIR = "bleu"
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer model."""
import os
......@@ -25,7 +25,6 @@ import tensorflow as tf
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
FIXED_TIMESTAMP = 'my_time_stamp'
......@@ -41,7 +40,7 @@ def _generate_file(filepath, lines):
class TransformerTaskTest(tf.test.TestCase):
local_flags = None
def setUp(self):
def setUp(self): # pylint: disable=g-missing-super-call
temp_dir = self.get_temp_dir()
if TransformerTaskTest.local_flags is None:
misc.define_transformer_flags()
......@@ -70,7 +69,7 @@ class TransformerTaskTest(tf.test.TestCase):
self.orig_policy = (
tf.compat.v2.keras.mixed_precision.experimental.global_policy())
def tearDown(self):
def tearDown(self): # pylint: disable=g-missing-super-call
tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy)
def _assert_exists(self, filepath):
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer model."""
import tensorflow as tf
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Translate text or files using trained transformer model."""
# Import libraries
......@@ -146,7 +146,7 @@ def translate_file(model,
def text_as_per_replica():
replica_context = tf.distribute.get_replica_context()
replica_id = replica_context.replica_id_in_sync_group
return replica_id, text[replica_id]
return replica_id, text[replica_id] # pylint: disable=cell-var-from-loop
text = distribution_strategy.run(text_as_per_replica)
outputs = distribution_strategy.experimental_local_results(
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for calculating loss, accuracy, and other model metrics.
Metrics:
......@@ -203,7 +203,7 @@ def bleu_score(logits, labels):
bleu: int, approx bleu score
"""
predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
# TODO: Look into removing use of py_func
# TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
return bleu, tf.constant(1.0)
......@@ -308,7 +308,7 @@ def rouge_2_fscore(logits, labels):
rouge2_fscore: approx rouge-2 f1 score.
"""
predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
# TODO: Look into removing use of py_func
# TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
return rouge_2_f_score, tf.constant(1.0)
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Subtokenizer class to encode and decode strings."""
from __future__ import absolute_import
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Subtokenizer and string helper methods."""
import collections
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for pre-processing classification data."""
from absl import logging
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common flags used in XLNet model."""
from absl import flags
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used for data preparation."""
import collections
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
from absl import logging
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to pre-process classification data into tfrecords."""
import collections
......
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
"""Script to pre-process pre-training data into tfrecords."""
import json
......@@ -34,15 +34,15 @@ FLAGS = flags.FLAGS
special_symbols = {
"<unk>" : 0,
"<s>" : 1,
"</s>" : 2,
"<cls>" : 3,
"<sep>" : 4,
"<pad>" : 5,
"<mask>" : 6,
"<eod>" : 7,
"<eop>" : 8,
"<unk>": 0,
"<s>": 1,
"</s>": 2,
"<cls>": 3,
"<sep>": 4,
"<pad>": 5,
"<mask>": 6,
"<eod>": 7,
"<eop>": 8,
}
VOCAB_SIZE = 32000
......@@ -627,6 +627,7 @@ def _convert_example(example, use_bfloat16):
def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
host_id, num_core_per_host, bsz_per_core):
"""Parses files to a dataset."""
del num_batch
# list of file paths
num_files = len(file_names)
num_files_per_host = num_files // num_hosts
......@@ -733,6 +734,8 @@ def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
mask_beta, use_bfloat16=False, num_predict=None):
"""Gets the dataset."""
del mask_alpha
del mask_beta
bsz_per_core = params["batch_size"]
if num_hosts > 1:
host_id = params["context"].current_host
......
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# coding=utf-8
"""Script to pre-process SQUAD data into tfrecords."""
import os
......
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# coding=utf-8
"""Utilities for pre-processing."""
import unicodedata
......@@ -36,7 +36,7 @@ def printable_text(text):
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
elif isinstance(text, unicode): # pylint: disable=undefined-variable
return text.encode('utf-8')
else:
raise ValueError('Unsupported string type: %s' % (type(text)))
......@@ -81,7 +81,7 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
"""Encodes pieces."""
# return_unicode is used only for py2
if six.PY2 and isinstance(text, unicode):
if six.PY2 and isinstance(text, unicode): # pylint: disable=undefined-variable
text = text.encode('utf-8')
if not sample:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment