Commit a0b548e2 authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 430740048
parent b72f4975
...@@ -402,6 +402,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -402,6 +402,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
_transformer_cls2str.get(transformer_cls, str(transformer_cls)) _transformer_cls2str.get(transformer_cls, str(transformer_cls))
} }
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
def call(self, inputs): def call(self, inputs):
# inputs are [word_ids, mask, type_ids] # inputs are [word_ids, mask, type_ids]
if isinstance(inputs, (list, tuple)): if isinstance(inputs, (list, tuple)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment