Commit d169c201 authored by Jing Li, committed by A. Unique TensorFlower

Always pass **kwargs to __call__ override for custom layers.

PiperOrigin-RevId: 275644913
parent f0f42c82
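
For context, here is a minimal sketch of the pattern this commit fixes. The layer name below is hypothetical, not from the repo: a custom layer that overrides __call__ to pack several tensors into one inputs structure must also accept and forward **kwargs, otherwise standard Keras keyword arguments such as training=... hit the override's signature and raise a TypeError before ever reaching tf.keras.layers.Layer.__call__.

import tensorflow as tf

class PackedInputsBlock(tf.keras.layers.Layer):
  """Hypothetical layer demonstrating the **kwargs forwarding pattern."""

  def __call__(self, input_tensor, attention_mask=None, **kwargs):
    # Pack the positional tensors into the single structure that call()
    # expects, then forward the remaining keyword arguments (for example
    # training=...) to the base Layer.__call__, which knows how to
    # consume them.
    inputs = (input_tensor, attention_mask)
    return super(PackedInputsBlock, self).__call__(inputs, **kwargs)

  def call(self, inputs):
    input_tensor, attention_mask = inputs
    return input_tensor  # Placeholder computation.

layer = PackedInputsBlock()
x = tf.ones([2, 4])
# Without **kwargs in the override, the next line raises:
#   TypeError: __call__() got an unexpected keyword argument 'training'
y = layer(x, attention_mask=tf.ones([2, 4]), training=False)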
@@ -777,9 +777,9 @@ class TransformerBlock(tf.keras.layers.Layer):
         self.output_layer_norm
     ]
 
-  def __call__(self, input_tensor, attention_mask=None):
+  def __call__(self, input_tensor, attention_mask=None, **kwargs):
     inputs = tf_utils.pack_inputs([input_tensor, attention_mask])
-    return super(TransformerBlock, self).__call__(inputs)
+    return super(TransformerBlock, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
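
In each of these hunks, the override packs its named tensors before handing them to Layer.__call__, because Keras treats the first positional argument as the layer's single inputs structure and uses it for building and shape inference; optional arguments that may be None therefore need a tensor stand-in. A rough sketch of what such a pack/unpack helper pair can look like, offered as an illustration of the idea rather than the repo's actual tf_utils implementation:

import tensorflow as tf

def pack_inputs_sketch(inputs):
  """Packs a list of tensors, encoding None entries as a sentinel."""
  outputs = []
  for x in inputs:
    # Keras wants real tensors in the inputs structure, so stand in a
    # scalar int32 zero for None; unpack_inputs_sketch restores it.
    outputs.append(tf.constant(0, dtype=tf.int32) if x is None else x)
  return tuple(outputs)

def unpack_inputs_sketch(inputs):
  """Restores None for every sentinel written by pack_inputs_sketch."""
  outputs = []
  for x in inputs:
    is_sentinel = x.dtype == tf.int32 and x.shape.rank == 0
    outputs.append(None if is_sentinel else x)
  return tuple(outputs)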
@@ -116,10 +116,11 @@ class BertPretrainLayer(tf.keras.layers.Layer):
   def __call__(self,
                pooled_output,
                sequence_output=None,
-               masked_lm_positions=None):
+               masked_lm_positions=None,
+               **kwargs):
     inputs = tf_utils.pack_inputs(
         [pooled_output, sequence_output, masked_lm_positions])
-    return super(BertPretrainLayer, self).__call__(inputs)
+    return super(BertPretrainLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -153,12 +154,14 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
                sentence_output=None,
                lm_label_ids=None,
                lm_label_weights=None,
-               sentence_labels=None):
+               sentence_labels=None,
+               **kwargs):
     inputs = tf_utils.pack_inputs([
         lm_output, sentence_output, lm_label_ids, lm_label_weights,
         sentence_labels
     ])
-    return super(BertPretrainLossAndMetricLayer, self).__call__(inputs)
+    return super(BertPretrainLossAndMetricLayer, self).__call__(
+        inputs, **kwargs)
 
   def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                    lm_per_example_loss, sentence_output, sentence_labels,
@@ -160,11 +160,9 @@ class PositionalEmbedding(tf.keras.layers.Layer):
     self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
     super(PositionalEmbedding, self).build(unused_input_shapes)
 
-  def __call__(self, pos_seq, batch_size):
-    return super(PositionalEmbedding, self).__call__((
-        pos_seq,
-        batch_size,
-    ))
+  def __call__(self, pos_seq, batch_size, **kwargs):
+    return super(PositionalEmbedding, self).__call__(
+        (pos_seq, batch_size), **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -197,12 +195,12 @@ class RelativeAttention(tf.keras.layers.Layer):
     super(RelativeAttention, self).build(unused_input_shapes)
 
   def __call__(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
-               r_w_bias, r_r_bias, r_s_bias, attn_mask):
+               r_w_bias, r_r_bias, r_s_bias, attn_mask, **kwargs):
     inputs = pack_inputs([
         q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
         r_r_bias, r_s_bias, attn_mask
     ])
-    return super(RelativeAttention, self).__call__(inputs)
+    return super(RelativeAttention, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -364,12 +362,12 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     super(RelativeMultiheadAttention, self).build(unused_input_shapes)
 
   def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
-               attn_mask_h, attn_mask_g, mems, target_mapping):
+               attn_mask_h, attn_mask_g, mems, target_mapping, **kwargs):
     inputs = pack_inputs([
         h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
         attn_mask_g, mems, target_mapping,
     ])
-    return super(RelativeMultiheadAttention, self).__call__(inputs)
+    return super(RelativeMultiheadAttention, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -597,7 +595,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
                mems=None,
                perm_mask=None,
                target_mapping=None,
-               inp_q=None):
+               inp_q=None,
+               **kwargs):
     # Uses dict to feed inputs into call() in order to keep mems as a python
     # list.
     inputs = {
@@ -609,7 +608,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
         'target_mapping': target_mapping,
         'inp_q': inp_q
     }
-    return super(TransformerXLModel, self).__call__(inputs)
+    return super(TransformerXLModel, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -1011,9 +1010,9 @@ class LMLossLayer(tf.keras.layers.Layer):
     super(LMLossLayer, self).build(unused_input_shapes)
 
-  def __call__(self, hidden, target, lookup_table, target_mask):
+  def __call__(self, hidden, target, lookup_table, target_mask, **kwargs):
     inputs = pack_inputs([hidden, target, lookup_table, target_mask])
-    return super(LMLossLayer, self).__call__(inputs)
+    return super(LMLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
@@ -1117,9 +1116,9 @@ class ClassificationLossLayer(tf.keras.layers.Layer):
     super(ClassificationLossLayer, self).build(unused_input_shapes)
 
-  def __call__(self, hidden, labels):
+  def __call__(self, hidden, labels, **kwargs):
     inputs = pack_inputs([hidden, labels])
-    return super(ClassificationLossLayer, self).__call__(inputs, **kwargs)
+    return super(ClassificationLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
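
The practical payoff of forwarding **kwargs is that standard Keras keyword arguments such as training can now flow through these wrappers to the sublayers whose behavior depends on them, dropout being the obvious example. A self-contained illustration with a stock Keras layer:

import tensorflow as tf

drop = tf.keras.layers.Dropout(0.5)
x = tf.ones([1, 8])
# Layer.__call__ consumes the training argument and routes it to call();
# dropout is the identity at inference time and randomly zeroes units
# (rescaling the rest by 1 / (1 - rate)) during training.
print(drop(x, training=False))  # All ones.
print(drop(x, training=True))   # Some entries zeroed, the rest 2.0.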