ModelZoo / ResNet50_tensorflow / Commits

Commit f0e2f833, authored Aug 13, 2020 by xinliupitt
Parent: 9bb04e60

    remove params

Showing 2 changed files with 192 additions and 86 deletions (+192 -86):
  official/nlp/modeling/models/seq2seq_transformer.py       +189 -75
  official/nlp/modeling/models/seq2seq_transformer_test.py  +3   -11
official/nlp/modeling/models/seq2seq_transformer.py
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Defines the Transformer model in TF 2.0.
+"""Implement Seq2Seq Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
-Transformer model code source: https://github.com/tensorflow/tensor2tensor
+TF official NLP library:
+https://github.com/tensorflow/models/tree/master/official/nlp/modeling
"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
import math
import tensorflow as tf
...
...
@@ -29,6 +26,7 @@ from official.nlp.modeling import layers
from official.nlp.modeling.layers import position_embedding
from official.nlp.modeling.layers import transformer
from official.nlp.modeling.ops import beam_search
from official.nlp.transformer import metrics
from official.nlp.transformer import model_utils
from official.nlp.transformer.utils.tokenizer import EOS_ID
...
...
@@ -37,6 +35,66 @@ from official.nlp.transformer.utils.tokenizer import EOS_ID
# callable when they actually are.
# pylint: disable=not-callable


def create_model(params, is_train):
  """Creates transformer model."""
  encdec_kwargs = dict(
      num_layers=params["num_hidden_layers"],
      num_attention_heads=params["num_heads"],
      intermediate_size=params["filter_size"],
      activation="relu",
      dropout_rate=params["relu_dropout"],
      attention_dropout_rate=params["attention_dropout"],
      use_bias=False,
      norm_first=True,
      norm_epsilon=1e-6,
      intermediate_dropout=params["relu_dropout"])
  encoder_layer = TransformerEncoder(**encdec_kwargs)
  decoder_layer = TransformerDecoder(**encdec_kwargs)

  model_kwargs = dict(
      vocab_size=params["vocab_size"],
      hidden_size=params["hidden_size"],
      dropout_rate=params["layer_postprocess_dropout"],
      padded_decode=params["padded_decode"],
      num_replicas=params["num_replicas"],
      decode_batch_size=params["decode_batch_size"],
      decode_max_length=params["decode_max_length"],
      dtype=params["dtype"],
      extra_decode_length=params["extra_decode_length"],
      num_heads=params["num_heads"],
      num_layers=params["num_hidden_layers"],
      beam_size=params["beam_size"],
      alpha=params["alpha"],
      encoder_layer=encoder_layer,
      decoder_layer=decoder_layer,
      name="transformer_v2")

  with tf.name_scope("model"):
    if is_train:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
      internal_model = Seq2SeqTransformer(**model_kwargs)
      logits = internal_model([inputs, targets], training=is_train)
      vocab_size = params["vocab_size"]
      label_smoothing = params["label_smoothing"]
      if params["enable_metrics_in_training"]:
        logits = metrics.MetricLayer(vocab_size)([logits, targets])
      logits = tf.keras.layers.Lambda(
          lambda x: x, name="logits", dtype=tf.float32)(logits)
      model = tf.keras.Model([inputs, targets], logits)
      loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                      vocab_size)
      model.add_loss(loss)
      return model

    else:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      internal_model = Seq2SeqTransformer(**model_kwargs)
      ret = internal_model([inputs], training=is_train)
      outputs, scores = ret["outputs"], ret["scores"]
      return tf.keras.Model(inputs, [outputs, scores])


@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
  """Transformer model with Keras.
...
...
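For orientation, here is a minimal sketch of how the new create_model helper above could be driven. Every dictionary key is one that create_model reads in this hunk; the concrete values, and the assumption that the surrounding Model Garden modules are importable, are illustrative rather than taken from this commit.

# Hypothetical hyperparameter dictionary; keys mirror those read by
# create_model above, values are placeholders for illustration only.
params = {
    "vocab_size": 33708,
    "hidden_size": 512,
    "num_hidden_layers": 6,
    "num_heads": 8,
    "filter_size": 2048,
    "relu_dropout": 0.1,
    "attention_dropout": 0.1,
    "layer_postprocess_dropout": 0.1,
    "label_smoothing": 0.1,
    "enable_metrics_in_training": False,
    "padded_decode": False,
    "num_replicas": 1,
    "decode_batch_size": 32,
    "decode_max_length": 97,
    "extra_decode_length": 0,
    "beam_size": 4,
    "alpha": 0.6,
    "dtype": tf.float32,
}

# Training graph: Keras model taking (inputs, targets) and producing logits,
# with the transformer loss already attached via model.add_loss.
train_model = create_model(params, is_train=True)

# Inference graph: Keras model taking only inputs and returning
# (decoded token ids, beam-search scores).
predict_model = create_model(params, is_train=False)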
@@ -49,54 +107,108 @@ class Seq2SeqTransformer(tf.keras.Model):
  probabilities for the output sequence.
  """

-  def __init__(self, params, name=None):
+  def __init__(self,
+               vocab_size=33708,
+               hidden_size=512,
+               dropout_rate=0.0,
+               padded_decode=False,
+               num_replicas=1,
+               decode_batch_size=2048,
+               decode_max_length=97,
+               dtype=tf.float32,
+               extra_decode_length=0,
+               num_heads=8,
+               num_layers=6,
+               beam_size=4,
+               alpha=0.6,
+               encoder_layer=None,
+               decoder_layer=None,
+               name=None,
+               **kwargs):
    """Initialize layers to build Transformer model.

-    Args:
-      params: hyperparameter object defining layer sizes, dropout values, etc.
+    Arguments:
+      vocab_size: Size of vocabulary.
+      hidden_size: Size of hidden layer for embedding.
+      dropout_rate: Dropout probability.
+      padded_decode: Whether max_sequence_length padding is used. If set to
+        False, max_sequence_length padding is not used.
+      num_replicas: Number of replicas for distribution strategy.
+      decode_batch_size: batch_size for decoding.
+      decode_max_length: maximum number of steps to decode a sequence.
+      dtype: data type.
+      num_heads: Number of attention heads.
+      num_layers: Number of identical layers for Transformer architecture.
+      beam_size: Number of beams for beam search.
+      alpha: The strength of length normalization for beam search.
+      encoder_layer: An initialized encoder layer.
+      decoder_layer: An initialized decoder layer.
+      name: name of the model.
    """
-    super(Seq2SeqTransformer, self).__init__(name=name)
-    self.params = params
+    super(Seq2SeqTransformer, self).__init__(**kwargs)
+    self._vocab_size = vocab_size
+    self._hidden_size = hidden_size
+    self._dropout_rate = dropout_rate
+    self._padded_decode = padded_decode
+    self._num_replicas = num_replicas
+    self._decode_batch_size = decode_batch_size
+    self._decode_max_length = decode_max_length
+    self._dtype = dtype
+    self._extra_decode_length = extra_decode_length
+    self._num_heads = num_heads
+    self._num_layers = num_layers
+    self._beam_size = beam_size
+    self._alpha = alpha
    self.embedding_lookup = layers.OnDeviceEmbedding(
-        vocab_size=params["vocab_size"],
-        embedding_width=params["hidden_size"],
+        vocab_size=self._vocab_size,
+        embedding_width=self._hidden_size,
        initializer=tf.random_normal_initializer(
-            mean=0., stddev=params["hidden_size"]**-0.5),
+            mean=0., stddev=self._hidden_size**-0.5),
        use_scale=True)
-    self.encoder_layer = TransformerEncoder(
-        num_layers=self.params["num_hidden_layers"],
-        num_attention_heads=self.params["num_heads"],
-        intermediate_size=self.params["filter_size"],
-        activation="relu",
-        dropout_rate=self.params["relu_dropout"],
-        attention_dropout_rate=self.params["attention_dropout"],
-        use_bias=False,
-        norm_first=True,
-        norm_epsilon=1e-6,
-        intermediate_dropout=self.params["relu_dropout"])
-    self.decoder_layer = TransformerDecoder(
-        num_layers=self.params["num_hidden_layers"],
-        num_attention_heads=self.params["num_heads"],
-        intermediate_size=self.params["filter_size"],
-        activation="relu",
-        dropout_rate=self.params["relu_dropout"],
-        attention_dropout_rate=self.params["attention_dropout"],
-        use_bias=False,
-        norm_first=True,
-        norm_epsilon=1e-6,
-        intermediate_dropout=self.params["relu_dropout"])
+    self.encoder_layer = encoder_layer
+    self.decoder_layer = decoder_layer
    self.position_embedding = position_embedding.RelativePositionEmbedding(
-        hidden_size=self.params["hidden_size"])
+        hidden_size=self._hidden_size)
    self.encoder_dropout = tf.keras.layers.Dropout(
-        rate=self.params["layer_postprocess_dropout"])
+        rate=self._dropout_rate)
    self.decoder_dropout = tf.keras.layers.Dropout(
-        rate=self.params["layer_postprocess_dropout"])
+        rate=self._dropout_rate)

  def get_config(self):
-    return {
-        "params": self.params,
-    }
+    config = {
+        "vocab_size": self._vocab_size,
+        "hidden_size": self._hidden_size,
+        "dropout_rate": self._dropout_rate,
+        "padded_decode": self._padded_decode,
+        "num_replicas": self._num_replicas,
+        "decode_batch_size": self._decode_batch_size,
+        "decode_max_length": self._decode_max_length,
+        "dtype": self._dtype,
+        "extra_decode_length": self._extra_decode_length,
+        "num_heads": self._num_heads,
+        "num_layers": self._num_layers,
+        "beam_size": self._beam_size,
+        "alpha": self._alpha,
+        "encoder_layer": self.encoder_layer,
+        "decoder_layer": self.decoder_layer
+    }
+    base_config = super(Seq2SeqTransformer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    """Calculate target logits or inferred target sequences.
...
...
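To see what "remove params" means for callers: the model is no longer built from one opaque hyperparameter dict but from explicit keyword arguments, with the encoder and decoder layers constructed separately and injected. A rough before/after sketch, with placeholder values that are not taken from this commit:

# Before this commit:
#   model = Seq2SeqTransformer(params)

# After this commit (hyperparameter values are illustrative):
encdec_kwargs = dict(
    num_layers=6,
    num_attention_heads=8,
    intermediate_size=2048,
    activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    use_bias=False,
    norm_first=True,
    norm_epsilon=1e-6,
    intermediate_dropout=0.1)
model = Seq2SeqTransformer(
    vocab_size=33708,
    hidden_size=512,
    dropout_rate=0.1,
    num_heads=8,
    num_layers=6,
    beam_size=4,
    alpha=0.6,
    encoder_layer=TransformerEncoder(**encdec_kwargs),
    decoder_layer=TransformerDecoder(**encdec_kwargs),
    name="transformer_v2")
# get_config() now reports each of these arguments individually instead of a
# single "params" entry.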
@@ -124,19 +236,19 @@ class Seq2SeqTransformer(tf.keras.Model):
    else:
      # Decoding path.
      inputs, targets = inputs[0], None
-      if self.params["padded_decode"]:
-        if not self.params["num_replicas"]:
+      if self._padded_decode:
+        if not self._num_replicas:
          raise NotImplementedError(
              "Padded decoding on CPU/GPUs is not supported.")
-        decode_batch_size = int(self.params["decode_batch_size"] /
-                                self.params["num_replicas"])
-        inputs.set_shape([decode_batch_size, self.params["decode_max_length"]])
+        decode_batch_size = int(self._decode_batch_size / self._num_replicas)
+        inputs.set_shape([decode_batch_size, self._decode_max_length])

    with tf.name_scope("Transformer"):
      attention_bias = model_utils.get_padding_bias(inputs)
-      attention_bias = tf.cast(attention_bias, self.params["dtype"])
+      attention_bias = tf.cast(attention_bias, self._dtype)
      with tf.name_scope("encode"):
        # Prepare inputs to the layer stack by adding positional encodings and
        # applying dropout.
...
...
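A quick note on the padded-decode arithmetic above: the global decode batch is split evenly across replicas so that each replica sees a statically shaped input. A tiny sketch with assumed numbers, not values from this commit:

decode_batch_size = 2048                              # self._decode_batch_size (constructor default)
num_replicas = 8                                      # e.g. an 8-core TPU slice
per_replica = int(decode_batch_size / num_replicas)   # -> 256
# inputs.set_shape([per_replica, decode_max_length]) then pins the static shape.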
@@ -144,7 +256,7 @@ class Seq2SeqTransformer(tf.keras.Model):
        embedding_mask = tf.cast(tf.not_equal(inputs, 0),
                                 self.embedding_lookup.embeddings.dtype)
        embedded_inputs *= tf.expand_dims(embedding_mask, -1)
-        embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
+        embedded_inputs = tf.cast(embedded_inputs, self._dtype)
        # Attention_mask generation.
        input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
...
...
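The embedding mask in the hunk above zeroes out embeddings at padded positions (token id 0) before casting to the model dtype. A self-contained sketch of the same masking pattern with toy shapes, independent of this codebase:

import tensorflow as tf

inputs = tf.constant([[5, 3, 0, 0],
                      [7, 0, 0, 0]], dtype=tf.int64)   # 0 = padding id
embedded_inputs = tf.random.normal([2, 4, 8])          # [batch, length, hidden]

embedding_mask = tf.cast(tf.not_equal(inputs, 0), embedded_inputs.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)  # padded positions become 0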
@@ -158,7 +270,7 @@ class Seq2SeqTransformer(tf.keras.Model):
        with tf.name_scope("add_pos_encoding"):
          pos_encoding = self.position_embedding(inputs=embedded_inputs)
-          pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+          pos_encoding = tf.cast(pos_encoding, self._dtype)
          encoder_inputs = embedded_inputs + pos_encoding

        encoder_inputs = self.encoder_dropout(encoder_inputs)
...
...
@@ -168,16 +280,16 @@ class Seq2SeqTransformer(tf.keras.Model):
      if targets is None:
        encoder_decoder_attention_bias = attention_bias
-        encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
-        if self.params["padded_decode"]:
+        encoder_outputs = tf.cast(encoder_outputs, self._dtype)
+        if self._padded_decode:
          batch_size = encoder_outputs.shape.as_list()[0]
          input_length = encoder_outputs.shape.as_list()[1]
        else:
          batch_size = tf.shape(encoder_outputs)[0]
          input_length = tf.shape(encoder_outputs)[1]
-        max_decode_length = input_length + self.params["extra_decode_length"]
+        max_decode_length = input_length + self._extra_decode_length
        encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                                 self.params["dtype"])
+                                                 self._dtype)
        symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
...
...
@@ -188,9 +300,9 @@ class Seq2SeqTransformer(tf.keras.Model):
        # Create cache storing decoder attention values for each layer.
        # pylint: disable=g-complex-comprehension
        init_decode_length = (
-            max_decode_length if self.params["padded_decode"] else 0)
-        num_heads = self.params["num_heads"]
-        dim_per_head = self.params["hidden_size"] // num_heads
+            max_decode_length if self._padded_decode else 0)
+        num_heads = self._num_heads
+        dim_per_head = self._hidden_size // num_heads
        cache = {
            str(layer): {
...
...
@@ -198,13 +310,13 @@ class Seq2SeqTransformer(tf.keras.Model):
                tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
-                        dtype=self.params["dtype"]),
+                        dtype=self._dtype),
            "value":
                tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
-                        dtype=self.params["dtype"])
-            } for layer in range(self.params["num_hidden_layers"])
+                        dtype=self._dtype)
+            } for layer in range(self._num_layers)
        }
        # pylint: enable=g-complex-comprehension
...
...
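The cache built above pre-allocates per-layer key/value tensors that the decoder fills in as it generates tokens; with padded_decode the length dimension is fixed at max_decode_length, otherwise it starts at 0 and grows. A stand-alone sketch with toy sizes chosen only for illustration:

import tensorflow as tf

batch_size, init_decode_length, num_heads, dim_per_head, num_layers = 2, 0, 8, 64, 6
cache = {
    str(layer): {
        "key": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
                        dtype=tf.float32),
        "value": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
                          dtype=tf.float32),
    } for layer in range(num_layers)
}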
@@ -218,13 +330,13 @@ class Seq2SeqTransformer(tf.keras.Model):
            symbols_to_logits_fn=symbols_to_logits_fn,
            initial_ids=initial_ids,
            initial_cache=cache,
-            vocab_size=self.params["vocab_size"],
-            beam_size=self.params["beam_size"],
-            alpha=self.params["alpha"],
+            vocab_size=self._vocab_size,
+            beam_size=self._beam_size,
+            alpha=self._alpha,
            max_decode_length=max_decode_length,
            eos_id=EOS_ID,
-            padded_decode=self.params["padded_decode"],
-            dtype=self.params["dtype"])
+            padded_decode=self._padded_decode,
+            dtype=self._dtype)

        # Get the top sequence for each batch element
        top_decoded_ids = decoded_ids[:, 0, 1:]
...
...
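The alpha forwarded to the beam search above sets the strength of length normalization. In the tensor2tensor-style beam search used by the official Transformer this is commonly the GNMT penalty ((5 + length) / 6) ** alpha; a sketch under that assumption, not code from this commit:

def length_normalization(alpha, length):
  # Assumed GNMT-style length penalty; larger alpha favours longer hypotheses,
  # alpha=0 disables normalization.
  return ((5.0 + float(length)) / 6.0) ** alpha

# With the constructor default alpha=0.6, a 20-token hypothesis would be ranked
# by log_prob / length_normalization(0.6, 20).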
@@ -238,7 +350,7 @@ class Seq2SeqTransformer(tf.keras.Model):
        embedding_mask = tf.cast(tf.not_equal(targets, 0),
                                 self.embedding_lookup.embeddings.dtype)
        decoder_inputs *= tf.expand_dims(embedding_mask, -1)
-        decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
+        decoder_inputs = tf.cast(decoder_inputs, self._dtype)
        with tf.name_scope("shift_targets"):
          # Shift targets to the right, and remove the last element
          decoder_inputs = tf.pad(decoder_inputs,
...
...
@@ -246,7 +358,7 @@ class Seq2SeqTransformer(tf.keras.Model):
        with tf.name_scope("add_pos_encoding"):
          length = tf.shape(decoder_inputs)[1]
          pos_encoding = self.position_embedding(decoder_inputs)
-          pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+          pos_encoding = tf.cast(pos_encoding, self._dtype)
          decoder_inputs += pos_encoding
        decoder_inputs = self.decoder_dropout(decoder_inputs)
...
...
@@ -282,9 +394,9 @@ class Seq2SeqTransformer(tf.keras.Model):
    """Returns a decoding function that calculates logits of the next tokens."""
    timing_signal = self.position_embedding(
        inputs=None, length=max_decode_length + 1)
-    timing_signal = tf.cast(timing_signal, self.params["dtype"])
+    timing_signal = tf.cast(timing_signal, self._dtype)
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        max_decode_length, dtype=self.params["dtype"])
+        max_decode_length, dtype=self._dtype)

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.
...
...
@@ -312,7 +424,7 @@ class Seq2SeqTransformer(tf.keras.Model):
                                self.embedding_lookup.embeddings.dtype)
      decoder_input *= tf.expand_dims(embedding_mask, -1)

-      if self.params["padded_decode"]:
+      if self._padded_decode:
        timing_signal_shape = timing_signal.shape.as_list()
        decoder_input += tf.slice(timing_signal, [i, 0],
                                  [1, timing_signal_shape[1]])
...
...
@@ -350,7 +462,7 @@ class Seq2SeqTransformer(tf.keras.Model):
          memory_mask=self_attention_mask,
          target_mask=attention_mask,
          cache=cache,
-          decode_loop_step=i if self.params["padded_decode"] else None)
+          decode_loop_step=i if self._padded_decode else None)
      logits = embedding_linear(self.embedding_lookup.embeddings,
                                decoder_outputs)
...
...
@@ -392,8 +504,9 @@ class TransformerEncoder(tf.keras.layers.Layer):
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
-               intermediate_dropout=0.0):
-    super(TransformerEncoder, self).__init__()
+               intermediate_dropout=0.0,
+               **kwargs):
+    super(TransformerEncoder, self).__init__(**kwargs)
    self._num_layers = num_layers
    self._num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
...
...
@@ -507,8 +620,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
-               intermediate_dropout=0.0):
-    super(TransformerDecoder, self).__init__()
+               intermediate_dropout=0.0,
+               **kwargs):
+    super(TransformerDecoder, self).__init__(**kwargs)
    self._num_layers = num_layers
    self._num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
...
...
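Because TransformerEncoder and TransformerDecoder now forward **kwargs to the Keras base Layer, standard layer arguments such as name can be passed straight through. A short sketch; the hyperparameter values are placeholders, not values from this commit:

encoder = TransformerEncoder(
    num_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    activation="relu",
    dropout_rate=0.0,
    attention_dropout_rate=0.0,
    name="toy_encoder")   # consumed by tf.keras.layers.Layer via **kwargs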
official/nlp/modeling/models/seq2seq_transformer_test.py
...
...
@@ -28,7 +28,6 @@ class TransformerV2Test(tf.test.TestCase):
  def setUp(self):
    self.params = params = model_params.TINY_PARAMS
    params["batch_size"] = params["default_batch_size"] = 16
    params["use_synthetic_data"] = True
    params["hidden_size"] = 12
    params["num_hidden_layers"] = 2
    params["filter_size"] = 14
...
...
@@ -39,11 +38,7 @@ class TransformerV2Test(tf.test.TestCase):
    params["dtype"] = tf.float32

  def test_create_model_train(self):
-    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
-    targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
-    internal_model = seq2seq_transformer.Seq2SeqTransformer(self.params)
-    logits = internal_model([inputs, targets], training=True)
-    model = tf.keras.Model([inputs, targets], logits)
+    model = seq2seq_transformer.create_model(self.params, True)
    inputs, outputs = model.inputs, model.outputs
    self.assertEqual(len(inputs), 2)
    self.assertEqual(len(outputs), 1)
...
...
@@ -55,11 +50,7 @@ class TransformerV2Test(tf.test.TestCase):
    self.assertEqual(outputs[0].dtype, tf.float32)

  def test_create_model_not_train(self):
-    inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
-    internal_model = seq2seq_transformer.Seq2SeqTransformer(self.params)
-    ret = internal_model([inputs], training=False)
-    outputs, scores = ret["outputs"], ret["scores"]
-    model = tf.keras.Model(inputs, [outputs, scores])
+    model = seq2seq_transformer.create_model(self.params, False)
    inputs, outputs = model.inputs, model.outputs
    self.assertEqual(len(inputs), 1)
    self.assertEqual(len(outputs), 2)
...
...
@@ -71,5 +62,6 @@ class TransformerV2Test(tf.test.TestCase):
    self.assertEqual(outputs[1].dtype, tf.float32)


if __name__ == "__main__":
  tf.test.main()