"git@developer.sourcefind.cn:chenpangpang/diffusers.git" did not exist on "f243282e3e9a81ebe2727dcc931d26b2c7e644f3"
Unverified Commit 48ae1b08 authored by shenggan, committed by GitHub

fix backward return for fused softmax torch function (#36)

parent ded582b2
@@ -54,7 +54,7 @@ class FusedMaskSoftmaxFunction(torch.autograd.Function):
         grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(
             grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)
-        return grad_input.contiguous(), None, None
+        return grad_input.contiguous(), None


 class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):
@@ -85,7 +85,7 @@ class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):
         grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
-        return grad_input.contiguous(), grad_bias, None, None
+        return grad_input.contiguous(), None, grad_bias


 softmax = SoftmaxAffineFunction.apply
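For context, `torch.autograd.Function.backward` must return exactly one value per argument of `forward` (excluding `ctx`), in the same positional order, with `None` for inputs that need no gradient. That is the contract this commit restores: `FusedMaskSoftmaxFunction.forward` takes two inputs, so its backward returns two values, and `FusedMaskBiasSoftmaxFunction` takes three, with the bias gradient in the bias's position. The sketch below illustrates the corrected return shape with a plain-PyTorch stand-in for the fused CUDA kernel; the forward signature `(input, mask, bias)` and the broadcast shape of `bias` are assumptions inferred from the diff, not FastFold's actual code.

```python
import torch

class MaskBiasSoftmaxSketch(torch.autograd.Function):
    """Plain-PyTorch stand-in (assumed signature) for FusedMaskBiasSoftmaxFunction."""

    @staticmethod
    def forward(ctx, input, mask, bias):
        # Stand-in for the fused kernel: softmax over the last dim of input + mask + bias.
        output = torch.softmax(input + mask + bias, dim=-1)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # Softmax backward: dx = y * (dy - sum(y * dy) over the softmax dim).
        grad_input = output * (grad_output -
                               (output * grad_output).sum(dim=-1, keepdim=True))
        # bias is assumed broadcast over dim 1 in forward, so its gradient
        # sums over that dim (mirrors the torch.sum in the diff).
        grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
        # Three forward inputs (input, mask, bias) -> exactly three returns,
        # positionally aligned: grad for input, None for mask, grad for bias.
        return grad_input.contiguous(), None, grad_bias

# Usage: mask gets None because it never requires a gradient.
x = torch.randn(2, 4, 8, requires_grad=True)
mask = torch.zeros(2, 1, 8)
bias = torch.randn(2, 1, 8, requires_grad=True)
MaskBiasSoftmaxSketch.apply(x, mask, bias).sum().backward()
assert bias.grad.shape == bias.shape
```

Returning the wrong number of gradients, or placing `grad_bias` in the wrong slot as the pre-fix code did, makes autograd route gradients to the wrong forward arguments or raise a mismatch error, which is why the fix realigns the returns rather than changing any of the gradient math.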