Unverified commit f1a965bd authored by Thorsten Kurth, committed by GitHub

fixing amp stuff again (#48)

- doing manual type conversion in the custom autograd functions
- fixing stride issues in the backward pass by making some output tensors contiguous
parent 5d7e9b06
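The change replaces the automatic `cast_inputs=torch.float32` on `@custom_fwd` with an explicit dtype round-trip inside the autograd functions: the input is cast to float32 (and made contiguous) before it is handed to the CUDA extension, and the result is cast back to the original dtype so code downstream of the autocast region sees the dtype it expects; the backward pass applies the same treatment to the incoming gradient. Below is a minimal sketch of the pattern, assuming PyTorch >= 2.4 (where `torch.amp.custom_fwd`/`custom_bwd` accept `device_type`); `_fp32_only_kernel` is a made-up stand-in for a float32-only extension call such as `disco_cuda_extension.forward`.

import torch
from torch.amp import custom_fwd, custom_bwd

def _fp32_only_kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a CUDA kernel that requires contiguous float32 input.
    assert x.dtype == torch.float32 and x.is_contiguous()
    return 2.0 * x

class _Float32OnlyOp(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Remember the incoming (possibly half-precision) dtype, run the kernel in fp32,
        # then cast the result back so the rest of the autocast region is unaffected.
        xtype = x.dtype
        x = x.to(torch.float32).contiguous()
        out = _fp32_only_kernel(x)
        return out.to(xtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Same round-trip for the gradient; this simple scaling kernel is its own adjoint.
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
        grad_input = _fp32_only_kernel(grad_output)
        return grad_input.to(gtype)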
@@ -44,7 +44,7 @@ except ImportError as err:
 class _DiscoS2ContractionCuda(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+    @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):
@@ -52,7 +52,10 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
         ctx.nlon_in = x.shape[-1]
-        output = disco_cuda_extension.forward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        xtype = x.dtype
+        x = x.to(torch.float32).contiguous()
+        output = disco_cuda_extension.forward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        output = output.to(xtype)
         return output
@@ -60,15 +63,18 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
-        grad_input = disco_cuda_extension.backward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
+        gtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float32).contiguous()
+        grad_input = disco_cuda_extension.backward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                                    ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
+        grad_input = grad_input.to(gtype)
         return grad_input, None, None, None, None, None, None, None, None


 class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+    @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):
@@ -76,7 +82,10 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
         ctx.nlon_in = x.shape[-1]
-        output = disco_cuda_extension.backward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        xtype = x.dtype
+        x = x.to(torch.float32).contiguous()
+        output = disco_cuda_extension.backward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        output = output.to(xtype)
         return output
@@ -84,8 +93,11 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
-        grad_input = disco_cuda_extension.forward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
+        gtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float32).contiguous()
+        grad_input = disco_cuda_extension.forward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                                   ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
+        grad_input = grad_input.to(gtype)
         return grad_input, None, None, None, None, None, None, None, None
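With the manual conversion in place, the op can be applied to half-precision tensors produced inside an autocast region while the extension itself still computes in float32, and the gradient comes back in the caller's dtype. A short usage example based on the `_Float32OnlyOp` sketch above (requires a CUDA device; the shapes are arbitrary):

x = torch.randn(4, 8, 16, device="cuda", dtype=torch.float16, requires_grad=True)

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = _Float32OnlyOp.apply(x)   # kernel runs in fp32, result is cast back to fp16

y.float().sum().backward()
print(y.dtype, x.grad.dtype)      # torch.float16 torch.float16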
...@@ -417,7 +417,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -417,7 +417,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
x = x.reshape(B, self.groups, self.groupsize, K, H, W) x = x.reshape(B, self.groups, self.groupsize, K, H, W)
# do weight multiplication # do weight multiplication
out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])) out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
out = out.reshape(B, -1, H, W) out = out.reshape(B, -1, H, W)
if self.bias is not None: if self.bias is not None:
...@@ -508,7 +508,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -508,7 +508,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
x = x.reshape(B, self.groups, self.groupsize, H, W) x = x.reshape(B, self.groups, self.groupsize, H, W)
# do weight multiplication # do weight multiplication
x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])) x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
x = x.reshape(B, -1, x.shape[-3], H, W) x = x.reshape(B, -1, x.shape[-3], H, W)
if x.is_cuda and _cuda_extension_available: if x.is_cuda and _cuda_extension_available:
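The second half of the change addresses the stride issue from the commit message: `torch.einsum` may return a permuted, non-contiguous result, and handing such a tensor (or a `reshape` of it) to a CUDA kernel that assumes a dense row-major layout leads to wrong strides, so the result is made contiguous right after the contraction. A small illustration of the layout check; the shapes only loosely mirror the convolution above and are otherwise arbitrary:

import torch

x = torch.randn(2, 3, 4, 5, 6, 7)   # (batch, groups, channels, kernel, lat, lon)
w = torch.randn(3, 8, 4, 5)         # (groups, out_channels, channels, kernel)

out = torch.einsum("bgckxy,gock->bgoxy", x, w)
print(out.shape, out.is_contiguous())   # einsum may hand back a non-contiguous result

# Make the memory layout explicit before reshaping and dispatching to an extension
# that indexes raw memory with fixed strides.
out = out.contiguous().reshape(2, -1, 6, 7)
print(out.shape, out.is_contiguous())   # True after .contiguous()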