Commit 4627f55d authored by Jialu Liu's avatar Jialu Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 401000583
parent 7bb4d442
...@@ -82,15 +82,11 @@ class PackedSequenceEmbedding(tf.keras.Model): ...@@ -82,15 +82,11 @@ class PackedSequenceEmbedding(tf.keras.Model):
shape=(None,), dtype=tf.int32, name='input_mask') shape=(None,), dtype=tf.int32, name='input_mask')
type_ids = tf.keras.layers.Input( type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids') shape=(None,), dtype=tf.int32, name='input_type_ids')
inputs = { inputs = [word_ids, mask, type_ids]
'input_word_ids': word_ids,
'input_mask': mask,
'input_type_ids': type_ids,
}
if use_position_id: if use_position_id:
position_ids = tf.keras.layers.Input( position_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='position_ids') shape=(None,), dtype=tf.int32, name='position_ids')
inputs['position_ids'] = position_ids inputs.append(position_ids)
else: else:
position_ids = None position_ids = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment