Unverified commit 548a8f61, authored by Dean Wyatte, committed by GitHub

Fix ONNX export for causal LM sequence classifiers by removing reverse indexing (#28144)

* normalize reverse indexing for causal lm sequence classifiers

* normalize reverse indexing for causal lm sequence classifiers

* normalize reverse indexing for causal lm sequence classifiers

* use modulo instead

* unify modulo-based sequence lengths
parent 71f46057
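
The change is the same in every file touched below: the old code computed the last non-pad position with arithmetic that ends up relying on a reverse (negative) index, which the ONNX exporter cannot represent with a dynamic sequence length. The new code finds the first pad token with `argmax`, subtracts one, and wraps the result with a modulo so the index is always non-negative. A minimal sketch of why the two forms agree for right-padded input (the tensor values and `pad_token_id = 0` are illustrative assumptions, not from the diff):

```python
import torch

pad_token_id = 0  # assumed value for illustration
input_ids = torch.tensor(
    [
        [5, 6, 7, 0, 0],  # right-padded: last real token at index 2
        [5, 6, 7, 8, 9],  # no padding: last real token at index 4
    ]
)

# Old form: count the non-pad tokens and step back one.
old = torch.ne(input_ids, pad_token_id).sum(-1) - 1

# New form: argmax locates the first pad token, so subtracting one gives the
# last real token. When no pad token exists, argmax returns 0 and 0 - 1 = -1;
# the modulo wraps -1 around to seq_len - 1 without any reverse indexing.
new = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
new = new % input_ids.shape[-1]

print(old)  # tensor([2, 4])
print(new)  # tensor([2, 4])
```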
@@ -1011,7 +1011,10 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -796,9 +796,10 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -923,9 +923,10 @@ class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
@@ -1247,9 +1247,10 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1432,7 +1432,10 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1451,9 +1451,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1384,9 +1384,10 @@ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1113,9 +1113,10 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1200,9 +1200,10 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1001,9 +1001,10 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1370,9 +1370,10 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
@@ -1338,9 +1338,10 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
@@ -1518,9 +1518,10 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
@@ -729,7 +729,10 @@ class MptForSequenceClassification(MptPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -814,9 +814,10 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -1294,9 +1294,10 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -969,9 +969,10 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
@@ -1225,9 +1225,10 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
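
For context, each classifier then uses `sequence_lengths` to gather the pooled logits at the last real token, along the lines of this sketch (shapes and values are illustrative; the gather mirrors the surrounding transformers code, which is not shown in the diff):

```python
import torch

batch_size, seq_len, num_labels = 2, 5, 3
logits = torch.randn(batch_size, seq_len, num_labels)  # per-token classifier logits
sequence_lengths = torch.tensor([2, 4])  # indices from the modulo computation above

# Select the logits at the last non-pad position of every sequence. Because
# the modulo guarantees non-negative indices, this gather traces cleanly to ONNX.
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```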