Commit 2deb09ae authored Aug 07, 2020 by xinliupitt

attention_initializer

parent f68a262d
Showing 2 changed files with 40 additions and 8 deletions
official/nlp/modeling/layers/transformer.py       +28 -4
official/nlp/modeling/layers/transformer_test.py  +12 -4
official/nlp/modeling/layers/transformer.py
@@ -56,6 +56,8 @@ class Transformer(tf.keras.layers.Layer):
       normalized.
     norm_epsilon: Epsilon value to initialize normalization layers.
     intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+    attention_initializer: Initializer for kernels of attention layers. If set
+      `None`, attention layers use kernel_initializer as initializer for kernel.
   """
 
   def __init__(self,
@@ -76,6 +78,7 @@ class Transformer(tf.keras.layers.Layer):
                norm_first=False,
                norm_epsilon=1e-12,
                intermediate_dropout=0.0,
+               attention_initializer=None,
                **kwargs):
     super(Transformer, self).__init__(**kwargs)
@@ -96,6 +99,10 @@ class Transformer(tf.keras.layers.Layer):
     self._norm_first = norm_first
     self._norm_epsilon = norm_epsilon
     self._intermediate_dropout = intermediate_dropout
+    if attention_initializer:
+      self._attention_initializer = attention_initializer
+    else:
+      self._attention_initializer = self._kernel_initializer
 
   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -121,7 +128,6 @@ class Transformer(tf.keras.layers.Layer):
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
@@ -133,6 +139,7 @@ class Transformer(tf.keras.layers.Layer):
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
         use_bias=self._use_bias,
+        kernel_initializer=self._attention_initializer,
         name="self_attention",
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
@@ -148,6 +155,7 @@ class Transformer(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self._intermediate_size),
         bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
@@ -165,6 +173,7 @@ class Transformer(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
+        kernel_initializer=self._kernel_initializer,
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
@@ -211,7 +220,9 @@ class Transformer(tf.keras.layers.Layer):
         "norm_epsilon":
             self._norm_epsilon,
         "intermediate_dropout":
-            self._intermediate_dropout
+            self._intermediate_dropout,
+        "attention_initializer":
+            self._attention_initializer
     }
     base_config = super(Transformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
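Taken together, these hunks move kernel_initializer out of common_kwargs so the self-attention sublayer can take self._attention_initializer while the intermediate and output dense layers keep self._kernel_initializer; when attention_initializer is None, the constructor falls back to kernel_initializer. A minimal usage sketch for the encoder block, assuming the Model Garden import path official.nlp.modeling.layers.transformer and the [data, mask] call convention used in the surrounding tests (the sizes and the TruncatedNormal choice are illustrative only, not part of this commit):

import tensorflow as tf

from official.nlp.modeling.layers import transformer

# Illustrative sizes only: hidden width 16 with 2 heads -> head size 8.
encoder_block = transformer.Transformer(
    num_attention_heads=2,
    intermediate_size=64,
    intermediate_activation="relu",
    # Only the self-attention kernels use this initializer; the
    # intermediate/output dense kernels keep kernel_initializer.
    attention_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))

# Forward pass on dummy data, mirroring the shapes used in the tests.
data = tf.zeros([2, 4, 16], dtype=tf.float32)
mask = tf.ones([2, 4, 4], dtype=tf.float32)
output = encoder_block([data, mask])
print(output.shape)  # (2, 4, 16): the output dense maps back to hidden size.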
@@ -300,6 +311,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       normalized.
     norm_epsilon: Epsilon value to initialize normalization layers.
     intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+    attention_initializer: Initializer for kernels of attention layers. If set
+      `None`, attention layers use kernel_initializer as initializer for kernel.
   """
 
   def __init__(self,
@@ -320,6 +333,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                norm_first=False,
                norm_epsilon=1e-12,
                intermediate_dropout=0.0,
+               attention_initializer=None,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
@@ -340,6 +354,10 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self._norm_first = norm_first
     self._norm_epsilon = norm_epsilon
     self._intermediate_dropout = intermediate_dropout
+    if attention_initializer:
+      self._attention_initializer = attention_initializer
+    else:
+      self._attention_initializer = self._kernel_initializer
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -357,7 +375,6 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "heads (%d)" % (hidden_size, self.num_attention_heads))
     self.attention_head_size = int(hidden_size / self.num_attention_heads)
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
@@ -370,12 +387,14 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         use_bias=self._use_bias,
+        kernel_initializer=self._attention_initializer,
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
         name="output",
         **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
@@ -392,6 +411,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
         use_bias=self._use_bias,
+        kernel_initializer=self._attention_initializer,
         name="attention/encdec",
         **common_kwargs)
@@ -408,6 +428,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self.intermediate_size),
         bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
         name="intermediate",
         **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
@@ -418,6 +439,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
+        kernel_initializer=self._kernel_initializer,
         name="output",
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
@@ -460,7 +482,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         "norm_epsilon":
             self._norm_epsilon,
         "intermediate_dropout":
-            self._intermediate_dropout
+            self._intermediate_dropout,
+        "attention_initializer":
+            self._attention_initializer
     }
     base_config = super(TransformerDecoderLayer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
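The decoder layer gets the same treatment: the self-attention and encoder-decoder attention sublayers now receive self._attention_initializer, and when the argument is omitted the constructor reuses kernel_initializer. A small sketch of that fallback, assuming the same module path as above and that TransformerDecoderLayer accepts the same core constructor arguments as the encoder block (the sizes and initializer choice are illustrative only):

import tensorflow as tf

from official.nlp.modeling.layers import transformer

kernel_init = tf.keras.initializers.TruncatedNormal(stddev=0.02)

# attention_initializer is left at its default of None here.
decoder_block = transformer.TransformerDecoderLayer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation="relu",
    kernel_initializer=kernel_init)

# Per the constructor change in the hunk above, a missing
# attention_initializer falls back to the kernel initializer, so the two
# private attributes end up referring to the same object.
print(decoder_block._attention_initializer is decoder_block._kernel_initializer)
# Expected: True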
official/nlp/modeling/layers/transformer_test.py
@@ -231,7 +231,9 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         use_bias=False,
         norm_first=True,
         norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -250,7 +252,9 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         use_bias=False,
         norm_first=True,
         norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     encoder_block_config = encoder_block.get_config()
     new_encoder_block = transformer.Transformer.from_config(encoder_block_config)
@@ -302,7 +306,9 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
         use_bias=False,
         norm_first=True,
         norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -321,7 +327,9 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
         use_bias=False,
         norm_first=True,
         norm_epsilon=1e-6,
-        intermediate_dropout=0.1)
+        intermediate_dropout=0.1,
+        attention_initializer=tf.keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
     decoder_block_config = decoder_block.get_config()
     new_decoder_block = transformer.TransformerDecoderLayer.from_config(
         decoder_block_config)
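Both test updates pass an explicit RandomUniform attention initializer and reuse the existing get_config / from_config round-trips, so the new argument has to survive serialization. A compact, self-contained variant of that check (the assertions here are illustrative and not copied from the test file; sizes are arbitrary):

import tensorflow as tf

from official.nlp.modeling.layers import transformer

attention_init = tf.keras.initializers.RandomUniform(minval=0., maxval=1.)

encoder_block = transformer.Transformer(
    num_attention_heads=2,
    intermediate_size=32,
    intermediate_activation="relu",
    attention_initializer=attention_init)

# get_config now includes "attention_initializer" (see the get_config hunks
# above), so a clone rebuilt from the config keeps the setting.
config = encoder_block.get_config()
assert "attention_initializer" in config

new_encoder_block = transformer.Transformer.from_config(config)
assert new_encoder_block.get_config()["attention_initializer"] is attention_init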