"Cargo.lock" did not exist on "e86ecbac63ab8cec773e07549a286a77edaac1d4"
Commit 6a845fd3 authored and committed by Boris Bonev

adding spherical attention

parent b3816ebc
......@@ -156,8 +156,8 @@
"metadata": {},
"outputs": [],
"source": [
"model = SFNO(img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, big_skip=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
"model = SFNO(spectral_transform='sht', operator_type='driscoll-healy', img_size=(nlat, nlon), grid=\"equiangular\",\n",
" num_layers=4, scale_factor=3, embed_dim=16, residual_prediction=True, pos_embed=\"lat\", use_mlp=False, normalization_layer=\"none\").to(device)\n"
]
},
{
......
......@@ -7,7 +7,10 @@ name = "torch_harmonics"
authors = [
{ name="Boris Bonev" },
{ name="Thorsten Kurth" },
{ name="Max Rietmann" },
{ name="Mauro Bisson" },
{ name="Andrea Paris" },
{ name="Alberto Carpentieri" },
{ name="Massimiliano Fatica" },
{ name="Jean Kossaifi" },
{ name="Nikola Kovachki" },
......@@ -38,6 +41,7 @@ dependencies = [
"numpy>=1.22.4",
]
[tool.setuptools.dynamic]
version = {attr = "torch_harmonics.__version__"}
......@@ -49,3 +53,10 @@ dev = [
"pytest>=6.0.0",
"coverage>=6.5.0",
]
2d3ds = [
"requests",
"tqdm",
# "tarfile" is part of the Python standard library and must not be listed as a dependency
"Pillow", # maintained fork providing the "PIL" module; "PIL" itself is not installable from PyPI
"h5py",
]
......@@ -53,6 +53,22 @@ try:
except (ImportError, TypeError, AssertionError, AttributeError) as e:
warnings.warn(f"building custom extensions skipped: {e}")
def get_compile_args(module_name):
"""If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
if debug_mode:
print(f"WARNING: Compiling {module_name} with debugging flags")
return {
'cxx': ['-g', '-O0', '-Wall'],
'nvcc': ['-g', '-G', '-O0']
}
else:
print(f"NOTE: Compiling {module_name} with release flags")
return {
'cxx': ['-O3', "-DNDEBUG"],
'nvcc': ['-O3', "-DNDEBUG"]
}
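# A minimal sketch (not part of the build script) of how the debug switch is intended to be used:
# building from source with the environment variable set, e.g.
#   TORCH_HARMONICS_DEBUG=1 pip install -e .
# makes get_compile_args() return the '-g -O0' style flags for both the host compiler and nvcc.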
def get_ext_modules():
ext_modules = []
......@@ -73,6 +89,19 @@ def get_ext_modules():
"torch_harmonics/csrc/disco/disco_cuda_fwd.cu",
"torch_harmonics/csrc/disco/disco_cuda_bwd.cu",
],
extra_compile_args=get_compile_args("disco")
)
)
ext_modules.append(
CUDAExtension(
name="attention_cuda_extension",
sources=[
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu",
"torch_harmonics/csrc/attention/attention_row_offset.cu"
],
extra_compile_args=get_compile_args("neighborhood_attention")
)
)
cmdclass["build_ext"] = BuildExtension
......@@ -87,4 +116,4 @@ if __name__ == "__main__":
packages=find_packages(),
ext_modules=ext_modules,
cmdclass=cmdclass,
)
\ No newline at end of file
)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import unittest
from parameterized import parameterized
# import math
import numpy as np
import torch
# from torch.autograd import gradcheck
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
from torch_harmonics._neighborhood_attention import (
_neighborhood_attention_s2_torch,
_neighborhood_attention_s2_fwd_torch,
_neighborhood_attention_s2_bwd_dv_torch,
_neighborhood_attention_s2_bwd_dk_torch,
_neighborhood_attention_s2_bwd_dq_torch,
)
# import custom C++/CUDA extensions
try:
import attention_cuda_extension
_cuda_extension_available = True
except ImportError as err:
print(f"Warning: Couldn't Import cuda attention: {err}")
attention_cuda_extension = None
_cuda_extension_available = False
# this routine is only meant to be used in this test: it is numerically unstable, but it supports
# autograd, which some of the more optimized kernels do not
def _neighborhood_attention_s2_torch_test(
kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_idx: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int
):
out = torch.zeros_like(qy)
for ho in range(nlat_out):
# get nonzero indices in output row
idx_ho = col_idx[row_idx == ho]
for wo in range(nlon_out):
alpha_sum = torch.zeros((out.shape[0],), dtype=out.dtype, device=out.device)
alpha = torch.zeros((out.shape[0], len(idx_ho)), dtype=out.dtype, device=out.device)
for inz, nz_col_idx in enumerate(idx_ho):
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hi_wip = kx[:, :, hi, wip]
alpha[:, inz] = torch.exp(torch.sum(q_ho_wo * k_hi_wip, dim=1)) * quad_weights[hi]
# softmax denominator
alpha_sum[:] = alpha_sum[:] + alpha[:, inz]
for inz, nz_col_idx in enumerate(idx_ho):
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
# compute matmul of attention matrix with V-vector
out[:, :, ho, wo] = out[:, :, ho, wo] + (alpha[:, None, inz] / alpha_sum[:, None]) * vx[:, :, hi, wip]
return out
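# For reference, the loops above evaluate the quadrature-weighted softmax
#   out[b, :, ho, wo] = sum_i w_{hi} * exp(q . k_i) * v_i / sum_j w_{hj} * exp(q . k_j),
# where the sums run over the nonzero columns of the sparsity pattern psi for output row ho,
# w are the quadrature weights, q = qy[b, :, ho, wo], and k_i, v_i are gathered at the shifted
# input locations (hi, wip). No max subtraction is applied to the exponent, hence the instability.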
class TestNeighborhoodAttention(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index)
torch.cuda.manual_seed(333)
else:
self.device = torch.device("cpu")
torch.manual_seed(333)
@parameterized.expand(
[
# regular convolution
[8, 4, 6, (17, 32), 1e-6, 1e-5],
]
)
def test_batched_linear(self, batch_size, in_channels, out_channels, shape, atol, rtol):
# weight
weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 1, 1, dtype=torch.float32, device=self.device))
bias = torch.nn.Parameter(torch.randn(out_channels, dtype=torch.float32, device=self.device))
# input
inp = torch.randn(batch_size, in_channels, *shape, dtype=torch.float32, device=self.device)
inp.requires_grad = True
# operation
out = torch.nn.functional.conv2d(inp, weight=weight, bias=bias)
out_grad = torch.randn(batch_size, out_channels, *shape, dtype=torch.float32, device=self.device)
out.backward(out_grad)
# store for comparison
wgrad = weight.grad.clone()
bgrad = bias.grad.clone()
igrad = inp.grad.clone()
# explicit layers
igrad_explicit = torch.nn.functional.conv2d(out_grad, weight=weight.permute([1, 0, 2, 3]), bias=None)
wgrad_explicit = torch.einsum("bchw,bfhw->cf", out_grad, inp).reshape(out_channels, in_channels, 1, 1).contiguous()
bgrad_explicit = torch.sum(out_grad, dim=(0, 2, 3))
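# the expressions above use the standard identities for a 1x1 convolution y = W x + b:
#   dL/dx = W^T (dL/dy),  dL/dW = sum over batch and spatial dims of (dL/dy) outer (x),
#   dL/db = sum over batch and spatial dims of dL/dy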
# check consistency
self.assertTrue(torch.allclose(igrad, igrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(wgrad, wgrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(bgrad, bgrad_explicit, atol=atol, rtol=rtol))
@parameterized.expand(
[
# self attention
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
]
)
def test_fwd(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_inp.requires_grad = False
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_inp.requires_grad = False
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
with torch.no_grad():
out_torch_explicit = _neighborhood_attention_s2_fwd_torch(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
self.assertTrue(torch.allclose(out_torch_explicit, out_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
out_cuda = attention_cuda_extension.forward(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
self.assertTrue(torch.allclose(out_torch, out_cuda, atol=atol, rtol=rtol))
@parameterized.expand(
[
# regular convolution
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
]
)
def test_bwd_dv(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
_, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device)
k_inp.requires_grad = False
v_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device)
v_inp.requires_grad = True
q_inp = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device)
q_inp.requires_grad = False
out_grad = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device)
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dv_inp_torch = v_inp.grad.clone()
with torch.no_grad():
dv_inp_torch_explicit = _neighborhood_attention_s2_bwd_dv_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dv_inp_torch_explicit, dv_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
dv_inp_cuda_explicit = attention_cuda_extension.backward_dv(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dv_inp_cuda_explicit, dv_inp_torch, atol=atol, rtol=rtol))
@parameterized.expand(
[
# regular convolution
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-3],
]
)
def test_bwd_dk(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_inp.requires_grad = True
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_inp.requires_grad = False
out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dk_inp_torch = k_inp.grad.clone()
with torch.no_grad():
dk_inp_torch_explicit = _neighborhood_attention_s2_bwd_dk_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dk_inp_torch_explicit, dk_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
dk_inp_cuda_explicit = attention_cuda_extension.backward_dk(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dk_inp_cuda_explicit, dk_inp_torch, atol=atol, rtol=rtol))
@parameterized.expand(
[
# regular convolution
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 4e-6, 1e-3],
]
)
def test_bwd_dq(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_inp.requires_grad = False
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_inp.requires_grad = True
out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dq_inp_torch = q_inp.grad.clone()
with torch.no_grad():
dq_inp_torch_explicit = _neighborhood_attention_s2_bwd_dq_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dq_inp_torch_explicit, dq_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
dq_inp_cuda_explicit = attention_cuda_extension.backward_dq(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dq_inp_cuda_explicit, dq_inp_torch, atol=atol, rtol=rtol))
@parameterized.expand(
[
# self attention
[1, 73, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
]
)
def test_big(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# this test only makes sense when CUDA version is available
if torch.cuda.is_available():
if not _cuda_extension_available:
print("WARNING: Problem loading CUDA attention module")
return
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# TODO: this test seems hardcoded for GPU. Is this necessary?
k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
k_gpu.requires_grad = False
v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
v_gpu.requires_grad = False
q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cuda:0")
q_gpu.requires_grad = False
# set up layers
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to("cuda:0")
# random weights
with torch.no_grad():
att_gpu.q_weights.normal_()
att_gpu.k_weights.normal_()
att_gpu.v_weights.normal_()
att_gpu.q_bias.normal_()
att_gpu.k_bias.normal_()
att_gpu.v_bias.normal_()
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_gpu.q_weights)
att_gpu.k_weights.copy_(att_gpu.k_weights)
att_gpu.v_weights.copy_(att_gpu.v_weights)
att_gpu.q_bias.copy_(att_gpu.q_bias)
att_gpu.k_bias.copy_(att_gpu.k_bias)
att_gpu.v_bias.copy_(att_gpu.v_bias)
q_gpu = q_gpu.detach().clone().to(self.device)
q_gpu.requires_grad = True
k_gpu = k_gpu.detach().clone().to(self.device)
k_gpu.requires_grad = True
v_gpu = v_gpu.detach().clone().to(self.device)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0")
out_gpu.backward(out_grad.to("cuda:0"))
@parameterized.expand(
[
# self attention
[10, 2, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-5, 1e-5],
]
)
def test_neighborhood(self, batch_size, num_channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
"""
This test sets a specific q[ho,wo] value to 1.0 (elsewhere 0), and then a neighborhood of k around ho,wo to 1.0 (else 0.0). Also vi is set to a sinusoidal input. We also run it with fully 0 q,k. We test that the output of the nonzero q,k is only different to the zero q,k in a single point. We also check the value of this difference (as a regression test).
"""
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
from torch_harmonics import _neighborhood_attention_s2_fwd_torch
device = "cpu"
nas2_2 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
theta_cutoff=torch.pi / 128 * 10)
nas2_2.to(device)
qo = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
x = torch.linspace(0, 2 * np.pi, nlat_in) # nlat_in points in the latitude direction
y = torch.linspace(0, 2 * np.pi, nlon_in) # nlon_in points in the longitude direction
# Create a meshgrid
X, Y = torch.meshgrid(x, y, indexing="ij")
vi = torch.ones((batch_size, num_channels, nlat_in, nlon_in)).to(device)
vi[:, :, :, :] = (torch.sin(X) + torch.sin(Y))[None, None, :, :]
ki = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
ki2 = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
ho = 10
wo = 15
qo[:, 0, ho, wo] = 1.0
nas3 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
theta_cutoff=torch.pi / 128 * 7)
zstart = nas3.psi_roff_idx[ho]
zend = nas3.psi_roff_idx[ho + 1]
# set a small neighborhood of k around (ho,wo) to 1
for idz in range(zstart, zend):
nz_col_idx = nas3.psi_col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
ki2[:, 0, hi, wip] = 1.0
# run with k zero
y = _neighborhood_attention_s2_fwd_torch(ki, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
# run with k 1 at neighborhood of ho,wo
y2 = _neighborhood_attention_s2_fwd_torch(ki2, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
# for viz if desired
# plt.matshow((y[0,0,:,:]-y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
# plt.matshow((y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
# compare zero k vs. nonzero k and ensure difference only occurs at ho,wo
nz_x, nz_y = torch.where((y[0, 0, :, :] - y2[0, 0, :, :]).abs() > 0)
self.assertTrue(nz_x.item() == ho)
self.assertTrue(nz_y.item() == wo)
h, w = nz_x.item(), nz_y.item()
diff_hw = y[0, 0, h, w] - y2[0, 0, h, w]
# print("diff_hw=", diff_hw.item())
# regression test the difference. Unfortunately difficult to come up with an
# analytical value, so we just have it hardcoded.
self.assertTrue(torch.allclose(diff_hw, torch.tensor([0.00753], device=device), rtol=rtol, atol=atol))
@parameterized.expand(
[
# self attention
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
[8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
]
)
def test_full(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
k_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
k_cpu.requires_grad = True
v_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
v_cpu.requires_grad = True
q_cpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
q_cpu.requires_grad = True
# set up layers
att_cpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True)
# random weights
with torch.no_grad():
att_cpu.q_weights.normal_()
att_cpu.k_weights.normal_()
att_cpu.v_weights.normal_()
att_cpu.q_bias.normal_()
att_cpu.k_bias.normal_()
att_cpu.v_bias.normal_()
out_cpu = att_cpu(q_cpu, k_cpu, v_cpu)
out_grad = torch.randn(out_cpu.shape, dtype=torch.float32, device="cpu")
out_cpu.backward(out_grad)
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_cpu.q_weights)
att_gpu.k_weights.copy_(att_cpu.k_weights)
att_gpu.v_weights.copy_(att_cpu.v_weights)
att_gpu.proj_weights.copy_(att_cpu.proj_weights)
att_gpu.q_bias.copy_(att_cpu.q_bias)
att_gpu.k_bias.copy_(att_cpu.k_bias)
att_gpu.v_bias.copy_(att_cpu.v_bias)
att_gpu.proj_bias.copy_(att_cpu.proj_bias)
q_gpu = q_cpu.detach().clone().to(self.device)
q_gpu.requires_grad = True
k_gpu = k_cpu.detach().clone().to(self.device)
k_gpu.requires_grad = True
v_gpu = v_cpu.detach().clone().to(self.device)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_gpu.backward(out_grad.to(self.device))
# check forward
self.assertTrue(torch.allclose(out_cpu.to(self.device), out_gpu, atol=atol, rtol=rtol))
# check input gradients:
self.assertTrue(torch.allclose(q_cpu.grad.to(self.device), q_gpu.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_cpu.grad.to(self.device), k_gpu.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_cpu.grad.to(self.device), v_gpu.grad, atol=atol, rtol=rtol))
# check weight gradients
self.assertTrue(torch.allclose(att_cpu.q_weights.grad.to(self.device), att_gpu.q_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.k_weights.grad.to(self.device), att_gpu.k_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.v_weights.grad.to(self.device), att_gpu.v_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.proj_weights.grad.to(self.device), att_gpu.proj_weights.grad, atol=atol, rtol=rtol))
# check bias gradients
self.assertTrue(torch.allclose(att_cpu.q_bias.grad.to(self.device), att_gpu.q_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.k_bias.grad.to(self.device), att_gpu.k_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.v_bias.grad.to(self.device), att_gpu.v_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.proj_bias.grad.to(self.device), att_gpu.proj_bias.grad, atol=atol, rtol=rtol))
@parameterized.expand(
[
# self attention
[8, 8, 8, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
[8, 8, 8, 2, (17, 32), (17, 32), "legendre-gauss", "legendre-gauss", 2e-4, 1e-5],
[8, 8, 8, 2, (17, 32), (17, 32), "lobatto", "lobatto", 2e-4, 1e-5],
[8, 8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
]
)
def test_full_attention(self, batch_size, channels_in, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
k_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
k_cpu.requires_grad = True
v_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
v_cpu.requires_grad = True
q_cpu = torch.randn(batch_size, channels_in, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
q_cpu.requires_grad = True
att_cpu = AttentionS2(in_channels=channels_in, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)
out = att_cpu(q_cpu, k_cpu, v_cpu)
# check if output is sane
self.assertFalse(torch.isnan(out).any())
if __name__ == "__main__":
unittest.main()
......@@ -36,7 +36,7 @@ import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -219,10 +219,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# set up handles
if transpose:
conv_local = harmonics.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
else:
conv_local = harmonics.DiscreteContinuousConvS2(**disco_args).to(self.device)
conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)
# copy the weights from the local conv into the dist conv
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -196,9 +196,9 @@ class TestDistributedResampling(unittest.TestCase):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
res_args = dict(
nlat_in=nlat_in,
nlon_in=nlon_in,
nlat_out=nlat_out,
nlon_out=nlon_out,
grid_in=grid_in,
grid_out=grid_out,
......@@ -206,7 +206,7 @@ class TestDistributedResampling(unittest.TestCase):
)
# set up handles
res_local = harmonics.ResampleS2(**res_args).to(self.device)
res_local = th.ResampleS2(**res_args).to(self.device)
res_dist = thd.DistributedResampleS2(**res_args).to(self.device)
# create tensors
......
......@@ -36,7 +36,7 @@ from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd
......@@ -218,10 +218,10 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
# set up handles
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
......@@ -304,12 +304,12 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
B, C, H, W = batch_size, num_chan, nlat, nlon
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
......
......@@ -34,8 +34,7 @@ from parameterized import parameterized
import math
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
import torch_harmonics as th
class TestLegendrePolynomials(unittest.TestCase):
......@@ -63,10 +62,9 @@ class TestLegendrePolynomials(unittest.TestCase):
def test_legendre(self, verbose=False):
if verbose:
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import legpoly
t = torch.linspace(0, 1, 100, dtype=torch.float64)
vdm = legpoly(self.mmax, self.lmax, t)
vdm = th.legendre.legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax):
for m in range(l + 1):
......@@ -109,8 +107,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
mmax = nlat
lmax = mmax
sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
with torch.no_grad():
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
......@@ -167,8 +165,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
mmax = nlat
lmax = mmax
sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
with torch.no_grad():
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
......
......@@ -29,11 +29,13 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = "0.7.6"
__version__ = "0.8.0"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .resample import ResampleS2
from .attention import AttentionS2, NeighborhoodAttentionS2
from ._neighborhood_attention import _neighborhood_attention_s2_fwd_torch, _NeighborhoodAttentionS2 # for tests
from . import quadrature
from . import random_fields
from . import examples
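A minimal usage sketch of the newly exported attention modules (illustrative only; grid sizes, channel counts and head counts are arbitrary):

import torch
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2

x = torch.randn(2, 16, 33, 64)  # (batch, channels, nlat, nlon) on an equiangular grid
att = AttentionS2(in_channels=16, num_heads=4, in_shape=(33, 64), out_shape=(33, 64))
natt = NeighborhoodAttentionS2(in_channels=16, num_heads=4, in_shape=(33, 64), out_shape=(33, 64))
y_global = att(x)        # global attention; key and value default to the query (self-attention)
y_local = natt(x, x, x)  # neighborhood attention restricted to a spherical cap around each output point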
......@@ -36,10 +36,8 @@ from torch.amp import custom_fwd, custom_bwd
try:
import disco_cuda_extension
_cuda_extension_available = True
except ImportError as err:
disco_cuda_extension = None
_cuda_extension_available = False
class _DiscoS2ContractionCuda(torch.autograd.Function):
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math
from typing import Union
import torch
import torch.nn.functional as F
from torch.amp import custom_fwd, custom_bwd
try:
import attention_cuda_extension
_cuda_extension_available = True
except ImportError as err:
attention_cuda_extension = None
_cuda_extension_available = False
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
# prepare result tensor
y = torch.zeros_like(qy)
for ho in range(nlat_out):
# get number of nonzeros
zstart = row_off[ho]
zend = row_off[ho+1]
for wo in range(nlon_out):
alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
qdotk_nz = torch.zeros((y.shape[0], zend-zstart,), dtype=y.dtype, device=y.device)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hi_wip = kx[:, :, hi, wip]
qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wip, dim=1)
qdotk_max, _ = torch.max(qdotk_nz, dim=1)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
alpha = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max)
# softmax denominator
alpha_sum[:] += alpha[:] * quad_weights[hi]
y[:,:,ho,wo] += alpha[:, None] * vx[:,:,hi,wip] * quad_weights[hi]
y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
return y
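# The loops above evaluate the same quadrature-weighted softmax as the test reference,
#   y[b, :, ho, wo] = sum_i w_{hi} * exp(q . k_i - m) * v_i / sum_j w_{hj} * exp(q . k_j - m),
# with m = max_i (q . k_i). Subtracting the row-wise maximum is the usual log-sum-exp
# stabilization: the factor exp(-m) cancels between numerator and denominator, so the result is
# mathematically unchanged but the exponentials stay bounded.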
# Explicit gradient w.r.t. vx: dM/dv
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
# shapes:
# input
# kx: B, C, Hi, Wi
# vx: B, C, Hi, Wi
# qy: B, C, Ho, Wo
# quad_weights: Hi
# output
# dvx: B, C, Hi, Wi
dvx = torch.zeros_like(vx)
for ho in range(nlat_out):
# get number of nonzeros
zstart = row_off[ho]
zend = row_off[ho+1]
for wo in range(nlon_out):
alpha_nz = torch.zeros((dy.shape[0], zend-zstart), dtype=dy.dtype, device=dy.device)
qdotk_nz = torch.zeros((dy.shape[0], zend-zstart), dtype=dy.dtype, device=dy.device)
alpha_sum = torch.zeros((dy.shape[0],), dtype=dy.dtype, device=dy.device)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi+wo) % nlon_in
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hi_wi = kx[:, :, hi, wip]
qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
qdotk_max, _ = torch.max(qdotk_nz, dim=1)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi+wo) % nlon_in
alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
alpha_sum[:] += alpha_nz[:,idz-zstart]
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi+wo) % nlon_in
dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]
return dvx
# Explicit gradient w.r.t. kx: dM/dk
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
# shapes:
# input
# kx: B, C, Hi, Wi
# vx: B, C, Hi, Wi
# qy: B, C, Ho, Wo
# quad_weights: Hi
# output
# dkx: B, C, Hi, Wi
dkx = torch.zeros_like(kx)
for ho in range(nlat_out):
# get number of nonzeros
zstart = row_off[ho]
zend = row_off[ho+1]
for wo in range(nlon_out):
qdotk_nz = torch.zeros((dy.shape[0], zend-zstart), dtype=dy.dtype, device=dy.device)
integral = torch.zeros((dy.shape[0],), dtype=dy.dtype, device=dy.device)
alpha = torch.zeros((dy.shape[0], zend-zstart), dtype=dy.dtype, device=dy.device)
alpha_sum = torch.zeros((dy.shape[0],), dtype=dy.dtype, device=dy.device)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hj = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wj = nz_col_idx % nlon_in
wjp = (wj+wo) % nlon_in
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hj_wjp = kx[:, :, hj, wjp]
qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)
qdotk_max, _ = torch.max(qdotk_nz, dim=1)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hj = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wj = nz_col_idx % nlon_in
wjp = (wj+wo) % nlon_in
alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
alpha_sum[:] += alpha[:, idz-zstart]
# input dot
gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)
# integral term
integral[:] += alpha[:, idz-zstart] * gdotv[:]
integral[:] = integral[:] / alpha_sum[:]
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi+wo) % nlon_in
# compute correlation & softmax numerator
gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])
return dkx
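# Derivation of the update above: with a_i = w_i exp(q . k_i) / sum_j w_j exp(q . k_j) and
# y = sum_i a_i v_i, the softmax Jacobian gives da_m/dk_i = a_m (delta_{mi} - a_i) q, hence
#   dL/dk_i = q * a_i * ((g . v_i) - sum_j a_j (g . v_j)),   with g = dL/dy,
# where a_i corresponds to alpha[:, idz-zstart] / alpha_sum and the bracketed sum is the
# 'integral' term accumulated above.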
# Explicit gradient w.r.t. qy: dM/dq
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
# shapes:
# input
# kx: B, C, Hi, Wi
# vx: B, C, Hi, Wi
# qy: B, C, Ho, Wo
# quad_weights: Hi
# output
# dvx: B, C, Hi, Wi
dqy = torch.zeros_like(qy)
for ho in range(nlat_out):
# get number of nonzeros
zstart = row_off[ho]
zend = row_off[ho+1]
for wo in range(nlon_out):
alpha = torch.zeros((dy.shape[0], zend-zstart), dtype=dy.dtype, device=dy.device)
qdotk_nz = torch.zeros((dy.shape[0], zend-zstart), dtype=dy.dtype, device=dy.device)
alpha_k = torch.zeros((dy.shape[0], dy.shape[1]), dtype=dy.dtype, device=dy.device)
alpha_vw = torch.zeros((dy.shape[0], dy.shape[1]), dtype=dy.dtype, device=dy.device)
alpha_kvw = torch.zeros((dy.shape[0], dy.shape[1]), dtype=dy.dtype, device=dy.device)
alpha_sum = torch.zeros((dy.shape[0],), dtype=dy.dtype, device=dy.device)
alpha_sum2 = torch.zeros((dy.shape[0],), dtype=dy.dtype, device=dy.device)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi+wo) % nlon_in
idz_i = idz-zstart
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hi_wi = kx[:, :, hi, wip]
qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
qdotk_max,_ = qdotk_nz.max(dim=1)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi+wo) % nlon_in
q_ho_wo = qy[:, :, ho, wo]
k_hi_wi = kx[:, :, hi, wip]
idz_i = idz-zstart
alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
alpha_sum[:] += alpha[:, idz_i]
gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
alpha_vw[:,:] += alpha[:, None, idz_i] * gdotv[:,None]
alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]
dqy[:,:,ho,wo] = (alpha_kvw*alpha_sum[:,None] - alpha_vw*alpha_k) / (alpha_sum[:,None]*alpha_sum[:,None])
return dqy
class _NeighborhoodAttentionS2(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cpu")
def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.nlon_in = nlon_in
ctx.nlat_out = nlat_out
ctx.nlon_out = nlon_out
kw = F.conv2d(k, weight=wk, bias=bk)
vw = F.conv2d(v, weight=wv, bias=bv)
qw = F.conv2d(q, weight=wq, bias=bq)
# reshape, folding num heads into batch dim
B, _, H, W = kw.shape
kw = kw.reshape(B*nh, -1, H, W)
B, _, H, W = vw.shape
vw = vw.reshape(B*nh, -1, H, W)
B, _, H, W = qw.shape
qw = qw.reshape(B*nh, -1, H, W)
kw = kw.to(torch.float32)
vw = vw.to(torch.float32)
qw = qw.to(torch.float32)
output = _neighborhood_attention_s2_fwd_torch(kw, vw, qw, quad_weights,
col_idx, row_off,
nlon_in, nlat_out, nlon_out)
_, C, H, W = output.shape
output = output.reshape(B, -1, H, W)
return output
@staticmethod
@custom_bwd(device_type="cpu")
def backward(ctx, grad_output):
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh
nlon_in = ctx.nlon_in
nlat_out = ctx.nlat_out
nlon_out = ctx.nlon_out
kw = F.conv2d(k, weight=wk, bias=bk)
vw = F.conv2d(v, weight=wv, bias=bv)
qw = F.conv2d(q, weight=wq, bias=bq)
# reshape, folding num heads into batch dim
B, _, H, W = kw.shape
kw = kw.reshape(B*nh, -1, H, W)
B, _, H, W = vw.shape
vw = vw.reshape(B*nh, -1, H, W)
B, _, H, W = qw.shape
qw = qw.reshape(B*nh, -1, H, W)
B, _, H, W = grad_output.shape
grad_output = grad_output.reshape(B*nh, -1, H, W)
dvw = _neighborhood_attention_s2_bwd_dv_torch(kw, vw, qw, grad_output,
quad_weights,
col_idx, row_off,
nlon_in, nlat_out, nlon_out)
dkw = _neighborhood_attention_s2_bwd_dk_torch(kw, vw, qw, grad_output,
quad_weights,
col_idx, row_off,
nlon_in, nlat_out, nlon_out)
dqw = _neighborhood_attention_s2_bwd_dq_torch(kw, vw, qw, grad_output,
quad_weights,
col_idx, row_off,
nlon_in, nlat_out, nlon_out)
# reshape again
_, C, H, W = dkw.shape
dkw = dkw.reshape(B, -1, H, W)
_, C, H, W = dvw.shape
dvw = dvw.reshape(B, -1, H, W)
_, C, H, W = dqw.shape
dqw = dqw.reshape(B, -1, H, W)
# input grads
dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
# weight grads
dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
# bias grads:
if bv is not None:
dbv = torch.sum(dvw, dim=(0,2,3))
else:
dbv = None
if bk is not None:
dbk = torch.sum(dkw, dim=(0,2,3))
else:
dbk = None
if bq is not None:
dbq = torch.sum(dqw, dim=(0,2,3))
else:
dbq = None
return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
None, None, None, None, None, None, None
def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None],
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off,
nh, nlon_in, nlat_out, nlon_out)
class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.max_psi_nnz = max_psi_nnz
ctx.nlon_in = nlon_in
ctx.nlat_out = nlat_out
ctx.nlon_out = nlon_out
kw = F.conv2d(k, weight=wk, bias=bk)
vw = F.conv2d(v, weight=wv, bias=bv)
qw = F.conv2d(q, weight=wq, bias=bq)
# reshape, folding num heads into batch dim
B, _, H, W = kw.shape
kw = kw.reshape(B*nh, -1, H, W)
B, _, H, W = vw.shape
vw = vw.reshape(B*nh, -1, H, W)
B, _, H, W = qw.shape
qw = qw.reshape(B*nh, -1, H, W)
# convert to float32
kw = kw.to(torch.float32)
vw = vw.to(torch.float32)
qw = qw.to(torch.float32)
output = attention_cuda_extension.forward(kw, vw, qw, quad_weights,
col_idx, row_off,
nlon_in, nlat_out, nlon_out)
_, C, H, W = output.shape
output = output.reshape(B, -1, H, W)
return output
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh
max_psi_nnz = ctx.max_psi_nnz
nlon_in = ctx.nlon_in
nlat_out = ctx.nlat_out
nlon_out = ctx.nlon_out
kw = F.conv2d(k, weight=wk, bias=bk)
vw = F.conv2d(v, weight=wv, bias=bv)
qw = F.conv2d(q, weight=wq, bias=bq)
# reshape, folding num heads into batch dim
B, _, H, W = kw.shape
kw = kw.reshape(B*nh, -1, H, W)
B, _, H, W = vw.shape
vw = vw.reshape(B*nh, -1, H, W)
B, _, H, W = qw.shape
qw = qw.reshape(B*nh, -1, H, W)
B, _, H, W = grad_output.shape
grad_output = grad_output.reshape(B*nh, -1, H, W)
dkw,dvw,dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output,
quad_weights,
col_idx, row_off,
nlon_in, nlat_out, nlon_out)
# reshape again
_, C, H, W = dkw.shape
dkw = dkw.reshape(B, -1, H, W)
_, C, H, W = dvw.shape
dvw = dvw.reshape(B, -1, H, W)
_, C, H, W = dqw.shape
dqw = dqw.reshape(B, -1, H, W)
# input grads
dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
# weight grads
dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
# bias grads:
if bv is not None:
dbv = torch.sum(dvw, dim=(0,2,3))
else:
dbv = None
if bk is not None:
dbk = torch.sum(dkw, dim=(0,2,3))
else:
dbk = None
if bq is not None:
dbq = torch.sum(dqw, dim=(0,2,3))
else:
dbq = None
return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
None, None, None, None, None, None, None, None
def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None],
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, max_psi_nnz,
nh, nlon_in, nlat_out, nlon_out)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List, Tuple, Union, Optional
from warnings import warn
import math
import torch
import torch.nn as nn
import numpy as np
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.convolution import _precompute_convolution_tensor_s2
from torch_harmonics._neighborhood_attention import _neighborhood_attention_s2_torch, _neighborhood_attention_s2_cuda
from torch_harmonics.filter_basis import get_filter_basis
# import custom C++/CUDA extensions
try:
import attention_cuda_extension
_cuda_extension_available = True
except ImportError as err:
attention_cuda_extension = None
_cuda_extension_available = False
class AttentionS2(nn.Module):
"""
(Global) attention on the 2-sphere.
Parameters
-----------
in_channels: int
number of channels of the input signal (corresponds to embed_dim in MHA in PyTorch)
num_heads: int
number of attention heads
in_shape: tuple
shape of the input grid
out_shape: tuple
shape of the output grid
grid_in: str, optional
input grid type, "equiangular" by default
grid_out: str, optional
output grid type, "equiangular" by default
bias: bool, optional
if True, adds a bias term to the input / output projection layers
k_channels: int, optional
number of channels used for the inner product in the attention matrix (corresponds to kdim in MHA in PyTorch)
out_channels: int, optional
number of output channels of the value / output projection (corresponds to vdim in MHA in PyTorch)
"""
def __init__(
self,
in_channels: int,
num_heads: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
scale: Optional[Union[torch.Tensor, float]] = None,
bias: Optional[bool] = True,
k_channels: Optional[int] = None,
out_channels: Optional[int] = None,
drop_rate: Optional[float]=0.0,
):
super().__init__()
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
self.in_channels = in_channels
self.num_heads = num_heads
self.k_channels = in_channels if k_channels is None else k_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.drop_rate = drop_rate
self.scale = scale
# integration weights
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in
# we need to tile and flatten them accordingly
quad_weights = torch.tile(quad_weights, (1, self.nlon_in)).flatten()
# compute log because they are applied as an addition prior to the softmax ('attn_mask'), which includes an exponential.
# see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# for info on how 'attn_mask' is applied to the attention weights
log_quad_weights = torch.log(quad_weights).reshape(1,1,-1)
self.register_buffer("log_quad_weights", log_quad_weights, persistent=False)
# learnable parameters
# TODO: double-check that this gives us the correct initialization magnitudes
# the standard MHA uses xavier uniform, NATTEN uses kaiming. Let's use that for now
if self.k_channels % self.num_heads != 0:
raise ValueError(f"Please make sure that number of heads {self.num_heads} divides k_channels {self.k_channels} evenly.")
if self.out_channels % self.num_heads != 0:
raise ValueError(f"Please make sure that number of heads {self.num_heads} divides out_channels {self.out_channels} evenly.")
scale_qkv = math.sqrt(3.0 / self.in_channels)
self.q_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1))
self.k_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1))
self.v_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.out_channels, self.in_channels, 1, 1) - 1))
scale_proj = math.sqrt(3.0 / self.out_channels)
self.proj_weights = nn.Parameter(scale_proj * (2 * torch.rand(self.out_channels, self.out_channels, 1, 1) - 1))
if bias:
self.q_bias = nn.Parameter(torch.zeros(self.k_channels))
self.k_bias = nn.Parameter(torch.zeros(self.k_channels))
self.v_bias = nn.Parameter(torch.zeros(self.out_channels))
self.proj_bias = nn.Parameter(torch.zeros(self.out_channels))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.proj_bias = None
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
# self attention simplification
if key is None:
key = query
if value is None:
value = query
# change this later to allow arbitrary number of batch dims
assert (query.dim() == key.dim()) and (key.dim() == value.dim()) and (value.dim() == 4)
# input projections (1x1 convolutions) for query, key and value
query = nn.functional.conv2d(query, self.q_weights, bias=self.q_bias)
key = nn.functional.conv2d(key, self.k_weights, bias=self.k_bias)
value = nn.functional.conv2d(value, self.v_weights, bias=self.v_bias)
# reshape
B, _, H, W = query.shape
query = query.reshape(B, self.num_heads, -1, H, W)
B, _, H, W = key.shape
key = key.reshape(B, self.num_heads, -1, H, W)
B, _, H, W = value.shape
value = value.reshape(B, self.num_heads, -1, H, W)
# reshape to the right dimensions
B, _, C, H, W = query.shape
query = query.permute(0,1,3,4,2).reshape(B, self.num_heads, H*W, C)
B, _, C, H, W = key.shape
key = key.permute(0,1,3,4,2).reshape(B, self.num_heads, H*W, C)
B, _, C, H, W = value.shape
value = value.permute(0,1,3,4,2).reshape(B, self.num_heads, H*W, C)
# multiply the query, key and value tensors
out = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=self.log_quad_weights, dropout_p=self.drop_rate, scale=self.scale)
# reshape
B, _, _, C = out.shape
# (B, heads, H*W, C)
out = out.permute(0,1,3,2)
# (B, heads, C, H*W)
out = out.reshape(B, self.num_heads*C, self.nlat_out, self.nlon_out)
# (B, heads*C, H, W)
out = nn.functional.conv2d(out, self.proj_weights, bias=self.proj_bias)
return out
class NeighborhoodAttentionS2(nn.Module):
"""
Neighborhood attention on the 2-sphere.
Parameters
-----------
in_channels: int
number of channels of the input signal (corresponds to embed_dim in MHA in PyTorch)
in_shape: tuple
shape of the input grid
out_shape: tuple
shape of the output grid
grid_in: str, optional
input grid type, "equiangular" by default
grid_out: str, optional
output grid type, "equiangular" by default
num_heads: int, optional
number of attention heads, 1 by default
scale: float, optional
scaling applied to the query-key inner product; defaults to 1/sqrt(k_channels)
bias: bool, optional
if True, adds a learnable bias to the input / output projection layers
theta_cutoff: float, optional
angular size of the neighborhood in radians; defaults to pi / (nlat_out - 1)
k_channels: int, optional
number of channels used for the query-key inner product in the attention matrix (corresponds to kdim in MHA in PyTorch)
out_channels: int, optional
number of output channels of the value and output projections (corresponds to vdim in MHA in PyTorch)
"""
def __init__(
self,
in_channels: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
num_heads: Optional[int] = 1,
scale: Optional[Union[torch.Tensor, float]] = None,
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
k_channels: Optional[int] = None,
out_channels: Optional[int] = None,
):
super().__init__()
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
self.in_channels = in_channels
self.num_heads = num_heads
self.k_channels = in_channels if k_channels is None else k_channels
self.out_channels = in_channels if out_channels is None else out_channels
# heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
if theta_cutoff is None:
theta_cutoff = torch.pi / float(self.nlat_out - 1)
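# e.g. on an equiangular grid with nlat_out = 181 (1 degree spacing) this heuristic
# gives theta_cutoff = pi / 180, i.e. a neighborhood radius of roughly one degree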
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
# integration weights
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in
self.register_buffer("quad_weights", quad_weights, persistent=False)
# create a dummy filter basis to pass to the construction of the convolution tensor
# this is to avoid code duplication as the logic of pre-computing the sparsity pattern
# is identical to convolutions with a constant filter function
fb = get_filter_basis(kernel_shape=1, basis_type="zernike")
# precompute the neighborhood sparsity pattern
idx, vals = _precompute_convolution_tensor_s2(
in_shape,
out_shape,
fb,
grid_in=grid_in,
grid_out=grid_out,
theta_cutoff=theta_cutoff,
transpose_normalization=False,
basis_norm_mode="none",
merge_quadrature=True,
)
# this is kept for legacy reasons in case we want to reuse the sorting of these entries
row_idx = idx[1, ...].contiguous()
col_idx = idx[2, ...].contiguous()
# compute row offsets for more structured traversal.
# this only works if the rows are sorted, which they are by construction
row_offset = np.empty(self.nlat_out + 1, dtype=np.int64)
row_offset[0] = 0
row = row_idx[0]
for idz in range(col_idx.shape[0]):
if row_idx[idz] != row:
row_offset[row + 1] = idz
row = row_idx[idz]
# set the last value
row_offset[row + 1] = idz + 1
row_offset = torch.from_numpy(row_offset)
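# worked example (illustration only): for row_idx = [0, 0, 0, 1, 1, 2] the loop above
# yields row_offset = [0, 3, 5, 6], so the nonzeros of output row r are found in
# col_idx[row_offset[r]:row_offset[r+1]]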
self.max_psi_nnz = col_idx.max().item() + 1
self.register_buffer("psi_row_idx", row_idx, persistent=False)
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_roff_idx", row_offset, persistent=False)
# self.register_buffer("psi_vals", vals, persistent=False)
# learnable parameters
# TODO: double-check that this gives us the correct initialization magnitudes
# the standard MHA uses Xavier uniform while NATTEN uses Kaiming; we follow the latter for now
if self.k_channels % self.num_heads != 0:
raise ValueError(f"Please make sure that number of heads {self.num_heads} divides k_channels {self.k_channels} evenly.")
if self.out_channels % self.num_heads != 0:
raise ValueError(f"Please make sure that number of heads {self.num_heads} divides out_channels {self.out_channels} evenly.")
scale_qkv = math.sqrt(3.0 / self.in_channels)
self.q_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1))
self.k_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.k_channels, self.in_channels, 1, 1) - 1))
self.v_weights = nn.Parameter(scale_qkv * (2 * torch.rand(self.out_channels, self.in_channels, 1, 1) - 1))
scale_proj = math.sqrt(3.0 / self.out_channels)
self.proj_weights = nn.Parameter(scale_proj * (2 * torch.rand(self.out_channels, self.out_channels, 1, 1) - 1))
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(self.k_channels)
if bias:
self.q_bias = nn.Parameter(torch.zeros(self.k_channels))
self.k_bias = nn.Parameter(torch.zeros(self.k_channels))
self.v_bias = nn.Parameter(torch.zeros(self.out_channels))
self.proj_bias = nn.Parameter(torch.zeros(self.out_channels))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.proj_bias = None
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"
def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
# self attention simplification
if key is None:
key = query
if value is None:
value = query
# change this later to allow arbitrary number of batch dims
assert (query.dim() == key.dim()) and (key.dim() == value.dim()) and (value.dim() == 4)
# do the scaling
query_scaled = query * self.scale
# TODO: insert dimension checks for input
if query.is_cuda and _cuda_extension_available:
out = _neighborhood_attention_s2_cuda(
key,
value,
query_scaled,
self.k_weights,
self.v_weights,
self.q_weights,
self.k_bias,
self.v_bias,
self.q_bias,
self.quad_weights,
self.psi_col_idx,
self.psi_roff_idx,
self.max_psi_nnz,
self.num_heads,
self.nlon_in,
self.nlat_out,
self.nlon_out,
)
else:
if query.is_cuda:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
# call attention
out = _neighborhood_attention_s2_torch(
key,
value,
query_scaled,
self.k_weights,
self.v_weights,
self.q_weights,
self.k_bias,
self.v_bias,
self.q_bias,
self.quad_weights,
self.psi_col_idx,
self.psi_roff_idx,
self.num_heads,
self.nlon_in,
self.nlat_out,
self.nlon_out,
)
out = nn.functional.conv2d(out, self.proj_weights, bias=self.proj_bias)
return out
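# minimal usage sketch (illustration only; shapes and values are assumptions, not part of the library API):
#   att = NeighborhoodAttentionS2(in_channels=16, in_shape=(32, 64), out_shape=(32, 64), num_heads=4)
#   x = torch.randn(2, 16, 32, 64)  # (batch, channels, nlat, nlon)
#   y = att(x)                      # self-attention: key and value default to the query
#   assert y.shape == (2, 16, 32, 64)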
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <cmath>
#include <cstdint>
#include <torch/torch.h>
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
int s2_idx_offset_cuda(const at::Tensor &psi_col_idx,
const at::Tensor &psi_row_idx, at::Tensor &row_offset,
at::Tensor &row_count);
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
__global__ void
s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem; // 1
float* sh_qdotk_max = (float*)&sharedMem[1]; // 1
float* sh_qy_ho_wo = (float*)&sharedMem[2]; // num_channels
if (threadIdx.x == 0) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
sh_alpha_sum[0] = 0.0;
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
// collect thread-local qdotk max
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
// form alpha & sum alpha
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
}
// collect thread-local alpha_sum
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
// alpha * dy * omega / alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, dy, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydv[batch_b][channel_idx][hi][wip], (alpha_inz/alpha_sum) * dy[batch_b][channel_idx][ho][wo]);
}
}
}
at::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydv = torch::zeros_like(vx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dv_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydv;
}
__global__ void
s2_attention_bwd_dk_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_integral = (float *)&sharedMem[1 + num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[2 + num_channels];
float *sh_qdotk_max = (float *)&sharedMem[2 + 2 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = max(qdotk_max, qdotk);
}
// compute max over all threads
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
float integral = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha & integral
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
}
// block sum thread-local alpha_sum and integral
atomicAdd(&sh_alpha_sum[0], alpha_sum);
atomicAdd(&sh_integral[0], integral);
__syncthreads();
// finish integral computation
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// broadcast sum and integral back to thread-local registers
integral = sh_integral[0];
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz/alpha_sum) * (gdotv - integral));
}
}
__syncthreads();
}
__global__ void
s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_alpha_k = (float *)&sharedMem[1 + num_channels];
float *sh_alpha_vw = (float *)&sharedMem[1 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[1 + 3*num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[1 + 4 * num_channels];
float *sh_qdotk_max = (float *)&sharedMem[1 + 5 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
}
}
// sum thread-local alpha_sums across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum);
}
}
__global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>
dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float *sh_alpha_sum = (float *)&sharedMem;
float* sh_integral = (float*)&sharedMem[1];
float *sh_qy_ho_wo = (float *)&sharedMem[2];
float *sh_alpha_k = (float *)&sharedMem[2 + num_channels];
float *sh_alpha_vw = (float *)&sharedMem[2 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[2 + 3*num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[2 + 4 * num_channels];
float *sh_qdotk_max = (float *)&sharedMem[2 + 5 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
float integral = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
}
integral += alpha_inz * gdotv;
}
// sum thread-local alpha_sums & integral across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
atomicAdd(&sh_integral[0], integral);
__syncthreads();
// finalize integral
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// "broadcast" alpha sum & integral back to thread-local registers
alpha_sum = sh_alpha_sum[0];
integral = sh_integral[0];
// dq
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum);
}
__syncthreads();
// dk & dv
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) *
(gdotv - integral));
atomicAdd(&dydv[batch_b][channel_idx][hi][wip], (alpha_inz/alpha_sum) * sh_dy_ho_wo[channel_idx]);
}
}
__syncthreads();
}
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydk = torch::zeros_like(kx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (2*uo_num_channels+3)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dk_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydk;
}
at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydq = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (5*uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dq_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydq;
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydk = torch::zeros_like(qy);
torch::Tensor dydv = torch::zeros_like(qy);
torch::Tensor dydq = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (6*uo_num_channels+3)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dkvq_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(dydk, dydv, dydq);
}
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
__global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float *sh_alpha_sum = (float *)&sharedMem;
float* sh_qdotk_max = (float*)&sharedMem[1];
float* sh_qy_ho_wo = (float *)&sharedMem[2];
if (threadIdx.x == 0) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
sh_alpha_sum[0] = 0.0;
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
y[batch_b][channel_idx][ho][wo] = 0.0;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk_max, qdotk);
}
// collect thread-local qdotk max
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0f;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
float alpha_inz = 0.0;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz < psi_nnz_ho) {
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator with minus qdotk_max to avoid numerical overflow.
// Because qdotk_max is in both numerator and denominator (due to
// alpha_sum), it does not affect the result other than preventing overflow
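// (softmax is shift invariant: exp(s_i - m) / sum_j exp(s_j - m) = exp(s_i) / sum_j exp(s_j) for any m)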
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// thread-local sum alpha
alpha_sum += alpha_inz;
// multiply alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&y[batch_b][channel_idx][ho][wo], alpha_inz * vx[batch_b][channel_idx][hi][wip]);
}
}
}
// collect all alpha_sum across threads
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// rebroadcast sum to all threads
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
y[batch_b][channel_idx][ho][wo] /= alpha_sum;
}
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in,
int nlat_out,
int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
// TODO: check sizes
auto stream = at::cuda::getCurrentCUDAStream().stream();
// allocate output
torch::Tensor y = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
// note: blocksize of 512 runs into resource limits
dim3 blockDim(256,1,1);
s2_attention_kernel<<<gridDim, blockDim, sharedMemSize,stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return y;
}
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("compute_row_offset", &s2_idx_offset_cuda, "Row offset on S2");
m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
m.def("backward_dq", &s2_attention_bwd_dq_cuda,
"(Local) Attention gradient on S2 (gradient for q)");
m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
#include "ATen/core/TensorAccessor.h"
#include <cmath>
#include <cstdint>
#include <torch/extension.h>
#include <torch/torch.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
__global__ void alpha_count_kernel(int len_alpha_count,
int len_psi_row_idx,
torch::PackedTensorAccessor64<int64_t, 1> psi_row_idx,
int64_t* alpha_start,
torch::PackedTensorAccessor64<int64_t, 1> alpha_count
) {
int ho = blockIdx.x * blockDim.x + threadIdx.x;
if(ho < len_alpha_count) {
// initialize alpha_count;
alpha_count[ho] = 0;
// NOTE: Assumes that psi_row_idx is sorted
for(int i=alpha_start[ho]; i<len_psi_row_idx; i++) {
if(psi_row_idx[i] == ho) alpha_count[ho]++;
else if(psi_row_idx[i] > ho) break;
}
}
}
int s2_idx_offset_cuda(const at::Tensor& psi_col_idx,
const at::Tensor& psi_row_idx,
at::Tensor& row_offset,
at::Tensor& row_count) {
auto stream = at::cuda::getCurrentCUDAStream();
int64_t* d_alpha_start;
int64_t* d_sequence;
int64_t* d_alpha_count = row_count.data_ptr<int64_t>();
int64_t* d_alpha_offset = row_offset.data_ptr<int64_t>();
C10_CUDA_CHECK(cudaMalloc(&d_alpha_start, row_offset.size(0)*sizeof(int64_t)));
// Find the first time each index occurs in psi_row_idx
// psi_row_idx = [0,0,0,0,1,1,1,1,2,2,2...]
// 0 starts at idx=0, 1 starts at idx=4, 2 starts at idx=8, etc
// this assumes that psi_row_idx is sorted!
C10_CUDA_CHECK(cudaMalloc(&d_sequence, row_offset.size(0)*sizeof(int64_t)));
thrust::sequence(thrust::device, d_sequence, d_sequence+row_offset.size(0), 0);
thrust::counting_iterator<int> start(0);
// thrust::lower_bound(thrust::device,
// psi_row_idx.data_ptr<int64_t>(),
// psi_row_idx.data_ptr<int64_t>()+psi_row_idx.size(0),
// start, start+psi_row_idx.size(0), d_alpha_start);
thrust::lower_bound(thrust::device,
psi_row_idx.data_ptr<int64_t>(),
psi_row_idx.data_ptr<int64_t>()+psi_row_idx.size(0),
d_sequence, d_sequence+row_offset.size(0), d_alpha_start);
alpha_count_kernel<<<at::cuda::detail::GET_BLOCKS(row_offset.size(0),512),512,
0,stream.stream()>>>(row_count.size(0),
psi_row_idx.size(0),
psi_row_idx.packed_accessor64<int64_t, 1>(),
d_alpha_start,
row_count.packed_accessor64<int64_t , 1>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
int maxAlphaSize = thrust::reduce(thrust::device,
d_alpha_count,
d_alpha_count+row_count.size(0),
0,
thrust::maximum<int>());
thrust::exclusive_scan(thrust::device,
d_alpha_count,
d_alpha_count+row_count.size(0),
d_alpha_offset);
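// e.g. for psi_row_idx = [0,0,0,0,1,1,1,1,2,2,2] this produces alpha_count = [4,4,3],
// row_offset = [0,4,8] and maxAlphaSize = 4 (concrete illustration of the example above)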
C10_CUDA_CHECK(cudaFree(d_alpha_start));
C10_CUDA_CHECK(cudaFree(d_sequence));
return maxAlphaSize;
}
......@@ -32,4 +32,5 @@
from .pde_sphere import SphereSolver
from .shallow_water_equations import ShallowWaterSolver
from .pde_dataset import PdeDataset
\ No newline at end of file
from .pde_dataset import PdeDataset
from .stanford_2d3ds_dataset import StanfordSegmentationDataset, StanfordDepthDataset, Stanford2D3DSDownloader, compute_stats_s2, StanfordDatasetSubset
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from abc import ABC, abstractmethod
from torch_harmonics.quadrature import _precompute_latitudes
def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False, normalized: bool = True) -> torch.Tensor:
# area weights
_, q = _precompute_latitudes(nlat=nlat, grid=grid)
q = q.reshape(-1, 1) * 2 * torch.pi / nlon
# numerical precision can be an issue here, make sure it sums to 1:
if normalized:
q = q / torch.sum(q) / float(nlon)
if tile:
q = torch.tile(q, (1, nlon)).contiguous()
return q.to(torch.float32)
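# quick sanity check (illustrative sketch, not part of the library): with normalized=True
# the tiled weights integrate to one over the whole grid, e.g.
#   q = get_quadrature_weights(nlat=32, nlon=64, grid="equiangular", tile=True)
#   assert torch.allclose(q.sum(), torch.tensor(1.0))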
class DiceLossS2(nn.Module):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
self.smooth = smooth
self.ignore_index = ignore_index
self.mode = mode
# area weights
q = get_quadrature_weights(nlat=nlat, nlon=nlon, grid=grid)
self.register_buffer("quad_weights", q)
if weight is None:
self.weight = None
else:
self.register_buffer("weight", weight.unsqueeze(0))
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prd = nn.functional.softmax(prd, dim=1)
# mask values
if self.ignore_index is not None:
mask = torch.where(tar == self.ignore_index, 0, 1)
prd = prd * mask.unsqueeze(1)
tar = tar * mask
# one hot encode
taroh = nn.functional.one_hot(tar, num_classes=prd.shape[1]).permute(0, 3, 1, 2)
# compute numerator and denominator
intersection = torch.sum((prd * taroh) * self.quad_weights, dim=(-2, -1))
union = torch.sum((prd + taroh) * self.quad_weights, dim=(-2, -1))
if self.mode == "micro":
if self.weight is not None:
intersection = torch.sum(intersection * self.weight, dim=1)
union = torch.sum(union * self.weight, dim=1)
else:
intersection = torch.mean(intersection, dim=1)
union = torch.mean(union, dim=1)
# compute score
dice = (2 * intersection + self.smooth) / (union + self.smooth)
# compute average over classes
if self.mode == "macro":
if self.weight is not None:
dice = torch.sum(dice * self.weight, dim=1)
else:
dice = torch.mean(dice, dim=1)
# average over batch
dice = torch.mean(dice)
return 1 - dice
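# Illustrative usage sketch (hypothetical helper; shapes and class count are arbitrary
# assumptions). DiceLossS2 expects raw class logits of shape (batch, classes, nlat, nlon)
# and integer labels of shape (batch, nlat, nlon); softmax and one-hot encoding happen
# inside forward().
def _example_dice_loss() -> torch.Tensor:
    nlat, nlon, nclasses = 32, 64, 5
    loss_fn = DiceLossS2(nlat=nlat, nlon=nlon, grid="equiangular", mode="macro")
    prd = torch.randn(2, nclasses, nlat, nlon)         # raw logits
    tar = torch.randint(0, nclasses, (2, nlat, nlon))  # integer class labels
    return loss_fn(prd, tar)                           # scalar dice loss in [0, 1]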
class CrossEntropyLossS2(nn.Module):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
super().__init__()
self.smooth = smooth
self.ignore_index = ignore_index
if weight is None:
self.weight = None
else:
self.register_buffer("weight", weight)
q = get_quadrature_weights(nlat=nlat, nlon=nlon, grid=grid)
self.register_buffer("quad_weights", q)
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
# compute log softmax
logits = nn.functional.log_softmax(prd, dim=1)
ce = nn.functional.cross_entropy(logits, tar, weight=self.weight, reduction="none", ignore_index=self.ignore_index, label_smoothing=self.smooth)
ce = (ce * self.quad_weights).sum(dim=(-1, -2))
ce = torch.mean(ce)
return ce
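# Illustrative usage sketch (hypothetical helper; shapes are arbitrary assumptions).
# CrossEntropyLossS2 weights the per-pixel cross entropy by the quadrature weights
# before averaging over the batch, so pixels near the poles contribute less.
def _example_cross_entropy_loss() -> torch.Tensor:
    nlat, nlon, nclasses = 32, 64, 5
    loss_fn = CrossEntropyLossS2(nlat=nlat, nlon=nlon, grid="equiangular")
    prd = torch.randn(2, nclasses, nlat, nlon)
    tar = torch.randint(0, nclasses, (2, nlat, nlon))
    return loss_fn(prd, tar)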
class FocalLossS2(nn.Module):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
super().__init__()
self.smooth = smooth
self.ignore_index = ignore_index
if weight is None:
self.weight = None
else:
self.register_buffer("weight", weight)
q = get_quadrature_weights(nlat=nlat, nlon=nlon, grid=grid)
self.register_buffer("quad_weights", q)
def forward(self, prd: torch.Tensor, tar: torch.Tensor, alpha: float = 0.25, gamma: float = 2):
# compute logits
logits = nn.functional.log_softmax(prd, dim=1)
# w = (1.0 - nn.functional.softmax(prd, dim=-3)).pow(gamma)
# w = torch.where(tar == self.ignore_index, 0.0, w.gather(-3, tar.unsqueeze(-3)).squeeze(-3))
ce = nn.functional.cross_entropy(logits, tar, weight=self.weight, reduction="none", ignore_index=self.ignore_index, label_smoothing=self.smooth)
fl = alpha * (1 - torch.exp(-ce)) ** gamma * ce
# fl = w * ce
fl = (fl * self.quad_weights).sum(dim=(-1, -2))
fl = fl.mean()
return fl
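# Illustrative usage sketch (hypothetical helper). FocalLossS2 shares the interface of
# CrossEntropyLossS2 and additionally takes the focal parameters alpha and gamma in
# forward(); the values below are common choices, not prescribed by this module.
def _example_focal_loss() -> torch.Tensor:
    nlat, nlon, nclasses = 32, 64, 5
    loss_fn = FocalLossS2(nlat=nlat, nlon=nlon, grid="equiangular")
    prd = torch.randn(2, nclasses, nlat, nlon)
    tar = torch.randint(0, nclasses, (2, nlat, nlon))
    return loss_fn(prd, tar, alpha=0.25, gamma=2.0)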
class SphericalLossBase(nn.Module, ABC):
"""Abstract base class for spherical losses that handles common initialization and integration."""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", normalized: bool = True):
super().__init__()
self.nlat = nlat
self.nlon = nlon
self.grid = grid
# get quadrature weights - these sum to 1!
q = get_quadrature_weights(nlat=nlat, nlon=nlon, grid=grid, normalized=normalized)
self.register_buffer("quad_weights", q)
def _integrate_sphere(self, ugrid, mask=None):
if mask is None:
out = torch.sum(ugrid * self.quad_weights, dim=(-2, -1))
        else:
out = torch.sum(mask * ugrid * self.quad_weights, dim=(-2, -1)) / torch.sum(mask * self.quad_weights, dim=(-2, -1))
return out
@abstractmethod
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""Abstract method that must be implemented by child classes to compute loss terms.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
Returns:
torch.Tensor: Computed loss term before integration
"""
pass
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
"""Post-integration hook. Commonly used for the roots in Lp norms"""
return loss
def forward(self, prd: torch.Tensor, tar: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Common forward pass that handles masking and reduction.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None.
Returns:
torch.Tensor: Final loss value
"""
loss_term = self._compute_loss_term(prd, tar)
# Integrate over the sphere for each item in the batch
loss = self._integrate_sphere(loss_term, mask)
# potentially call root
loss = self._post_integration_hook(loss)
# Average the loss over the batch dimension
return torch.mean(loss)
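# Illustrative sketch (hypothetical class, not part of the committed module): a new
# pointwise loss only needs to implement _compute_loss_term (and optionally
# _post_integration_hook); masking, spherical integration, and batch averaging are
# inherited from SphericalLossBase.
class _ExampleHuberLossS2(SphericalLossBase):
    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", delta: float = 1.0):
        super().__init__(nlat=nlat, nlon=nlon, grid=grid)
        self.delta = delta

    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        diff = torch.abs(prd - tar)
        quadratic = 0.5 * torch.square(diff)
        linear = self.delta * (diff - 0.5 * self.delta)
        return torch.where(diff <= self.delta, quadratic, linear)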
class SquaredL2LossS2(SphericalLossBase):
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
return torch.square(prd - tar)
class L1LossS2(SphericalLossBase):
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
return torch.abs(prd - tar)
class L2LossS2(SquaredL2LossS2):
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
return torch.sqrt(loss)
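# Illustrative usage sketch (hypothetical helper; shapes are arbitrary assumptions).
# The Lp losses accept fields with trailing (nlat, nlon) dimensions; an optional mask
# restricts the integral to valid pixels and renormalizes accordingly.
def _example_l2_loss() -> torch.Tensor:
    nlat, nlon = 32, 64
    loss_fn = L2LossS2(nlat=nlat, nlon=nlon, grid="equiangular")
    prd = torch.randn(2, 3, nlat, nlon)
    tar = torch.randn(2, 3, nlat, nlon)
    mask = torch.ones(2, 3, nlat, nlon)
    return loss_fn(prd, tar, mask=mask)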
class W11LossS2(SphericalLossBase):
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
super().__init__(nlat=nlat, nlon=nlon, grid=grid)
# Set up grid and domain for FFT
l_phi = 2 * torch.pi # domain size
l_theta = torch.pi # domain size
k_phi = torch.fft.fftfreq(nlon, d=l_phi / (2 * torch.pi * nlon))
k_theta = torch.fft.fftfreq(nlat, d=l_theta / (2 * torch.pi * nlat))
k_theta_mesh, k_phi_mesh = torch.meshgrid(k_theta, k_phi, indexing="ij")
self.register_buffer("k_phi_mesh", k_phi_mesh)
self.register_buffer("k_theta_mesh", k_theta_mesh)
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real
# Return the element-wise loss term
return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h)
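# Illustrative usage sketch (hypothetical helper). W11LossS2 penalizes the absolute
# difference of the FFT-based longitude and latitude derivatives, i.e. a W^{1,1}-type
# seminorm of the error, and expects real-valued fields with trailing (nlat, nlon) dims.
def _example_w11_loss() -> torch.Tensor:
    nlat, nlon = 32, 64
    loss_fn = W11LossS2(nlat=nlat, nlon=nlon, grid="equiangular")
    prd = torch.randn(2, 1, nlat, nlon)
    tar = torch.randn(2, 1, nlat, nlon)
    return loss_fn(prd, tar)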
class NormalLossS2(SphericalLossBase):
"""Combined L1 and Surface Normal Consistency Loss for spherical data.
This loss function combines an L1 loss term with a surface normal alignment term.
The loss consists of:
1. L1 Loss: Absolute difference between predicted and target values
2. Normal Consistency Loss: 1 - cosine similarity between surface normals
(equivalent to cosine distance between normal vectors)
Surface normals are computed by calculating gradients in latitude and longitude
directions using FFT, then constructing 3D normal vectors that are normalized.
Args:
nlat (int): Number of latitude points
nlon (int): Number of longitude points
grid (str, optional): Grid type. Defaults to "equiangular".
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
super().__init__(nlat=nlat, nlon=nlon, grid=grid)
# Set up grid and domain for FFT
l_phi = 2 * torch.pi # domain size
l_theta = torch.pi # domain size
k_phi = torch.fft.fftfreq(nlon, d=l_phi / (2 * torch.pi * nlon))
k_theta = torch.fft.fftfreq(nlat, d=l_theta / (2 * torch.pi * nlat))
k_theta_mesh, k_phi_mesh = torch.meshgrid(k_theta, k_phi, indexing="ij")
self.register_buffer("k_phi_mesh", k_phi_mesh)
self.register_buffer("k_theta_mesh", k_theta_mesh)
def compute_gradients(self, x):
# Make sure x is reshaped to have a batch dimension if it's missing
if x.dim() == 2:
x = x.unsqueeze(0) # Add batch dimension
x_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
x_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
return x_prime_fft2_theta_h, x_prime_fft2_phi_h
def compute_normals(self, x):
x = x.to(torch.float32)
# Ensure x has a batch dimension
if x.dim() == 2:
x = x.unsqueeze(0)
grad_lat, grad_lon = self.compute_gradients(x)
# Create 3D normal vectors
ones = torch.ones_like(x)
normals = torch.stack([-grad_lon, -grad_lat, ones], dim=1)
# Normalize along component dimension
normals = F.normalize(normals, p=2, dim=1)
return normals
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
# Handle dimensions for both prediction and target
# Ensure we have at least a batch dimension
if prd.dim() == 2:
prd = prd.unsqueeze(0)
if tar.dim() == 2:
tar = tar.unsqueeze(0)
# For 4D tensors (batch, channel, height, width), remove channel if it's 1
if prd.dim() == 4 and prd.size(1) == 1:
prd = prd.squeeze(1)
if tar.dim() == 4 and tar.size(1) == 1:
tar = tar.squeeze(1)
pred_normals = self.compute_normals(prd)
tar_normals = self.compute_normals(tar)
# Compute cosine similarity
normal_loss = 1 - torch.sum(pred_normals * tar_normals, dim=1, keepdim=True)
return normal_loss
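# Illustrative usage sketch (hypothetical helper; shapes are arbitrary assumptions).
# NormalLossS2 targets scalar fields such as depth maps; single-channel 4D inputs are
# squeezed internally and surface normals are compared via cosine distance.
def _example_normal_loss() -> torch.Tensor:
    nlat, nlon = 32, 64
    loss_fn = NormalLossS2(nlat=nlat, nlon=nlon, grid="equiangular")
    prd = torch.randn(2, 1, nlat, nlon)  # predicted depth
    tar = torch.randn(2, 1, nlat, nlon)  # target depth
    return loss_fn(prd, tar)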