Unverified Commit 6c8ec2a9 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix tf led pt test (#9513)

parent 1e3c3622
...@@ -166,7 +166,13 @@ def prepare_led_inputs_dict( ...@@ -166,7 +166,13 @@ def prepare_led_inputs_dict(
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8) decoder_attention_mask = tf.concat(
[
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
],
axis=-1,
)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment