Unverified commit ab0e66bc authored by Thorsten Kurth, committed by GitHub

Tkurth/distributed memory reduction (#41)

* removing unnecessary contiguous statements

* replacing fused reduce-scatter with an independent reduce and scatter
parent 89fb38df
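For context, the second bullet amounts to trading a single fused reduce-scatter collective for an all-reduce followed by a local split. The sketch below shows the two formulations with plain torch.distributed calls; the function names are illustrative and this is not the torch_harmonics implementation, which additionally defines autograd backward passes for each step.

import torch
import torch.distributed as dist

def fused_reduce_scatter(x, dim, group=None):
    # fused collective: every rank contributes the full tensor and receives
    # only the sum of its own chunk along `dim` (chunks must be contiguous
    # and equally sized, i.e. `dim` must divide evenly across the group)
    chunks = [c.contiguous() for c in torch.chunk(x, dist.get_world_size(group), dim=dim)]
    out = torch.empty_like(chunks[dist.get_rank(group)])
    dist.reduce_scatter(out, chunks, group=group)
    return out

def reduce_then_scatter(x, dim, group=None):
    # same result expressed as two independent steps:
    # (1) sum the full tensor across the group, (2) keep only the local chunk
    x = x.contiguous()
    dist.all_reduce(x, group=group)
    return torch.chunk(x, dist.get_world_size(group), dim=dim)[dist.get_rank(group)]

The two-step form typically moves more data (the all-reduce communicates the full tensor), but it exposes the reduction and the split as separate operations with their own backward rules, which is what the renamed helpers in the diff below reflect.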
@@ -30,7 +30,7 @@
 # build after cloning in directoy torch_harmonics via
 # docker build . -t torch_harmonics
-FROM nvcr.io/nvidia/pytorch:23.11-py3
+FROM nvcr.io/nvidia/pytorch:24.05-py3
 COPY . /workspace/torch_harmonics
......
@@ -54,7 +54,7 @@ from torch_harmonics.convolution import (
 from torch_harmonics.distributed import polar_group_size, azimuth_group_size
 from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
-from torch_harmonics.distributed import reduce_from_scatter_to_polar_region, gather_from_copy_to_polar_region
+from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region
 from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
 from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
@@ -289,7 +289,8 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
             x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

         # perform reduce scatter in polar region
-        x = reduce_from_scatter_to_polar_region(x, -2)
+        x = reduce_from_polar_region(x)
+        x = scatter_to_polar_region(x, -2)

         # now we can transpose back the result, so that lon is split and channels are local
         if self.comm_size_azimuth > 1:
@@ -427,7 +428,8 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         x = self.quad_weights * x

         # gather input tensor and set up backward reduction hooks
-        x = gather_from_copy_to_polar_region(x, -2, self.lat_in_shapes)
+        x = gather_from_polar_region(x, -2, self.lat_in_shapes)
+        x = copy_to_polar_region(x)

         if x.is_cuda and _cuda_extension_available:
             out = _disco_s2_transpose_contraction_cuda(
......
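The pairing of the new calls above is deliberate: the forward convolution reduces and then scatters over the polar group, while the transpose convolution gathers and then copies, because a reduction in the forward pass needs only an identity in the backward pass, whereas an identity "copy" in the forward pass needs a gradient reduction. A minimal, hypothetical sketch of such an adjoint reduce/copy pair as autograd functions is given below; the class names are made up and this is not the torch_harmonics code.

import torch
import torch.distributed as dist

class _ReduceFromGroup(torch.autograd.Function):
    # forward: sum the tensor across the process group
    # backward: gradients pass through unchanged (identity)
    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        x = x.contiguous()
        dist.all_reduce(x, group=group)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

class _CopyToGroup(torch.autograd.Function):
    # forward: identity (every rank already holds the replicated tensor)
    # backward: sum the incoming gradients across the process group
    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, group=ctx.group)
        return grad_output, None

By the same logic, the scatter and gather helpers play complementary roles along the latitude dimension, so reduce + scatter in the forward convolution and gather + copy in the transpose convolution act as adjoints of each other and the distributed gradients stay consistent.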
@@ -98,9 +98,8 @@ class distributed_transpose_azimuth(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, dims, dim1_split_sizes):
         # WAR for a potential contig check torch bug for channels last contig tensors
-        x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
-        x = torch.cat(xlist, dim=dims[1]).contiguous()
+        x = torch.cat(xlist, dim=dims[1])
         ctx.dims = dims
         ctx.dim0_split_sizes = dim0_split_sizes
@@ -111,9 +110,8 @@ class distributed_transpose_azimuth(torch.autograd.Function):
         dims = ctx.dims
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
-        go = go.contiguous()
         gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
-        gi = torch.cat(gilist, dim=dims[0]).contiguous()
+        gi = torch.cat(gilist, dim=dims[0])
         return gi, None, None
@@ -123,9 +121,8 @@ class distributed_transpose_polar(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, dim, dim1_split_sizes):
         # WAR for a potential contig check torch bug for channels last contig tensors
-        x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
-        x = torch.cat(xlist, dim=dim[1]).contiguous()
+        x = torch.cat(xlist, dim=dim[1])
         ctx.dim = dim
         ctx.dim0_split_sizes = dim0_split_sizes
         return x
@@ -135,9 +132,8 @@ class distributed_transpose_polar(torch.autograd.Function):
         dim = ctx.dim
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
-        go = go.contiguous()
         gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
-        gi = torch.cat(gilist, dim=dim[0]).contiguous()
+        gi = torch.cat(gilist, dim=dim[0])
         return gi, None, None
@@ -148,17 +144,16 @@ def _reduce(input_, use_fp32=True, group=None):
     # Bypass the function if we are using only 1 GPU.
     if dist.get_world_size(group=group) == 1:
         return input_

-    # make input contiguous
-    input_ = input_.contiguous()
     # All-reduce.
     if use_fp32:
         dtype = input_.dtype
         inputf_ = input_.float()
+        inputf_ = inputf_.contiguous()
         dist.all_reduce(inputf_, group=group)
         input_ = inputf_.to(dtype)
     else:
+        input_ = input_.contiguous()
         dist.all_reduce(input_, group=group)

     return input_
@@ -176,7 +171,7 @@ def _split(input_, dim_, group=None):
     # Note: torch.split does not create contiguous tensors by default.
     rank = dist.get_rank(group=group)
-    output = input_list[rank].contiguous()
+    output = input_list[rank]
     return output
@@ -214,7 +209,7 @@ def _gather(input_, dim_, shapes_, group=None):
     dist.all_gather(input_list, input_, group=group)
-    output = torch.cat(input_list, dim=dim_).contiguous()
+    output = torch.cat(input_list, dim=dim_)
     return output
@@ -229,12 +224,14 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
     # make input contiguous
     comm_size = dist.get_world_size(group=group)
     comm_rank = dist.get_rank(group=group)
-    input_list = [x.contiguous() for x in split_tensor_along_dim(input_, dim_, comm_size)]
+    input_list = split_tensor_along_dim(input_, dim_, comm_size)

     dtype = input_.dtype
     if (use_fp32 and (dtype != torch.float32)):
         input_list = [x.to(torch.float32) for x in input_list]
+    input_list = [x.contiguous() for x in input_list]

     # perform reduce_scatter
     output = torch.empty_like(input_list[comm_rank])
     dist.reduce_scatter(output, input_list, group=group)
......
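A note on the primitive changes in the last file: Tensor.contiguous() is free only when the tensor is already contiguous in the default row-major layout; on a channels-last tensor it silently materializes a full NCHW copy, which is presumably why the blanket calls were dropped and the remaining ones sit directly in front of the torch.distributed collectives, which generally expect contiguous buffers. A small standalone illustration, not taken from the repository:

import torch

x = torch.randn(4, 8)
assert x.contiguous() is x                    # already contiguous: no copy is made

y = torch.randn(2, 4, 8, 8).contiguous(memory_format=torch.channels_last)
assert not y.is_contiguous()                  # dense, but strided in NHWC order
assert y.is_contiguous(memory_format=torch.channels_last)

z = y.contiguous()                            # this call does copy, back to NCHW
assert z.is_contiguous() and z.data_ptr() != y.data_ptr()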