Unverified commit e66743e6 authored by Lysandre Debut, committed by GitHub

DeBERTa/DeBERTa-v2/SEW Support for torch 1.11 (#16043)

* Support for torch 1.11

* Address Sylvain's comment
parent 741e4930
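Context for the change: from PyTorch 1.11 onwards, the private `torch._softmax_backward_data` function expects the input dtype as its last argument instead of a tensor, which breaks the `XSoftmax.backward` implementations that called it directly. The diff below therefore drops the direct import in the DeBERTa, DeBERTa-v2 and SEW-D modeling files and routes the call through a new version-aware `softmax_backward_data` helper in `pytorch_utils`. The standalone sketch below is not part of the diff; it only assumes torch and `packaging` are installed, and illustrates the two signatures:

# Standalone sketch of the signature difference this commit handles (not part of the diff).
import torch
from packaging import version
from torch import _softmax_backward_data

x = torch.randn(2, 4)
output = torch.softmax(x, dim=-1)
grad_output = torch.ones_like(output)

if version.parse(torch.__version__) < version.parse("1.11"):
    # torch < 1.11: the last argument is a tensor (the models pass the saved softmax output)
    grad_input = _softmax_backward_data(grad_output, output, -1, output)
else:
    # torch >= 1.11: the last argument is the input dtype
    grad_input = _softmax_backward_data(grad_output, output, -1, output.dtype)

print(grad_input.shape)  # torch.Size([2, 4])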
src/transformers/models/deberta/modeling_deberta.py

@@ -18,7 +18,7 @@ import math
 from collections.abc import Sequence
 import torch
-from torch import _softmax_backward_data, nn
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
@@ -31,12 +31,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import softmax_backward_data
 from ...utils import logging
 from .configuration_deberta import DebertaConfig
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "DebertaConfig"
 _TOKENIZER_FOR_DOC = "DebertaTokenizer"
 _CHECKPOINT_FOR_DOC = "microsoft/deberta-base"
@@ -115,7 +115,7 @@ class XSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         (output,) = self.saved_tensors
-        inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
         return inputGrad, None, None
     @staticmethod
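A quick way to exercise the patched DeBERTa backward pass on a recent torch is a forward/backward smoke test. This is a hedged sketch, assuming a transformers build that includes this commit and access to the `microsoft/deberta-base` checkpoint:

# Smoke-test sketch: one forward + backward pass goes through XSoftmax.backward
# and therefore through the new softmax_backward_data helper.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = AutoModel.from_pretrained("microsoft/deberta-base")

inputs = tokenizer("Support for torch 1.11", return_tensors="pt")
outputs = model(**inputs)
outputs.last_hidden_state.sum().backward()
print(outputs.last_hidden_state.shape)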
src/transformers/models/deberta_v2/modeling_deberta_v2.py

@@ -19,7 +19,7 @@ from collections.abc import Sequence
 import numpy as np
 import torch
-from torch import _softmax_backward_data, nn
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 from ...activations import ACT2FN
@@ -32,6 +32,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import softmax_backward_data
 from ...utils import logging
 from .configuration_deberta_v2 import DebertaV2Config
@@ -116,7 +117,7 @@ class XSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         (output,) = self.saved_tensors
-        inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
         return inputGrad, None, None
     @staticmethod
src/transformers/models/sew_d/modeling_sew_d.py

@@ -22,7 +22,7 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.utils.checkpoint
-from torch import _softmax_backward_data, nn
+from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from transformers.deepspeed import is_deepspeed_zero3_enabled
@@ -31,14 +31,13 @@ from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import torch_int_div
+from ...pytorch_utils import softmax_backward_data, torch_int_div
 from ...utils import logging
 from .configuration_sew_d import SEWDConfig
 logger = logging.get_logger(__name__)
 _HIDDEN_STATES_START_POSITION = 1
@@ -545,7 +544,7 @@ class XSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         (output,) = self.saved_tensors
-        inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
         return inputGrad, None, None
     @staticmethod
src/transformers/pytorch_utils.py

@@ -14,18 +14,34 @@
 import torch
 from packaging import version
+from torch import _softmax_backward_data
 from .utils import logging
 logger = logging.get_logger(__name__)
+is_torch_less_than_1_8 = version.parse(torch.__version__) < version.parse("1.8.0")
+is_torch_less_than_1_11 = version.parse(torch.__version__) < version.parse("1.11")
 def torch_int_div(tensor1, tensor2):
     """
     A function that performs integer division across different versions of PyTorch.
     """
-    if version.parse(torch.__version__) < version.parse("1.8.0"):
+    if is_torch_less_than_1_8:
         return tensor1 // tensor2
     else:
         return torch.div(tensor1, tensor2, rounding_mode="floor")
+def softmax_backward_data(parent, grad_output, output, dim, self):
+    """
+    A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
+    to the torch version detected.
+    """
+    if is_torch_less_than_1_11:
+        return _softmax_backward_data(grad_output, output, parent.dim, self)
+    else:
+        return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
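For reference, a minimal sketch of how a custom `torch.autograd.Function` routes its backward through the new helper, mirroring the three `XSoftmax` changes above. `SimpleSoftmax` is a hypothetical stand-in (it omits XSoftmax's masking), and the import assumes a transformers install that contains this commit:

# Illustrative sketch only: SimpleSoftmax is a hypothetical, unmasked stand-in for XSoftmax.
import torch
from transformers.pytorch_utils import softmax_backward_data


class SimpleSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim):
        ctx.dim = dim  # the helper reads `parent.dim` during backward
        output = torch.softmax(input, dim)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # softmax_backward_data picks the right _softmax_backward_data signature
        # for the installed torch version (tensor for < 1.11, dtype for >= 1.11).
        grad_input = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
        return grad_input, None


x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
# gradcheck compares the analytic gradient from backward() against a numerical estimate
assert torch.autograd.gradcheck(lambda t: SimpleSoftmax.apply(t, -1), (x,))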