ModelZoo / ResNet50_tensorflow · Commits

Commit c4ebfef2
Authored Nov 01, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406906304

Parent: 460890ed
Showing 2 changed files with 89 additions and 33 deletions:

- official/nlp/modeling/models/seq2seq_transformer.py (+44, -14)
- official/nlp/modeling/models/seq2seq_transformer_test.py (+45, -19)
official/nlp/modeling/models/seq2seq_transformer.py
```diff
@@ -103,7 +103,7 @@ class Seq2SeqTransformer(tf.keras.Model):
         "beam_size": self._beam_size,
         "alpha": self._alpha,
         "encoder_layer": self.encoder_layer,
-        "decoder_layer": self.decoder_layer
+        "decoder_layer": self.decoder_layer,
     }
     base_config = super(Seq2SeqTransformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
```
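The change here only adds the trailing comma after the final `decoder_layer` entry. For context, the `dict(list(...) + list(...))` return is an older idiom for merging the parent Keras config with this class's fields; a minimal sketch with illustrative values (not taken from this diff):

```python
# Illustrative stand-ins for the two dicts merged by get_config().
base_config = {"name": "seq2seq_transformer", "trainable": True}
config = {"beam_size": 4, "alpha": 0.6}

merged = dict(list(base_config.items()) + list(config.items()))
assert merged == {**base_config, **config}  # same merge; later keys win on clashes
```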
```diff
@@ -122,14 +122,47 @@ class Seq2SeqTransformer(tf.keras.Model):
     return tf.reshape(logits, [batch_size, length, vocab_size])
 
+  def _parse_inputs(self, inputs):
+    """Parses the `call` inputs and returns an uniformed output."""
+    sources = inputs.get("inputs", None)
+    input_mask = inputs.get("input_masks", None)
+    embedded = inputs.get("embedded_inputs", None)
+    if sources is None and embedded is not None:
+      embedded_inputs = embedded
+      boolean_mask = input_mask
+      input_shape = tf_utils.get_shape_list(embedded, expected_rank=3)
+      source_dtype = embedded.dtype
+    elif sources is not None:
+      embedded_inputs = self.embedding_lookup(sources)
+      boolean_mask = tf.not_equal(sources, 0)
+      input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
+      source_dtype = sources.dtype
+    else:
+      raise KeyError(
+          "The call method expects either `inputs` or `embedded_inputs` and "
+          "`input_masks` as input features.")
+    return embedded_inputs, boolean_mask, input_shape, source_dtype
+
   def call(self, inputs):
     """Calculate target logits or inferred target sequences.
 
     Args:
       inputs: a dictionary of tensors.
-        Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
+        Feature `inputs` (optional): int tensor with shape
+          `[batch_size, input_length]`.
+        Feature `embedded_inputs` (optional): float tensor with shape
+          `[batch_size, input_length, embedding_width]`.
         Feature `targets` (optional): None or int tensor with shape
           `[batch_size, target_length]`.
+        Feature `input_masks` (optional): When providing the `embedded_inputs`,
+          the dictionary must provide a boolean mask marking the filled time
+          steps. The shape of the tensor is `[batch_size, input_length]`.
+        Either `inputs` or `embedded_inputs` and `input_masks` must be present
+          in the input dictionary. In the second case the projection of the
+          integer tokens to the transformer embedding space is skipped and
+          `input_masks` is expected to be present.
 
     Returns:
       If targets is defined, then return logits for each word in the target
```
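The docstring above now describes two mutually exclusive calling conventions. A hedged sketch of both, assuming `model` is an already-built `Seq2SeqTransformer` (the shapes mirror the test file below):

```python
import numpy as np

batch_size, input_length, embedding_width = 4, 10, 16

# Form 1: integer token ids; the model runs embedding_lookup itself and
# derives the padding mask from the zero token.
outputs = model(dict(
    inputs=np.zeros((batch_size, input_length), dtype=np.int32)))

# Form 2: pre-embedded float inputs; the lookup is skipped, so the boolean
# `input_masks` marking filled time steps must be supplied explicitly.
outputs = model(dict(
    embedded_inputs=np.zeros(
        (batch_size, input_length, embedding_width), dtype=np.float32),
    input_masks=np.ones((batch_size, input_length), dtype=bool)))
```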
```diff
@@ -144,21 +177,19 @@ class Seq2SeqTransformer(tf.keras.Model):
     Raises:
       NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
-    sources = inputs["inputs"]
-    targets = inputs.get("targets", None)
-
-    # Prepare inputs to the layer stack by adding positional encodings and
-    # applying dropout.
-    embedded_inputs = self.embedding_lookup(sources)
-    embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
+    targets = inputs.get("targets", None)
+    (embedded_inputs, boolean_mask, input_shape,
+     source_dtype) = self._parse_inputs(inputs)
+    embedding_mask = tf.cast(boolean_mask, embedded_inputs.dtype)
     embedded_inputs *= tf.expand_dims(embedding_mask, -1)
 
     # Attention_mask generation.
-    input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
     attention_mask = tf.cast(
-        tf.reshape(tf.not_equal(sources, 0),
-                   [input_shape[0], 1, input_shape[1]]),
-        dtype=sources.dtype)
+        tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
+        dtype=source_dtype)
     broadcast_ones = tf.ones(
-        shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
+        shape=[input_shape[0], input_shape[1], 1], dtype=source_dtype)
     attention_mask = broadcast_ones * attention_mask
 
     pos_encoding = self.position_embedding(embedded_inputs)
```
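The mask arithmetic in this hunk is pure shape bookkeeping: `boolean_mask` of shape `[batch, length]` is reshaped to `[batch, 1, length]` and multiplied by ones of shape `[batch, length, 1]`, broadcasting to a full `[batch, length, length]` attention mask. A standalone sketch of that broadcast:

```python
import tensorflow as tf

boolean_mask = tf.constant([[True, True, False]])  # [batch=1, length=3]; position 2 is padding
mask = tf.cast(tf.reshape(boolean_mask, [1, 1, 3]), tf.float32)  # [1, 1, 3]
broadcast_ones = tf.ones([1, 3, 1], dtype=tf.float32)            # [1, 3, 1]
attention_mask = broadcast_ones * mask                           # broadcasts to [1, 3, 3]
# Each of the 3 query rows is [1, 1, 0]: attend to tokens 0 and 1, never the pad.
print(attention_mask.numpy())
```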
```diff
@@ -206,8 +237,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       # Add encoder output and attention bias to the cache.
       encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
       attention_mask = tf.cast(
-          tf.reshape(tf.not_equal(sources, 0),
-                     [input_shape[0], 1, input_shape[1]]),
+          tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
           dtype=self.compute_dtype)
       cache["encoder_outputs"] = encoder_outputs
       cache["encoder_decoder_attention_mask"] = attention_mask
```
```diff
@@ -252,7 +282,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
       attention_mask = tf.cast(
-          tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
+          tf.expand_dims(boolean_mask, axis=1), dtype=source_dtype)
       attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
 
       outputs = self.decoder_layer(
```
official/nlp/modeling/models/seq2seq_transformer_test.py
```diff
@@ -26,12 +26,11 @@ from official.nlp.modeling.models import seq2seq_transformer
 class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
 
-  def _build_model(self, padded_decode, decode_max_length):
+  def _build_model(self, padded_decode, decode_max_length, embedding_width):
     num_layers = 1
     num_attention_heads = 2
     intermediate_size = 32
     vocab_size = 100
-    embedding_width = 16
     encdec_kwargs = dict(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
```
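The body of `_build_model` continues past this hunk. For orientation only, a rough reconstruction of what it assembles; the constructor calls are an assumption based on the classes defined in `seq2seq_transformer.py`, and the beam-search values are illustrative, none of it shown in this diff:

```python
# Hypothetical continuation of _build_model; not part of this diff.
encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)
return seq2seq_transformer.Seq2SeqTransformer(
    vocab_size=vocab_size,
    embedding_width=embedding_width,  # now a caller-supplied parameter
    padded_decode=padded_decode,
    decode_max_length=decode_max_length,
    beam_size=4,    # illustrative
    alpha=0.6,      # illustrative
    encoder_layer=encoder_layer,
    decoder_layer=decoder_layer)
```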
```diff
@@ -63,15 +62,19 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
           strategy_combinations.default_strategy,
           strategy_combinations.cloud_tpu_strategy,
       ],
+      embed=[True, False],
+      is_training=[True, False],
       mode="eager"))
-  def test_create_model_with_ds(self, distribution):
+  def test_create_model_with_ds(self, distribution, embed, is_training):
     with distribution.scope():
       padded_decode = isinstance(
          distribution,
          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
       decode_max_length = 10
       batch_size = 4
-      model = self._build_model(padded_decode, decode_max_length)
+      embedding_width = 16
+      model = self._build_model(padded_decode, decode_max_length,
+                                embedding_width)
 
       @tf.function
       def step(inputs):
```
```diff
@@ -83,23 +86,32 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
         return tf.nest.map_structure(distribution.experimental_local_results,
                                      outputs)
 
-      fake_inputs = dict(
-          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
-      local_outputs = step(fake_inputs)
-      logging.info("local_outputs=%s", local_outputs)
-      self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
-
-      fake_inputs = dict(
-          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
-          targets=np.zeros((batch_size, 8), dtype=np.int32))
-      local_outputs = step(fake_inputs)
-      logging.info("local_outputs=%s", local_outputs)
-      self.assertEqual(local_outputs[0].shape, (4, 8, 100))
+      if embed:
+        fake_inputs = dict(
+            embedded_inputs=np.zeros(
+                (batch_size, decode_max_length, embedding_width),
+                dtype=np.float32),
+            input_masks=np.ones((batch_size, decode_max_length),
+                                dtype=np.bool))
+      else:
+        fake_inputs = dict(
+            inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
+      if is_training:
+        fake_inputs["targets"] = np.zeros((batch_size, 8), dtype=np.int32)
+        local_outputs = step(fake_inputs)
+        logging.info("local_outputs=%s", local_outputs)
+        self.assertEqual(local_outputs[0].shape, (4, 8, 100))
+      else:
+        local_outputs = step(fake_inputs)
+        logging.info("local_outputs=%s", local_outputs)
+        self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
 
   @parameterized.parameters(True, False)
   def test_create_savedmodel(self, padded_decode):
     decode_max_length = 10
-    model = self._build_model(padded_decode, decode_max_length)
+    embedding_width = 16
+    model = self._build_model(padded_decode, decode_max_length,
+                              embedding_width)
 
     class SaveModule(tf.Module):
```
```diff
@@ -111,14 +123,28 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
       def serve(self, inputs):
         return self.model.call(dict(inputs=inputs))
 
+      @tf.function
+      def embedded_serve(self, embedded_inputs, input_masks):
+        return self.model.call(
+            dict(embedded_inputs=embedded_inputs, input_masks=input_masks))
+
     save_module = SaveModule(model)
     if padded_decode:
-      tensor_shape = (4, 10)
+      tensor_shape = (4, decode_max_length)
+      embedded_tensor_shape = (4, decode_max_length, embedding_width)
     else:
       tensor_shape = (None, None)
+      embedded_tensor_shape = (None, None, embedding_width)
     signatures = dict(
         serving_default=save_module.serve.get_concrete_function(
-            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")))
+            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")),
+        embedded_serving=save_module.embedded_serve.get_concrete_function(
+            tf.TensorSpec(
+                shape=embedded_tensor_shape,
+                dtype=tf.float32,
+                name="embedded_inputs"),
+            tf.TensorSpec(shape=tensor_shape, dtype=tf.bool, name="input_masks"),
+        ))
     tf.saved_model.save(save_module, self.get_temp_dir(), signatures=signatures)
```
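With the second signature exported, the SavedModel can be driven with either token ids or pre-embedded tensors. A hedged usage sketch (the export path is illustrative; signature names and TensorSpecs follow the hunk above):

```python
import tensorflow as tf

loaded = tf.saved_model.load("/tmp/seq2seq_export")  # illustrative path
serve = loaded.signatures["serving_default"]
embedded_serve = loaded.signatures["embedded_serving"]

# Token-id path: one int32 tensor named "inputs".
out = serve(inputs=tf.zeros((4, 10), tf.int32))

# Pre-embedded path: float inputs plus the boolean mask, matching the
# TensorSpecs registered for "embedded_serving".
out = embedded_serve(
    embedded_inputs=tf.zeros((4, 10, 16), tf.float32),
    input_masks=tf.ones((4, 10), tf.bool))
```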