Unverified Commit 53f0c9c2 authored by Sai-Suraj-27, committed by GitHub

fix: Removed unnecessary `@staticmethod` decorator (#32361)

* Fixed staticmethods with self as first argument.

* Fixed staticmethods with self as first argument.

* Fixed staticmethods with self as first argument.

* Fixed staticmethods with self as first argument.
parent 92abe603
@@ -104,20 +104,20 @@ class XSoftmax(torch.autograd.Function):
     ```"""
 
     @staticmethod
-    def forward(self, input, mask, dim):
-        self.dim = dim
+    def forward(ctx, input, mask, dim):
+        ctx.dim = dim
         rmask = ~(mask.to(torch.bool))
 
         output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
-        output = torch.softmax(output, self.dim)
+        output = torch.softmax(output, ctx.dim)
         output.masked_fill_(rmask, 0)
-        self.save_for_backward(output)
+        ctx.save_for_backward(output)
         return output
 
     @staticmethod
-    def backward(self, grad_output):
-        (output,) = self.saved_tensors
-        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
+    def backward(ctx, grad_output):
+        (output,) = ctx.saved_tensors
+        inputGrad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
         return inputGrad, None, None
 
     @staticmethod
@@ -98,20 +98,20 @@ class XSoftmax(torch.autograd.Function):
     ```"""
 
     @staticmethod
-    def forward(self, input, mask, dim):
-        self.dim = dim
+    def forward(ctx, input, mask, dim):
+        ctx.dim = dim
         rmask = ~(mask.to(torch.bool))
 
         output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
-        output = torch.softmax(output, self.dim)
+        output = torch.softmax(output, ctx.dim)
         output.masked_fill_(rmask, 0)
-        self.save_for_backward(output)
+        ctx.save_for_backward(output)
         return output
 
     @staticmethod
-    def backward(self, grad_output):
-        (output,) = self.saved_tensors
-        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
+    def backward(ctx, grad_output):
+        (output,) = ctx.saved_tensors
+        inputGrad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
         return inputGrad, None, None
 
     @staticmethod
@@ -520,20 +520,20 @@ class XSoftmax(torch.autograd.Function):
     ```"""
 
    @staticmethod
-    def forward(self, input, mask, dim):
-        self.dim = dim
+    def forward(ctx, input, mask, dim):
+        ctx.dim = dim
         rmask = ~(mask.to(torch.bool))
 
         output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
-        output = torch.softmax(output, self.dim)
+        output = torch.softmax(output, ctx.dim)
         output.masked_fill_(rmask, 0)
-        self.save_for_backward(output)
+        ctx.save_for_backward(output)
         return output
 
     @staticmethod
-    def backward(self, grad_output):
-        (output,) = self.saved_tensors
-        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
+    def backward(ctx, grad_output):
+        (output,) = ctx.saved_tensors
+        inputGrad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
         return inputGrad, None, None
 
     @staticmethod
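For context, `forward` and `backward` on a `torch.autograd.Function` are true static methods: PyTorch passes a context object as the first argument, so naming it `ctx` (and calling `ctx.save_for_backward` / reading `ctx.saved_tensors`) is the convention this patch restores in place of the misleading `self`. Below is a minimal, self-contained sketch of the same pattern. It is not the library's `XSoftmax`: the internal `softmax_backward_data` helper is replaced with an explicit softmax gradient, and the class and variable names are made up for illustration.

```python
import torch


class MaskedSoftmax(torch.autograd.Function):
    """Softmax that zeroes out masked positions (illustrative only)."""

    @staticmethod
    def forward(ctx, input, mask, dim):
        # `ctx` is the autograd context object PyTorch passes to the static
        # forward(); attributes and saved tensors stashed on it are visible
        # again in backward().
        ctx.dim = dim
        rmask = ~(mask.to(torch.bool))
        output = input.masked_fill(rmask, torch.finfo(input.dtype).min)
        output = torch.softmax(output, ctx.dim)
        output = output.masked_fill(rmask, 0)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # Standard softmax gradient, s * (g - sum(g * s)), written out here
        # instead of calling the library's softmax_backward_data helper.
        grad_input = output * (grad_output - (grad_output * output).sum(ctx.dim, keepdim=True))
        # backward() returns one value per forward() argument; `mask` and
        # `dim` take no gradient, hence the two Nones.
        return grad_input, None, None


x = torch.randn(2, 4, requires_grad=True)
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
y = MaskedSoftmax.apply(x, mask, -1)  # custom Functions are invoked via .apply
y.sum().backward()
```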