Commit 27c888db authored by Lysandre's avatar Lysandre
Browse files

Fix copies

parent 3f23634a
......@@ -494,6 +494,21 @@ class XSoftmax(torch.autograd.Function):
inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
return inputGrad, None, None
@staticmethod
def symbolic(g, self, mask, dim):
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_opset9 import masked_fill, softmax
mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
class DropoutContext(object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment