Unverified Commit 9d7b70bc authored by Gary Miguel's avatar Gary Miguel Committed by GitHub
Browse files

support ONNX export of XDropout in deberta{,_v2} and sew_d (#17502)

* support ONNX export of XDropout in deberta{,_v2}

* black

* copy to sew_d

* add test

* isort

* use pytest.mark.filterwarnings

* review comments
parent 92915ebe
...@@ -185,6 +185,21 @@ class XDropout(torch.autograd.Function): ...@@ -185,6 +185,21 @@ class XDropout(torch.autograd.Function):
else: else:
return grad_output, None return grad_output, None
@staticmethod
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
    """Emit the ONNX representation of XDropout during `torch.onnx.export`.

    `local_ctx` is either a plain dropout probability or a DropoutContext
    carrying one; either way the probability is forwarded to the opset-12
    dropout symbolic.
    """
    dropout_p = local_ctx.dropout if isinstance(local_ctx, DropoutContext) else local_ctx
    # StableDropout only calls this function when training.
    train = True
    # TODO: We should check if the opset_version being used to export
    # is > 12 here, but there's no good way to do that. As-is, if the
    # opset_version < 12, export will fail with a CheckerError.
    # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
    # if opset_version < 12:
    #     return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
    return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
class StableDropout(nn.Module): class StableDropout(nn.Module):
""" """
......
...@@ -191,6 +191,21 @@ class XDropout(torch.autograd.Function): ...@@ -191,6 +191,21 @@ class XDropout(torch.autograd.Function):
else: else:
return grad_output, None return grad_output, None
@staticmethod
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
    """ONNX symbolic for XDropout: maps the autograd op to an opset-12 Dropout node."""
    if isinstance(local_ctx, DropoutContext):
        dropout_p = local_ctx.dropout
    else:
        # A bare float probability was passed instead of a DropoutContext.
        dropout_p = local_ctx
    # StableDropout only calls this function when training.
    train = True
    # TODO: We should check if the opset_version being used to export
    # is > 12 here, but there's no good way to do that. As-is, if the
    # opset_version < 12, export will fail with a CheckerError.
    # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
    # if opset_version < 12:
    #     return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
    return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
# Copied from transformers.models.deberta.modeling_deberta.StableDropout # Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(nn.Module): class StableDropout(nn.Module):
......
...@@ -595,6 +595,21 @@ class XDropout(torch.autograd.Function): ...@@ -595,6 +595,21 @@ class XDropout(torch.autograd.Function):
else: else:
return grad_output, None return grad_output, None
@staticmethod
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
    """Export hook translating XDropout into an ONNX Dropout op (opset 12).

    Accepts either a raw probability or a DropoutContext holding one.
    """
    dropout_p = local_ctx.dropout if isinstance(local_ctx, DropoutContext) else local_ctx
    # StableDropout only calls this function when training.
    train = True
    # TODO: We should check if the opset_version being used to export
    # is > 12 here, but there's no good way to do that. As-is, if the
    # opset_version < 12, export will fail with a CheckerError.
    # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
    # if opset_version < 12:
    #     return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
    return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
# Copied from transformers.models.deberta.modeling_deberta.StableDropout # Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(nn.Module): class StableDropout(nn.Module):
......
import os
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from unittest import TestCase from unittest import TestCase
...@@ -26,6 +27,11 @@ from transformers.testing_utils import require_onnx, require_rjieba, require_tf, ...@@ -26,6 +27,11 @@ from transformers.testing_utils import require_onnx, require_rjieba, require_tf,
if is_torch_available() or is_tf_available(): if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager from transformers.onnx.features import FeaturesManager
if is_torch_available():
import torch
from transformers.models.deberta import modeling_deberta
@require_onnx @require_onnx
class OnnxUtilsTestCaseV2(TestCase): class OnnxUtilsTestCaseV2(TestCase):
...@@ -356,3 +362,40 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -356,3 +362,40 @@ class OnnxExportTestCaseV2(TestCase):
self, test_name, name, model_name, feature, onnx_config_class_constructor self, test_name, name, model_name, feature, onnx_config_class_constructor
): ):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
class StableDropoutTestCase(TestCase):
    """Tests export of StableDropout module."""

    @require_torch
    @pytest.mark.filterwarnings("ignore:.*Dropout.*:UserWarning:torch.onnx.*")  # torch.onnx is spammy.
    def test_training(self):
        """Tests export of StableDropout in training mode."""
        # drop_prob must be > 0 for the test to be meaningful
        sd = modeling_deberta.StableDropout(0.1)
        # Avoid warnings in training mode
        do_constant_folding = False
        # Dropout is a no-op in inference mode
        training = torch.onnx.TrainingMode.PRESERVE
        input = (torch.randn(2, 2),)
        # Use a context manager so the devnull handle is closed even on
        # failure (the original opened it and never closed it).
        with open(os.devnull, "wb") as devnull:
            torch.onnx.export(
                sd,
                input,
                devnull,
                opset_version=12,  # Minimum supported
                do_constant_folding=do_constant_folding,
                training=training,
            )

            # Expected to fail with opset_version < 12
            with self.assertRaises(Exception):
                torch.onnx.export(
                    sd,
                    input,
                    devnull,
                    opset_version=11,
                    do_constant_folding=do_constant_folding,
                    training=training,
                )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment