Unverified Commit f1045227 authored by Arthur, committed by GitHub

[`ForSequenceClassification`] Support `left` padding (#24979)

* support left padding

* nit

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py
parent 1e662f0f
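
The gist of the change below: instead of counting non-pad tokens (`torch.ne(...).sum(-1) - 1`), the classification heads now take the index of the first pad token minus one (`torch.eq(...).long().argmax(-1) - 1`). With right padding both expressions point at the last real token; with left padding the first pad token sits at index 0, so the new expression yields -1, which indexes the last (real) token of the row. A minimal sketch illustrating this, not part of the commit; the pad token id of 0 and the toy batch are made up for illustration:

```python
import torch

pad_token_id = 0  # assumed for illustration only

# Toy batch: the real tokens are 5, 6, 7 in both rows.
right_padded = torch.tensor([[5, 6, 7, 0, 0]])
left_padded = torch.tensor([[0, 0, 5, 6, 7]])

def old_index(input_ids):
    # Previous computation: number of non-pad tokens minus one.
    return torch.ne(input_ids, pad_token_id).sum(-1) - 1

def new_index(input_ids):
    # New computation: position of the first pad token minus one.
    # With right padding this is the first pad position; with left padding
    # (or no padding at all) argmax is 0, so the result is -1, which indexes
    # the last token of the row.
    return torch.eq(input_ids, pad_token_id).long().argmax(-1) - 1

for name, batch in [("right", right_padded), ("left", left_padded)]:
    print(name, batch[0, old_index(batch)].item(), batch[0, new_index(batch)].item())
# right 7 7  -> both indices point at the last real token
# left  5 7  -> only the new index still points at the last real token
```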
@@ -956,7 +956,9 @@ class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
@@ -1443,7 +1443,9 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -934,7 +934,9 @@ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -878,7 +878,9 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -926,7 +926,9 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1002,7 +1002,9 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -971,7 +971,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
@@ -1084,7 +1084,9 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                    logits.device
+                )
             else:
                 sequence_lengths = -1
                 logger.warning(
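
For context, a minimal usage sketch of what this change enables: batched classification with a left-padded tokenizer. This is not part of the diff; the `gpt2` checkpoint and the pad-token setup below are assumptions for illustration, since GPT-2 ships without a pad token and without a trained classification head.

```python
import torch
from transformers import AutoTokenizer, GPT2ForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

model = GPT2ForSequenceClassification.from_pretrained("gpt2")  # score head is freshly initialized
model.config.pad_token_id = tokenizer.pad_token_id

batch = tokenizer(["short", "a noticeably longer input"], padding=True, return_tensors="pt")
with torch.no_grad():
    # (batch_size, num_labels); with this commit, the logits of the last real
    # token are pooled correctly even though padding is on the left.
    logits = model(**batch).logits
```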