Unverified commit 89fb38df, authored by Thorsten Kurth, committed by GitHub

adding reduce_scatter (#40)

parent 3a3480b8
@@ -130,6 +130,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         lon_shapes = convolution_dist.lon_out_shapes
         # gather in W
+        tensor = tensor.contiguous()
         if self.grid_size_w > 1:
             gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
             olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
@@ -140,6 +141,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
             tensor_gather = tensor
         # gather in H
+        tensor_gather = tensor_gather.contiguous()
         if self.grid_size_h > 1:
             gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
             olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
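The two `.contiguous()` calls added above guard the gather collectives that follow them: slicing or splitting a tensor along an inner dimension yields a non-contiguous view, and torch.distributed collectives generally expect contiguous buffers. A minimal single-process sketch of the issue (illustrative only, not part of the test):

    import torch

    # splitting along an inner dim (as the lat/lon splits do) produces a non-contiguous view
    B, C, H, W = 2, 4, 6, 8
    full = torch.randn(B, C, H, W)
    local = full[..., 2:5]            # slice along the last (lon) dimension
    print(local.is_contiguous())      # False
    local = local.contiguous()        # what the test now does before gathering
    print(local.is_contiguous())      # True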
@@ -268,6 +270,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
         #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
             err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
             if self.world_rank == 0:
                 print(f"final relative error of gradients: {err.item()}")
...
@@ -39,7 +39,9 @@ from .primitives import (
     reduce_from_polar_region,
     scatter_to_polar_region,
     gather_from_polar_region,
-    copy_to_polar_region
+    copy_to_polar_region,
+    reduce_from_scatter_to_polar_region,
+    gather_from_copy_to_polar_region
 )

 # import the sht
...
@@ -54,7 +54,7 @@ from torch_harmonics.convolution import (
 from torch_harmonics.distributed import polar_group_size, azimuth_group_size
 from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
-from torch_harmonics.distributed import copy_to_polar_region, reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region
+from torch_harmonics.distributed import reduce_from_scatter_to_polar_region, gather_from_copy_to_polar_region
 from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
 from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
@@ -219,7 +219,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         # compute theta cutoff based on the bandlimit of the input field
         if theta_cutoff is None:
-            theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_out - 1)
+            theta_cutoff = torch.pi / float(self.nlat_out - 1)
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -268,7 +268,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         # store number of channels
         num_chans = x.shape[1]
         # h and w is split. First we make w local by transposing into channel dim
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
@@ -288,11 +288,8 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
             x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

-        # allreduce over latitudes: h is still local
-        x = reduce_from_polar_region(x)
-        # split tensor along latitudes: h is split
-        x = scatter_to_polar_region(x, -2)
+        # perform reduce scatter in polar region
+        x = reduce_from_scatter_to_polar_region(x, -2)

         # now we can transpose back the result, so that lon is split and channels are local
         if self.comm_size_azimuth > 1:
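This hunk is the core of the change: the separate all-reduce over the polar group (`reduce_from_polar_region`) followed by a split along latitude (`scatter_to_polar_region`) is fused into a single reduce-scatter, so each polar rank only receives the latitude chunk it actually keeps. A minimal single-process sketch of the equivalence (illustrative only, no process groups involved):

    import torch

    P, dim = 4, -2                                                                 # polar group size, latitude dim
    per_rank = [torch.randn(2, 3, 8, 5, dtype=torch.float64) for _ in range(P)]    # one partial result per rank

    # old path: all-reduce across ranks, then every rank keeps only its latitude chunk
    allreduced = torch.stack(per_rank).sum(dim=0)
    old = list(torch.chunk(allreduced, P, dim=dim))

    # reduce-scatter computes the same thing in one collective: rank r only reduces chunk r
    new = [torch.stack([torch.chunk(x, P, dim=dim)[r] for x in per_rank]).sum(dim=0) for r in range(P)]

    print(all(torch.allclose(old[r], new[r]) for r in range(P)))                   # True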
@@ -352,7 +349,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         # bandlimit
         if theta_cutoff is None:
-            theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = torch.pi / float(self.nlat_in - 1)
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -429,11 +426,8 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         # multiply weights
         x = self.quad_weights * x

-        # we need to gather the input tensor
-        x = gather_from_polar_region(x, -2, self.lat_in_shapes)
-        # register allreduce for bwd pass
-        x = copy_to_polar_region(x)
+        # gather input tensor and set up backward reduction hooks
+        x = gather_from_copy_to_polar_region(x, -2, self.lat_in_shapes)

         if x.is_cuda and _cuda_extension_available:
             out = _disco_s2_transpose_contraction_cuda(
...
@@ -56,14 +56,6 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
     return sections

-# general helpers
-def get_memory_format(tensor):
-    if tensor.is_contiguous(memory_format=torch.channels_last):
-        return torch.channels_last
-    else:
-        return torch.contiguous_format

 def split_tensor_along_dim(tensor, dim, num_chunks):
     assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
@@ -78,23 +70,20 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
 def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
-    # get input format
-    input_format = get_memory_format(tensor)
     # get comm params
     comm_size = dist.get_world_size(group=group)
     comm_rank = dist.get_rank(group=group)

     # split and local transposition
     tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
-    x_send = [y.contiguous(memory_format=input_format) for y in tsplit]
+    x_send = [y.contiguous() for y in tsplit]
     x_send_shapes = [x.shape for x in x_send]
     x_recv = []
     x_shape = list(x_send_shapes[comm_rank])
     for dim1_len in dim1_split_sizes:
         x_shape[dim1] = dim1_len
-        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device, memory_format=input_format))
+        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))

     # global transposition
     req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
@@ -108,24 +97,24 @@ class distributed_transpose_azimuth(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, dims, dim1_split_sizes):
-        input_format = get_memory_format(x)
         # WAR for a potential contig check torch bug for channels last contig tensors
         x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
-        x = torch.cat(xlist, dim=dims[1]).contiguous(memory_format=input_format)
+        x = torch.cat(xlist, dim=dims[1]).contiguous()
         ctx.dims = dims
         ctx.dim0_split_sizes = dim0_split_sizes
         return x

     @staticmethod
     def backward(ctx, go):
-        input_format = get_memory_format(go)
         dims = ctx.dims
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
         go = go.contiguous()
         gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
-        gi = torch.cat(gilist, dim=dims[0]).contiguous(memory_format=input_format)
+        gi = torch.cat(gilist, dim=dims[0]).contiguous()
         return gi, None, None
@@ -133,24 +122,22 @@ class distributed_transpose_polar(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, dim, dim1_split_sizes):
-        input_format = get_memory_format(x)
         # WAR for a potential contig check torch bug for channels last contig tensors
         x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
-        x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
+        x = torch.cat(xlist, dim=dim[1]).contiguous()
         ctx.dim = dim
         ctx.dim0_split_sizes = dim0_split_sizes
         return x

     @staticmethod
     def backward(ctx, go):
-        input_format = get_memory_format(go)
         dim = ctx.dim
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
         go = go.contiguous()
         gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
-        gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
+        gi = torch.cat(gilist, dim=dim[0]).contiguous()
         return gi, None, None
@@ -175,7 +162,7 @@ def _reduce(input_, use_fp32=True, group=None):
         dist.all_reduce(input_, group=group)

     return input_


 def _split(input_, dim_, group=None):
     """Split the tensor along its last dimension and keep the corresponding slice."""
@@ -232,6 +219,33 @@ def _gather(input_, dim_, shapes_, group=None):
     return output


+def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
+    """All-reduce the input tensor across the model parallel group and scatter it back."""
+
+    # Bypass the function if we are using only 1 GPU.
+    if dist.get_world_size(group=group) == 1:
+        return input_
+
+    # make input contiguous
+    comm_size = dist.get_world_size(group=group)
+    comm_rank = dist.get_rank(group=group)
+    input_list = [x.contiguous() for x in split_tensor_along_dim(input_, dim_, comm_size)]
+
+    dtype = input_.dtype
+    if use_fp32 and (dtype != torch.float32):
+        input_list = [x.to(torch.float32) for x in input_list]
+
+    # perform reduce_scatter
+    output = torch.empty_like(input_list[comm_rank])
+    dist.reduce_scatter(output, input_list, group=group)
+
+    # convert dtype back if necessary
+    if use_fp32:
+        output = output.to(dtype=dtype)
+
+    return output


 class _CopyToPolarRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chunk to the rank."""
@@ -322,6 +336,62 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
         return grad_output


+class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
+    """All-reduce the input over the polar region and scatter the result back onto the polar region."""
+
+    @staticmethod
+    def symbolic(graph, input_, dim_):
+        if is_distributed_polar():
+            return _reduce_scatter(input_, dim_, group=polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def forward(ctx, input_, dim_):
+        if is_distributed_polar():
+            ctx.dim = dim_
+            ctx.split_shapes = compute_split_shapes(input_.shape[dim_], polar_group_size())
+            return _reduce_scatter(input_, dim_, group=polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if is_distributed_polar():
+            return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None
+        else:
+            return grad_output, None
+
+
+class _GatherFromCopyToPolarRegion(torch.autograd.Function):
+    """Gather the input from the polar region and register a reduce-scatter for the backward pass; essentially the inverse of the reduce-scatter op above."""
+
+    @staticmethod
+    def symbolic(graph, input_, dim_, shapes_):
+        if is_distributed_polar():
+            return _gather(input_, dim_, shapes_, polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def forward(ctx, input_, dim_, shapes_):
+        if is_distributed_polar():
+            ctx.dim = dim_
+            return _gather(input_, dim_, shapes_, group=polar_group())
+        else:
+            return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if is_distributed_polar():
+            return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None
+        else:
+            return grad_output, None, None


 def copy_to_polar_region(input_):
     return _CopyToPolarRegion.apply(input_)
@@ -336,3 +406,11 @@ def scatter_to_polar_region(input_, dim_):

 def gather_from_polar_region(input_, dim_, shapes_):
     return _GatherFromPolarRegion.apply(input_, dim_, shapes_)
+
+
+def reduce_from_scatter_to_polar_region(input_, dim_):
+    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)
+
+
+def gather_from_copy_to_polar_region(input_, dim_, shapes_):
+    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)