"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "023f51fe16e34e0ca2b5598791ae508874d5b443"
Unverified Commit 6c8ec2a9 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix tf led pt test (#9513)

parent 1e3c3622
......@@ -166,7 +166,13 @@ def prepare_led_inputs_dict(
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
if decoder_attention_mask is None:
decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8)
decoder_attention_mask = tf.concat(
[
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
],
axis=-1,
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment