Unverified commit e66743e6, authored by Lysandre Debut, committed by GitHub

DeBERTa/DeBERTa-v2/SEW-D Support for torch 1.11 (#16043)

* Support for torch 1.11

* Address Sylvain's comment
parent 741e4930
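For context (an illustration, not part of the commit): torch 1.11 changed the signature of the private torch._softmax_backward_data function that these models call in their XSoftmax backward pass — the last argument went from the input tensor to the input dtype. A minimal sketch of the two call shapes, assuming only that torch and packaging are installed:

import torch
from packaging import version
from torch import _softmax_backward_data

# A softmax output and an incoming gradient of the same shape.
output = torch.softmax(torch.randn(2, 4), dim=-1)
grad_output = torch.randn(2, 4)

if version.parse(torch.__version__) < version.parse("1.11"):
    # torch <= 1.10: the last argument is the input *tensor*
    grad = _softmax_backward_data(grad_output, output, -1, output)
else:
    # torch >= 1.11: the last argument is the input *dtype*
    grad = _softmax_backward_data(grad_output, output, -1, output.dtype)

The diff below routes all three models through a single version-aware wrapper in pytorch_utils.py instead of calling the private function directly.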
src/transformers/models/deberta/modeling_deberta.py

@@ -18,7 +18,7 @@ import math
 from collections.abc import Sequence
 
 import torch
-from torch import _softmax_backward_data, nn
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
@@ -31,12 +31,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import softmax_backward_data
 from ...utils import logging
 from .configuration_deberta import DebertaConfig
 
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "DebertaConfig"
 _TOKENIZER_FOR_DOC = "DebertaTokenizer"
 _CHECKPOINT_FOR_DOC = "microsoft/deberta-base"
@@ -115,7 +115,7 @@ class XSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         (output,) = self.saved_tensors
-        inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
         return inputGrad, None, None
 
     @staticmethod
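A note on the new call (illustration only, not part of the commit): in backward, self is the autograd context, which becomes the parent argument of the helper added in src/transformers/pytorch_utils.py (last hunk below). The mapping is:

# softmax_backward_data(parent, grad_output, output, dim, self)
#   parent      <- self        (autograd context; the helper reads parent.dim)
#   grad_output <- grad_output
#   output      <- output
#   dim         <- self.dim    (accepted by the helper, which uses parent.dim)
#   self        <- output      (passed as a tensor on torch < 1.11, as .dtype on >= 1.11)
inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)

The identical one-line change is applied to DeBERTa-v2 and SEW-D below.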
src/transformers/models/deberta_v2/modeling_deberta_v2.py

@@ -19,7 +19,7 @@ from collections.abc import Sequence
 
 import numpy as np
 import torch
-from torch import _softmax_backward_data, nn
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 
 from ...activations import ACT2FN
@@ -32,6 +32,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import softmax_backward_data
 from ...utils import logging
 from .configuration_deberta_v2 import DebertaV2Config
@@ -116,7 +117,7 @@ class XSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         (output,) = self.saved_tensors
-        inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
         return inputGrad, None, None
 
     @staticmethod
src/transformers/models/sew_d/modeling_sew_d.py

@@ -22,7 +22,7 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.utils.checkpoint
-from torch import _softmax_backward_data, nn
+from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 
 from transformers.deepspeed import is_deepspeed_zero3_enabled
@@ -31,14 +31,13 @@ from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import torch_int_div
+from ...pytorch_utils import softmax_backward_data, torch_int_div
 from ...utils import logging
 from .configuration_sew_d import SEWDConfig
 
 
 logger = logging.get_logger(__name__)
 
 _HIDDEN_STATES_START_POSITION = 1
@@ -545,7 +544,7 @@ class XSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         (output,) = self.saved_tensors
-        inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
         return inputGrad, None, None
 
     @staticmethod
src/transformers/pytorch_utils.py

@@ -14,18 +14,34 @@
 import torch
 from packaging import version
+from torch import _softmax_backward_data
 
 from .utils import logging
 
 
 logger = logging.get_logger(__name__)
 
+is_torch_less_than_1_8 = version.parse(torch.__version__) < version.parse("1.8.0")
+is_torch_less_than_1_11 = version.parse(torch.__version__) < version.parse("1.11")
+
 
 def torch_int_div(tensor1, tensor2):
     """
     A function that performs integer division across different versions of PyTorch.
     """
-    if version.parse(torch.__version__) < version.parse("1.8.0"):
+    if is_torch_less_than_1_8:
         return tensor1 // tensor2
     else:
         return torch.div(tensor1, tensor2, rounding_mode="floor")
+
+
+def softmax_backward_data(parent, grad_output, output, dim, self):
+    """
+    A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
+    to the torch version detected.
+    """
+    if is_torch_less_than_1_11:
+        return _softmax_backward_data(grad_output, output, parent.dim, self)
+    else:
+        return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
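To see both helpers in use end to end, here is a minimal sketch (not part of the commit; TinySoftmax is a hypothetical stand-in for XSoftmax, and the import assumes a transformers version that includes this patch):

import torch
from transformers.pytorch_utils import softmax_backward_data, torch_int_div


class TinySoftmax(torch.autograd.Function):
    """A stripped-down autograd Function in the style of XSoftmax."""

    @staticmethod
    def forward(ctx, input, dim):
        ctx.dim = dim  # the helper reads the softmax dimension from here
        output = torch.softmax(input, dim)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # Same argument order as the patched XSoftmax.backward above.
        return softmax_backward_data(ctx, grad_output, output, ctx.dim, output), None


x = torch.randn(2, 4, requires_grad=True)
TinySoftmax.apply(x, -1).sum().backward()               # works on torch < 1.11 and >= 1.11
print(torch_int_div(torch.tensor(7), torch.tensor(2)))  # tensor(3) on both branches

Centralizing the version check in pytorch_utils.py is the design point of the patch: the per-model backward methods stay identical across torch releases, and any future signature change only needs to be handled in one place.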