Unverified commit 1394e08c authored by Dean Wyatte, committed by GitHub

Support ONNX export for causal LM sequence classifiers (#27450)

support onnx for causal lm sequence classification
parent 06343b06
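
The change is the same one-line cast in every touched classifier head: the boolean pad mask is converted with `.int()` instead of `.long()` before `argmax`, so the traced graph performs the argmax on an int32 tensor rather than int64, which appears to be what this commit needs for ONNX export of these causal LM sequence classifiers. As a minimal sketch of the shared pooling pattern (the function name and arguments below are illustrative, not taken from the repository):

```python
import torch

def last_token_logits(logits: torch.Tensor, input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Pick each sequence's last non-padding position from (batch, seq_len, num_labels) logits."""
    # Boolean pad mask -> int32 -> index of the first pad token, minus one = last real token.
    # If a row contains no padding, argmax returns 0 and the index becomes -1, i.e. the last position.
    sequence_lengths = (torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1).to(logits.device)
    batch_size = input_ids.shape[0]
    return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
```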
@@ -796,7 +796,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -924,7 +924,7 @@ class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1451,7 +1451,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1184,7 +1184,7 @@ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1090,7 +1090,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -948,7 +948,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1001,7 +1001,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1204,7 +1204,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1174,7 +1174,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -814,7 +814,7 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1030,7 +1030,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -925,7 +925,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -938,7 +938,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
@@ -1247,7 +1247,7 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
......
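
A hedged way to exercise the change end to end is to export just this computation with `torch.onnx.export` and confirm it traces cleanly; the module name, pad token id, opset, and output file below are assumptions for illustration and are not part of this commit:

```python
import torch

class LastTokenIndex(torch.nn.Module):
    """Wraps only the line changed in this commit so its ONNX trace can be inspected."""

    def __init__(self, pad_token_id: int):
        super().__init__()
        self.pad_token_id = pad_token_id

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # bool pad mask -> int32 -> index of the first pad token, minus one = last real token
        return torch.eq(input_ids, self.pad_token_id).int().argmax(-1) - 1

input_ids = torch.tensor([[5, 6, 7, 0, 0], [1, 2, 3, 4, 0]])  # 0 acts as the pad token here
torch.onnx.export(
    LastTokenIndex(pad_token_id=0),
    (input_ids,),
    "last_token_index.onnx",
    input_names=["input_ids"],
    output_names=["sequence_lengths"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
    opset_version=14,
)
```

Loading the resulting file with an ONNX runtime and comparing against the eager output would complete the check.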