ModelZoo / ResNet50_tensorflow / Commits

Commit 37e76715, authored Sep 21, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 475997523

Parent: 60568599
Showing 3 changed files with 17 additions and 14 deletions:

  official/nlp/modeling/networks/bert_dense_encoder_test.py   +2  -2
  official/nlp/modeling/networks/bert_encoder.py               +13 -8
  official/nlp/modeling/networks/bert_encoder_test.py          +2  -4
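At a glance, the change moves output_range handling in the Model Garden BERT encoders from construction time to call time: BertEncoderV2 now stores output_range in __init__ and passes it to the final TransformerEncoderBlock when the layer is called, instead of baking it into that block's constructor, and the functional BertEncoder gets the equivalent treatment. The public constructor signatures are unchanged. A minimal usage sketch (toy hyperparameters invented for illustration; output_range=1 restricts the last block's output to the first sequence position):

    import numpy as np
    from official.nlp.modeling.networks import bert_encoder

    # Toy hyperparameters, invented for illustration only.
    network = bert_encoder.BertEncoderV2(
        vocab_size=100,
        hidden_size=32,
        num_attention_heads=2,
        num_layers=3,
        type_vocab_size=2,
        output_range=1)  # last transformer block emits only position 0

    data = dict(
        input_word_ids=np.random.randint(100, size=(2, 16)).astype('int32'),
        input_mask=np.ones((2, 16), dtype='int32'),
        input_type_ids=np.zeros((2, 16), dtype='int32'))
    outputs = network(data)
    print(outputs['sequence_output'].shape)  # expected: (2, 1, 32)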
official/nlp/modeling/networks/bert_dense_encoder_test.py (view file @ 37e76715)
@@ -196,9 +196,9 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
-        output_range=output_range,
         dict_outputs=True,
-        with_dense_inputs=True)
+        with_dense_inputs=True,
+        output_range=output_range)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
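The hunk above only regroups keyword arguments so that output_range sits last; Python keyword arguments are order-insensitive, so the constructed network is identical. For reference, the symbolic inputs created just after the edited call look like this in isolation (the sequence_length value here is an arbitrary stand-in for the test's parameter):

    import tensorflow as tf

    sequence_length = 16  # arbitrary stand-in for the test parameter
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)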
official/nlp/modeling/networks/bert_encoder.py (view file @ 37e76715)
@@ -116,6 +116,8 @@ class BertEncoderV2(tf.keras.layers.Layer):
       attention_dropout = kwargs.pop('attention_dropout_rate')
     super().__init__(**kwargs)
 
+    self._output_range = output_range
+
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -163,6 +165,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
     self._transformer_layers = []
     self._attention_mask_layer = layers.SelfAttentionMask(
         name='self_attention_mask')
+    self._num_layers = num_layers
     for i in range(num_layers):
       layer = layers.TransformerEncoderBlock(
           num_attention_heads=num_attention_heads,
@@ -172,7 +175,6 @@ class BertEncoderV2(tf.keras.layers.Layer):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           return_attention_scores=return_attention_scores,
-          output_range=output_range if i == num_layers - 1 else None,
           kernel_initializer=tf_utils.clone_initializer(initializer),
           name='transformer/layer_%d' % i)
       self._transformer_layers.append(layer)
@@ -257,8 +259,11 @@ class BertEncoderV2(tf.keras.layers.Layer):
     encoder_outputs = []
     attention_outputs = []
     x = embeddings
-    for layer in self._transformer_layers:
-      x = layer([x, attention_mask])
+    for i, layer in enumerate(self._transformer_layers):
+      transformer_output_range = None
+      if i == self._num_layers - 1:
+        transformer_output_range = self._output_range
+      x = layer([x, attention_mask], output_range=transformer_output_range)
       if self._config['return_attention_scores']:
         x, attention_scores = x
         attention_outputs.append(attention_scores)
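The rewritten loop relies on TransformerEncoderBlock accepting output_range at call time rather than at construction time (a companion change to that layer at this revision). A standalone sketch of the call-time behavior, with toy shapes chosen for illustration:

    import tensorflow as tf
    from official.nlp.modeling import layers

    block = layers.TransformerEncoderBlock(
        num_attention_heads=2, inner_dim=64, inner_activation='relu')

    x = tf.random.normal((1, 8, 16))  # (batch, seq_len, hidden)
    mask = tf.ones((1, 8, 8))         # (batch, seq_len, seq_len)

    full = block([x, mask])                   # shape (1, 8, 16)
    first = block([x, mask], output_range=1)  # shape (1, 1, 16): self-attention
                                              # queries only the first position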
@@ -475,10 +480,9 @@ class BertEncoder(tf.keras.Model):
     encoder_outputs = []
     attention_outputs = []
     for i in range(num_layers):
-      if i == num_layers - 1 and output_range is not None:
+      transformer_output_range = None
+      if i == num_layers - 1:
         transformer_output_range = output_range
-      else:
-        transformer_output_range = None
       layer = layers.TransformerEncoderBlock(
           num_attention_heads=num_attention_heads,
           inner_dim=inner_dim,
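The hunk above replaces an if/else with a default-then-override pattern; the "and output_range is not None" guard becomes unnecessary because passing None through simply means "full sequence". The selection logic, distilled into plain Python:

    def output_range_for_layer(i, num_layers, output_range):
      # Every layer keeps the full sequence by default ...
      transformer_output_range = None
      # ... except the last one, which may be truncated.
      if i == num_layers - 1:
        transformer_output_range = output_range
      return transformer_output_range

    assert output_range_for_layer(0, 3, 1) is None
    assert output_range_for_layer(2, 3, 1) == 1
    assert output_range_for_layer(2, 3, None) is None  # None passes through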
@@ -487,11 +491,11 @@ class BertEncoder(tf.keras.Model):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           return_attention_scores=return_attention_scores,
-          output_range=transformer_output_range,
           kernel_initializer=tf_utils.clone_initializer(initializer),
           name='transformer/layer_%d' % i)
       transformer_layers.append(layer)
-      data = layer([data, attention_mask])
+      data = layer([data, attention_mask],
+                   output_range=transformer_output_range)
       if return_attention_scores:
         data, attention_scores = data
         attention_outputs.append(attention_scores)
@@ -600,3 +604,4 @@ class BertEncoder(tf.keras.Model):
       logging.warn(warn_string)
     return cls(**config)
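This last hunk sits at the tail of BertEncoder.from_config, which warns about legacy arguments and then rebuilds the model from its config. A config round trip, as a rough sketch with toy hyperparameters (assuming the usual Keras get_config/from_config symmetry of this model):

    from official.nlp.modeling.networks import bert_encoder

    net = bert_encoder.BertEncoder(
        vocab_size=100, hidden_size=32, num_attention_heads=2,
        num_layers=2, inner_dim=64)
    restored = bert_encoder.BertEncoder.from_config(net.get_config())
    assert restored.get_config()['vocab_size'] == 100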
official/nlp/modeling/networks/bert_encoder_test.py (view file @ 37e76715)
@@ -545,8 +545,7 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase):
         hidden_size=hidden_size,
         num_attention_heads=2,
         num_layers=3,
-        type_vocab_size=num_types,
-        output_range=None)
+        type_vocab_size=num_types)
 
     word_id_data = np.random.randint(
         vocab_size, size=(batch_size, sequence_length))
@@ -605,8 +604,7 @@ class BertEncoderV2CompatibilityTest(tf.test.TestCase):
         hidden_size=hidden_size,
         num_attention_heads=2,
         num_layers=3,
-        type_vocab_size=num_types,
-        output_range=None)
+        type_vocab_size=num_types)
 
     word_id_data = np.random.randint(
         vocab_size, size=(batch_size, sequence_length))
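Both test hunks drop an explicit output_range=None from the constructor kwargs; None is already the default, so omitting it leaves the networks unchanged and lets one kwargs dict serve both encoder versions, which is the point of a compatibility test. Roughly (toy values, assuming both constructors accept the same core kwargs, as the diffs above show):

    from official.nlp.modeling.networks import bert_encoder

    kwargs = dict(vocab_size=100, hidden_size=32, num_attention_heads=2,
                  num_layers=3, type_vocab_size=2)  # output_range defaults to None

    old_net = bert_encoder.BertEncoder(**kwargs)
    new_net = bert_encoder.BertEncoderV2(**kwargs)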