Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d169c201
Commit
d169c201
authored
Oct 19, 2019
by
Jing Li
Committed by
A. Unique TensorFlower
Oct 19, 2019
Browse files
Always pass **kwargs to __call__ override for custom layers.
PiperOrigin-RevId: 275644913
parent
f0f42c82
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
21 deletions
+23
-21
official/nlp/bert_modeling.py
official/nlp/bert_modeling.py
+2
-2
official/nlp/bert_models.py
official/nlp/bert_models.py
+7
-4
official/nlp/xlnet_modeling.py
official/nlp/xlnet_modeling.py
+14
-15
No files found.
official/nlp/bert_modeling.py
View file @
d169c201
...
...
@@ -777,9 +777,9 @@ class TransformerBlock(tf.keras.layers.Layer):
self
.
output_layer_norm
]
def
__call__
(
self
,
input_tensor
,
attention_mask
=
None
):
def
__call__
(
self
,
input_tensor
,
attention_mask
=
None
,
**
kwargs
):
inputs
=
tf_utils
.
pack_inputs
([
input_tensor
,
attention_mask
])
return
super
(
TransformerBlock
,
self
).
__call__
(
inputs
)
return
super
(
TransformerBlock
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
official/nlp/bert_models.py
View file @
d169c201
...
...
@@ -116,10 +116,11 @@ class BertPretrainLayer(tf.keras.layers.Layer):
def
__call__
(
self
,
pooled_output
,
sequence_output
=
None
,
masked_lm_positions
=
None
):
masked_lm_positions
=
None
,
**
kwargs
):
inputs
=
tf_utils
.
pack_inputs
(
[
pooled_output
,
sequence_output
,
masked_lm_positions
])
return
super
(
BertPretrainLayer
,
self
).
__call__
(
inputs
)
return
super
(
BertPretrainLayer
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
@@ -153,12 +154,14 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
sentence_output
=
None
,
lm_label_ids
=
None
,
lm_label_weights
=
None
,
sentence_labels
=
None
):
sentence_labels
=
None
,
**
kwargs
):
inputs
=
tf_utils
.
pack_inputs
([
lm_output
,
sentence_output
,
lm_label_ids
,
lm_label_weights
,
sentence_labels
])
return
super
(
BertPretrainLossAndMetricLayer
,
self
).
__call__
(
inputs
)
return
super
(
BertPretrainLossAndMetricLayer
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
_add_metrics
(
self
,
lm_output
,
lm_labels
,
lm_label_weights
,
lm_per_example_loss
,
sentence_output
,
sentence_labels
,
...
...
official/nlp/xlnet_modeling.py
View file @
d169c201
...
...
@@ -160,11 +160,9 @@ class PositionalEmbedding(tf.keras.layers.Layer):
self
.
inv_freq
=
1.0
/
(
10000.0
**
(
tf
.
range
(
0
,
self
.
dim
,
2.0
)
/
self
.
dim
))
super
(
PositionalEmbedding
,
self
).
build
(
unused_input_shapes
)
def
__call__
(
self
,
pos_seq
,
batch_size
):
return
super
(
PositionalEmbedding
,
self
).
__call__
((
pos_seq
,
batch_size
,
))
def
__call__
(
self
,
pos_seq
,
batch_size
,
**
kwargs
):
return
super
(
PositionalEmbedding
,
self
).
__call__
(
(
pos_seq
,
batch_size
),
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
@@ -197,12 +195,12 @@ class RelativeAttention(tf.keras.layers.Layer):
super
(
RelativeAttention
,
self
).
build
(
unused_input_shapes
)
def
__call__
(
self
,
q_head
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_embed
,
seg_mat
,
r_w_bias
,
r_r_bias
,
r_s_bias
,
attn_mask
):
r_w_bias
,
r_r_bias
,
r_s_bias
,
attn_mask
,
**
kwargs
):
inputs
=
pack_inputs
([
q_head
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_embed
,
seg_mat
,
r_w_bias
,
r_r_bias
,
r_s_bias
,
attn_mask
])
return
super
(
RelativeAttention
,
self
).
__call__
(
inputs
)
return
super
(
RelativeAttention
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
@@ -364,12 +362,12 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
super
(
RelativeMultiheadAttention
,
self
).
build
(
unused_input_shapes
)
def
__call__
(
self
,
h
,
g
,
r
,
r_w_bias
,
r_r_bias
,
seg_mat
,
r_s_bias
,
seg_embed
,
attn_mask_h
,
attn_mask_g
,
mems
,
target_mapping
):
attn_mask_h
,
attn_mask_g
,
mems
,
target_mapping
,
**
kwargs
):
inputs
=
pack_inputs
([
h
,
g
,
r
,
r_w_bias
,
r_r_bias
,
seg_mat
,
r_s_bias
,
seg_embed
,
attn_mask_h
,
attn_mask_g
,
mems
,
target_mapping
,
])
return
super
(
RelativeMultiheadAttention
,
self
).
__call__
(
inputs
)
return
super
(
RelativeMultiheadAttention
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
@@ -597,7 +595,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
):
inp_q
=
None
,
**
kwargs
):
# Uses dict to feed inputs into call() in order to keep mems as a python
# list.
inputs
=
{
...
...
@@ -609,7 +608,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
'target_mapping'
:
target_mapping
,
'inp_q'
:
inp_q
}
return
super
(
TransformerXLModel
,
self
).
__call__
(
inputs
)
return
super
(
TransformerXLModel
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
@@ -1011,9 +1010,9 @@ class LMLossLayer(tf.keras.layers.Layer):
super
(
LMLossLayer
,
self
).
build
(
unused_input_shapes
)
def
__call__
(
self
,
hidden
,
target
,
lookup_table
,
target_mask
):
def
__call__
(
self
,
hidden
,
target
,
lookup_table
,
target_mask
,
**
kwargs
):
inputs
=
pack_inputs
([
hidden
,
target
,
lookup_table
,
target_mask
])
return
super
(
LMLossLayer
,
self
).
__call__
(
inputs
)
return
super
(
LMLossLayer
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
@@ -1117,9 +1116,9 @@ class ClassificationLossLayer(tf.keras.layers.Layer):
super
(
ClassificationLossLayer
,
self
).
build
(
unused_input_shapes
)
def
__call__
(
self
,
hidden
,
labels
):
def
__call__
(
self
,
hidden
,
labels
,
**
kwargs
):
inputs
=
pack_inputs
([
hidden
,
labels
])
return
super
(
ClassificationLossLayer
,
self
).
__call__
(
inputs
)
return
super
(
ClassificationLossLayer
,
self
).
__call__
(
inputs
,
**
kwargs
)
def
call
(
self
,
inputs
):
"""Implements call() for the layer."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment