Unverified commit f1045227, authored by Arthur and committed by GitHub
Browse files

[`ForSequenceClassification`] Support `left` padding (#24979)

* support left padding

* nit

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py
parent 1e662f0f
......@@ -956,7 +956,9 @@ class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
......
......@@ -1443,7 +1443,9 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
logger.warning(
......
......@@ -934,7 +934,9 @@ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
logger.warning(
......
......@@ -878,7 +878,9 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
logger.warning(
......
......@@ -926,7 +926,9 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
logger.warning(
......
......@@ -1002,7 +1002,9 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
logger.warning(
......
......@@ -971,7 +971,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
......
......@@ -1084,7 +1084,9 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
logits.device
)
else:
sequence_lengths = -1
logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment