ModelZoo / ResNet50_tensorflow / Commits / bb124157

Commit bb124157, authored Mar 10, 2021 by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

Parents: 2e9bb539, 0edeb7f6
Changes: 386. Showing 20 changed files with 238 additions and 68 deletions (+238 −68).
official/nlp/tools/__init__.py                     +15   −0
official/nlp/tools/export_tfhub.py                  +3   −3
official/nlp/tools/export_tfhub_lib.py              +2   −2
official/nlp/tools/export_tfhub_lib_test.py        +160  −18
official/nlp/train.py                               +4   −4
official/nlp/train_ctl_continuous_finetune.py       +4   −5
official/nlp/transformer/README.md                  +5   −3
official/nlp/transformer/__init__.py               +14   −0
official/nlp/transformer/attention_layer.py         +2   −2
official/nlp/transformer/beam_search_v1.py          +2   −2
official/nlp/transformer/compute_bleu.py            +2   −2
official/nlp/transformer/compute_bleu_test.py       +2   −2
official/nlp/transformer/data_download.py           +2   −2
official/nlp/transformer/data_pipeline.py           +3   −3
official/nlp/transformer/embedding_layer.py         +2   −2
official/nlp/transformer/ffn_layer.py               +2   −4
official/nlp/transformer/metrics.py                 +4   −4
official/nlp/transformer/misc.py                    +4   −4
official/nlp/transformer/model_params.py            +4   −4
official/nlp/transformer/model_utils.py             +2   −2
official/vision/detection/modeling/architecture/keras_utils.py → official/nlp/tools/__init__.py
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,34 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Util functions to integrate with Keras internals."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.keras import backend
-
-try:
-  from tensorflow.python.keras.engine import keras_tensor  # pylint: disable=g-import-not-at-top,unused-import
-  keras_tensor.disable_keras_tensors()
-except ImportError:
-  keras_tensor = None
-
-
-class NoOpContextManager(object):
-
-  def __enter__(self):
-    pass
-
-  def __exit__(self, *args):
-    pass
-
-
-def maybe_enter_backend_graph():
-  if (keras_tensor is not None) and keras_tensor.keras_tensors_enabled():
-    return NoOpContextManager()
-  else:
-    return backend.get_graph().as_default()
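For orientation, the deleted helper wrapped graph-mode layer construction so it worked both with and without KerasTensors; a minimal usage sketch (the call site below is hypothetical, not part of this commit):

    # Hypothetical call site for the removed helper: build ops inside the
    # Keras backend graph unless KerasTensors are enabled (then it's a no-op).
    with maybe_enter_backend_graph():
        outputs = dense_layer(inputs)  # dense_layer and inputs are placeholders.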
official/nlp/tools/export_tfhub.py
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.

 This tool creates preprocessor and encoder SavedModels suitable for uploading
 ...
@@ -145,7 +145,7 @@ flags.DEFINE_integer(
     "sequence length for the bert_pack_inputs subobject."
     "Needed for --export_type preprocessing.")
 flags.DEFINE_bool(
-    "tokenize_with_offsets", False,  # Broken by b/149576200.
+    "tokenize_with_offsets", False,  # TODO(b/181866850)
     "Whether to export a .tokenize_with_offsets subobject for "
     "--export_type preprocessing.")
 flags.DEFINE_multi_string(
 ...
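For reference, absl flags defined this way are parsed by app.run() and read back through the global FLAGS object (a generic sketch, not code from this commit):

    from absl import app
    from absl import flags

    FLAGS = flags.FLAGS
    flags.DEFINE_bool("tokenize_with_offsets", False,
                      "Whether to export a .tokenize_with_offsets subobject.")

    def main(_):
      if FLAGS.tokenize_with_offsets:  # Parsed from --tokenize_with_offsets.
        print("will export tokenize_with_offsets")

    if __name__ == "__main__":
      app.run(main)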
official/nlp/tools/export_tfhub_lib.py

-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Library of components of export_tfhub.py. See docstring there for more."""

 import contextlib
 ...
official/nlp/tools/export_tfhub_lib_test.py

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Tests export_tfhub_lib."""

 import os
 ...
@@ -21,6 +21,7 @@ from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 import tensorflow_hub as hub
+import tensorflow_text as text
 from sentencepiece import SentencePieceTrainer
 from official.modeling import tf_utils
 ...
@@ -32,11 +33,11 @@ from official.nlp.tools import export_tfhub_lib
 def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
-                                       num_hidden_layers):
+                                       num_hidden_layers, vocab_size=100):
   """Returns config args for export_tfhub_lib._create_model()."""
   if use_bert_config:
     bert_config = configs.BertConfig(
-        vocab_size=100,
+        vocab_size=vocab_size,
         hidden_size=hidden_size,
         intermediate_size=32,
         max_position_embeddings=128,
 ...
@@ -48,7 +49,7 @@ def _get_bert_config_or_encoder_config(use_bert_config, hidden_size,
     encoder_config = encoders.EncoderConfig(
         type="albert",
         albert=encoders.AlbertEncoderConfig(
-            vocab_size=100,
+            vocab_size=vocab_size,
             embedding_width=16,
             hidden_size=hidden_size,
             intermediate_size=32,
 ...
@@ -450,11 +451,12 @@ _STRING_NOT_TO_LEAK = "private_path_component_"
 class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):

-  def _make_vocab_file(self, vocab, filename="vocab.txt"):
+  def _make_vocab_file(self, vocab, filename="vocab.txt", add_mask_token=False):
     """Creates wordpiece vocab file with given words plus special tokens.

     The tokens of the resulting model are, in this order:
-        [PAD], [UNK], [CLS], [SEP], ...vocab...
+        [PAD], [UNK], [CLS], [SEP], [MASK]*, ...vocab...   *=if requested by args.

     This function also accepts wordpieces that start with the ## continuation
     marker, but avoiding those makes this function interchangeable with
 ...
@@ -465,11 +467,13 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         model's vocabulary. Do not include special tokens here.
       filename: Optionally, a filename (relative to the temporary directory
         created by this function).
+      add_mask_token: an optional bool, whether to include a [MASK] token.

     Returns:
       The absolute filename of the created vocab file.
     """
-    full_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"] + vocab
+    full_vocab = (["[PAD]", "[UNK]", "[CLS]", "[SEP]"] +
+                  ["[MASK]"] * add_mask_token + vocab)
     path = os.path.join(
         tempfile.mkdtemp(dir=self.get_temp_dir(),  # New subdir each time.
                          prefix=_STRING_NOT_TO_LEAK),
 ...
@@ -478,11 +482,12 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
       f.write("\n".join(full_vocab + [""]))
     return path
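The new full_vocab expression relies on Python treating bools as ints, so multiplying a list by the flag either keeps or drops the [MASK] entry; a quick illustration (not part of the diff):

    # ["[MASK]"] * True == ["[MASK]"]; ["[MASK]"] * False == [].
    vocab = ["d", "ef", "abc", "xy"]
    add_mask_token = True
    full_vocab = (["[PAD]", "[UNK]", "[CLS]", "[SEP]"] +
                  ["[MASK]"] * add_mask_token + vocab)
    assert full_vocab.index("[MASK]") == 4   # [MASK] gets token id 4.
    assert "[MASK]" not in (["[PAD]", "[UNK]", "[CLS]", "[SEP]"] +
                            ["[MASK]"] * False + vocab)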
-  def _make_sp_model_file(self, vocab, prefix="spm"):
+  def _make_sp_model_file(self, vocab, prefix="spm", add_mask_token=False):
     """Creates Sentencepiece word model with given words plus special tokens.

     The tokens of the resulting model are, in this order:
-        <pad>, <unk>, [CLS], [SEP], ...vocab..., <s>, </s>
+        <pad>, <unk>, [CLS], [SEP], [MASK]*, ...vocab..., <s>, </s>   *=if requested by args.

     The words in the input vocab are plain text, without the whitespace marker.
     That makes this function interchangeable with _make_vocab_file().
 ...
@@ -492,6 +497,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         vocabulary. Do not include special tokens here.
       prefix: an optional string, to change the filename prefix for the model
         (relative to the temporary directory created by this function).
+      add_mask_token: an optional bool, whether to include a [MASK] token.

     Returns:
       The absolute filename of the created Sentencepiece model file.
 ...
@@ -507,12 +513,16 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
       input_text.append(" ".join([token] * (len(vocab) - i)))
     with tf.io.gfile.GFile(input_file, "w") as f:
       f.write("\n".join(input_text + [""]))
+    control_symbols = "[CLS],[SEP]"
     full_vocab_size = len(vocab) + 6  # <pad>, <unk>, [CLS], [SEP], <s>, </s>.
+    if add_mask_token:
+      control_symbols += ",[MASK]"
+      full_vocab_size += 1
     flags = dict(
         model_prefix=model_prefix,
         model_type="word",
         input=input_file,
-        pad_id=0, unk_id=1, control_symbols="[CLS],[SEP]",
+        pad_id=0, unk_id=1, control_symbols=control_symbols,
         vocab_size=full_vocab_size,
         bos_id=full_vocab_size - 2, eos_id=full_vocab_size - 1)
     SentencePieceTrainer.Train(
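Taken together, the trainer flags above pin down the token-id layout that the new MLM test depends on; spelled out for a 10-word vocab with add_mask_token=True (derived from the flags above, not separate output):

    vocab = ["hello", "world", "nice", "movie", "great",
             "actors", "quick", "fox", "lazy", "dog"]
    full_vocab_size = len(vocab) + 6 + 1          # 6 specials plus [MASK].
    ids = {"<pad>": 0, "<unk>": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}
    ids.update({word: 5 + i for i, word in enumerate(vocab)})  # Words follow.
    ids["<s>"] = full_vocab_size - 2              # bos_id from the flags.
    ids["</s>"] = full_vocab_size - 1             # eos_id from the flags.
    assert ids["hello"] == 5 and ids["dog"] == 14 and ids["</s>"] == 16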
 ...
@@ -521,14 +531,15 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
   def _do_export(self, vocab, do_lower_case, default_seq_length=128,
                  tokenize_with_offsets=True, use_sp_model=False,
-                 experimental_disable_assert=False):
+                 experimental_disable_assert=False, add_mask_token=False):
     """Runs SavedModel export and returns the export_path."""
     export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
     vocab_file = sp_model_file = None
     if use_sp_model:
-      sp_model_file = self._make_sp_model_file(vocab)
+      sp_model_file = self._make_sp_model_file(vocab,
+                                               add_mask_token=add_mask_token)
     else:
-      vocab_file = self._make_vocab_file(vocab)
+      vocab_file = self._make_vocab_file(vocab, add_mask_token=add_mask_token)
     export_tfhub_lib.export_preprocessing(
         export_path,
         vocab_file=vocab_file,
 ...
@@ -553,7 +564,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
   def test_exported_callables(self, use_sp_model):
     preprocess = tf.saved_model.load(self._do_export(
         ["d", "ef", "abc", "xy"], do_lower_case=True,
-        tokenize_with_offsets=not use_sp_model,  # TODO(b/149576200): drop this.
+        tokenize_with_offsets=not use_sp_model,  # TODO(b/181866850): drop this.
         experimental_disable_assert=True,  # TODO(b/175369555): drop this.
         use_sp_model=use_sp_model))
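As background, the exported preprocessing SavedModel follows the usual TF Hub BERT-preprocessing calling convention (a sketch with a placeholder path; the exact output keys are an assumption, not asserted by this hunk):

    import tensorflow as tf

    preprocess = tf.saved_model.load("/tmp/preprocess_export")  # Placeholder.
    encoder_inputs = preprocess(tf.constant(["hello world"]))
    # Conventionally a dict of int32 Tensors:
    # input_word_ids, input_mask, input_type_ids.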
 ...
@@ -579,7 +590,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
     # .tokenize_with_offsets()
     if use_sp_model:
-      # TODO(b/149576200): Enable tokenize_with_offsets when it works and test.
+      # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
       self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
     else:
       token_ids, start_offsets, limit_offsets = (
 ...
@@ -680,7 +691,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
   def test_shapes(self, use_sp_model):
     preprocess = tf.saved_model.load(self._do_export(
         ["abc", "def"], do_lower_case=True,
-        tokenize_with_offsets=not use_sp_model,  # TODO(b/149576200): drop this.
+        tokenize_with_offsets=not use_sp_model,  # TODO(b/181866850): drop this.
         experimental_disable_assert=True,  # TODO(b/175369555): drop this.
         use_sp_model=use_sp_model))
 ...
@@ -700,7 +711,7 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         tf.TensorSpec([batch_size], tf.string)),
                      token_out_shape,
                      "with batch_size=%s" % batch_size)
-    # TODO(b/149576200): Enable tokenize_with_offsets when it works and test.
+    # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
     if use_sp_model:
       self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
     else:
 ...
@@ -751,6 +762,137 @@ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
         tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
+  @parameterized.named_parameters(("Bert", True), ("Albert", False))
+  def test_preprocessing_for_mlm(self, use_bert):
+    """Combines both SavedModel types and TF.text helpers for MLM."""
+    # Create the preprocessing SavedModel with a [MASK] token.
+    non_special_tokens = ["hello", "world", "nice", "movie", "great", "actors",
+                          "quick", "fox", "lazy", "dog"]
+    preprocess = tf.saved_model.load(self._do_export(
+        non_special_tokens, do_lower_case=True,
+        tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
+        experimental_disable_assert=True,  # TODO(b/175369555): drop this.
+        add_mask_token=True, use_sp_model=not use_bert))
+    vocab_size = len(non_special_tokens) + (5 if use_bert else 7)
+
+    # Create the encoder SavedModel with an .mlm subobject.
+    hidden_size = 16
+    num_hidden_layers = 2
+    bert_config, encoder_config = _get_bert_config_or_encoder_config(
+        use_bert, hidden_size, num_hidden_layers, vocab_size)
+    _, pretrainer = export_tfhub_lib._create_model(
+        bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
+    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
+    checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
+    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
+    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
+    vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(  # Not used below.
+        self.get_temp_dir(), use_sp_model=not use_bert)
+    encoder_export_path = os.path.join(self.get_temp_dir(), "encoder_export")
+    export_tfhub_lib.export_model(
+        export_path=encoder_export_path,
+        bert_config=bert_config,
+        encoder_config=encoder_config,
+        model_checkpoint_path=model_checkpoint_path,
+        with_mlm=True,
+        vocab_file=vocab_file,
+        sp_model_file=sp_model_file,
+        do_lower_case=True)
+    encoder = tf.saved_model.load(encoder_export_path)
+
+    # Get special tokens from the vocab (and vocab size).
+    special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
+    self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
+    padding_id = int(special_tokens_dict["padding_id"])
+    self.assertEqual(padding_id, 0)
+    start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
+    self.assertEqual(start_of_sequence_id, 2)
+    end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
+    self.assertEqual(end_of_segment_id, 3)
+    mask_id = int(special_tokens_dict["mask_id"])
+    self.assertEqual(mask_id, 4)
+
+    # A batch of 3 segment pairs.
+    raw_segments = [tf.constant(["hello", "nice movie", "quick fox"]),
+                    tf.constant(["world", "great actors", "lazy dog"])]
+    batch_size = 3
+    # Misc hyperparameters.
+    seq_length = 10
+    max_selections_per_seq = 2
+
+    # Tokenize inputs.
+    tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]
+    # Trim inputs to eventually fit seq_length.
+    num_special_tokens = len(raw_segments) + 1
+    trimmed_segments = text.WaterfallTrimmer(
+        seq_length - num_special_tokens).trim(tokenized_segments)
+    # Combine input segments into one input sequence.
+    input_ids, segment_ids = text.combine_segments(
+        trimmed_segments,
+        start_of_sequence_id=start_of_sequence_id,
+        end_of_segment_id=end_of_segment_id)
+    # Apply random masking controlled by policy objects.
+    (masked_input_ids, masked_lm_positions,
+     masked_ids) = text.mask_language_model(
+         input_ids=input_ids,
+         item_selector=text.RandomItemSelector(
+             max_selections_per_seq,
+             selection_rate=0.5,  # Adjusted for the short test examples.
+             unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
+         mask_values_chooser=text.MaskValuesChooser(
+             vocab_size=vocab_size, mask_token=mask_id,
+             # Always put [MASK] to have a predictable result.
+             mask_token_rate=1.0, random_token_rate=0.0))
+    # Pad to fixed-length Transformer encoder inputs.
+    input_word_ids, _ = text.pad_model_inputs(masked_input_ids, seq_length,
+                                              pad_value=padding_id)
+    input_type_ids, input_mask = text.pad_model_inputs(segment_ids, seq_length,
+                                                       pad_value=0)
+    masked_lm_positions, _ = text.pad_model_inputs(masked_lm_positions,
+                                                   max_selections_per_seq,
+                                                   pad_value=0)
+    masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
+    num_predictions = int(tf.shape(masked_lm_positions)[1])
+
+    # Test transformer inputs.
+    self.assertEqual(num_predictions, max_selections_per_seq)
+    expected_word_ids = np.array([
+        # [CLS] hello [SEP] world [SEP]
+        [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
+        # [CLS] nice movie [SEP] great actors [SEP]
+        [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
+        # [CLS] quick fox [SEP] lazy dog [SEP]
+        [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]])
+    for i in range(batch_size):
+      for j in range(num_predictions):
+        k = int(masked_lm_positions[i, j])
+        if k != 0:
+          expected_word_ids[i, k] = 4  # [MASK]
+    self.assertAllEqual(input_word_ids, expected_word_ids)
+
+    # Call the MLM head of the Transformer encoder.
+    mlm_inputs = dict(
+        input_word_ids=input_word_ids,
+        input_mask=input_mask,
+        input_type_ids=input_type_ids,
+        masked_lm_positions=masked_lm_positions,
+    )
+    mlm_outputs = encoder.mlm(mlm_inputs)
+    self.assertEqual(mlm_outputs["pooled_output"].shape,
+                     (batch_size, hidden_size))
+    self.assertEqual(mlm_outputs["sequence_output"].shape,
+                     (batch_size, seq_length, hidden_size))
+    self.assertEqual(mlm_outputs["mlm_logits"].shape,
+                     (batch_size, num_predictions, vocab_size))
+    self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)
+
+    # A real trainer would now compute the loss of mlm_logits
+    # trying to predict the masked_ids.
+    del masked_ids  # Unused.
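The closing comment points at the step a real trainer would add; a minimal sketch of that MLM loss (a hypothetical continuation, assuming masked_ids is densified with pad_value=0; not part of the commit):

    # Hypothetical continuation: score mlm_logits against the masked-out ids.
    masked_ids_dense, _ = text.pad_model_inputs(
        masked_ids, max_selections_per_seq, pad_value=0)
    per_slot_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=masked_ids_dense, logits=mlm_outputs["mlm_logits"])
    # Ignore padded prediction slots when averaging the loss.
    weights = tf.cast(tf.not_equal(masked_ids_dense, 0), tf.float32)
    mlm_loss = (tf.reduce_sum(per_slot_loss * weights) /
                tf.maximum(tf.reduce_sum(weights), 1.0))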
   @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
   def test_special_tokens_in_estimator(self, use_sp_model):
     """Tests getting special tokens without an Eager init context."""
 ...
official/nlp/train.py
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -12,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """TFM common training driver."""

 from absl import app
 ...
@@ -47,7 +46,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
 ...
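The helper's internals are not shown in this diff; as a rough sketch, setting a mixed-precision policy in Keras amounts to the following (an assumption about what the Model Garden wrapper ultimately configures):

    import tensorflow as tf

    # Compute in float16, keep variables in float32 for numeric stability.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    policy = tf.keras.mixed_precision.global_policy()
    assert (policy.compute_dtype, policy.variable_dtype) == ("float16", "float32")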
official/nlp/train_ctl_continuous_finetune.py
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -12,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """TFM continuous finetuning+eval training driver."""

 from absl import app
 from absl import flags
 ...
@@ -39,8 +38,8 @@ def main(_):
   params = train_utils.parse_configuration(FLAGS)
   model_dir = FLAGS.model_dir
   train_utils.serialize_config(params, model_dir)
-  continuous_finetune_lib.run_continuous_finetune(FLAGS.mode, params, model_dir,
-                                                  FLAGS.pretrain_steps)
+  continuous_finetune_lib.run_continuous_finetune(
+      FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
   train_utils.save_gin_config(FLAGS.mode, model_dir)
 ...
official/nlp/transformer/README.md

@@ -3,9 +3,11 @@ This is an implementation of the Transformer translation model as described in
 the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. The
 implementation leverages tf.keras and makes sure it is compatible with TF 2.x.

-**Note: this transformer folder is subject to be integrated into official/nlp
-folder. Due to its dependencies, we will finish the refactoring after the model
-garden 2.1 release.**
+**Warning: the features in the `transformer/` folder have been fully integrated
+into nlp/modeling. Due to its dependencies, we will remove this folder after
+the model garden 2.5 release. The model in
+`nlp/modeling/models/seq2seq_transformer.py` is identical to the model in this
+folder.**

 ## Contents
   * [Contents](#contents)
 ...
official/nlp/transformer/__init__.py

+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
official/nlp/transformer/attention_layer.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Implementation of multiheaded attention and self-attention layers."""

 import math
 ...
official/nlp/transformer/beam_search_v1.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Beam search to find the translated sequence with the highest probability."""

 import tensorflow.compat.v1 as tf
 ...
official/nlp/transformer/compute_bleu.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Script to compute official BLEU score.

 Source:
 ...
official/nlp/transformer/compute_bleu_test.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Test functions in compute_bleu.py."""

 import tempfile
 ...
official/nlp/transformer/data_download.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Download and preprocess WMT17 ende training and evaluation datasets."""

 import os
 ...
official/nlp/transformer/data_pipeline.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Input pipeline for the transformer model to read, filter, and batch examples.

 Two things to note in the pipeline:
 ...
@@ -242,7 +242,7 @@ def _read_and_batch_from_files(file_pattern,
       num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)

   # Parse each tf.Example into a dictionary
   # TODO: Look into prefetch_input_elements for performance optimization.
+  # pylint: disable=g-bad-todo
   dataset = dataset.map(
       _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
 ...
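For reference, the with_options/map pattern in this hunk looks like the following in isolation (a generic sketch, not the module's actual pipeline):

    import tensorflow as tf

    options = tf.data.Options()
    options.experimental_deterministic = False  # Assumed setting, for speed.
    dataset = tf.data.Dataset.from_tensor_slices(list(range(8)))
    dataset = dataset.with_options(options)
    # AUTOTUNE lets tf.data pick the parallelism for the map stage.
    dataset = dataset.map(lambda x: x * 2,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)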
official/nlp/transformer/embedding_layer.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Implementation of embedding layer with shared weights."""

 import tensorflow as tf
 ...
official/nlp/transformer/ffn_layer.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Implementation of fully connected network."""

 import tensorflow as tf
 ...
@@ -62,8 +62,6 @@ class FeedForwardNetwork(tf.keras.layers.Layer):
       tensor with shape [batch_size, length, hidden_size]
     """
     # Retrieve dynamically known shapes
-    batch_size = tf.shape(x)[0]
-    length = tf.shape(x)[1]
     output = self.filter_dense_layer(x)
     if training:
 ...
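The deleted lines fetched dynamic shapes that the method never used; for the record, this is the distinction they captured (a generic illustration, not from this module):

    import tensorflow as tf

    x = tf.keras.Input(shape=(None, 64))  # Batch and length unknown statically.
    static_length = x.shape[1]            # None at graph-construction time.
    dynamic_length = tf.shape(x)[1]       # Scalar tensor, resolved at runtime.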
official/nlp/transformer/metrics.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
-# Licensed under the Apache License, Version 2.0 (the 'License');
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
+# distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Functions for calculating loss, accuracy, and other model metrics.

 Metrics:
 ...
official/nlp/transformer/misc.py

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
-# Licensed under the Apache License, Version 2.0 (the 'License');
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
+# distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Misc for Transformer."""

 # pylint: disable=g-bad-import-order
 ...
official/nlp/transformer/model_params.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Defines Transformer model parameters."""

-from collections import defaultdict
+import collections

-BASE_PARAMS = defaultdict(
+BASE_PARAMS = collections.defaultdict(
     lambda: None,  # Set default value to None.

     # Input params
 ...
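Behavior is unchanged by the import style; collections.defaultdict(lambda: None) makes unset hyperparameters read as None instead of raising (a quick illustration, not part of the diff):

    import collections

    BASE_PARAMS = collections.defaultdict(lambda: None, batch_size=2048)
    assert BASE_PARAMS["batch_size"] == 2048
    assert BASE_PARAMS["learning_rate"] is None  # Missing key, no KeyError.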
official/nlp/transformer/model_utils.py

-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Transformer model helper methods."""

 import math
 ...