"vscode:/vscode.git/clone" did not exist on "2e544bd77afe019c4bb9d8c6882879c48d3ac65f"
Unverified Commit 29e7fb68 authored by Boris Bonev, committed by GitHub

Tkurth/cuda disco (#38)



* adding cuda kernels for disco conv

* making psi_idx an attribute

* adding license headers

* adding author files

* reorganizing files

* draft implementation

* added conditional installation to setup.py

* formatting changes

* removing triton kernel in DISCO convolution

* updated github actions

* updated Readme and changelog

* adding another guard for the cuda installation

* renaming the CUDA extension

* simplifying setup.py

* minor bugfix

* Bbonev/cuda disco cleanup (#32)

* cleanup of disco convolutions based on CUDA extension

* fixing unittest

* changing version to experimental 0.7.0a

* initial rewrite of the distributed convolution with CUDA

* fixing streams

* need to fix install options

* fixing streams

* undid setup.py changes

* reset setup.py

* including CUDAStream

* adjusted the precomputation of theta_cutoff. If you rely on this, your models will not be backwards-compatible (see the sketch after this list).

* adjusting theta_cutoff in the unittest

* adding newly refactored kernels for faster compile

* Tkurth/cuda disco distributed fix (#34)

* attempt to make disco distributed

* working distributed convolutions

* fixing distributed conv

* working distributed disco

* removing irrelevant extra argument

* using stream functions from at instead of c10

* using stream functions from at instead of c10, small fix

* Bbonev/disc even filters (#35)

* initial working commit with the new convention of counting collocation points across the diameter instead of across the radius

* fixed a bug in the computation of the even kernels

* changing heuristic for computing theta_cutoff

* Fixing unittest

* Readability improvements

* reworked normalization of filter basis functions

* implemented discrete normalization of disco filters (see the sketch below)

* relaxing tolerances in convolution unit test

* bugfix to correctly support unequal scale factors in latitudes and longitudes

* hotfix to a bug in the imports

* Bbonev/distributed disco refactor (#37)

* cleaned up normalization code in convolution

* formatting changes in distributed convolution

* Fixing default theta_cutoff to be the same in distributed and local case

* fixed distributed convolution to support the same normalization as non-distributed one

* readability improvements

* fixed initial scale of convolution parameter weights and fixed naming of the normalization routine

* Updated Readme.md

* added comment in Dockerfile regarding older architectures
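
For orientation, here is a minimal sketch of what a diameter-based theta_cutoff default can look like. The names and constants are illustrative assumptions, not the shipped heuristic, which lives in the convolution module and changed again over the commits above:

import math

# Illustrative sketch only: under the new convention, the kernel collocation
# points are counted across the filter *diameter* rather than across the
# radius, so the angular cutoff spans roughly half as many latitude spacings
# per point as before. Constants may differ from the shipped heuristic.
def default_theta_cutoff(nlat: int, kernel_size: int) -> float:
    return (kernel_size + 1) * math.pi / (2 * (nlat - 1))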

---------

Co-authored-by: Thorsten Kurth <tkurth@nvidia.com>
Co-authored-by: Boris Bonev <bbonev@nvidia.com>
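
The discrete normalization of the DISCO filters (#35, #37) can be pictured as scaling each discretized basis function so that its quadrature-weighted sum over the grid points equals one. Below is a minimal dense sketch under assumed shapes; the shipped implementation works on the sparse psi data (see psi_idx above) and may normalize differently:

import torch

# Hypothetical dense sketch of discrete filter normalization. psi holds the
# sampled basis functions with shape (nkernel, npoints), and quad_weights the
# matching quadrature weights with shape (npoints,); basis values are assumed
# nonnegative so each discrete integral is positive.
def normalize_discrete(psi: torch.Tensor, quad_weights: torch.Tensor) -> torch.Tensor:
    norm = torch.einsum("kp,p->k", psi, quad_weights)  # discrete integral per basis function
    return psi / norm.clamp_min(torch.finfo(psi.dtype).tiny).unsqueeze(-1)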
parent 214fa40a
@@ -230,7 +230,26 @@ def _gather(input_, dim_, shapes_, group=None):
    output = torch.cat(input_list, dim=dim_).contiguous()

    return output
class _CopyToPolarRegion(torch.autograd.Function):
    """Copy the input to the polar region: identity in the forward pass, all-reduce of the gradient over the polar group in the backward pass."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            # forward takes a single input, so backward returns a single gradient
            return grad_output
class _ScatterToPolarRegion(torch.autograd.Function):
    """Split the input along the given dimension and keep only the chunk corresponding to this rank."""

@@ -257,6 +276,29 @@ class _ScatterToPolarRegion(torch.autograd.Function):
        else:
            return grad_output, None
class _GatherFromPolarRegion(torch.autograd.Function):
    """Gather the input from the polar group and keep the full tensor on this rank."""

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
        return _gather(input_, dim_, shapes_, polar_group())

    @staticmethod
    def forward(ctx, input_, dim_, shapes_):
        if is_distributed_polar():
            ctx.dim = dim_
            return _gather(input_, dim_, shapes_, group=polar_group())
        else:
            return input_

    @staticmethod
    def backward(ctx, grad_output):
        if is_distributed_polar():
            # the conjugate of a gather is a split; dim_ and shapes_ receive no gradient
            return _split(grad_output, ctx.dim, group=polar_group()), None, None
        else:
            return grad_output, None, None
class _ReduceFromPolarRegion(torch.autograd.Function):
    """All-reduce the input from the polar region."""

@@ -279,6 +321,10 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
    def backward(ctx, grad_output):
        return grad_output
def copy_to_polar_region(input_):
    return _CopyToPolarRegion.apply(input_)


def reduce_from_polar_region(input_):
    return _ReduceFromPolarRegion.apply(input_)

@@ -286,3 +332,7 @@ def reduce_from_polar_region(input_):
def scatter_to_polar_region(input_, dim_):
    return _ScatterToPolarRegion.apply(input_, dim_)


def gather_from_polar_region(input_, dim_, shapes_):
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)
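
Taken together, these primitives give a polar-distributed layer differentiable entry and exit points. Below is a minimal usage sketch; distributed_forward, local_disco_conv, and the choice of the latitude dimension are illustrative assumptions, not part of this diff:

# Hypothetical forward pass built on the wrappers above. copy_to_polar_region
# is the identity in the forward pass but all-reduces the gradient in the
# backward pass, keeping weight gradients consistent across the polar group;
# scatter and gather act as conjugates of one another.
def distributed_forward(x, weight, lat_shapes):
    x = copy_to_polar_region(x)
    x = scatter_to_polar_region(x, -2)               # each rank keeps its latitude chunk
    y = local_disco_conv(x, weight)                  # placeholder for the local DISCO kernel
    y = gather_from_polar_region(y, -2, lat_shapes)  # reassemble full latitudes on every rank
    return y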