Commit 6a845fd3 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

adding spherical attention

parent b3816ebc
...@@ -156,8 +156,8 @@ ...@@ -156,8 +156,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"model = SFNO(img_size=(nlat, nlon), grid=\"equiangular\",\n", "model = SFNO(spectral_transform='sht', operator_type='driscoll-healy', img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, big_skip=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n" " num_layers=4, scale_factor=3, embed_dim=16, residual_prediction=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
] ]
}, },
{ {
......
...@@ -7,7 +7,10 @@ name = "torch_harmonics" ...@@ -7,7 +7,10 @@ name = "torch_harmonics"
authors = [ authors = [
{ name="Boris Bonev" }, { name="Boris Bonev" },
{ name="Thorsten Kurth" }, { name="Thorsten Kurth" },
{ name="Max Rietmann" },
{ name="Mauro Bisson" }, { name="Mauro Bisson" },
{ name="Andrea Paris" },
{ name="Alberto Carpentieri" },
{ name="Massimiliano Fatica" }, { name="Massimiliano Fatica" },
{ name="Jean Kossaifi" }, { name="Jean Kossaifi" },
{ name="Nikola Kovachki" }, { name="Nikola Kovachki" },
...@@ -38,6 +41,7 @@ dependencies = [ ...@@ -38,6 +41,7 @@ dependencies = [
"numpy>=1.22.4", "numpy>=1.22.4",
] ]
[tool.setuptools.dynamic] [tool.setuptools.dynamic]
version = {attr = "torch_harmonics.__version__"} version = {attr = "torch_harmonics.__version__"}
...@@ -49,3 +53,10 @@ dev = [ ...@@ -49,3 +53,10 @@ dev = [
"pytest>=6.0.0", "pytest>=6.0.0",
"coverage>=6.5.0", "coverage>=6.5.0",
] ]
2d3ds = [
"requests",
# NOTE: "tarfile" removed — it is part of the Python standard library and is
# not installable from PyPI; listing it breaks `pip install .[2d3ds]`.
"tqdm",
"Pillow", # the PyPI distribution that provides the PIL import is named "Pillow"
"h5py",
]
...@@ -53,6 +53,22 @@ try: ...@@ -53,6 +53,22 @@ try:
except (ImportError, TypeError, AssertionError, AttributeError) as e: except (ImportError, TypeError, AssertionError, AttributeError) as e:
warnings.warn(f"building custom extensions skipped: {e}") warnings.warn(f"building custom extensions skipped: {e}")
def get_compile_args(module_name):
    """Return the extra compiler flags for building *module_name*.

    If the environment variable ``TORCH_HARMONICS_DEBUG`` is set to ``'1'``
    at build time, debugging flags (no optimization, debug symbols) are used
    for both the host (cxx) and device (nvcc) compilers; otherwise release
    flags (-O3, NDEBUG) are returned.
    """
    # Debug builds are opt-in via TORCH_HARMONICS_DEBUG=1; anything else means release.
    if os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1':
        print(f"WARNING: Compiling {module_name} with debugging flags")
        flags = {
            'cxx': ['-g', '-O0', '-Wall'],
            'nvcc': ['-g', '-G', '-O0'],
        }
    else:
        print(f"NOTE: Compiling {module_name} with release flags")
        flags = {
            'cxx': ['-O3', '-DNDEBUG'],
            'nvcc': ['-O3', '-DNDEBUG'],
        }
    return flags
def get_ext_modules(): def get_ext_modules():
ext_modules = [] ext_modules = []
...@@ -73,6 +89,19 @@ def get_ext_modules(): ...@@ -73,6 +89,19 @@ def get_ext_modules():
"torch_harmonics/csrc/disco/disco_cuda_fwd.cu", "torch_harmonics/csrc/disco/disco_cuda_fwd.cu",
"torch_harmonics/csrc/disco/disco_cuda_bwd.cu", "torch_harmonics/csrc/disco/disco_cuda_bwd.cu",
], ],
extra_compile_args=get_compile_args("disco")
)
)
ext_modules.append(
CUDAExtension(
name="attention_cuda_extension",
sources=[
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu",
"torch_harmonics/csrc/attention/attention_row_offset.cu"
],
extra_compile_args=get_compile_args("neighborhood_attention")
) )
) )
cmdclass["build_ext"] = BuildExtension cmdclass["build_ext"] = BuildExtension
......
This diff is collapsed.
...@@ -36,7 +36,7 @@ import math ...@@ -36,7 +36,7 @@ import math
import numpy as np import numpy as np
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_harmonics import * from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
......
...@@ -36,7 +36,7 @@ from parameterized import parameterized ...@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.distributed as dist import torch.distributed as dist
import torch_harmonics as harmonics import torch_harmonics as th
import torch_harmonics.distributed as thd import torch_harmonics.distributed as thd
...@@ -219,10 +219,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -219,10 +219,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# set up handles # set up handles
if transpose: if transpose:
conv_local = harmonics.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device) conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device) conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
else: else:
conv_local = harmonics.DiscreteContinuousConvS2(**disco_args).to(self.device) conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device) conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)
# copy the weights from the local conv into the dist conv # copy the weights from the local conv into the dist conv
......
...@@ -36,7 +36,7 @@ from parameterized import parameterized ...@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.distributed as dist import torch.distributed as dist
import torch_harmonics as harmonics import torch_harmonics as th
import torch_harmonics.distributed as thd import torch_harmonics.distributed as thd
...@@ -206,7 +206,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -206,7 +206,7 @@ class TestDistributedResampling(unittest.TestCase):
) )
# set up handlesD # set up handlesD
res_local = harmonics.ResampleS2(**res_args).to(self.device) res_local = th.ResampleS2(**res_args).to(self.device)
res_dist = thd.DistributedResampleS2(**res_args).to(self.device) res_dist = thd.DistributedResampleS2(**res_args).to(self.device)
# create tensors # create tensors
......
...@@ -36,7 +36,7 @@ from parameterized import parameterized ...@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.distributed as dist import torch.distributed as dist
import torch_harmonics as harmonics import torch_harmonics as th
import torch_harmonics.distributed as thd import torch_harmonics.distributed as thd
...@@ -218,10 +218,10 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -218,10 +218,10 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
# set up handles # set up handles
if vector: if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else: else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device) forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device) forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors # create tensors
...@@ -304,12 +304,12 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -304,12 +304,12 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
B, C, H, W = batch_size, num_chan, nlat, nlon B, C, H, W = batch_size, num_chan, nlat, nlon
if vector: if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device) backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else: else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device) forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device) backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device) backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors # create tensors
......
...@@ -34,8 +34,7 @@ from parameterized import parameterized ...@@ -34,8 +34,7 @@ from parameterized import parameterized
import math import math
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_harmonics import * import torch_harmonics as th
class TestLegendrePolynomials(unittest.TestCase): class TestLegendrePolynomials(unittest.TestCase):
...@@ -63,10 +62,9 @@ class TestLegendrePolynomials(unittest.TestCase): ...@@ -63,10 +62,9 @@ class TestLegendrePolynomials(unittest.TestCase):
def test_legendre(self, verbose=False): def test_legendre(self, verbose=False):
if verbose: if verbose:
print("Testing computation of associated Legendre polynomials") print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import legpoly
t = torch.linspace(0, 1, 100, dtype=torch.float64) t = torch.linspace(0, 1, 100, dtype=torch.float64)
vdm = legpoly(self.mmax, self.lmax, t) vdm = th.legendre.legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax): for l in range(self.lmax):
for m in range(l + 1): for m in range(l + 1):
...@@ -109,8 +107,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase): ...@@ -109,8 +107,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
mmax = nlat mmax = nlat
lmax = mmax lmax = mmax
sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
with torch.no_grad(): with torch.no_grad():
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
...@@ -167,8 +165,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase): ...@@ -167,8 +165,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
mmax = nlat mmax = nlat
lmax = mmax lmax = mmax
sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
with torch.no_grad(): with torch.no_grad():
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
......
...@@ -29,11 +29,13 @@ ...@@ -29,11 +29,13 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# #
__version__ = "0.7.6" __version__ = "0.8.0"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .resample import ResampleS2 from .resample import ResampleS2
from .attention import AttentionS2, NeighborhoodAttentionS2
from ._neighborhood_attention import _neighborhood_attention_s2_fwd_torch, _NeighborhoodAttentionS2 # for tests
from . import quadrature from . import quadrature
from . import random_fields from . import random_fields
from . import examples from . import examples
...@@ -36,10 +36,8 @@ from torch.amp import custom_fwd, custom_bwd ...@@ -36,10 +36,8 @@ from torch.amp import custom_fwd, custom_bwd
try: try:
import disco_cuda_extension import disco_cuda_extension
_cuda_extension_available = True
except ImportError as err: except ImportError as err:
disco_cuda_extension = None disco_cuda_extension = None
_cuda_extension_available = False
class _DiscoS2ContractionCuda(torch.autograd.Function): class _DiscoS2ContractionCuda(torch.autograd.Function):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -33,3 +33,4 @@ from .pde_sphere import SphereSolver ...@@ -33,3 +33,4 @@ from .pde_sphere import SphereSolver
from .shallow_water_equations import ShallowWaterSolver from .shallow_water_equations import ShallowWaterSolver
from .pde_dataset import PdeDataset from .pde_dataset import PdeDataset
from .stanford_2d3ds_dataset import StanfordSegmentationDataset, StanfordDepthDataset, Stanford2D3DSDownloader, compute_stats_s2, StanfordDatasetSubset
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment