Commit 1eb0ec0e authored by Hongkun Yu, committed by A. Unique TensorFlower

Define the package pattern for keras-nlp

PiperOrigin-RevId: 328872353
parent 3fc70674
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-NLP package definition."""
# pylint: disable=wildcard-import
from official.nlp.keras_nlp.layers import *
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-NLP layers package definition."""
from official.nlp.keras_nlp.layers.transformer_encoder_block import TransformerEncoderBlock
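For orientation, a minimal usage sketch of the package pattern this commit defines: the wildcard import above re-exports the layer at the package level, so callers can reach it as keras_nlp.TransformerEncoderBlock (exactly as the hunks below do). The argument names mirror the TransformerEncoder hunk further down; the concrete values and the single-tensor call are illustrative assumptions, not part of the commit.

import tensorflow as tf

from official.nlp import keras_nlp

# Build the block through the package-level namespace. The settings here
# (2 heads, inner_dim=128, relu) are example values, not taken from the commit.
encoder_block = keras_nlp.TransformerEncoderBlock(
    num_attention_heads=2,
    inner_dim=128,
    inner_activation='relu')

# Apply it to a batch of token embeddings of shape (batch=4, seq_len=16, hidden=64);
# the hidden size is inferred from the input and must be divisible by the head count.
outputs = encoder_block(tf.ones((4, 16, 64)))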
@@ -246,16 +246,5 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
self.assertEqual(encoder_block_config, new_encoder_block.get_config())
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
  return {
      'key':
          tf.zeros([batch_size, init_decode_length, num_heads, head_size],
                   dtype=tf.float32),
      'value':
          tf.zeros([batch_size, init_decode_length, num_heads, head_size],
                   dtype=tf.float32)
  }


if __name__ == '__main__':
  tf.test.main()
@@ -18,14 +18,14 @@
import gin
import tensorflow as tf
from official.nlp.keras_nlp.layers import transformer_encoder_block
from official.nlp import keras_nlp
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers.util import tf_function_if_eager
@tf.keras.utils.register_keras_serializable(package="Text")
class Transformer(transformer_encoder_block.TransformerEncoderBlock):
class Transformer(keras_nlp.TransformerEncoderBlock):
"""Transformer layer.
This layer implements the Transformer from "Attention Is All You Need".
......
@@ -20,7 +20,7 @@ import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.keras_nlp.layers import transformer_encoder_block
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search
from official.nlp.transformer import metrics
@@ -472,7 +472,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
self.encoder_layers = []
for i in range(self.num_layers):
self.encoder_layers.append(
transformer_encoder_block.TransformerEncoderBlock(
keras_nlp.TransformerEncoderBlock(
num_attention_heads=self.num_attention_heads,
inner_dim=self._intermediate_size,
inner_activation=self._activation,
......