Unverified Commit 7bde5d63 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`TFxxxxForSequenceClassifciation`] Fix the eager mode after #25085 (#25751)



* TODOS

* Switch .shape -> shape_list

---------
Co-authored-by: default avatarMatt <rocketknight1@gmail.com>
parent e2d6d5ce
...@@ -870,7 +870,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific ...@@ -870,7 +870,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1) tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
- 1 - 1
) )
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) sequence_lengths = tf.where(
sequence_lengths >= 0,
sequence_lengths,
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else: else:
sequence_lengths = -1 sequence_lengths = -1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment