Unverified Commit 3f23634a authored by Jay Zhang, committed by GitHub

[ONNX] Add symbolic function for XSoftmax op for exporting to ONNX. (#14013)

* Add symbolic function for XSoftmax op for exporting to ONNX.

* Fix format issues.

* Fix a CI issue related to copies.
parent 9f3aa46f
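
With this symbolic function registered, a DeBERTa checkpoint can be traced through `torch.onnx.export` without the custom `XSoftmax` autograd Function aborting the export. A minimal sketch of such an export is below; the checkpoint name, output file name, and opset version are illustrative assumptions, not part of this change:

```python
import torch
from transformers import DebertaModel, DebertaTokenizer

# Checkpoint and output path are illustrative.
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaModel.from_pretrained("microsoft/deberta-base").eval()

inputs = tokenizer("Export me to ONNX", return_tensors="pt")

# With XSoftmax.symbolic defined, the op is lowered to Cast/Sub/MaskedFill/Softmax
# ONNX nodes instead of failing on the custom autograd Function.
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    "deberta.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
    },
    opset_version=12,
)
```

The `dynamic_axes` entries simply keep batch size and sequence length symbolic in the exported graph.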
@@ -113,6 +113,21 @@ class XSoftmax(torch.autograd.Function):
         inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
         return inputGrad, None, None
 
+    @staticmethod
+    def symbolic(g, self, mask, dim):
+        import torch.onnx.symbolic_helper as sym_help
+        from torch.onnx.symbolic_opset9 import masked_fill, softmax
+
+        mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
+        r_mask = g.op(
+            "Cast",
+            g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
+            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+        )
+        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
+        output = softmax(g, output, dim)
+        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
+
 
 class DropoutContext(object):
     def __init__(self):
@@ -1178,10 +1193,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
         else:
             return SequenceClassifierOutput(
-                loss=loss,
-                logits=logits,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
+                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
             )
@@ -1266,10 +1278,7 @@ class DebertaForTokenClassification(DebertaPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
         )
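
For readers of the hunk above: the ONNX subgraph assembled by `symbolic` (Cast, Sub, MaskedFill, Softmax, MaskedFill) mirrors the eager-mode masked softmax that `XSoftmax.forward` computes. A rough restatement in plain PyTorch, written here for illustration rather than taken from the commit:

```python
import torch

def masked_softmax_reference(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    # r_mask marks positions to suppress: 1 - mask, matching the Sub node in the graph.
    r_mask = (1 - mask.long()).bool()
    # Fill suppressed positions with -inf so they contribute nothing to the softmax.
    x = x.masked_fill(r_mask, float("-inf"))
    out = torch.softmax(x, dim)
    # Zero the suppressed positions, matching the final MaskedFill.
    return out.masked_fill(r_mask, 0.0)
```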
@@ -114,6 +114,21 @@ class XSoftmax(torch.autograd.Function):
         inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
         return inputGrad, None, None
 
+    @staticmethod
+    def symbolic(g, self, mask, dim):
+        import torch.onnx.symbolic_helper as sym_help
+        from torch.onnx.symbolic_opset9 import masked_fill, softmax
+
+        mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
+        r_mask = g.op(
+            "Cast",
+            g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
+            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+        )
+        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
+        output = softmax(g, output, dim)
+        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
+
 
 # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
 class DropoutContext(object):
@@ -1288,10 +1303,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
             return ((loss,) + output) if loss is not None else output
         else:
             return SequenceClassifierOutput(
-                loss=loss,
-                logits=logits,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
+                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
             )
@@ -1377,10 +1389,7 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
         )
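
Once both modeling files carry the symbolic function, one way to sanity-check an export is to run the saved graph under ONNX Runtime and compare it against the PyTorch model. A sketch, assuming the `deberta.onnx` file produced in the earlier example and an installed `onnxruntime`:

```python
import numpy as np
import onnxruntime as ort
import torch
from transformers import DebertaModel, DebertaTokenizer

# Checkpoint and file name are the same illustrative assumptions as above.
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaModel.from_pretrained("microsoft/deberta-base").eval()

inputs = tokenizer("Export me to ONNX", return_tensors="pt")
with torch.no_grad():
    torch_out = model(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    ).last_hidden_state

session = ort.InferenceSession("deberta.onnx")
onnx_out = session.run(
    None,
    {
        "input_ids": inputs["input_ids"].numpy(),
        "attention_mask": inputs["attention_mask"].numpy(),
    },
)[0]

# Small numerical drift between the two runtimes is expected.
np.testing.assert_allclose(torch_out.numpy(), onnx_out, rtol=1e-3, atol=1e-4)
```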