ModelZoo / ResNet50_tensorflow

Commit 86df41f7
Authored Jan 05, 2021 by Hongkun Yu; committed by A. Unique TensorFlower on Jan 05, 2021

#KerasNLP Update TransformerEncoderBlock to support Q, KV as two input streams.

PiperOrigin-RevId: 350170448
Parent: ece64b24
Showing 19 changed files with 73 additions and 49 deletions (+73, -49).
official/modeling/progressive/utils.py                           +1  -1
official/modeling/tf_utils.py                                    +1  -1
official/nlp/keras_nlp/layers/transformer_encoder_block.py       +41 -19
official/nlp/keras_nlp/layers/transformer_encoder_block_test.py  +14 -12
official/nlp/modeling/layers/transformer.py                      +2  -2
orbit/__init__.py                                                +1  -1
orbit/controller.py                                              +1  -1
orbit/controller_test.py                                         +1  -1
orbit/runner.py                                                  +1  -1
orbit/standard_runner.py                                         +1  -1
orbit/standard_runner_test.py                                    +1  -1
orbit/utils/__init__.py                                          +1  -1
orbit/utils/common.py                                            +1  -1
orbit/utils/common_test.py                                       +1  -1
orbit/utils/epoch_helper.py                                      +1  -1
orbit/utils/loop_fns.py                                          +1  -1
orbit/utils/summary_manager.py                                   +1  -1
orbit/utils/tpu_summaries.py                                     +1  -1
orbit/utils/tpu_summaries_test.py                                +1  -1
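In short: TransformerEncoderBlock.call now accepts a third input form with separate query and key/value streams. A minimal usage sketch, with the import path assumed from this repo's layout and shapes mirroring the new test_separate_qkv test further down:

    import tensorflow as tf
    from official.nlp import keras_nlp  # import path assumed from this repo's layout

    block = keras_nlp.layers.TransformerEncoderBlock(
        num_attention_heads=2, inner_dim=128, inner_activation='relu')

    x = tf.random.normal([2, 8, 16])   # [batch, seq, width]
    mask = tf.ones([2, 8, 8])          # [batch, seq, seq]
    q = tf.random.normal([2, 4, 16])   # [batch, q_len, width]
    qkv_mask = tf.ones([2, 4, 8])      # [batch, q_len, kv_len]

    y1 = block(x)                 # self-attention, no mask
    y2 = block([x, mask])         # self-attention with an attention mask
    y3 = block([q, x, qkv_mask])  # new: separate query and key/value streams
    # y3.shape == q.shape: the output follows the query tensor.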
official/modeling/progressive/utils.py (view file @ 86df41f7)

-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/modeling/tf_utils.py (view file @ 86df41f7)

-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/nlp/keras_nlp/layers/transformer_encoder_block.py (view file @ 86df41f7)

...
@@ -85,7 +85,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         kernel.
       **kwargs: keyword arguments.
     """
-    super(TransformerEncoderBlock, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._num_heads = num_attention_heads
     self._inner_dim = inner_dim
...
@@ -111,23 +111,18 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._attention_initializer = self._kernel_initializer

   def build(self, input_shape):
-    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
-    input_tensor_shape = tf.TensorShape(input_tensor)
+    if isinstance(input_shape, tf.TensorShape):
+      input_tensor_shape = input_shape
+    elif isinstance(input_shape, (list, tuple)):
+      input_tensor_shape = tf.TensorShape(input_shape[0])
+    else:
+      raise ValueError(
+          "The type of input shape argument is not supported, got: %s" %
+          type(input_shape))
     if len(input_tensor_shape.as_list()) != 3:
       raise ValueError("TransformerEncoderBlock expects a three-dimensional "
                        "input of shape [batch, sequence, width].")
-    batch_size, sequence_length, hidden_size = input_tensor_shape
+    hidden_size = input_tensor_shape[-1]
-
-    if len(input_shape) == 2:
-      mask_tensor_shape = tf.TensorShape(input_shape[1])
-      expected_mask_tensor_shape = tf.TensorShape(
-          [batch_size, sequence_length, sequence_length])
-      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
-        raise ValueError("When passing a mask tensor to "
-                         "TransformerEncoderBlock, the mask tensor must be of "
-                         "shape [batch, sequence_length, sequence_length] "
-                         "(here %s). Got a mask tensor of shape %s." %
-                         (expected_mask_tensor_shape, mask_tensor_shape))
     if hidden_size % self._num_heads != 0:
       raise ValueError(
           "The input size (%d) is not a multiple of the number of attention "
...
@@ -234,15 +229,38 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     return dict(list(base_config.items()) + list(config.items()))

   def call(self, inputs):
+    """Transformer self-attention encoder block call.
+
+    Args:
+      inputs: a single tensor or a list of tensors.
+        `input tensor` as the single sequence of embeddings.
+        [`input tensor`, `attention mask`] to have the additional attention
+          mask.
+        [`query tensor`, `key value tensor`, `attention mask`] to have separate
+          input streams for the query, and key/value to the multi-head
+          attention.
+
+    Returns:
+      An output tensor with the same dimensions as input/query tensor.
+    """
-    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
-      input_tensor, attention_mask = inputs
+    if isinstance(inputs, (list, tuple)):
+      if len(inputs) == 2:
+        input_tensor, attention_mask = inputs
+        key_value = None
+      elif len(inputs) == 3:
+        input_tensor, key_value, attention_mask = inputs
+      else:
+        raise ValueError("Unexpected inputs to %s with length at %d" %
+                         (self.__class__, len(inputs)))
     else:
-      input_tensor, attention_mask = (inputs, None)
+      input_tensor, key_value, attention_mask = (inputs, None, None)

     if self._output_range:
       if self._norm_first:
         source_tensor = input_tensor[:, 0:self._output_range, :]
         input_tensor = self._attention_layer_norm(input_tensor)
+        if key_value is not None:
+          key_value = self._attention_layer_norm(key_value)
       target_tensor = input_tensor[:, 0:self._output_range, :]
       if attention_mask is not None:
         attention_mask = attention_mask[:, 0:self._output_range, :]
...
@@ -250,10 +268,14 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     if self._norm_first:
       source_tensor = input_tensor
       input_tensor = self._attention_layer_norm(input_tensor)
+      if key_value is not None:
+        key_value = self._attention_layer_norm(key_value)
     target_tensor = input_tensor
+
+    if key_value is None:
+      key_value = input_tensor
     attention_output = self._attention_layer(
-        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
+        query=target_tensor, value=key_value, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     if self._norm_first:
       attention_output = source_tensor + attention_output
...
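Two details of the new call path are easy to miss: under pre-norm (norm_first) the key/value stream is layer-normalized alongside the query, and with a single-tensor input key_value falls back to the input tensor, so existing self-attention behavior is unchanged. A sketch exercising the pre-norm branch; the import path and the norm_first constructor flag are assumed from this repo's layout:

    import tensorflow as tf
    from official.nlp import keras_nlp  # import path assumed from this repo's layout

    block = keras_nlp.layers.TransformerEncoderBlock(
        num_attention_heads=2, inner_dim=128, inner_activation='relu',
        norm_first=True)  # constructor flag assumed to match self._norm_first

    q = tf.random.normal([2, 4, 16])
    kv = tf.random.normal([2, 8, 16])
    mask = tf.ones([2, 4, 8])
    out = block([q, kv, mask])  # out: [2, 4, 16], same shape as the query

    # With a single tensor, key_value is None and defaults to the input,
    # reproducing the old self-attention encoder behavior exactly.
    self_out = block(q)         # [2, 4, 16]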
official/nlp/keras_nlp/layers/transformer_encoder_block_test.py (view file @ 86df41f7)

...
@@ -55,18 +55,6 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

-  def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
-    test_layer = transformer_cls(
-        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
-    sequence_length = 21
-    width = 80
-    # Create a 3-dimensional input (the first dimension is implicit).
-    data_tensor = tf.keras.Input(shape=(sequence_length, width))
-    # Create a 2-dimensional input (the first dimension is implicit).
-    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
-    with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
-      _ = test_layer([data_tensor, mask_tensor])
-
   def test_layer_invocation(self, transformer_cls):
     test_layer = transformer_cls(
         num_attention_heads=10, inner_dim=2048, inner_activation='relu')
...
@@ -249,6 +237,20 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
     self.assertAllEqual([1, input_length, width], output_data.shape)

+  def test_separate_qkv(self, transformer_cls):
+    test_layer = transformer_cls(
+        num_attention_heads=2,
+        inner_dim=128,
+        inner_activation='relu',
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+    # Forward path.
+    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    kv_tensor = tf.zeros([2, 8, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
+    inputs = [q_tensor, kv_tensor, dummy_mask]
+    output = test_layer(inputs)
+    self.assertEqual(output.shape, q_tensor.shape)
+

 @keras_parameterized.run_all_keras_modes
 class TransformerArgumentTest(keras_parameterized.TestCase):
...
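The new test drives the layer eagerly; the same three-tensor form also wires into a functional Keras model, since build() accepts the list of input shapes. A hypothetical sketch, not part of this commit, with the import path assumed as above:

    import tensorflow as tf
    from official.nlp import keras_nlp  # import path assumed from this repo's layout

    q_in = tf.keras.Input(shape=(4, 16))    # query stream
    kv_in = tf.keras.Input(shape=(8, 16))   # key/value stream
    mask_in = tf.keras.Input(shape=(4, 8))  # [q_len, kv_len] per example

    block = keras_nlp.layers.TransformerEncoderBlock(
        num_attention_heads=2, inner_dim=128, inner_activation='relu')
    out = block([q_in, kv_in, mask_in])     # out: [batch, 4, 16]
    model = tf.keras.Model([q_in, kv_in, mask_in], out)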
official/nlp/modeling/layers/transformer.py (view file @ 86df41f7)

...
@@ -77,7 +77,7 @@ class Transformer(keras_nlp.layers.TransformerEncoderBlock):
       intermediate_dropout=0.0,
       attention_initializer=None,
       **kwargs):
-    super(Transformer, self).__init__(
+    super().__init__(
         num_attention_heads=num_attention_heads,
         inner_dim=intermediate_size,
         inner_activation=intermediate_activation,
...
@@ -105,7 +105,7 @@ class CompiledTransformer(Transformer):
   @tf_function_if_eager(experimental_compile=True)
   def call(self, inputs):
-    return super(CompiledTransformer, self).call(inputs)
+    return super().call(inputs)


 @tf.keras.utils.register_keras_serializable(package="Text")
...
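The transformer.py edits are purely the Python 3 zero-argument super() idiom. A self-contained illustration of the equivalence, using hypothetical classes rather than anything from this commit:

    class Base:
      def __init__(self, **kwargs):
        self.config = kwargs

    # Pre-change spelling (Python 2 compatible, repeats the class name):
    class OldStyle(Base):
      def __init__(self, **kwargs):
        super(OldStyle, self).__init__(**kwargs)

    # Post-change spelling (Python 3 zero-argument form; same behavior,
    # but it no longer breaks if the class is renamed):
    class NewStyle(Base):
      def __init__(self, **kwargs):
        super().__init__(**kwargs)

    assert OldStyle(a=1).config == NewStyle(a=1).config == {'a': 1}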
orbit/__init__.py, orbit/controller.py, orbit/controller_test.py, orbit/runner.py,
orbit/standard_runner.py, orbit/standard_runner_test.py, orbit/utils/__init__.py,
orbit/utils/common.py, orbit/utils/common_test.py, orbit/utils/epoch_helper.py,
orbit/utils/loop_fns.py, orbit/utils/summary_manager.py, orbit/utils/tpu_summaries.py,
orbit/utils/tpu_summaries_test.py (view files @ 86df41f7)

Each of these 14 files carries the same one-line copyright bump:

-# Copyright 2020 The Orbit Authors. All Rights Reserved.
+# Copyright 2021 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...