Unverified Commit 0b369703 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Remove `decoder_position_ids` from `check_decoder_model_past_large_inputs` (#18980)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent a86acb75
......@@ -125,21 +125,10 @@ class TFBartModelTester:
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32)
output_from_no_past = model(
next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids
)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)
output_from_no_past = output_from_no_past[0]
decoder_position_ids = (
tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2]
)
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
position_ids=decoder_position_ids,
)
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)
output_from_past = output_from_past[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment