Commit 6a845fd3 authored by Boris Bonev, committed by Boris Bonev

adding spherical attention

parent b3816ebc
......@@ -156,8 +156,8 @@
"metadata": {},
"outputs": [],
"source": [
"model = SFNO(img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, big_skip=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
"model = SFNO(spectral_transform='sht', operator_type='driscoll-healy', img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, residual_prediction=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
]
},
{
......
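The notebook change above pins the spectral transform and operator type explicitly and swaps `big_skip=True` for `residual_prediction=True`. A minimal sketch of the updated call in context (the `torch_harmonics.examples` import path, grid resolution, and device handling are assumptions, not shown in this diff):

```python
import torch
from torch_harmonics.examples import SFNO  # import path assumed

nlat, nlon = 128, 256  # illustrative resolution; the notebook defines its own nlat/nlon
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SFNO(
    spectral_transform="sht", operator_type="driscoll-healy",
    img_size=(nlat, nlon), grid="equiangular",
    num_layers=4, scale_factor=3, embed_dim=16,
    residual_prediction=True,  # replaces the former big_skip flag
    pos_embed="lat", use_mlp=False, normalization_layer="none",
).to(device)
```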
......@@ -7,7 +7,10 @@ name = "torch_harmonics"
authors = [
{ name="Boris Bonev" },
{ name="Thorsten Kurth" },
{ name="Max Rietmann" },
{ name="Mauro Bisson" },
{ name="Andrea Paris" },
{ name="Alberto Carpentieri" },
{ name="Massimiliano Fatica" },
{ name="Jean Kossaifi" },
{ name="Nikola Kovachki" },
......@@ -38,6 +41,7 @@ dependencies = [
"numpy>=1.22.4",
]
[tool.setuptools.dynamic]
version = {attr = "torch_harmonics.__version__"}
......@@ -49,3 +53,10 @@ dev = [
"pytest>=6.0.0",
"coverage>=6.5.0",
]
2d3ds = [
"requests",
"tarfile",
"tqdm",
"PIL",
"h5py",
]
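With this extras group, the example-data helpers would be pulled in via `pip install "torch-harmonics[2d3ds]"` (assuming the package is published under that name). Note that `requests`, `tqdm`, `Pillow`, and `h5py` are the installable dependencies; Python's standard-library `tarfile` module needs no entry here.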
......@@ -53,6 +53,22 @@ try:
except (ImportError, TypeError, AssertionError, AttributeError) as e:
warnings.warn(f"building custom extensions skipped: {e}")
def get_compile_args(module_name):
"""If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
if debug_mode:
print(f"WARNING: Compiling {module_name} with debugging flags")
return {
'cxx': ['-g', '-O0', '-Wall'],
'nvcc': ['-g', '-G', '-O0']
}
else:
print(f"NOTE: Compiling {module_name} with release flags")
return {
'cxx': ['-O3', "-DNDEBUG"],
'nvcc': ['-O3', "-DNDEBUG"]
}
def get_ext_modules():
ext_modules = []
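With `get_compile_args` in place, a debug build of the extensions can be requested at install time, e.g. `TORCH_HARMONICS_DEBUG=1 pip install -e .` (a sketch assuming the usual editable-install workflow); otherwise the release flags `-O3 -DNDEBUG` apply.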
......@@ -73,6 +89,19 @@ def get_ext_modules():
"torch_harmonics/csrc/disco/disco_cuda_fwd.cu",
"torch_harmonics/csrc/disco/disco_cuda_bwd.cu",
],
extra_compile_args=get_compile_args("disco")
)
)
ext_modules.append(
CUDAExtension(
name="attention_cuda_extension",
sources=[
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu",
"torch_harmonics/csrc/attention/attention_row_offset.cu"
],
extra_compile_args=get_compile_args("neighborhood_attention")
)
)
cmdclass["build_ext"] = BuildExtension
......@@ -87,4 +116,4 @@ if __name__ == "__main__":
packages=find_packages(),
ext_modules=ext_modules,
cmdclass=cmdclass,
)
\ No newline at end of file
)
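The new `attention_cuda_extension` is built alongside the existing DISCO extension. How the package consumes it is not shown in this excerpt; a minimal sketch of the usual guarded-import pattern (the fallback flag name is an assumption):

```python
# Guarded import of the compiled kernels: fall back to the pure PyTorch
# implementation when the CUDA extension was not built or cannot load.
try:
    import attention_cuda_extension  # name matches the CUDAExtension above

    _cuda_attention_available = True
except ImportError:
    _cuda_attention_available = False
```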
......@@ -36,7 +36,7 @@ import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -219,10 +219,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# set up handles
if transpose:
conv_local = harmonics.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
else:
conv_local = harmonics.DiscreteContinuousConvS2(**disco_args).to(self.device)
conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)
# copy the weights from the local conv into the dist conv
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -196,9 +196,9 @@ class TestDistributedResampling(unittest.TestCase):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
res_args = dict(
nlat_in=nlat_in,
nlon_in=nlon_in,
nlat_out=nlat_out,
nlon_out=nlon_out,
grid_in=grid_in,
grid_out=grid_out,
......@@ -206,7 +206,7 @@ class TestDistributedResampling(unittest.TestCase):
)
# set up handlesD
res_local = harmonics.ResampleS2(**res_args).to(self.device)
res_local = th.ResampleS2(**res_args).to(self.device)
res_dist = thd.DistributedResampleS2(**res_args).to(self.device)
# create tensors
......
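For reference, the `res_args` above map one-to-one onto the public `ResampleS2` module; a minimal sketch (channel count and resolutions are illustrative, and any remaining constructor options are elided):

```python
import torch
import torch_harmonics as th

resample = th.ResampleS2(
    nlat_in=64, nlon_in=128,   # input equiangular grid
    nlat_out=32, nlon_out=64,  # output equiangular grid
    grid_in="equiangular", grid_out="equiangular",
)

x = torch.randn(1, 3, 64, 128)  # (batch, channels, nlat_in, nlon_in)
y = resample(x)                 # -> (1, 3, 32, 64)
```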
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -218,10 +218,10 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
# set up handles
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
......@@ -304,12 +304,12 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
B, C, H, W = batch_size, num_chan, nlat, nlon
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
......
......@@ -34,8 +34,7 @@ from parameterized import parameterized
import math
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
import torch_harmonics as th
class TestLegendrePolynomials(unittest.TestCase):
......@@ -63,10 +62,9 @@ class TestLegendrePolynomials(unittest.TestCase):
def test_legendre(self, verbose=False):
if verbose:
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import legpoly
t = torch.linspace(0, 1, 100, dtype=torch.float64)
vdm = legpoly(self.mmax, self.lmax, t)
vdm = th.legendre.legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax):
for m in range(l + 1):
......@@ -109,8 +107,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
mmax = nlat
lmax = mmax
sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
with torch.no_grad():
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
......@@ -167,8 +165,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
mmax = nlat
lmax = mmax
sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
with torch.no_grad():
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
......
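The test drives the transforms from the spectral side, since an inverse-then-forward round trip is exact for band-limited coefficients. A standalone sketch of the same idea (grid size is illustrative; `legendre-gauss` is chosen because its quadrature is exact for band-limited signals):

```python
import torch
import torch_harmonics as th

nlat, nlon = 32, 64
sht = th.RealSHT(nlat, nlon, grid="legendre-gauss")
isht = th.InverseRealSHT(nlat, nlon, grid="legendre-gauss")

# excite a single spherical-harmonic mode (l=1, m=0) and round-trip it
coeffs = torch.zeros(1, sht.lmax, sht.mmax, dtype=torch.complex128)
coeffs[0, 1, 0] = 1.0
signal = isht(coeffs)      # band-limited signal on the grid
recovered = sht(signal)    # back to spectral coefficients
print(torch.allclose(coeffs, recovered, atol=1e-10))
```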
......@@ -29,11 +29,13 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = "0.7.6"
__version__ = "0.8.0"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .resample import ResampleS2
from .attention import AttentionS2, NeighborhoodAttentionS2
from ._neighborhood_attention import _neighborhood_attention_s2_fwd_torch, _NeighborhoodAttentionS2 # for tests
from . import quadrature
from . import random_fields
from . import examples
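Version 0.8.0 promotes the new attention modules to the top-level namespace. A minimal usage sketch (the constructor parameters are assumptions modeled on the package's other S2 modules such as `ResampleS2`; the authoritative signature lives in `torch_harmonics/attention.py`, which is not shown in this excerpt):

```python
import torch
import torch_harmonics as th

nlat, nlon = 32, 64  # illustrative equiangular grid

# parameter names below follow the in/out shape-and-grid convention of the
# other S2 modules and are assumptions, not taken from this diff
attn = th.NeighborhoodAttentionS2(
    in_channels=16,
    num_heads=4,
    in_shape=(nlat, nlon),
    out_shape=(nlat, nlon),
    grid_in="equiangular",
    grid_out="equiangular",
)

x = torch.randn(1, 16, nlat, nlon)  # (batch, channels, lat, lon)
y = attn(x)
```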