Commit bb124157 authored by stephenwu's avatar stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 2e9bb539 0edeb7f6
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dual encoder network."""
from absl.testing import parameterized
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for ELECTRA models."""
# pylint: disable=g-classes-have-attributes
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA pre trainer network."""
import tensorflow as tf
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implement Seq2Seq Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer model."""
from absl import logging
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLNet classifier network."""
from absl.testing import parameterized
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks package definition."""
from official.nlp.modeling.networks.albert_encoder import AlbertEncoder
from official.nlp.modeling.networks.bert_encoder import BertEncoder
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
# pylint: disable=g-classes-have-attributes
import collections
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ALBERT transformer-based text encoder network."""
from __future__ import absolute_import
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import collections
......@@ -65,7 +65,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
keyed by `encoder_outputs`.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yeilds the full
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer-based bert encoder network."""
# Import libraries
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classification and regression network."""
# pylint: disable=g-classes-have-attributes
import collections
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for classification network."""
from __future__ import absolute_import
......
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import inspect
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for EncoderScaffold network."""
from absl.testing import parameterized
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MobileBERT text encoder network."""
import gin
import tensorflow as tf
......@@ -43,6 +43,7 @@ class MobileBERTEncoder(tf.keras.Model):
num_feedforward_networks=4,
normalization_type='no_norm',
classifier_activation=False,
input_mask_dtype='int32',
**kwargs):
"""Class initialization.
......@@ -76,6 +77,11 @@ class MobileBERTEncoder(tf.keras.Model):
MobileBERT paper. 'layer_norm' is used for the teacher model.
classifier_activation: If using the tanh activation for the final
representation of the [CLS] token in fine-tuning.
input_mask_dtype: The dtype of `input_mask` tensor, which is one of the
input tensors of this encoder. Defaults to `int32`. If you want
to use `tf.lite` quantization, which does not support `Cast` op,
please set this argument to `tf.float32` and feed `input_mask`
tensor with values in float32 to avoid `tf.cast` in the computation.
**kwargs: Other keyworded and arguments.
"""
self._self_setattr_tracking = False
......@@ -115,11 +121,14 @@ class MobileBERTEncoder(tf.keras.Model):
input_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask')
shape=(None,), dtype=input_mask_dtype, name='input_mask')
type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
self.inputs = [input_ids, input_mask, type_ids]
attention_mask = keras_nlp.layers.SelfAttentionMask()(input_ids, input_mask)
# The dtype of `attention_mask` will the same as the dtype of `input_mask`.
attention_mask = keras_nlp.layers.SelfAttentionMask()(input_mask,
input_mask)
# build the computation graph
all_layer_outputs = []
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import numpy as np
......@@ -21,7 +21,7 @@ from official.nlp.modeling.networks import mobile_bert_encoder
def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
"""Generate consisitant fake integer input sequences."""
"""Generate consistent fake integer input sequences."""
np.random.seed(seed)
fake_input = []
for _ in range(batch_size):
......@@ -89,7 +89,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertIsInstance(all_layer_output, list)
self.assertLen(all_layer_output, num_blocks + 1)
def test_mobilebert_encoder_invocation(self):
@parameterized.parameters('int32', 'float32')
def test_mobilebert_encoder_invocation(self, input_mask_dtype):
vocab_size = 100
hidden_size = 32
sequence_length = 16
......@@ -97,10 +98,11 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
test_network = mobile_bert_encoder.MobileBERTEncoder(
word_vocab_size=vocab_size,
hidden_size=hidden_size,
num_blocks=num_blocks)
num_blocks=num_blocks,
input_mask_dtype=input_mask_dtype)
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=input_mask_dtype)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
outputs = test_network([word_ids, mask, type_ids])
model = tf.keras.Model([word_ids, mask, type_ids], outputs)
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An embedding network supporting packed sequences and position ids."""
# pylint: disable=g-classes-have-attributes
import collections
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.modeling.networks.packed_sequence_embedding."""
# Import libraries
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment