Unverified Commit 78365cb9 authored by Thorsten Kurth, committed by GitHub

Tkurth/distributed memory reduction (#43)

* updating amp to the new torch.amp API (see the sketch below)
* using amp autocast to cast inputs to FP32 for the disco convolution kernels
* implemented reduce_scatter routines, but disabled them because of memory fluctuations which can cause OOM on big networks
parent 77a64b2c
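The first bullet refers to the pattern visible throughout the diff below: the deprecated `torch.cuda.amp.custom_fwd`/`custom_bwd` decorators are replaced by `torch.amp.custom_fwd`/`custom_bwd`, which take an explicit `device_type` and, for FP32-only extension kernels, `cast_inputs=torch.float32`. A minimal, self-contained sketch of that usage; `_toy_kernel` and `ToyFP32Op` are hypothetical stand-ins for the CUDA extension, not code from this repository:

```python
# Sketch of the torch.cuda.amp -> torch.amp decorator migration applied in this commit.
# _toy_kernel is a hypothetical placeholder for an FP32-only extension kernel.
import torch
from torch.amp import custom_fwd, custom_bwd


def _toy_kernel(x):
    # placeholder for a kernel that only supports float32 inputs
    return x * 2.0


class ToyFP32Op(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, x):
        # under autocast, x arrives here already cast to float32
        return _toy_kernel(x)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # derivative of x * 2.0 is 2.0, so the same toy kernel works here
        return _toy_kernel(grad_output)


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(4, 4, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        y = ToyFP32Op.apply(x)  # computed in float32 despite autocast
    y.sum().backward()
```

Under `torch.autocast`, `cast_inputs=torch.float32` disables autocast inside `forward` and casts floating-point inputs to FP32, so the extension kernel never sees reduced-precision tensors.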
@@ -30,7 +30,7 @@
 # build after cloning in directoy torch_harmonics via
 # docker build . -t torch_harmonics
-FROM nvcr.io/nvidia/pytorch:24.05-py3
+FROM nvcr.io/nvidia/pytorch:24.07-py3
 COPY . /workspace/torch_harmonics
...
@@ -32,6 +32,7 @@
 import math
 import torch
+from torch.amp import custom_fwd, custom_bwd

 try:
     import disco_cuda_extension
@@ -43,6 +44,7 @@ except ImportError as err:

 class _DiscoS2ContractionCuda(torch.autograd.Function):
     @staticmethod
+    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):
@@ -50,10 +52,12 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
         ctx.nlon_in = x.shape[-1]
-        return disco_cuda_extension.forward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        output = disco_cuda_extension.forward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        return output

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         grad_input = disco_cuda_extension.backward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
@@ -64,6 +68,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):

 class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     @staticmethod
+    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                 row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                 kernel_size: int, nlat_out: int, nlon_out: int):
@@ -71,14 +76,18 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
         ctx.nlon_in = x.shape[-1]
-        return disco_cuda_extension.backward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        output = disco_cuda_extension.backward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+        return output

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
+        inp_type = grad_output.dtype
         grad_input = disco_cuda_extension.forward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
                                                   ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
+        grad_input = grad_input.to(dtype=inp_type)
         return grad_input, None, None, None, None, None, None, None, None
...
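One detail of the transpose contraction above: the backward now records `grad_output.dtype` and casts the FP32 result of the extension call back to it, so that under autocast the rest of the backward pass keeps its reduced-precision dtype. A minimal sketch of that bookkeeping, with `_fp32_backward_kernel` as a hypothetical stand-in for `disco_cuda_extension.forward`:

```python
# Sketch of the dtype restore added to the transpose contraction backward:
# the FP32 extension returns a float32 gradient, which is cast back to the
# dtype of the incoming grad_output (e.g. bfloat16 under autocast).
import torch


def _fp32_backward_kernel(g):
    # placeholder: an extension-style kernel that always computes in float32
    return g.float() * 2.0


def backward_like(grad_output: torch.Tensor) -> torch.Tensor:
    inp_type = grad_output.dtype                                 # remember the incoming dtype
    grad_input = _fp32_backward_kernel(grad_output.contiguous()) # float32 result
    return grad_input.to(dtype=inp_type)                         # hand back the original dtype
```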
@@ -376,7 +376,8 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         ker_idx = idx[0, ...].contiguous()
         row_idx = idx[1, ...].contiguous()
         col_idx = idx[2, ...].contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals)
+        vals = vals.contiguous()
+        roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()

         # preprocessed data-structure for GPU kernel
         self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
@@ -466,7 +467,8 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         ker_idx = idx[0, ...].contiguous()
         row_idx = idx[1, ...].contiguous()
         col_idx = idx[2, ...].contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals)
+        vals = vals.contiguous()
+        roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()

         # preprocessed data-structure for GPU kernel
         self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
...
@@ -239,7 +239,8 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         ker_idx = idx[0, ...].contiguous()
         row_idx = idx[1, ...].contiguous()
         col_idx = idx[2, ...].contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals)
+        vals = vals.contiguous()
+        roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous()

         # preprocessed data-structure for GPU kernel
         self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
@@ -284,7 +285,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
             x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

         # perform reduce scatter in polar region
-        x = reduce_from_polar_region (x)
+        x = reduce_from_polar_region(x)
         x = scatter_to_polar_region(x, -2)

         # now we can transpose back the result, so that lon is split and channels are local
@@ -368,7 +369,8 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         ker_idx = idx[0, ...].contiguous()
         row_idx = idx[1, ...].contiguous()
         col_idx = idx[2, ...].contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals)
+        vals = vals.contiguous()
+        roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous()

         # preprocessed data-structure for GPU kernel
         self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
...
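In the distributed forward above, the polar-region reduction is performed as an all-reduce over the polar group (`reduce_from_polar_region`) followed by a split along the latitude dimension (`scatter_to_polar_region`); per the commit message, a fused reduce_scatter variant was implemented but is disabled. A hedged sketch of the two-step pattern, assuming an initialized `torch.distributed` process group standing in for the repo's polar group and a latitude dimension that divides evenly across ranks:

```python
# Sketch of reduce-then-scatter: all-reduce the full tensor across the group,
# then keep only the chunk of the split dimension owned by this rank.
import torch
import torch.distributed as dist


def reduce_then_scatter(x: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    if not dist.is_initialized():
        return x
    # in-place sum over the group; the full tensor lives on every rank briefly
    dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    chunks = torch.split(x, x.shape[dim] // comm_size, dim=dim)
    return chunks[comm_rank].contiguous()
```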
@@ -32,6 +32,7 @@ from typing import List

 import torch
 import torch.distributed as dist
+from torch.amp import custom_fwd, custom_bwd

 from .utils import polar_group, azimuth_group, polar_group_size
 from .utils import is_initialized, is_distributed_polar
@@ -96,6 +97,7 @@ def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False)

 class distributed_transpose_azimuth(torch.autograd.Function):
     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, x, dims, dim1_split_sizes):
         # WAR for a potential contig check torch bug for channels last contig tensors
         xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
@@ -106,6 +108,7 @@ class distributed_transpose_azimuth(torch.autograd.Function):
         return x

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, go):
         dims = ctx.dims
         dim0_split_sizes = ctx.dim0_split_sizes
@@ -119,6 +122,7 @@ class distributed_transpose_azimuth(torch.autograd.Function):

 class distributed_transpose_polar(torch.autograd.Function):
     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, x, dim, dim1_split_sizes):
         # WAR for a potential contig check torch bug for channels last contig tensors
         xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
@@ -128,6 +132,7 @@ class distributed_transpose_polar(torch.autograd.Function):
         return x

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, go):
         dim = ctx.dim
         dim0_split_sizes = ctx.dim0_split_sizes
@@ -194,15 +199,10 @@ def _gather(input_, dim_, shapes_, group=None):
     input_shape = list(input_.shape)

     if shapes_ is not None:
-        input_list = [None] * comm_size
+        input_list = []
         for src in range(comm_size):
             input_shape[dim_] = shapes_[src]
-            input_list[src] = torch.empty(
-                input_shape,
-                dtype=input_.dtype,
-                device=input_.device,
-            )
+            input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))
     else:
         # assume equal shape on all ranks
         input_list = [torch.empty_like(input_) for _ in range(comm_size)]
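The `_gather` hunk above simplifies the allocation of per-rank receive buffers for an uneven gather. A hedged sketch of the overall pattern (an illustrative helper, not the repo's `_gather`): allocate one correctly shaped empty tensor per rank, run a regular `dist.all_gather`, then concatenate along the gather dimension.

```python
# Sketch of an uneven all-gather matching the buffer allocation in the hunk above:
# one empty receive tensor per rank, sized from the per-rank shapes along the
# gather dimension, followed by all_gather and concatenation.
import torch
import torch.distributed as dist


def gather_uneven(input_: torch.Tensor, dim: int, shapes, group=None) -> torch.Tensor:
    comm_size = dist.get_world_size(group=group)
    input_shape = list(input_.shape)

    # per-rank receive buffers with the correct size along the gather dim
    input_list = []
    for src in range(comm_size):
        input_shape[dim] = shapes[src]
        input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device))

    dist.all_gather(input_list, input_.contiguous(), group=group)
    return torch.cat(input_list, dim=dim)
```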
@@ -251,10 +251,12 @@ class _CopyToPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input_):
         return input_

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         if is_distributed_polar():
             return _reduce(grad_output, group=polar_group())
@@ -270,6 +272,7 @@ class _ScatterToPolarRegion(torch.autograd.Function):
         return _split(input_, dim_, group=polar_group())

     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input_, dim_):
         if is_distributed_polar():
             ctx.dim = dim_
@@ -281,6 +284,7 @@ class _ScatterToPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         if is_distributed_polar():
             return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
@@ -296,6 +300,7 @@ class _GatherFromPolarRegion(torch.autograd.Function):
         return _gather(input_, dim_, shapes_, polar_group())

     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input_, dim_, shapes_):
         if is_distributed_polar():
             ctx.dim = dim_
@@ -304,6 +309,7 @@ class _GatherFromPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         if is_distributed_polar():
             return _split(grad_output, ctx.dim, group=polar_group()), None, None
@@ -322,6 +328,7 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input_):
         if is_distributed_polar():
             return _reduce(input_, group=polar_group())
@@ -329,6 +336,7 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         return grad_output
@@ -344,6 +352,7 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input_, dim_):
         if is_distributed_polar():
             ctx.dim = dim_
@@ -355,6 +364,7 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         if is_distributed_polar():
             return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
@@ -373,6 +383,7 @@ class _GatherFromCopyToPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input_, dim_, shapes_):
         if is_distributed_polar():
             ctx.dim = dim_
@@ -381,6 +392,7 @@ class _GatherFromCopyToPolarRegion(torch.autograd.Function):
         return input_

     @staticmethod
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         if is_distributed_polar():
             return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
...
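The last hunk shows the backward of `_GatherFromCopyToPolarRegion` calling a `_reduce_scatter(..., use_fp32=True, ...)` helper from the repo. A hedged sketch of what an FP32-accumulating reduce-scatter along an arbitrary dimension can look like; this is an illustration under the assumption that the dimension divides evenly across ranks, not the repo's implementation:

```python
# Sketch of an FP32-accumulating reduce-scatter: chunk the local tensor per rank,
# reduce-scatter in float32 to avoid precision loss in the sum, then cast the
# local result back to the input dtype.
import torch
import torch.distributed as dist


def reduce_scatter_fp32(x: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    inp_dtype = x.dtype

    # one equally sized chunk per rank, accumulated in float32
    chunks = [c.contiguous().float() for c in torch.chunk(x, comm_size, dim=dim)]
    out = torch.empty_like(chunks[comm_rank])
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM, group=group)

    return out.to(dtype=inp_dtype)
```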