Commit d169c201 authored by Jing Li, committed by A. Unique TensorFlower

Always pass **kwargs to __call__ override for custom layers.

PiperOrigin-RevId: 275644913
parent f0f42c82
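Why this matters (a hedged sketch, not code from this commit): Keras routes framework-level arguments such as `training` through `Layer.__call__`. An override that neither accepts nor forwards `**kwargs` will raise a `TypeError` on `layer(x, training=True)`, and the flag can never reach `call()`. The names `PackedLayer` and the local `pack_inputs` below are illustrative stand-ins for the pattern used throughout this diff, not the model-garden code itself.

```python
import tensorflow as tf


def pack_inputs(inputs):
  # Illustrative stand-in for tf_utils.pack_inputs: replace None entries
  # with placeholder tensors so the packed list can flow through call().
  return [tf.constant(0) if x is None else x for x in inputs]


class PackedLayer(tf.keras.layers.Layer):
  """Hypothetical layer following the pack-inputs pattern in this diff."""

  def __init__(self):
    super(PackedLayer, self).__init__()
    self.dropout = tf.keras.layers.Dropout(0.5)

  def __call__(self, input_tensor, attention_mask=None, **kwargs):
    inputs = pack_inputs([input_tensor, attention_mask])
    # Forwarding **kwargs lets the base __call__ deliver `training`
    # (and any future framework kwargs) through to call().
    return super(PackedLayer, self).__call__(inputs, **kwargs)

  def call(self, inputs, training=None):
    input_tensor, _ = inputs
    # `training` only arrives here because __call__ forwards **kwargs.
    return self.dropout(input_tensor, training=training)
```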
@@ -777,9 +777,9 @@ class TransformerBlock(tf.keras.layers.Layer):
         self.output_layer_norm
     ]
 
-  def __call__(self, input_tensor, attention_mask=None):
+  def __call__(self, input_tensor, attention_mask=None, **kwargs):
     inputs = tf_utils.pack_inputs([input_tensor, attention_mask])
-    return super(TransformerBlock, self).__call__(inputs)
+    return super(TransformerBlock, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -116,10 +116,11 @@ class BertPretrainLayer(tf.keras.layers.Layer):
   def __call__(self,
                pooled_output,
                sequence_output=None,
-               masked_lm_positions=None):
+               masked_lm_positions=None,
+               **kwargs):
     inputs = tf_utils.pack_inputs(
         [pooled_output, sequence_output, masked_lm_positions])
-    return super(BertPretrainLayer, self).__call__(inputs)
+    return super(BertPretrainLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -153,12 +154,14 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
                sentence_output=None,
                lm_label_ids=None,
                lm_label_weights=None,
-               sentence_labels=None):
+               sentence_labels=None,
+               **kwargs):
     inputs = tf_utils.pack_inputs([
         lm_output, sentence_output, lm_label_ids, lm_label_weights,
         sentence_labels
     ])
-    return super(BertPretrainLossAndMetricLayer, self).__call__(inputs)
+    return super(BertPretrainLossAndMetricLayer, self).__call__(
+        inputs, **kwargs)
 
   def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                    lm_per_example_loss, sentence_output, sentence_labels,
@@ -160,11 +160,9 @@ class PositionalEmbedding(tf.keras.layers.Layer):
     self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
     super(PositionalEmbedding, self).build(unused_input_shapes)
 
-  def __call__(self, pos_seq, batch_size):
-    return super(PositionalEmbedding, self).__call__((
-        pos_seq,
-        batch_size,
-    ))
+  def __call__(self, pos_seq, batch_size, **kwargs):
+    return super(PositionalEmbedding, self).__call__(
+        (pos_seq, batch_size), **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -197,12 +195,12 @@ class RelativeAttention(tf.keras.layers.Layer):
     super(RelativeAttention, self).build(unused_input_shapes)
 
   def __call__(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
-               r_w_bias, r_r_bias, r_s_bias, attn_mask):
+               r_w_bias, r_r_bias, r_s_bias, attn_mask, **kwargs):
     inputs = pack_inputs([
         q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
         r_r_bias, r_s_bias, attn_mask
     ])
-    return super(RelativeAttention, self).__call__(inputs)
+    return super(RelativeAttention, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -364,12 +362,12 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     super(RelativeMultiheadAttention, self).build(unused_input_shapes)
 
   def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
-               attn_mask_h, attn_mask_g, mems, target_mapping):
+               attn_mask_h, attn_mask_g, mems, target_mapping, **kwargs):
     inputs = pack_inputs([
         h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
         attn_mask_g, mems, target_mapping,
     ])
-    return super(RelativeMultiheadAttention, self).__call__(inputs)
+    return super(RelativeMultiheadAttention, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -597,7 +595,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
                mems=None,
                perm_mask=None,
                target_mapping=None,
-               inp_q=None):
+               inp_q=None,
+               **kwargs):
     # Uses dict to feed inputs into call() in order to keep mems as a python
     # list.
     inputs = {
@@ -609,7 +608,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
         'target_mapping': target_mapping,
         'inp_q': inp_q
     }
-    return super(TransformerXLModel, self).__call__(inputs)
+    return super(TransformerXLModel, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -1011,9 +1010,9 @@ class LMLossLayer(tf.keras.layers.Layer):
     super(LMLossLayer, self).build(unused_input_shapes)
 
-  def __call__(self, hidden, target, lookup_table, target_mask):
+  def __call__(self, hidden, target, lookup_table, target_mask, **kwargs):
     inputs = pack_inputs([hidden, target, lookup_table, target_mask])
-    return super(LMLossLayer, self).__call__(inputs)
+    return super(LMLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -1117,9 +1116,9 @@ class ClassificationLossLayer(tf.keras.layers.Layer):
     super(ClassificationLossLayer, self).build(unused_input_shapes)
 
-  def __call__(self, hidden, labels):
+  def __call__(self, hidden, labels, **kwargs):
     inputs = pack_inputs([hidden, labels])
-    return super(ClassificationLossLayer, self).__call__(inputs)
+    return super(ClassificationLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
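A quick check of the behavior change, using the hypothetical `PackedLayer` from the sketch above: before this commit, the `training=True` call below would fail with a `TypeError`, since the override's signature had no `**kwargs` to absorb the keyword; after it, the flag reaches `call()` and toggles dropout.

```python
layer = PackedLayer()
x = tf.ones([4, 8])

# Pre-change: TypeError (unexpected keyword argument 'training').
# Post-change: dropout is active only when training=True.
train_out = layer(x, training=True)
infer_out = layer(x, training=False)
```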