ModelZoo / ResNet50_tensorflow / Commits / a81f8590

Commit a81f8590 (unverified), authored Aug 02, 2022 by karun; committed by GitHub on Aug 02, 2022.

    Adding transformer based bytestream models (#10734)

    Co-authored-by: Arun Kandoor <akandoor@google.com>

Parent: 82a26070

Changes: 7 files. Showing 7 changed files with 1548 additions and 0 deletions (+1548, -0).
Files changed:

  research/seq_flow_lite/layers/BUILD                                 +16   -0
  research/seq_flow_lite/layers/transformer_layers.py                 +672  -0
  research/seq_flow_lite/models/BUILD                                 +49   -0
  research/seq_flow_lite/models/charformer.py                         +153  -0
  research/seq_flow_lite/models/transformer_encoder.py                +112  -0
  research/seq_flow_lite/models/transformer_uniform_attn_decoder.py   +516  -0
  research/seq_flow_lite/tf_ops/tf_custom_ops.cc                      +30   -0

research/seq_flow_lite/layers/BUILD (view file @ a81f8590)

@@ -116,3 +116,19 @@ py_strict_library(
        "//layers:quantization_layers",
    ],
)

py_strict_library(
    name = "transformer_layers",
    srcs = ["transformer_layers.py"],
    srcs_version = "PY3",
    deps = [
        ":embedding_layers",
        # package tensorflow
        "//layers:base_layers",
        "//layers:dense_layers",
        "//layers:normalization_layers",
        "//layers:quantization_layers",
        "//tf_ops:tf_custom_ops",
        "//tf_ops:tf_custom_ops_py",
    ],
)

research/seq_flow_lite/layers/transformer_layers.py (new file, 0 → 100644, view file @ a81f8590)

# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for Transformer encoder."""
# pylint: disable=arguments-renamed
import tensorflow as tf

from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import normalization_layers
from layers import quantization_layers
from tf_ops import tf_custom_ops_py


class SelfAttention(base_layers.BaseLayer):
  """Self attention encoder (not suitable for causal attention)."""

  def __init__(self,
               model_dimension,
               num_heads,
               attention_dropout_rate=0.0,
               **kwargs):
    self.model_dimension = model_dimension
    self.num_heads = num_heads
    self.filters = model_dimension // num_heads
    self.dense_layers = [
        dense_layers.BaseQDenseVarLen(
            units=self.filters, activation=None, **kwargs)
        for i in range(num_heads * 3)
    ]
    self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
    self.attention_dropout_rate = attention_dropout_rate
    self.qconcat = quantization_layers.ConcatQuantization(axis=2, **kwargs)
    super(SelfAttention, self).__init__(**kwargs)

  def call(self, inputs, mask, inverse_normalizer, attn_mask=None):
    batch_size = self.get_batch_dimension(inputs)
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    assert inputs.get_shape().as_list()[-1] == self.model_dimension

    inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
    mask_rank2 = tf.reshape(mask, [-1, 1])
    tensors = [
        layer(inputs_rank2, mask_rank2, inverse_normalizer)
        for layer in self.dense_layers
    ]
    if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
      tensors = [
          tf.reshape(tensor, [batch_size, -1, self.filters])
          for tensor in tensors
      ]

    context = []
    if attn_mask is None:
      attn_mask = tf.matmul(mask, tf.transpose(mask, [0, 2, 1]))

    if (self.attention_dropout_rate > 0.0 and
        self.parameters.mode == base_layers.TRAIN):
      attn_mask *= self.random_drop_to_zero(attn_mask,
                                            self.attention_dropout_rate)

    invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit
    for _ in range(self.num_heads):
      keys = tensors.pop()
      values = tensors.pop()
      queries = tensors.pop()
      # Attention is not scaled dot product, batch normalization compensates
      # for it.
      if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
        queries = tf.transpose(queries, [0, 2, 1])
        attn_logits = self.qactivation(tf.matmul(keys, queries))
        attn_logits_masked = attn_logits * attn_mask + invalid_mask
        attention = tf.nn.softmax(attn_logits_masked)
        attention = self.qrange_sigmoid(attention, tf_only=True)
        context.append(tf.matmul(attention, values))
      else:
        queries = tf.transpose(queries)
        attn_logits_masked = self.qactivation(tf.matmul(keys, queries))
        attention = tf.nn.softmax(attn_logits_masked)
        attention = self.qrange_sigmoid(attention, tf_only=True)
        ctx = tf.matmul(attention, values)
        ctx = tf.reshape(ctx, [1, -1, self.filters])
        context.append(ctx)
    return self.qconcat(context)


class SelfAttentionV2(base_layers.BaseLayer):
  """Self attention encoder (not suitable for causal attention)."""

  def __init__(self,
               model_dimension,
               num_heads,
               attention_dropout_rate=0.0,
               **kwargs):
    self.model_dimension = model_dimension
    self.num_heads = num_heads
    self.filters = model_dimension // num_heads
    self.dense_layers = dense_layers.BaseQDenseVarLen(
        units=model_dimension * 3, activation=None, **kwargs)
    self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
    self.attention_dropout_rate = attention_dropout_rate
    self.qconcat = quantization_layers.ConcatQuantization(axis=1, **kwargs)
    super(SelfAttentionV2, self).__init__(**kwargs)

  def call(self, inputs, mask, inverse_normalizer, attn_mask=None):
    bsz = self.get_batch_dimension(inputs)
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    assert inputs.get_shape().as_list()[-1] == self.model_dimension

    inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
    mask_rank2 = tf.reshape(mask, [-1, 1])
    tensors = self.dense_layers(inputs_rank2, mask_rank2, inverse_normalizer)
    if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
      tensors = tf.reshape(tensors,
                           [bsz, -1, 3, self.num_heads, self.filters])
      tensors = tf.unstack(tensors, axis=2)
    else:
      tensors = tf.split(tensors, self.num_heads * 3, axis=1)

    if attn_mask is None:
      attn_mask = tf.matmul(mask, mask, transpose_b=True)

    if (self.attention_dropout_rate > 0.0 and
        self.parameters.mode == base_layers.TRAIN):
      attn_mask *= self.random_drop_to_zero(attn_mask,
                                            self.attention_dropout_rate)
    attn_mask = tf.expand_dims(attn_mask, axis=1)
    invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit

    if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
      queries = tf.transpose(tensors[0], [0, 2, 1, 3])
      keys = tf.transpose(tensors[1], [0, 2, 1, 3])
      values = tf.transpose(tensors[2], [0, 2, 1, 3])
      attn_logits = self.qactivation(
          tf.matmul(queries, keys, transpose_b=True))
      attn_logits_masked = attn_logits * attn_mask + invalid_mask
      attention = tf.nn.softmax(attn_logits_masked)
      attention = self.qrange_sigmoid(attention, tf_only=True)
      result = tf.matmul(attention, values)
      result = tf.transpose(result, [0, 2, 1, 3])
      result = tf.reshape(result, [bsz, -1, self.model_dimension])
      return self.qconcat([result])
    else:
      context = []
      for idx in range(self.num_heads):
        queries = tensors[idx]
        keys = tensors[idx + self.num_heads]
        values = tensors[idx + self.num_heads * 2]
        # Attention is not scaled dot product, batch normalization compensates
        # for it.
        attn_logits_masked = self.qactivation(
            tf.matmul(queries, keys, transpose_b=True))
        attention = tf.nn.softmax(attn_logits_masked)
        attention = self.qrange_sigmoid(attention, tf_only=True)
        context.append(tf.matmul(attention, values))
      result = self.qconcat(context)
      return tf.reshape(result, [1, -1, self.model_dimension])


class TransformerEncoder(base_layers.BaseLayer):
  """Transformer Encoder."""

  def __init__(self,
               model_dimension,
               num_heads,
               intermediate_size,
               initializer_stddev=0.02,
               activation_dropout_rate=0.0,
               attention_dropout_rate=0.0,
               **kwargs):
    super(TransformerEncoder, self).__init__(**kwargs)
    self.model_dimension = model_dimension
    self.parameters.initializer = tf.keras.initializers.TruncatedNormal(
        stddev=initializer_stddev)
    self.self_attn = SelfAttentionV2(
        model_dimension,
        num_heads,
        attention_dropout_rate=attention_dropout_rate,
        parameters=self.parameters)
    self.prx = dense_layers.BaseQDenseVarLen(
        model_dimension, activation=None, parameters=self.parameters)
    self.upprx = dense_layers.BaseQDenseVarLen(
        intermediate_size, parameters=self.parameters)
    self.downprx = dense_layers.BaseQDenseVarLen(
        model_dimension, activation=None, parameters=self.parameters)
    self.activation_dropout_rate = activation_dropout_rate
    self.ln1 = normalization_layers.LayerNormalization(**kwargs)
    self.ln2 = normalization_layers.LayerNormalization(**kwargs)
    self.q1 = quantization_layers.ActivationQuantization(**kwargs)
    self.q2 = quantization_layers.ActivationQuantization(**kwargs)

  def call(self, inputs, mask, inverse_normalizer, attn_mask=None):
    batch_size = self.get_batch_dimension(inputs)
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    assert inputs.get_shape().as_list()[-1] == self.model_dimension

    mask_rank2 = tf.reshape(mask, [-1, 1])
    assert inputs.get_shape().as_list()[-1] == self.model_dimension
    tensor = self.self_attn(inputs, mask, inverse_normalizer, attn_mask)
    inputs = tf.reshape(inputs, [-1, self.model_dimension])
    tensor = tf.reshape(tensor, [-1, self.model_dimension])
    tensor = self.prx(tensor, mask_rank2, inverse_normalizer)
    if (self.parameters.mode == base_layers.TRAIN and
        self.activation_dropout_rate > 0.0):
      tensor = tf.nn.dropout(tensor, rate=self.activation_dropout_rate)
    inputs_plus_selfattn = self.q1(self.ln1(inputs + tensor))
    ffn_up = self.upprx(inputs_plus_selfattn, mask_rank2, inverse_normalizer)
    ffn_down = self.downprx(ffn_up, mask_rank2, inverse_normalizer)
    if (self.parameters.mode == base_layers.TRAIN and
        self.activation_dropout_rate > 0.0):
      ffn_down = tf.nn.dropout(ffn_down, rate=self.activation_dropout_rate)
    inputs_plus_ffn = self.q2(self.ln2(inputs_plus_selfattn + ffn_down))
    return tf.reshape(inputs_plus_ffn,
                      [batch_size, -1, self.model_dimension])


class TransformerEncoderStack(base_layers.BaseLayer):
  """Transformer Encoder."""

  def __init__(self, num_layers, max_time_step, vocabulary_size,
               embedding_size, model_dimension, num_heads, intermediate_size,
               **kwargs):
    self.max_time_step = max_time_step
    self.vocabulary_size = vocabulary_size
    self.embedding_size = embedding_size
    activation_dropout_rate = kwargs.pop('activation_dropout_rate', 0.0)
    attention_dropout_rate = kwargs.pop('attention_dropout_rate', 0.0)
    self.layers = []
    for _ in range(num_layers):
      self.layers.append(
          TransformerEncoder(
              model_dimension=model_dimension,
              num_heads=num_heads,
              intermediate_size=intermediate_size,
              activation_dropout_rate=activation_dropout_rate,
              attention_dropout_rate=attention_dropout_rate,
              **kwargs))
    self.embedding = embedding_layers.EmbeddingLayer(
        shape=[self.vocabulary_size, self.embedding_size], **kwargs)
    self.positional_embedding = embedding_layers.EmbeddingLayer(
        shape=[self.max_time_step, self.embedding_size], **kwargs)
    self.ln = normalization_layers.LayerNormalization(**kwargs)
    self.qact = quantization_layers.ActivationQuantization(**kwargs)
    super(TransformerEncoderStack, self).__init__(**kwargs)

  def call(self, input_indices, sequence_length):
    mask_rank2 = tf.sequence_mask(
        sequence_length, tf.shape(input_indices)[1], dtype=tf.float32)
    mask_rank3 = tf.expand_dims(mask_rank2, axis=2)
    inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask_rank3))
    if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
      sequence_length = tf.reduce_sum(input_indices + 1 - input_indices)
      pos_indices = tf.range(sequence_length, dtype=tf.int32)
      pos_indices = tf.reshape(pos_indices, [1, -1])
    else:
      pos_indices = tf.cumsum(mask_rank2, axis=1, exclusive=True)
      pos_indices = tf.cast(pos_indices, dtype=tf.int32)

    input_values = self.embedding(input_indices)
    pos_values = self.positional_embedding(pos_indices)
    inputs = self.qact(self.ln(input_values + pos_values))
    attn_mask = tf.matmul(mask_rank3, tf.transpose(mask_rank3, [0, 2, 1]))
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      inputs = inputs * mask_rank3
    for layer in self.layers:
      outputs = layer(inputs, mask_rank3, inverse_normalizer, attn_mask)
      inputs = outputs
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      outputs = outputs * mask_rank3
    return outputs
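

# Illustrative worked example (not part of the checked-in file): how the
# rank-3 padding mask built in `call` above turns into the pairwise attention
# mask handed to each encoder layer. The batch size and sequence length below
# are made-up values.
#
#   mask_rank3 = tf.constant([[[1.], [1.], [0.]]])  # batch=1, two valid tokens
#   attn_mask = tf.matmul(mask_rank3, tf.transpose(mask_rank3, [0, 2, 1]))
#   # attn_mask[0] == [[1., 1., 0.],
#   #                  [1., 1., 0.],
#   #                  [0., 0., 0.]]  -> 1 only where both positions are valid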


class TransformerEncoderStackWithInputEmbedding(TransformerEncoderStack):
  """Transformer Encoder."""

  def call(self, inputs, sequence_length):
    mask_rank2 = tf.sequence_mask(
        sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
    mask_rank3 = tf.expand_dims(mask_rank2, axis=2)
    inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask_rank3))
    attn_mask = tf.matmul(mask_rank3, tf.transpose(mask_rank3, [0, 2, 1]))
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      inputs = inputs * mask_rank3
    for layer in self.layers:
      outputs = layer(inputs, mask_rank3, inverse_normalizer, attn_mask)
      inputs = outputs
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      outputs = outputs * mask_rank3
    return outputs


class FunnelAttention(base_layers.BaseLayer):
  """Self attention encoder (not suitable for causal attention)."""

  def __init__(self,
               model_dimension,
               num_heads,
               attention_dropout_rate=0.0,
               **kwargs):
    self.model_dimension = model_dimension
    self.num_heads = num_heads
    self.filters = model_dimension // num_heads
    self.q_dense_layer = dense_layers.BaseQDenseVarLen(
        units=model_dimension, activation=None, **kwargs)
    self.kv_dense_layer = dense_layers.BaseQDenseVarLen(
        units=model_dimension * 2, activation=None, **kwargs)
    self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
    self.attention_dropout_rate = attention_dropout_rate
    self.qconcat = quantization_layers.ConcatQuantization(axis=1, **kwargs)
    super(FunnelAttention, self).__init__(**kwargs)

  def call(self, inputs, mask, inverse_normalizer, memory, memory_mask,
           memory_inverse_normalizer, attn_mask):
    bsz = self.get_batch_dimension(inputs)
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    assert inputs.get_shape().as_list()[-1] == self.model_dimension
    self._assert_rank_and_type(memory, 3)
    self._assert_rank_and_type(memory_mask, 3)
    assert memory.get_shape().as_list()[-1] == self.model_dimension

    inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
    mask_rank2 = tf.reshape(mask, [-1, 1])
    q_tensor = self.q_dense_layer(inputs_rank2, mask_rank2, inverse_normalizer)
    memory_rank2 = tf.reshape(memory, [-1, self.model_dimension])
    memory_mask_rank2 = tf.reshape(memory_mask, [-1, 1])
    kv_tensors = self.kv_dense_layer(memory_rank2, memory_mask_rank2,
                                     inverse_normalizer)
    if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
      q_tensor = tf.reshape(q_tensor, [bsz, -1, self.num_heads, self.filters])
      kv_tensors = tf.reshape(kv_tensors,
                              [bsz, -1, 2, self.num_heads, self.filters])
      kv_tensors = tf.unstack(kv_tensors, axis=2)
    else:
      q_tensor = tf.split(q_tensor, self.num_heads, axis=1)
      kv_tensors = tf.split(kv_tensors, self.num_heads * 2, axis=1)

    attn_mask = tf.expand_dims(attn_mask, axis=1)
    invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit

    if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
      queries = tf.transpose(q_tensor, [0, 2, 1, 3])
      keys = tf.transpose(kv_tensors[0], [0, 2, 1, 3])
      values = tf.transpose(kv_tensors[1], [0, 2, 1, 3])
      attn_logits = self.qactivation(
          tf.matmul(queries, keys, transpose_b=True))
      attn_logits_masked = attn_logits * attn_mask + invalid_mask
      attention = tf.nn.softmax(attn_logits_masked)
      attention = self.qrange_sigmoid(attention, tf_only=True)
      result = tf.matmul(attention, values)
      result = tf.transpose(result, [0, 2, 1, 3])
      result = tf.reshape(result, [bsz, -1, self.model_dimension])
      return self.qconcat([result])
    else:
      context = []
      for idx in range(self.num_heads):
        queries = q_tensor[idx]
        keys = kv_tensors[idx]
        values = kv_tensors[idx + self.num_heads]
        # Attention is not scaled dot product, batch normalization compensates
        # for it.
        attn_logits_masked = self.qactivation(
            tf.matmul(queries, keys, transpose_b=True))
        attention = tf.nn.softmax(attn_logits_masked)
        attention = self.qrange_sigmoid(attention, tf_only=True)
        context.append(tf.matmul(attention, values))
      result = self.qconcat(context)
      return tf.reshape(result, [1, -1, self.model_dimension])


class FunnelTransformerEncoder(base_layers.BaseLayer):
  """Transformer Encoder."""

  def __init__(self,
               model_dimension,
               num_heads,
               intermediate_size,
               initializer_stddev=0.02,
               activation_dropout_rate=0.0,
               attention_dropout_rate=0.0,
               **kwargs):
    super(FunnelTransformerEncoder, self).__init__(**kwargs)
    self.model_dimension = model_dimension
    self.parameters.initializer = tf.keras.initializers.TruncatedNormal(
        stddev=initializer_stddev)
    self.self_attn = FunnelAttention(
        model_dimension,
        num_heads,
        attention_dropout_rate=attention_dropout_rate,
        parameters=self.parameters)
    self.prx = dense_layers.BaseQDenseVarLen(
        model_dimension, activation=None, parameters=self.parameters)
    self.upprx = dense_layers.BaseQDenseVarLen(
        intermediate_size, parameters=self.parameters)
    self.downprx = dense_layers.BaseQDenseVarLen(
        model_dimension, activation=None, parameters=self.parameters)
    self.activation_dropout_rate = activation_dropout_rate
    self.ln1 = normalization_layers.LayerNormalization(**kwargs)
    self.ln2 = normalization_layers.LayerNormalization(**kwargs)
    self.q1 = quantization_layers.ActivationQuantization(**kwargs)
    self.q2 = quantization_layers.ActivationQuantization(**kwargs)

  def call(self, inputs, mask, inverse_normalizer, memory, memory_mask,
           memory_inverse_normalizer, attn_mask):
    batch_size = self.get_batch_dimension(inputs)
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    assert inputs.get_shape().as_list()[-1] == self.model_dimension

    mask_rank2 = tf.reshape(mask, [-1, 1])
    assert inputs.get_shape().as_list()[-1] == self.model_dimension
    tensor = self.self_attn(inputs, mask, inverse_normalizer, memory,
                            memory_mask, memory_inverse_normalizer, attn_mask)
    inputs = tf.reshape(inputs, [-1, self.model_dimension])
    tensor = tf.reshape(tensor, [-1, self.model_dimension])
    tensor = self.prx(tensor, mask_rank2, inverse_normalizer)
    if (self.parameters.mode == base_layers.TRAIN and
        self.activation_dropout_rate > 0.0):
      tensor = tf.nn.dropout(tensor, rate=self.activation_dropout_rate)
    inputs_plus_selfattn = self.q1(self.ln1(inputs + tensor))
    ffn_up = self.upprx(inputs_plus_selfattn, mask_rank2, inverse_normalizer)
    ffn_down = self.downprx(ffn_up, mask_rank2, inverse_normalizer)
    if (self.parameters.mode == base_layers.TRAIN and
        self.activation_dropout_rate > 0.0):
      ffn_down = tf.nn.dropout(ffn_down, rate=self.activation_dropout_rate)
    inputs_plus_ffn = self.q2(self.ln2(inputs_plus_selfattn + ffn_down))
    return tf.reshape(inputs_plus_ffn,
                      [batch_size, -1, self.model_dimension])


class FunnelTransformerEncoderStack(base_layers.BaseLayer):
  """Transformer Encoder."""

  def __init__(self, num_layers, max_time_step, vocabulary_size,
               embedding_size, model_dimension, num_heads, intermediate_size,
               **kwargs):
    self.max_time_step = max_time_step
    self.pool_windows = kwargs.pop('pool_windows', [])
    assert len(self.pool_windows) == num_layers
    self.vocabulary_size = vocabulary_size
    activation_dropout_rate = kwargs.pop('activation_dropout_rate', 0.0)
    attention_dropout_rate = kwargs.pop('attention_dropout_rate', 0.0)
    self.layers = []
    for _ in range(num_layers):
      self.layers.append(
          FunnelTransformerEncoder(
              model_dimension=model_dimension,
              num_heads=num_heads,
              intermediate_size=intermediate_size,
              activation_dropout_rate=activation_dropout_rate,
              attention_dropout_rate=attention_dropout_rate,
              **kwargs))
    super(FunnelTransformerEncoderStack, self).__init__(**kwargs)

  def call(self, inputs, sequence_length):
    mask_rank2 = tf.sequence_mask(
        sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
    mask_rank3 = tf.expand_dims(mask_rank2, axis=2)
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      inputs = inputs * mask_rank3
    pooled_inputs = inputs
    pooled_mask = mask_rank3
    pooled_inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(pooled_mask))
    memory = pooled_inputs
    memory_mask = pooled_mask
    memory_inverse_normalizer = pooled_inverse_normalizer
    for i, layer in enumerate(self.layers):
      if self.pool_windows[i] > 1:
        pooled_inputs = tf.nn.avg_pool(
            pooled_inputs, [self.pool_windows[i]],
            strides=[self.pool_windows[i]],
            padding='SAME')
        pooled_mask = pooled_mask[:, ::self.pool_windows[i], :]
        pooled_inverse_normalizer = tf.math.reciprocal(
            tf.reduce_sum(pooled_mask))
      attn_mask = tf.matmul(pooled_mask, memory_mask, transpose_b=True)
      pooled_outputs = layer(pooled_inputs, pooled_mask,
                             pooled_inverse_normalizer, memory, memory_mask,
                             memory_inverse_normalizer, attn_mask)
      pooled_inputs = pooled_outputs
      pooled_inverse_normalizer = tf.math.reciprocal(
          tf.reduce_sum(pooled_mask))
      memory = pooled_inputs
      memory_mask = pooled_mask
      memory_inverse_normalizer = pooled_inverse_normalizer
    if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
      pooled_outputs = pooled_outputs * pooled_mask
    return pooled_outputs, pooled_mask


class DecoderMultiheadAttention(base_layers.BaseLayer):
  """Multihead attention for decoder."""

  def __init__(self,
               model_dimension,
               num_heads,
               attention_dropout_rate=0.0,
               cached_kv=False,
               **kwargs):
    self.model_dimension = model_dimension
    self.num_heads = num_heads
    self.filters = model_dimension // num_heads
    self.cached_kv = cached_kv
    self.q_dense_layers = dense_layers.BaseQDense(
        units=model_dimension,
        activation=None,
        normalize=False,
        bias=False,
        **kwargs)
    self.kv_dense_layers = dense_layers.BaseQDenseVarLen(
        units=model_dimension * 2, activation=None, **kwargs)
    self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
    self.attention_dropout_rate = attention_dropout_rate
    self.qconcat = quantization_layers.ConcatQuantization(axis=1, **kwargs)
    super(DecoderMultiheadAttention, self).__init__(**kwargs)

  def call(self,
           inputs,
           input_mask,
           input_inverse_normalizer,
           memory=None,
           memory_mask=None,
           memory_inverse_normalizer=None,
           attn_mask=None):
    bsz = self.get_batch_dimension(inputs)
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(input_mask, 3)
    assert inputs.get_shape().as_list()[-1] == self.model_dimension

    inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
    q_tensor = self.q_dense_layers(inputs_rank2)
    if memory is not None:
      self._assert_rank_and_type(memory, 2)
      self._assert_rank_and_type(memory_mask, 2)
      if self.cached_kv:
        # Keys and Values are cached and reused at each layer.
        assert memory.get_shape().as_list()[1] == 2 * self.model_dimension
        kv_tensors = memory
      else:
        kv_tensors = self.kv_dense_layers(memory, memory_mask,
                                          memory_inverse_normalizer)
    else:
      kv_tensors = self.kv_dense_layers(inputs_rank2)

    if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
      q_tensor = tf.reshape(q_tensor, [bsz, -1, self.num_heads, self.filters])
      kv_tensors = tf.reshape(kv_tensors,
                              [bsz, -1, 2, self.num_heads, self.filters])
      kv_tensors = tf.unstack(kv_tensors, axis=2)
    else:
      q_tensor = tf.split(q_tensor, self.num_heads, axis=1)
      kv_tensors = tf.split(kv_tensors, self.num_heads * 2, axis=1)

    if self.parameters.mode in [base_layers.TRAIN, base_layers.EVAL]:
      assert attn_mask is not None
      if (self.attention_dropout_rate > 0.0 and
          self.parameters.mode == base_layers.TRAIN):
        attn_mask *= self.random_drop_to_zero(attn_mask,
                                              self.attention_dropout_rate)
      attn_mask = tf.expand_dims(attn_mask, 1)
      invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit
      queries = tf.transpose(q_tensor, [0, 2, 1, 3])
      keys = tf.transpose(kv_tensors[0], [0, 2, 1, 3])
      values = tf.transpose(kv_tensors[1], [0, 2, 1, 3])
      attn_logits = self.qactivation(
          tf.matmul(queries, keys, transpose_b=True))
      attn_logits_masked = attn_logits * attn_mask + invalid_mask
      attention = tf.nn.softmax(attn_logits_masked)
      attention = self.qrange_sigmoid(attention, tf_only=True)
      result = tf.matmul(attention, values)
      result = tf.transpose(result, [0, 2, 1, 3])
      result = tf.reshape(result, [bsz, -1, self.model_dimension])
      return self.qconcat([result])
    else:
      # We need to invoke the keras layer before calling APIs that it provides
      # such as quantize_using_range.
      self.qconcat(None)
      context = []
      for head in range(self.num_heads):
        queries = q_tensor[head]
        if self.parameters.mode == base_layers.PREDICT:
          # PREDICT mode assumes callers tile and merge beam size with batch
          # size. Hence extracting the first entry in the tile to compute
          # attention.
          keys = tf.split(kv_tensors[head], bsz, axis=0)
          keys = keys[0]
          values = tf.split(kv_tensors[head + self.num_heads], bsz, axis=0)
          values = values[0]
        else:
          keys = kv_tensors[head]
          values = kv_tensors[head + self.num_heads]
        attn_logits_masked = self.qactivation(
            tf.matmul(queries, keys, transpose_b=True))
        attention = tf.nn.softmax(attn_logits_masked)
        attention = self.qrange_sigmoid(attention, tf_only=True)
        context.append(
            self.qconcat.quantize_using_range(tf.matmul(attention, values)))
      # Concatenating heads along axis 1.
      result = self.qconcat.quantize_using_range(tf.concat(context, axis=1))
      return tf.reshape(result, [-1, 1, self.model_dimension])


class DecoderUniformAttention(base_layers.BaseLayer):
  """Decoder uniform attention."""

  def __init__(self,
               model_dimension,
               max_time_step,
               attention_dropout_rate=0.0,
               beam_size=1,
               **kwargs):
    self.model_dimension = model_dimension
    self.max_time_step = max_time_step
    self.beam_size = beam_size
    self.causal_mask = tf.expand_dims(
        tf.linalg.band_part(tf.ones([max_time_step, max_time_step]), -1, 0), 0)
    self.dense_layers = dense_layers.BaseQDenseVarLen(
        units=model_dimension,
        activation=None,
        normalize=False,
        bias=False,
        rank=3,
        **kwargs)
    self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
    super(DecoderUniformAttention, self).__init__(**kwargs)

  def get_uniform_attention(self, attn_mask=None):
    """Generates uniform attention matrix using `causal_mask`."""
    mask = tf.math.divide_no_nan(
        self.causal_mask,
        tf.reduce_sum(self.causal_mask, axis=-1, keepdims=True))
    if attn_mask is not None:
      self._assert_rank_and_type(attn_mask, 3)
      mask = mask * attn_mask
    return mask

  def call(self,
           inputs,
           mask,
           inverse_normalizer,
           step=None,
           beam_indices=None,
           cache=None,
           attn_mask=None):
    self._assert_rank_and_type(inputs, 3)
    self._assert_rank_and_type(mask, 3)
    assert inputs.get_shape().as_list()[-1] == self.model_dimension

    layer_out = self.dense_layers(inputs, mask, inverse_normalizer)
    # TFLite mode is handled with a custom op.
    if self.parameters.mode == base_layers.TFLITE:
      assert beam_indices is not None
      assert step is not None
      layer_out = tf_custom_ops_py.uniform_causal_attn(
          layer_out, step, beam_indices, self.model_dimension, self.beam_size)
    else:
      # Cache is used for TF Predict and Eval modes.
      if cache is None:
        attention_matrix = self.get_uniform_attention(attn_mask)
        layer_out = tf.matmul(attention_matrix, layer_out)
      else:
        assert self.parameters.mode in [base_layers.PREDICT, base_layers.EVAL]
        assert step is not None
        cache['uniform_avg'] = layer_out + cache['uniform_avg']
        layer_out = cache['uniform_avg'] / tf.cast(step, dtype=tf.float32)
    return self.qoutput(layer_out)
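
A minimal usage sketch for the TransformerEncoderStack defined above (illustrative, not part of this commit). It assumes base_layers.Parameters accepts the (mode, quantize, regularizer_scale) arguments used by the model wrappers later in this change, picks arbitrary small hyperparameters, and keeps embedding_size equal to model_dimension so the embedded inputs feed the encoder layers directly.

import tensorflow as tf
from layers import base_layers
from layers import transformer_layers

params = base_layers.Parameters(
    base_layers.TRAIN, quantize=True, regularizer_scale=0.0)
encoder = transformer_layers.TransformerEncoderStack(
    parameters=params,
    num_layers=2,
    max_time_step=64,
    vocabulary_size=259,
    embedding_size=64,
    model_dimension=64,
    num_heads=4,
    intermediate_size=256)
token_ids = tf.zeros([1, 64], dtype=tf.int32)   # [batch, time]
lengths = tf.constant([64], dtype=tf.int32)     # valid tokens per example
outputs = encoder(token_ids, lengths)           # [batch, time, model_dimension]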

research/seq_flow_lite/models/BUILD (view file @ a81f8590)

@@ -54,3 +54,52 @@ py_library(
        "//tf_ops:tf_custom_ops_py",
    ],
)

py_library(
    name = "charformer",
    srcs = ["charformer.py"],
    srcs_version = "PY3",
    deps = [
        ":transformer_encoder",
        # package tensorflow
        "//layers:base_layers",
        "//layers:embedding_layers",
        "//layers:misc_layers",
        "//layers:normalization_layers",
        "//layers:quantization_layers",
        # "//tf_ops:tf_custom_ops",
        "//tf_ops:tf_custom_ops_py",
    ],
)

py_library(
    name = "transformer_encoder",
    srcs = ["transformer_encoder.py"],
    srcs_version = "PY3",
    deps = [
        # package absl/logging
        # package tensorflow
        "//layers:base_layers",
        "//layers:embedding_layers",
        "//layers:transformer_layers",
        # "//tf_ops:tf_custom_ops",
        "//tf_ops:tf_custom_ops_py",
    ],
)

py_library(
    name = "transformer_uniform_attn_decoder",
    srcs = ["transformer_uniform_attn_decoder.py"],
    srcs_version = "PY3",
    deps = [
        # package absl/logging
        # package tensorflow
        # tensor2tensor/utils:beam_search",
        "//layers:base_layers",
        "//layers:embedding_layers",
        "//layers:misc_layers",
        "//layers:transformer_layers",
        "//tf_ops:tf_custom_ops",
        "//tf_ops:tf_custom_ops_py",
    ],
)

research/seq_flow_lite/models/charformer.py (new file, 0 → 100644, view file @ a81f8590)

# Copyright 2022 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Charformer based model for in-training tokenization."""
from absl import logging
import tensorflow as tf

from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import misc_layers
from layers import normalization_layers
from layers import quantization_layers
from models import transformer_encoder


class Encoder(tf.keras.layers.Layer):
  """Encoder with GBST and Transformer layers."""

  def __init__(self, config, mode, **kwargs):
    super(Encoder, self).__init__(**kwargs)

    def _get_params(varname, default_value=None):
      value = config[varname] if varname in config else default_value
      default = "" if varname in config else " (default)"
      logging.info("%s = %s%s", varname, value, default)
      setattr(self, varname, value)

    _get_params("labels", [])
    _get_params("regularizer_scale")
    _get_params("quantize")
    _get_params("feature_size")
    _get_params("bottleneck_size")
    self.max_seq_len = config.get("max_seq_len", 128)
    self.gbst_max_token_len = config.get("gbst_max_token_len", 128)
    # Including 3 additional special token ids (0=padding, 1=EOS, 2=UNK).
    self.vocabulary_size = config.get("vocabulary_size", 259)
    self.parameters = base_layers.Parameters(
        mode, quantize=self.quantize, regularizer_scale=self.regularizer_scale)
    self.embedding = embedding_layers.EmbeddingLayer(
        shape=[self.vocabulary_size, self.feature_size],
        parameters=self.parameters)
    self.gbst_downsample_rate = config.get("gbst_downsample_rate", 1)
    self.positional_embedding = embedding_layers.EmbeddingLayer(
        shape=[self.gbst_max_token_len, self.feature_size],
        parameters=self.parameters)
    self.ln = normalization_layers.LayerNormalization(
        parameters=self.parameters)
    self.qact = quantization_layers.ActivationQuantization(
        parameters=self.parameters)
    self.bottleneck_layer = None
    gbst_size = self.feature_size
    if self.bottleneck_size != self.feature_size:
      self.bottleneck_layer = dense_layers.BaseQDenseVarLen(
          self.bottleneck_size,
          rank=3,
          normalize=False,
          activation=None,
          parameters=self.parameters)
      gbst_size = self.bottleneck_size
    self.gbst_max_subword_block_width = config.get(
        "gbst_max_subword_block_width", 5)
    self.gbst_conv_kernel_size = config.get("gbst_conv_kernel_size", 5)
    self.gbst_block_mixing_mode = config.get("gbst_block_mixing_mode", None)
    self.gbst_layer = misc_layers.GBSTLayerV2(
        feature_size=gbst_size,
        max_seq_len=self.gbst_max_token_len,
        downsample_rate=self.gbst_downsample_rate,
        max_subword_block_width=self.gbst_max_subword_block_width,
        conv_kernel_size=self.gbst_conv_kernel_size,
        block_mixing_mode=self.gbst_block_mixing_mode,
        parameters=self.parameters)
    self.pool_windows = config.get("pool_windows", None)
    if self.pool_windows:
      self.transformer_encoder_layer = (
          transformer_encoder.FunnelTransformerModel(config, mode))
    else:
      self.transformer_encoder_layer = (
          transformer_encoder.ModelWithEmbeddings(config, mode))
    self.attention_pool = misc_layers.AttentionPooling(
        parameters=self.parameters)
    self.num_classes = len(self.labels)
    if self.num_classes:
      self.final_fc = dense_layers.BaseQDense(
          units=self.num_classes,
          rank=2,
          parameters=self.parameters,
          activation=None)

  def call(self, token_ids, seq_length):
    if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
      mask_rank2 = tf.ones(tf.shape(token_ids), dtype=tf.int32)
      seq_length = tf.reduce_sum(mask_rank2, axis=1)
      pos_indices = tf.cumsum(mask_rank2, axis=1, exclusive=True)
      pos_indices = tf.cast(pos_indices, dtype=tf.int32)
      pos_indices = tf.reshape(pos_indices, [1, -1])
    else:
      mask_rank2 = tf.sequence_mask(
          seq_length, tf.shape(token_ids)[1], dtype=tf.float32)
      pos_indices = tf.cumsum(mask_rank2, axis=1, exclusive=True)
      pos_indices = tf.cast(pos_indices, dtype=tf.int32)

    input_values = self.embedding(token_ids)
    pos_values = self.positional_embedding(pos_indices)
    input_embeds = self.qact(self.ln(input_values + pos_values))
    if self.bottleneck_layer is not None:
      maskr3 = tf.expand_dims(mask_rank2, axis=2)
      maskr3 = tf.cast(maskr3, tf.float32)
      bottleneck_output = self.bottleneck_layer(input_embeds, maskr3)
    else:
      bottleneck_output = input_embeds
    gbst_output = self.gbst_layer(bottleneck_output, seq_length)
    if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
      mask_rank2 = tf.ones(tf.shape(gbst_output)[:-1], dtype=tf.float32)
      seq_length = tf.reduce_sum(mask_rank2, axis=1)
    else:
      seq_length = seq_length / self.gbst_downsample_rate
    if self.pool_windows:
      outputs, mask = self.transformer_encoder_layer(gbst_output, seq_length)
      inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
      pre_logits = self.attention_pool(outputs, mask, inverse_normalizer)
    else:
      outputs = self.transformer_encoder_layer(gbst_output, seq_length)
      mask = tf.sequence_mask(
          seq_length, tf.shape(outputs)[1], dtype=tf.float32)
      inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
      maskr3 = tf.expand_dims(mask, axis=2)
      pre_logits = self.attention_pool(outputs, maskr3, inverse_normalizer)
    if self.num_classes:
      return self.final_fc(pre_logits)
    else:
      return pre_logits
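
A hypothetical configuration sketch for the charformer Encoder above; every key and value here is illustrative (the commit does not ship a config), but the keys mirror the _get_params and config.get lookups in __init__ plus the keys consumed by the wrapped transformer_encoder model.

import tensorflow as tf
from layers import base_layers
from models import charformer

config = {
    "labels": ["negative", "positive"],
    "regularizer_scale": 0.0,
    "quantize": True,
    "feature_size": 64,
    "bottleneck_size": 64,
    "vocabulary_size": 259,
    "gbst_max_token_len": 128,
    "gbst_downsample_rate": 2,
    # Consumed by the inner transformer_encoder model:
    "num_layers": 2,
    "num_heads": 4,
    "model_dimension": 64,
    "embedding_size": 64,
    "intermediate_size": 256,
    "max_time_step": 64,
}
encoder = charformer.Encoder(config, base_layers.TRAIN)
token_ids = tf.zeros([1, 128], dtype=tf.int32)  # [batch, time] byte ids
seq_length = tf.constant([128], dtype=tf.int32)
logits = encoder(token_ids, seq_length)         # [batch, num_classes]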

research/seq_flow_lite/models/transformer_encoder.py (new file, 0 → 100644, view file @ a81f8590)

# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of pQRNN model."""
# pylint: disable=arguments-renamed
from absl import logging
import tensorflow as tf

from layers import base_layers
from layers import transformer_layers


class Model(tf.keras.layers.Layer):
  """Quantized transformer encoder."""

  def __init__(self, config, mode):

    def _get_params(varname, default_value=None):
      value = config[varname] if varname in config else default_value
      default = "" if varname in config else " (default)"
      logging.info("%s = %s%s", varname, value, default)
      setattr(self, varname, value)

    _get_params("intermediate_size")
    _get_params("max_time_step")
    _get_params("embedding_size")
    _get_params("vocabulary_size")
    _get_params("num_layers")
    _get_params("labels")
    _get_params("regularizer_scale")
    _get_params("num_heads")
    _get_params("model_dimension")
    _get_params("quantize")
    _get_params("activation_dropout_rate", 0.0)
    _get_params("attention_dropout_rate", 0.0)
    self.parameters = base_layers.Parameters(mode, self.quantize,
                                             self.regularizer_scale)
    super(Model, self).__init__()

  def build(self, input_shape):
    self.transformer = transformer_layers.TransformerEncoderStack(
        parameters=self.parameters,
        num_layers=self.num_layers,
        intermediate_size=self.intermediate_size,
        embedding_size=self.embedding_size,
        max_time_step=self.max_time_step,
        num_heads=self.num_heads,
        model_dimension=self.model_dimension,
        vocabulary_size=self.vocabulary_size,
        activation_dropout_rate=self.activation_dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate)

  def call(self, indices, sequence_length):
    return self.transformer(indices, sequence_length)


class ModelWithEmbeddings(Model):
  """Quantized transformer encoder which takes embeddings instead of indices."""

  def build(self, input_shape):
    self.transformer_with_input_embedding = (
        transformer_layers.TransformerEncoderStackWithInputEmbedding(
            parameters=self.parameters,
            num_layers=self.num_layers,
            intermediate_size=self.intermediate_size,
            embedding_size=self.embedding_size,
            max_time_step=self.max_time_step,
            num_heads=self.num_heads,
            model_dimension=self.model_dimension,
            vocabulary_size=self.vocabulary_size,
            activation_dropout_rate=self.activation_dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate))

  def call(self, embeddings, sequence_length):
    return self.transformer_with_input_embedding(embeddings, sequence_length)


class FunnelTransformerModel(Model):
  """Quantized transformer encoder which takes embeddings instead of indices."""

  def __init__(self, config, mode):
    self.pool_windows = config.get("pool_windows", None)
    super(FunnelTransformerModel, self).__init__(config, mode)

  def build(self, input_shape):
    self.funnel_transformer = transformer_layers.FunnelTransformerEncoderStack(
        parameters=self.parameters,
        num_layers=self.num_layers,
        intermediate_size=self.intermediate_size,
        embedding_size=self.embedding_size,
        max_time_step=self.max_time_step,
        num_heads=self.num_heads,
        model_dimension=self.model_dimension,
        vocabulary_size=self.vocabulary_size,
        activation_dropout_rate=self.activation_dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        pool_windows=self.pool_windows)

  def call(self, embeddings, sequence_length):
    return self.funnel_transformer(embeddings, sequence_length)
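
A sketch of how FunnelTransformerModel might be configured (illustrative values, not from the commit). The only funnel-specific key is pool_windows, which must contain one pooling window per encoder layer; windows larger than 1 shrink the sequence before that layer, and the model returns both the pooled outputs and the pooled mask.

import tensorflow as tf
from layers import base_layers
from models import transformer_encoder

config = {
    "num_layers": 2,
    "num_heads": 4,
    "model_dimension": 64,
    "embedding_size": 64,
    "intermediate_size": 256,
    "max_time_step": 128,
    "vocabulary_size": 259,
    "labels": [],
    "regularizer_scale": 0.0,
    "quantize": True,
    "pool_windows": [2, 1],  # one entry per layer
}
model = transformer_encoder.FunnelTransformerModel(config, base_layers.TRAIN)
embeddings = tf.zeros([1, 128, 64])  # [batch, time, model_dimension]
outputs, pooled_mask = model(embeddings, tf.constant([128], dtype=tf.int32))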

research/seq_flow_lite/models/transformer_uniform_attn_decoder.py (new file, 0 → 100644, view file @ a81f8590)

# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Transformer decoder model."""
import math

from absl import logging
from tensor2tensor.utils import beam_search
import tensorflow as tf

from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import normalization_layers
from layers import quantization_layers
from layers import transformer_layers


class TransformerUniformAttnDecoder(base_layers.BaseLayer):
  """Transformer Uniform Attention Decoder."""

  def __init__(self,
               model_dimension,
               max_time_step,
               num_heads,
               intermediate_size,
               activation_dropout_rate=0.0,
               attention_dropout_rate=0.0,
               beam_size=1,
               cached_kv=False,
               **kwargs):
    self.model_dimension = model_dimension
    self.decoder_uniform_attn = transformer_layers.DecoderUniformAttention(
        model_dimension,
        max_time_step,
        attention_dropout_rate=attention_dropout_rate,
        beam_size=beam_size,
        **kwargs)
    self.multihead_cross_attn = transformer_layers.DecoderMultiheadAttention(
        model_dimension,
        num_heads,
        cached_kv=cached_kv,
        attention_dropout_rate=attention_dropout_rate,
        **kwargs)
    self.prx = dense_layers.BaseQDense(
        model_dimension,
        activation=None,
        normalize=False,
        bias=False,
        **kwargs)
    self.upprx = dense_layers.BaseQDense(
        intermediate_size, normalize=False, **kwargs)
    self.downprx = dense_layers.BaseQDense(
        model_dimension, activation=None, normalize=False, **kwargs)
    self.activation_dropout_rate = activation_dropout_rate
    self.ln1 = normalization_layers.LayerNormalization(**kwargs)
    self.ln2 = normalization_layers.LayerNormalization(**kwargs)
    self.q0 = quantization_layers.ActivationQuantization(**kwargs)
    self.q1 = quantization_layers.ActivationQuantization(**kwargs)
    self.q2 = quantization_layers.ActivationQuantization(**kwargs)
    super(TransformerUniformAttnDecoder, self).__init__(**kwargs)

  def call(self,
           dec_inputs,
           dec_mask,
           dec_inverse_normalizer,
           enc_output,
           enc_mask,
           enc_inverse_normalizer,
           cross_attn_mask=None,
           step=None,
           selected_beams=None,
           cache=None):
    batch_size = self.get_batch_dimension(dec_inputs)
    self._assert_rank_and_type(dec_inputs, 3)
    self._assert_rank_and_type(dec_mask, 3)
    assert dec_inputs.get_shape().as_list()[-1] == self.model_dimension

    self_attn_output = self.decoder_uniform_attn(
        dec_inputs,
        dec_mask,
        dec_inverse_normalizer,
        step=step,
        beam_indices=selected_beams,
        cache=cache)
    cross_attn_output = self.multihead_cross_attn(
        dec_inputs, dec_mask, dec_inverse_normalizer, enc_output, enc_mask,
        enc_inverse_normalizer, cross_attn_mask)
    layer_out = self.q0(cross_attn_output + self_attn_output)
    layer_out = tf.reshape(layer_out, [-1, self.model_dimension])
    layer_out = self.prx(layer_out)
    if self.parameters.mode == base_layers.TRAIN:
      layer_out = tf.nn.dropout(layer_out, rate=self.activation_dropout_rate)
    dec_inputs = tf.reshape(dec_inputs, [-1, self.model_dimension])
    dec_inputs_updated = self.q1(self.ln1(dec_inputs + layer_out))

    # Feed forward network.
    layer_out = self.upprx(dec_inputs_updated)
    layer_out = self.downprx(layer_out)
    if self.parameters.mode == base_layers.TRAIN:
      layer_out = tf.nn.dropout(layer_out, rate=self.activation_dropout_rate)
    outputs = self.q2(self.ln2(dec_inputs_updated + layer_out))
    return tf.reshape(outputs, [batch_size, -1, self.model_dimension])


class TransformerUniformAttnDecoderStack(base_layers.BaseLayer):
  """TransformerUniformAttnDecoderStack Decoder."""

  def __init__(self,
               num_layers,
               max_time_step,
               vocabulary_size,
               embedding_size,
               model_dimension,
               num_heads,
               intermediate_size,
               beam_size=1,
               activation_dropout_rate=0.1,
               attention_dropout_rate=0.0,
               cached_kv=False,
               **kwargs):
    super(TransformerUniformAttnDecoderStack, self).__init__(**kwargs)
    self.max_time_step = max_time_step
    self.vocabulary_size = vocabulary_size
    self.embedding_size = embedding_size
    self.activation_dropout_rate = activation_dropout_rate
    self.layers = []
    for _ in range(num_layers):
      self.layers.append(
          TransformerUniformAttnDecoder(
              model_dimension=model_dimension,
              max_time_step=max_time_step,
              num_heads=num_heads,
              intermediate_size=intermediate_size,
              beam_size=beam_size,
              cached_kv=cached_kv,
              activation_dropout_rate=activation_dropout_rate,
              attention_dropout_rate=attention_dropout_rate,
              **kwargs))

  def call(self,
           dec_inputs,
           dec_mask,
           enc_output,
           enc_mask,
           step=None,
           selected_beams=None,
           cache=None):
    self._assert_rank_and_type(dec_mask, 2)
    self._assert_rank_and_type(enc_mask, 2)
    dec_mask_rank3 = tf.expand_dims(dec_mask, axis=2)
    dec_inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(dec_mask_rank3))
    enc_mask_rank3 = tf.expand_dims(enc_mask, 1)
    enc_inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(enc_mask_rank3))
    cross_attn_mask = enc_mask_rank3
    layer_in = dec_inputs
    if self.parameters.mode == base_layers.TRAIN:
      layer_in = tf.nn.dropout(layer_in, rate=self.activation_dropout_rate)
    enc_output_feature_dim = enc_output.get_shape().as_list()[2]
    enc_output = tf.reshape(enc_output, [-1, enc_output_feature_dim])
    for i, layer in enumerate(self.layers):
      layer_cache = cache["layer_%d" % i] if cache is not None else None
      layer_in = layer(
          layer_in,
          dec_mask_rank3,
          dec_inverse_normalizer,
          enc_output,
          enc_mask,
          enc_inverse_normalizer,
          cross_attn_mask,
          step=step,
          selected_beams=selected_beams,
          cache=layer_cache)
    return layer_in


class Model(tf.keras.layers.Layer):
  """Quantized transformer decoder."""

  def __init__(self, config, mode):
    super(Model, self).__init__()

    def _get_params(varname, default_value=None):
      value = config[varname] if varname in config else default_value
      default = "" if varname in config else " (default)"
      logging.info("%s = %s%s", varname, value, default)
      setattr(self, varname, value)

    _get_params("intermediate_size")
    _get_params("max_dec_time_step")
    _get_params("max_enc_time_step")
    _get_params("embedding_size")
    _get_params("vocabulary_size")
    _get_params("num_layers")
    _get_params("labels")
    _get_params("regularizer_scale")
    _get_params("num_heads")
    _get_params("model_dimension")
    _get_params("beam_size", 1)
    _get_params("quantize", True)
    _get_params("cached_kv", False)
    _get_params("attention_dropout_rate", 0.0)
    _get_params("activation_dropout_rate", 0.0)
    # If set, a separate dense layer is used to generate the logits instead of
    # re-using the input embedding table.
    _get_params("use_output_layer", False)
    self.parameters = base_layers.Parameters(mode, self.quantize,
                                             self.regularizer_scale)
    # Activation/Normalization enabled on input bottleneck as there is no
    # temporal information.
    self.input_bottleneck = dense_layers.BaseQDenseVarLen(
        self.model_dimension, rank=3, parameters=self.parameters)
    self.output_bottleneck = dense_layers.BaseQDense(
        self.embedding_size,
        normalize=False,
        activation=None,
        bias=False,
        parameters=self.parameters)
    self.embedding = embedding_layers.EmbeddingFullyConnected(
        shape=[self.vocabulary_size, self.embedding_size],
        initializer=tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3)),
        parameters=self.parameters)
    if self.use_output_layer:
      self.output_layer = dense_layers.BaseQDense(
          self.vocabulary_size,
          activation=None,
          normalize=False,
          bias=False,
          parameters=self.parameters)
    self.positional_embedding = embedding_layers.EmbeddingLayer(
        shape=[self.max_dec_time_step, self.model_dimension],
        initializer=tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3)),
        parameters=self.parameters)
    self.ln = normalization_layers.LayerNormalization(
        parameters=self.parameters)
    self.qact = quantization_layers.ActivationQuantization(
        parameters=self.parameters)
    # Scales the weights for computing logits.
    self.logits_fc_weights_scale_factor = None
    self.logits_fc_bias = self.add_weight(
        "logits_fc_bias",
        shape=[self.vocabulary_size],
        initializer=tf.constant_initializer(0),
        dtype="float32")
    # Optional bias which can be used to mask logits output.
    self.output_bias = None
    self.transformer_uniform_attn_decoder = TransformerUniformAttnDecoderStack(
        parameters=self.parameters,
        num_layers=self.num_layers,
        intermediate_size=self.intermediate_size,
        embedding_size=self.embedding_size,
        max_time_step=self.max_dec_time_step,
        num_heads=self.num_heads,
        model_dimension=self.model_dimension,
        vocabulary_size=self.vocabulary_size,
        beam_size=self.beam_size,
        cached_kv=self.cached_kv,
        attention_dropout_rate=self.attention_dropout_rate,
        activation_dropout_rate=self.activation_dropout_rate)
    # Beam search output.
    self.finished_seq = None
    self.finished_scores = None

  def call(self,
           decode_ids,
           decode_ids_mask,
           enc_output,
           enc_mask,
           start_ids=None,
           eos_id=None,
           pad_id=None,
           input_id=None,
           time_step=None,
           selected_beams=None):
    if self.parameters.mode == base_layers.TRAIN:
      inputs = self.training_inputs(decode_ids, decode_ids_mask)
      layer_out = self.transformer_uniform_attn_decoder(
          inputs, decode_ids_mask, enc_output, enc_mask)
      logits, predicted_ids = self.model_outputs(layer_out)
    elif self.parameters.mode in [base_layers.EVAL, base_layers.PREDICT]:
      logits, predicted_ids = self.decode_beam_search(start_ids, eos_id,
                                                      pad_id, enc_output,
                                                      enc_mask)
    elif self.parameters.mode == base_layers.TFLITE:
      input_values = self.embedding(input_id)
      # time_step starts from 1.
      pos_values = self.positional_embedding(time_step - 1)
      pos_values = tf.reshape(pos_values, [-1, 1, self.embedding_size])
      input_mask = tf.ones(tf.shape(input_values)[:-1], dtype=tf.float32)
      inputs = self.qact(self.ln(input_values + pos_values))
      layer_out = self.transformer_uniform_attn_decoder(
          inputs,
          input_mask,
          enc_output,
          enc_mask,
          step=time_step,
          selected_beams=selected_beams)
      logits, predicted_ids = self.model_outputs(layer_out)
    else:
      assert "Invalid mode."
    return logits, predicted_ids

  def training_inputs(self, input_ids, input_mask):
    input_values = self.embedding(input_ids)
    if self.embedding_size != self.model_dimension:
      input_values = self.input_bottleneck(input_values, input_mask)
    pos_indices = tf.cumsum(input_mask, axis=1, exclusive=True)
    pos_indices = tf.cast(pos_indices, dtype=tf.int32)
    pos_values = self.positional_embedding(pos_indices)
    inputs = self.qact(self.ln(input_values + pos_values))
    return inputs

  def model_outputs(self, layer_in):
    bsz = layer_in.get_shape().as_list()[0] or tf.shape(layer_in)[0]
    layer_out = tf.reshape(layer_in, [-1, self.model_dimension])
    if self.use_output_layer:
      logits = self.output_layer(layer_out)
    else:
      if self.model_dimension != self.embedding_size:
        layer_out = self.output_bottleneck(layer_out)
      logits = self.embedding.fully_connected(
          layer_out,
          bias=self.logits_fc_bias,
          weights_scale_factor=self.logits_fc_weights_scale_factor)
    logits = tf.reshape(logits, [bsz, -1, self.vocabulary_size])
    # Optional bias to mask out logits before applying argmax.
    if self.output_bias is not None:
      logits += self.output_bias
    predicted_ids = tf.argmax(logits, axis=2, output_type=tf.int64)
    return logits, predicted_ids

  def decode_beam_search(self,
                         start_ids,
                         eos_id,
                         pad_id,
                         enc_output,
                         enc_mask,
                         scope="model"):
    batch_size = tf.shape(start_ids)[0]
    cache = {  # pylint: disable=g-complex-comprehension
        "layer_%d" % layer: {
            "uniform_avg": tf.zeros([batch_size, 1, self.model_dimension]),
        } for layer in range(self.num_layers)
    }
    cache["logits"] = tf.zeros([batch_size, 0, self.vocabulary_size])
    pos_indices = tf.range(self.max_dec_time_step, dtype=tf.int32)
    pos_indices = tf.reshape(pos_indices, [1, -1])
    pos_values = self.positional_embedding(pos_indices)

    def beam_search_tile(output, tile_pattern, final_shape):
      x = tf.tile(output, tile_pattern)
      x = tf.reshape(x, final_shape)
      return x

    enc_output_feature_dim = enc_output.get_shape().as_list()[2]
    enc_output = beam_search_tile(
        enc_output, [1, self.beam_size, 1],
        [batch_size * self.beam_size, -1, enc_output_feature_dim])
    enc_mask = beam_search_tile(enc_mask, [1, self.beam_size],
                                [batch_size * self.beam_size, -1])

    def symbols_to_logits_fn(ids, step, cache):
      """Looks up ids to logits."""
      logging.info("Running symbols to logits. ids=%s, step=%s, cache=%s", ids,
                   step, cache)
      curr_id = ids[:, -1:]
      with tf.name_scope(scope):
        curr_embed = self.embedding(curr_id)
        input_mask = tf.ones(tf.shape(curr_embed)[:-1], dtype=tf.float32)
        if self.embedding_size != self.model_dimension:
          curr_embed = self.input_bottleneck(curr_embed, input_mask)
        inputs = self.qact(
            self.ln(curr_embed + pos_values[:, step:step + 1, :]))
        layer_out = self.transformer_uniform_attn_decoder(
            inputs,
            input_mask,
            enc_output,
            enc_mask,
            step=step + 1,
            cache=cache)
        next_logits, _ = self.model_outputs(layer_out)
        cache["logits"] = tf.concat([cache["logits"], next_logits], axis=1)
        return next_logits, cache

    self.finished_seq, self.finished_scores, states = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids=start_ids,
        beam_size=self.beam_size,
        decode_length=self.max_dec_time_step,
        vocab_size=self.vocabulary_size,
        alpha=0.6,
        eos_id=eos_id,
        states=cache)
    beam_ids = self.finished_seq[:, 0, 1:]
    beam_ids = tf.pad(
        beam_ids,
        [[0, 0], [0, self.max_dec_time_step - tf.shape(beam_ids)[1]]],
        constant_values=pad_id)
    logits = states["logits"][:, 0, :, :]
    logits = tf.pad(
        logits,
        [[0, 0], [0, self.max_dec_time_step - tf.shape(logits)[1]], [0, 0]],
        constant_values=self.parameters.invalid_logit)
    return logits, beam_ids


class ModelEvalWithGTLogitsAndPredictions(Model):
  """Model with EVAL mode logits and predictions based on ground truth inputs at each step."""

  def call(self,
           decode_ids,
           decode_ids_mask,
           enc_output,
           enc_mask,
           start_ids=None,
           eos_id=None,
           pad_id=None,
           input_id=None,
           time_step=None,
           selected_beams=None):
    if self.parameters.mode in [base_layers.TRAIN, base_layers.EVAL]:
      inputs = self.training_inputs(decode_ids, decode_ids_mask)
      layer_out = self.transformer_uniform_attn_decoder(
          inputs, decode_ids_mask, enc_output, enc_mask)
      logits, predicted_ids = self.model_outputs(layer_out)
    elif self.parameters.mode == base_layers.PREDICT:
      logits, predicted_ids = self.decode_beam_search(
          start_ids,
          eos_id,
          pad_id,
          enc_output,
          enc_mask,
          scope="model_eval_with_gt_logits_and_predictions")
    elif self.parameters.mode == base_layers.TFLITE:
      input_values = self.embedding(input_id)
      # time_step starts from 1.
      pos_values = self.positional_embedding(time_step - 1)
      pos_values = tf.reshape(pos_values, [-1, 1, self.embedding_size])
      input_mask = tf.ones(tf.shape(input_values)[:-1], dtype=tf.float32)
      inputs = self.qact(self.ln(input_values + pos_values))
      layer_out = self.transformer_uniform_attn_decoder(
          inputs,
          input_mask,
          enc_output,
          enc_mask,
          step=time_step,
          selected_beams=selected_beams)
      logits, predicted_ids = self.model_outputs(layer_out)
    else:
      assert "Invalid mode."
    return logits, predicted_ids


class ModelEvalWithGTLogits(Model):
  """Model with EVAL mode logits computed based on ground truth input at each step."""

  def call(self,
           decode_ids,
           decode_ids_mask,
           enc_output,
           enc_mask,
           start_ids=None,
           eos_id=None,
           pad_id=None,
           input_id=None,
           time_step=None,
           selected_beams=None):
    logits = None
    if self.parameters.mode in [base_layers.TRAIN, base_layers.EVAL]:
      inputs = self.training_inputs(decode_ids, decode_ids_mask)
      layer_out = self.transformer_uniform_attn_decoder(
          inputs, decode_ids_mask, enc_output, enc_mask)
      logits, predicted_ids = self.model_outputs(layer_out)
    if self.parameters.mode in [base_layers.EVAL, base_layers.PREDICT]:
      # EVAL mode predictions are based on beam search path.
      _, predicted_ids = self.decode_beam_search(
          start_ids,
          eos_id,
          pad_id,
          enc_output,
          enc_mask,
          scope="model_eval_with_gt_logits")
    if self.parameters.mode == base_layers.TFLITE:
      input_values = self.embedding(input_id)
      # time_step starts from 1.
      pos_values = self.positional_embedding(time_step - 1)
      pos_values = tf.reshape(pos_values, [-1, 1, self.embedding_size])
      input_mask = tf.ones(tf.shape(input_values)[:-1], dtype=tf.float32)
      inputs = self.qact(self.ln(input_values + pos_values))
      layer_out = self.transformer_uniform_attn_decoder(
          inputs,
          input_mask,
          enc_output,
          enc_mask,
          step=time_step,
          selected_beams=selected_beams)
      logits, predicted_ids = self.model_outputs(layer_out)
    return logits, predicted_ids
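
A minimal training-mode sketch for the decoder Model above (illustrative; the config values and tensors are made up, and enc_output would normally come from one of the encoders in this change). The config keys mirror the _get_params lookups in __init__.

import tensorflow as tf
from layers import base_layers
from models import transformer_uniform_attn_decoder as decoder_lib

config = {
    "intermediate_size": 256,
    "max_dec_time_step": 32,
    "max_enc_time_step": 64,
    "embedding_size": 64,
    "vocabulary_size": 259,
    "num_layers": 2,
    "labels": [],
    "regularizer_scale": 0.0,
    "num_heads": 4,
    "model_dimension": 64,
    "beam_size": 2,
}
decoder = decoder_lib.Model(config, base_layers.TRAIN)
decode_ids = tf.zeros([1, 32], dtype=tf.int32)        # shifted target ids
decode_ids_mask = tf.ones([1, 32], dtype=tf.float32)  # rank-2 decoder mask
enc_output = tf.zeros([1, 64, 64])                    # [batch, enc_time, model_dimension]
enc_mask = tf.ones([1, 64], dtype=tf.float32)         # rank-2 encoder mask
logits, predicted_ids = decoder(decode_ids, decode_ids_mask, enc_output,
                                enc_mask)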

research/seq_flow_lite/tf_ops/tf_custom_ops.cc (view file @ a81f8590)

@@ -93,3 +93,33 @@ REGISTER_OP("PoolingOp")
    .Doc(R"doc(
Dummy pooling op.
)doc");

class UniformCausalAttnOp : public tensorflow::OpKernel {
 public:
  explicit UniformCausalAttnOp(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context) {}
  void Compute(tensorflow::OpKernelContext* ctx) override {}
};

REGISTER_KERNEL_BUILDER(
    Name("UniformCausalAttn").Device(::tensorflow::DEVICE_CPU),
    UniformCausalAttnOp);

REGISTER_OP("UniformCausalAttn")
    .Input("input: float32")
    .Input("time_step: int32")
    .Input("selected_beams: int32")
    .Attr("feature_size: int")
    .Attr("beam_size: int")
    .Output("output: float32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      auto batch_size = c->Dim(c->input(0), 0);
      int32 feature_size;
      TF_RETURN_IF_ERROR(c->GetAttr("feature_size", &feature_size));
      c->set_output(0, c->MakeShape({batch_size, 1, feature_size}));
      return tensorflow::Status::OK();
    })
    .Doc(R"doc(
Dummy uniform causal attn op.
)doc");