Unverified Commit 418d9576 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

removing some redundant code (#44)

parent 78365cb9
...@@ -84,10 +84,8 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function): ...@@ -84,10 +84,8 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, grad_output): def backward(ctx, grad_output):
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
inp_type = grad_output.dtype
grad_input = disco_cuda_extension.forward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, grad_input = disco_cuda_extension.forward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
grad_input = grad_input.to(dtype=inp_type)
return grad_input, None, None, None, None, None, None, None, None return grad_input, None, None, None, None, None, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment