Unverified commit 2fac3422, authored by Arthur, committed by GitHub
Browse files

[`TF`] Also apply patch to support left padding (#25085)

* tf versions

* apply changes to other models

* 3 models slipped through the cracks
parent f1045227
...@@ -785,7 +785,9 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel): ...@@ -785,7 +785,9 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
sequence_lengths = -1 sequence_lengths = -1
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else: else:
sequence_lengths = -1 sequence_lengths = -1
logger.warning( logger.warning(
......
...@@ -798,16 +798,10 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific ...@@ -798,16 +798,10 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = ( sequence_lengths = (
tf.reduce_sum( tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
tf.cast(
tf.math.not_equal(input_ids, self.config.pad_token_id),
dtype=input_ids.dtype,
),
-1,
keepdims=False,
)
- 1 - 1
) )
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else: else:
sequence_lengths = -1 sequence_lengths = -1
......
...@@ -1082,16 +1082,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific ...@@ -1082,16 +1082,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = ( sequence_lengths = (
tf.reduce_sum( tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
tf.cast(
tf.math.not_equal(input_ids, self.config.pad_token_id),
dtype=input_ids.dtype,
),
-1,
keepdims=False,
)
- 1 - 1
) )
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else: else:
sequence_lengths = -1 sequence_lengths = -1
......
...@@ -867,16 +867,10 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific ...@@ -867,16 +867,10 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = ( sequence_lengths = (
tf.reduce_sum( tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
tf.cast(
tf.math.not_equal(input_ids, self.config.pad_token_id),
dtype=input_ids.dtype,
),
-1,
keepdims=False,
)
- 1 - 1
) )
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else: else:
sequence_lengths = -1 sequence_lengths = -1
......
...@@ -813,7 +813,9 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel): ...@@ -813,7 +813,9 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
sequence_lengths = -1 sequence_lengths = -1
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else: else:
sequence_lengths = -1 sequence_lengths = -1
logger.warning( logger.warning(
......
...@@ -809,16 +809,10 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc ...@@ -809,16 +809,10 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = ( sequence_lengths = (
tf.reduce_sum( tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
tf.cast(
tf.math.not_equal(input_ids, self.config.pad_token_id),
dtype=input_ids.dtype,
),
-1,
keepdims=False,
)
- 1 - 1
) )
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else: else:
sequence_lengths = -1 sequence_lengths = -1
......
...@@ -1066,16 +1066,10 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc ...@@ -1066,16 +1066,10 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = ( sequence_lengths = (
tf.reduce_sum( tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
tf.cast(
tf.math.not_equal(input_ids, self.config.pad_token_id),
dtype=input_ids.dtype,
),
-1,
keepdims=False,
)
- 1 - 1
) )
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else: else:
sequence_lengths = -1 sequence_lengths = -1
......
...@@ -1247,7 +1247,9 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel): ...@@ -1247,7 +1247,9 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
sequence_lengths = -1 sequence_lengths = -1
else: else:
if input_ids is not None: if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else: else:
sequence_lengths = -1 sequence_lengths = -1
logger.warning( logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment