Commit 6a845fd3 authored by Boris Bonev, committed by Boris Bonev

adding spherical attention

parent b3816ebc
......@@ -156,8 +156,8 @@
"metadata": {},
"outputs": [],
"source": [
"model = SFNO(img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, big_skip=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
"model = SFNO(spectral_transform='sht', operator_type='driscoll-healy', img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, residual_prediction=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
]
},
{
......
......@@ -7,7 +7,10 @@ name = "torch_harmonics"
authors = [
{ name="Boris Bonev" },
{ name="Thorsten Kurth" },
{ name="Max Rietmann" },
{ name="Mauro Bisson" },
{ name="Andrea Paris" },
{ name="Alberto Carpentieri" },
{ name="Massimiliano Fatica" },
{ name="Jean Kossaifi" },
{ name="Nikola Kovachki" },
......@@ -38,6 +41,7 @@ dependencies = [
"numpy>=1.22.4",
]
[tool.setuptools.dynamic]
version = {attr = "torch_harmonics.__version__"}
......@@ -49,3 +53,10 @@ dev = [
"pytest>=6.0.0",
"coverage>=6.5.0",
]
+2d3ds = [
+"requests",
+"tqdm",
+"Pillow", # ships the PIL module; "tarfile" is omitted here since it is part of the Python standard library and not installable from PyPI
+"h5py",
+]
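With this extras group defined, the optional dataset dependencies can be pulled in via standard pip extras syntax, e.g. `pip install torch-harmonics[2d3ds]` (the distribution name is taken from the `name` field above; the exact command depends on how the package is published).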
......@@ -53,6 +53,22 @@ try:
except (ImportError, TypeError, AssertionError, AttributeError) as e:
    warnings.warn(f"building custom extensions skipped: {e}")
+def get_compile_args(module_name):
+    """Compile with debugging flags when TORCH_HARMONICS_DEBUG=1 is set in the build environment."""
+    debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
+    if debug_mode:
+        print(f"WARNING: Compiling {module_name} with debugging flags")
+        return {
+            'cxx': ['-g', '-O0', '-Wall'],
+            'nvcc': ['-g', '-G', '-O0']
+        }
+    else:
+        print(f"NOTE: Compiling {module_name} with release flags")
+        return {
+            'cxx': ['-O3', "-DNDEBUG"],
+            'nvcc': ['-O3', "-DNDEBUG"]
+        }
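A debug build can then be requested by setting the variable before the build runs, e.g. `TORCH_HARMONICS_DEBUG=1 pip install -e .` (invocation assumed; any mechanism that exposes the environment variable to the `setup.py` process has the same effect).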
def get_ext_modules():
    ext_modules = []
......@@ -73,6 +89,19 @@ def get_ext_modules():
"torch_harmonics/csrc/disco/disco_cuda_fwd.cu",
"torch_harmonics/csrc/disco/disco_cuda_bwd.cu",
],
extra_compile_args=get_compile_args("disco")
)
)
ext_modules.append(
CUDAExtension(
name="attention_cuda_extension",
sources=[
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu",
"torch_harmonics/csrc/attention/attention_row_offset.cu"
],
extra_compile_args=get_compile_args("neighborhood_attention")
)
)
cmdclass["build_ext"] = BuildExtension
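Presumably the compiled `attention_cuda_extension` is consumed the same way as the DISCO extension guard shown further down in this diff; a minimal sketch of that optional-import pattern (the module name comes from the `CUDAExtension` above, the flag name mirrors the existing disco guard, the rest is an assumption):

```python
# Optional CUDA fast path: fall back to the pure-torch implementation when
# the compiled extension was not built (same pattern as the
# disco_cuda_extension guard later in this diff).
try:
    import attention_cuda_extension
    _cuda_extension_available = True
except ImportError:
    attention_cuda_extension = None
    _cuda_extension_available = False
```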
......
This diff is collapsed.
......@@ -36,7 +36,7 @@ import math
import numpy as np
import torch
from torch.autograd import gradcheck
-from torch_harmonics import *
+from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
-import torch_harmonics as harmonics
+import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -219,10 +219,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        # set up handles
        if transpose:
-            conv_local = harmonics.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
+            conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
            conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
        else:
-            conv_local = harmonics.DiscreteContinuousConvS2(**disco_args).to(self.device)
+            conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device)
            conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)
        # copy the weights from the local conv into the dist conv
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
-import torch_harmonics as harmonics
+import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -206,7 +206,7 @@ class TestDistributedResampling(unittest.TestCase):
)
        # set up handles
-        res_local = harmonics.ResampleS2(**res_args).to(self.device)
+        res_local = th.ResampleS2(**res_args).to(self.device)
        res_dist = thd.DistributedResampleS2(**res_args).to(self.device)
        # create tensors
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
-import torch_harmonics as harmonics
+import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -218,10 +218,10 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        # set up handles
        if vector:
-            forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
+            forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
        else:
-            forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
+            forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
        # create tensors
......@@ -304,12 +304,12 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        B, C, H, W = batch_size, num_chan, nlat, nlon
        if vector:
-            forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
-            backward_transform_local = harmonics.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
+            forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
+            backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
        else:
-            forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
-            backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
+            forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
+            backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
        # create tensors
......
......@@ -34,8 +34,7 @@ from parameterized import parameterized
import math
import torch
from torch.autograd import gradcheck
-from torch_harmonics import *
+import torch_harmonics as th
class TestLegendrePolynomials(unittest.TestCase):
......@@ -63,10 +62,9 @@ class TestLegendrePolynomials(unittest.TestCase):
    def test_legendre(self, verbose=False):
        if verbose:
            print("Testing computation of associated Legendre polynomials")
-        from torch_harmonics.legendre import legpoly
        t = torch.linspace(0, 1, 100, dtype=torch.float64)
-        vdm = legpoly(self.mmax, self.lmax, t)
+        vdm = th.legendre.legpoly(self.mmax, self.lmax, t)
        for l in range(self.lmax):
            for m in range(l + 1):
......@@ -109,8 +107,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
        mmax = nlat
        lmax = mmax
-        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
-        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        with torch.no_grad():
            coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
......@@ -167,8 +165,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
        mmax = nlat
        lmax = mmax
-        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
-        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        with torch.no_grad():
            coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
......
......@@ -29,11 +29,13 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = "0.7.6"
__version__ = "0.8.0"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .resample import ResampleS2
from .attention import AttentionS2, NeighborhoodAttentionS2
from ._neighborhood_attention import _neighborhood_attention_s2_fwd_torch, _NeighborhoodAttentionS2 # for tests
from . import quadrature
from . import random_fields
from . import examples
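The new public entry points are `AttentionS2` and `NeighborhoodAttentionS2` (plus a pure-torch reference path exported for tests). A minimal usage sketch of the neighborhood variant; the constructor arguments below are assumptions modeled on the sibling `DiscreteContinuousConvS2` module, not taken from this diff:

```python
import torch
import torch_harmonics as th

nlat, nlon = 60, 120
# All keyword arguments are illustrative assumptions; only the class name
# NeighborhoodAttentionS2 appears in this commit.
attn = th.NeighborhoodAttentionS2(
    in_channels=16,
    out_channels=16,
    in_shape=(nlat, nlon),
    out_shape=(nlat, nlon),
    grid_in="equiangular",
    grid_out="equiangular",
)
x = torch.randn(2, 16, nlat, nlon)
y = attn(x)  # expected: (2, 16, nlat, nlon), attention restricted to a spherical neighborhood
```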
......@@ -36,10 +36,8 @@ from torch.amp import custom_fwd, custom_bwd
try:
    import disco_cuda_extension
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False
class _DiscoS2ContractionCuda(torch.autograd.Function):
......
This diff is collapsed. (7 files)
......@@ -33,3 +33,4 @@ from .pde_sphere import SphereSolver
from .shallow_water_equations import ShallowWaterSolver
from .pde_dataset import PdeDataset
+from .stanford_2d3ds_dataset import StanfordSegmentationDataset, StanfordDepthDataset, Stanford2D3DSDownloader, compute_stats_s2, StanfordDatasetSubset
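The dataset helpers exported here pair with the new `2d3ds` optional-dependency group in `pyproject.toml`. A hedged sketch of how they might be driven; only the class names come from this diff, and the constructor and method names below are hypothetical:

```python
from torch_harmonics.examples import Stanford2D3DSDownloader, StanfordSegmentationDataset

# Hypothetical usage: the argument and method names are illustrative
# assumptions, only the imported class names appear in this commit.
downloader = Stanford2D3DSDownloader(root="data/2d3ds")
downloader.download()
dataset = StanfordSegmentationDataset(root="data/2d3ds")
print(len(dataset))
```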
This diff is collapsed.