Commit 4350ba9f authored by Boris Bonev, committed by Boris Bonev

Fix bug in quadrature weights for full attention. Add better unit tests for attention. Clean up the CUDA code.
parent b6c48457
@@ -99,7 +99,6 @@ def get_ext_modules():
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu",
"torch_harmonics/csrc/attention/attention_row_offset.cu"
],
extra_compile_args=get_compile_args("neighborhood_attention")
)
@@ -35,6 +35,7 @@ from parameterized import parameterized
# import math
import numpy as np
import torch
import torch.nn as nn
# from torch.autograd import gradcheck
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
@@ -58,520 +59,158 @@ except ImportError as err:
_cuda_extension_available = False
# this routine is only supposed to be used in this test, since it is not numerically stable but supports
# autograd, which some of the better kernels do not
def _neighborhood_attention_s2_torch_test(
kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_idx: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int
):
out = torch.zeros_like(qy)
for ho in range(nlat_out):
# get nonzero indices in output row
idx_ho = col_idx[row_idx == ho]
for wo in range(nlon_out):
alpha_sum = torch.zeros((out.shape[0],), dtype=out.dtype, device=out.device)
alpha = torch.zeros((out.shape[0], len(idx_ho)), dtype=out.dtype, device=out.device)
for inz, nz_col_idx in enumerate(idx_ho):
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hi_wip = kx[:, :, hi, wip]
alpha[:, inz] = torch.exp(torch.sum(q_ho_wo * k_hi_wip, dim=1)) * quad_weights[hi]
# softmax denominator
alpha_sum[:] = alpha_sum[:] + alpha[:, inz]
for inz, nz_col_idx in enumerate(idx_ho):
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
# compute matmul of attention matrix with V-vector
out[:, :, ho, wo] = out[:, :, ho, wo] + (alpha[:, None, inz] / alpha_sum[:, None]) * vx[:, :, hi, wip]
return out
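For intuition, the weighted exponential used in this reference routine is just a softmax with the log of the quadrature weights added to the logits, which is how the full attention module applies them (see the attn_mask note further down). A minimal sketch of that identity, using made-up tensors rather than anything from the test suite:

import torch

# hypothetical logits (q dot k) for one output point over its neighborhood, plus quadrature weights
logits = torch.randn(5)
w = torch.rand(5) + 0.1

# weighted-exponential normalization, as in the reference routine above
alpha = torch.exp(logits) * w
attn_weighted = alpha / alpha.sum()

# equivalent formulation: add log(w) to the logits and apply a plain softmax
attn_logmask = torch.softmax(logits + torch.log(w), dim=0)

assert torch.allclose(attn_weighted, attn_logmask, atol=1e-6)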
class TestNeighborhoodAttention(unittest.TestCase):
class TestNeighborhoodAttentionS2(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index)
torch.cuda.manual_seed(333)
torch.manual_seed(333)
else:
self.device = torch.device("cpu")
torch.manual_seed(333)
@parameterized.expand(
[
# regular convolution
[8, 4, 6, (17, 32), 1e-6, 1e-5],
]
)
def test_batched_linear(self, batch_size, in_channels, out_channels, shape, atol, rtol):
# weight
weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 1, 1, dtype=torch.float32, device=self.device))
bias = torch.nn.Parameter(torch.randn(out_channels, dtype=torch.float32, device=self.device))
# input
inp = torch.randn(batch_size, in_channels, *shape, dtype=torch.float32, device=self.device)
inp.requires_grad = True
# operation
out = torch.nn.functional.conv2d(inp, weight=weight, bias=bias)
out_grad = torch.randn(batch_size, out_channels, *shape, dtype=torch.float32, device=self.device)
out.backward(out_grad)
# store for comparison
wgrad = weight.grad.clone()
bgrad = bias.grad.clone()
igrad = inp.grad.clone()
# explicit layers
igrad_explicit = torch.nn.functional.conv2d(out_grad, weight=weight.permute([1, 0, 2, 3]), bias=None)
wgrad_explicit = torch.einsum("bchw,bfhw->cf", out_grad, inp).reshape(out_channels, in_channels, 1, 1).contiguous()
bgrad_explicit = torch.sum(out_grad, dim=(0, 2, 3))
# check consistency
self.assertTrue(torch.allclose(igrad, igrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(wgrad, wgrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(bgrad, bgrad_explicit, atol=atol, rtol=rtol))
torch.manual_seed(333)
@parameterized.expand(
[
# self attention
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
# Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
]
)
def test_fwd(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
"""Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_inp.requires_grad = False
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_inp.requires_grad = False
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
with torch.no_grad():
out_torch_explicit = _neighborhood_attention_s2_fwd_torch(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
self.assertTrue(torch.allclose(out_torch_explicit, out_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
out_cuda = attention_cuda_extension.forward(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
self.assertTrue(torch.allclose(out_torch, out_cuda, atol=atol, rtol=rtol))
@parameterized.expand(
[
# regular convolution
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
]
)
def test_bwd_dv(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
_, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device)
k_inp.requires_grad = False
v_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device)
v_inp.requires_grad = True
q_inp = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device)
q_inp.requires_grad = False
out_grad = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device)
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dv_inp_torch = v_inp.grad.clone()
with torch.no_grad():
dv_inp_torch_explicit = _neighborhood_attention_s2_bwd_dv_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dv_inp_torch_explicit, dv_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
dv_inp_cuda_explicit = attention_cuda_extension.backward_dv(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dv_inp_cuda_explicit, dv_inp_torch, atol=atol, rtol=rtol))
# Helper: create inputs
inputs_ref = {
"k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
"v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
"q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
}
inputs = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()}
# reference input and model
model_ref = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True).to(
self.device
)
# Device model and inputs
model = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)
# Synchronize parameters of model
model.load_state_dict(model_ref.state_dict())
model = model.to(self.device)
for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
# reference forward passes
out_ref = _neighborhood_attention_s2_torch(
inputs_ref["k"],
inputs_ref["v"],
inputs_ref["q"] * model_ref.scale,
model_ref.k_weights,
model_ref.v_weights,
model_ref.q_weights,
model_ref.k_bias,
model_ref.v_bias,
model_ref.q_bias,
model_ref.quad_weights,
model_ref.psi_col_idx,
model_ref.psi_roff_idx,
model_ref.num_heads,
model_ref.nlon_in,
model_ref.nlat_out,
model_ref.nlon_out,
)
out_ref = nn.functional.conv2d(out_ref, model_ref.proj_weights, bias=model_ref.proj_bias)
out = model(inputs["q"], inputs["k"], inputs["v"])
# Check forward equivalence
self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")
# Backward passes
grad = torch.randn_like(out_ref)
out_ref.backward(grad)
out.backward(grad.to(self.device))
# Check input gradient equivalence
for inp in ["q", "k", "v"]:
grad_ref = inputs_ref[inp].grad.cpu()
grad = inputs[inp].grad.cpu()
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")
# Check parameter gradient equivalence
for p_ref, p in zip(model_ref.parameters(), model.parameters()):
self.assertTrue(torch.allclose(p.grad, p_ref.grad, atol=atol, rtol=rtol), f"Parameter gradient mismatch: {type(p_ref).__name__}")
# caution: the multi-head implementations of full and neighborhood attention still seem to differ; tests are only run with a single head
@parameterized.expand(
[
# regular convolution
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-3],
# Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
# [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
# [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
]
)
def test_bwd_dk(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
"""Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set ot the whole sphere"""
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_inp.requires_grad = True
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_inp.requires_grad = False
out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dk_inp_torch = k_inp.grad.clone()
with torch.no_grad():
dk_inp_torch_explicit = _neighborhood_attention_s2_bwd_dk_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dk_inp_torch_explicit, dk_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
dk_inp_cuda_explicit = attention_cuda_extension.backward_dk(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dk_inp_cuda_explicit, dk_inp_torch, atol=atol, rtol=rtol))
@parameterized.expand(
[
# regular convolution
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 4e-6, 1e-3],
]
)
def test_bwd_dq(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Execute and compare
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_inp.requires_grad = False
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_inp.requires_grad = True
out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dq_inp_torch = q_inp.grad.clone()
with torch.no_grad():
dq_inp_torch_explicit = _neighborhood_attention_s2_bwd_dq_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dq_inp_torch_explicit, dq_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
dq_inp_cuda_explicit = attention_cuda_extension.backward_dq(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dq_inp_cuda_explicit, dq_inp_torch, atol=atol, rtol=rtol))
@parameterized.expand(
[
# self attention
[1, 73, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
]
)
def test_big(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# this test only makes sense when CUDA version is available
if torch.cuda.is_available():
if not _cuda_extension_available:
print("WARNING: Problem loading CUDA attention module")
return
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# TODO: this test seems hardcoded for GPU. Is this necessary?
k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
k_gpu.requires_grad = False
v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
v_gpu.requires_grad = False
q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cuda:0")
q_gpu.requires_grad = False
# set up layers
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to("cuda:0")
# random weights
with torch.no_grad():
att_gpu.q_weights.normal_()
att_gpu.k_weights.normal_()
att_gpu.v_weights.normal_()
att_gpu.q_bias.normal_()
att_gpu.k_bias.normal_()
att_gpu.v_bias.normal_()
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_gpu.q_weights)
att_gpu.k_weights.copy_(att_gpu.k_weights)
att_gpu.v_weights.copy_(att_gpu.v_weights)
att_gpu.q_bias.copy_(att_gpu.q_bias)
att_gpu.k_bias.copy_(att_gpu.k_bias)
att_gpu.v_bias.copy_(att_gpu.v_bias)
q_gpu = q_gpu.detach().clone().to(self.device)
q_gpu.requires_grad = True
k_gpu = k_gpu.detach().clone().to(self.device)
k_gpu.requires_grad = True
v_gpu = v_gpu.detach().clone().to(self.device)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0")
out_gpu.backward(out_grad.to("cuda:0"))
@parameterized.expand(
[
# self attention
[10, 2, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-5, 1e-5],
]
)
def test_neighborhood(self, batch_size, num_channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
"""
This test sets a single q[ho, wo] entry to 1.0 (zero elsewhere), sets a neighborhood of k around (ho, wo) to 1.0 (zero elsewhere), and uses a sinusoidal v. The same forward pass is also run with k identically zero, and we check that the two outputs differ only at the single point (ho, wo). The value of that difference is also checked as a regression test.
"""
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
from torch_harmonics import _neighborhood_attention_s2_fwd_torch
device = "cpu"
nas2_2 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
theta_cutoff=torch.pi / 128 * 10)
nas2_2.to(device)
qo = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
x = torch.linspace(0, 2 * np.pi, nlat_in) # nlat_in points in the x direction
y = torch.linspace(0, 2 * np.pi, nlon_in) # nlon_in points in the y direction
# Create a meshgrid
X, Y = torch.meshgrid(x, y, indexing="ij")
vi = torch.ones((batch_size, num_channels, nlat_in, nlon_in)).to(device)
vi[:, :, :, :] = (torch.sin(X) + torch.sin(Y))[None, None, :, :]
ki = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
ki2 = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
ho = 10
wo = 15
qo[:, 0, ho, wo] = 1.0
nas3 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
theta_cutoff=torch.pi / 128 * 7)
zstart = nas3.psi_roff_idx[ho]
zend = nas3.psi_roff_idx[ho + 1]
# set a small neighborhood of k around (ho,wo) to 1
for idz in range(zstart, zend):
nz_col_idx = nas3.psi_col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
ki2[:, 0, hi, wip] = 1.0
# run with k zero
y = _neighborhood_attention_s2_fwd_torch(ki, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
# run with k 1 at neighborhood of ho,wo
y2 = _neighborhood_attention_s2_fwd_torch(ki2, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
# for viz if desired
# plt.matshow((y[0,0,:,:]-y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
# plt.matshow((y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
# compare zero k vs. nonzero k and ensure difference only occurs at ho,wo
nz_x, nz_y = torch.where((y[0, 0, :, :] - y2[0, 0, :, :]).abs() > 0)
self.assertTrue(nz_x.item() == ho)
self.assertTrue(nz_y.item() == wo)
h, w = nz_x.item(), nz_y.item()
diff_hw = y[0, 0, h, w] - y2[0, 0, h, w]
# print("diff_hw=", diff_hw.item())
# regression test the difference. Unfortunately difficult to come up with an
# analytical value, so we just have it hardcoded.
self.assertTrue(torch.allclose(diff_hw, torch.tensor([0.00753], device=device), rtol=rtol, atol=atol))
@parameterized.expand(
[
# self attention
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
[8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
]
)
def test_full(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
k_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
k_cpu.requires_grad = True
v_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
v_cpu.requires_grad = True
q_cpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
q_cpu.requires_grad = True
# set up layers
att_cpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True)
# random weights
with torch.no_grad():
att_cpu.q_weights.normal_()
att_cpu.k_weights.normal_()
att_cpu.v_weights.normal_()
att_cpu.q_bias.normal_()
att_cpu.k_bias.normal_()
att_cpu.v_bias.normal_()
out_cpu = att_cpu(q_cpu, k_cpu, v_cpu)
out_grad = torch.randn(out_cpu.shape, dtype=torch.float32, device="cpu")
out_cpu.backward(out_grad)
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_cpu.q_weights)
att_gpu.k_weights.copy_(att_cpu.k_weights)
att_gpu.v_weights.copy_(att_cpu.v_weights)
att_gpu.proj_weights.copy_(att_cpu.proj_weights)
att_gpu.q_bias.copy_(att_cpu.q_bias)
att_gpu.k_bias.copy_(att_cpu.k_bias)
att_gpu.v_bias.copy_(att_cpu.v_bias)
att_gpu.proj_bias.copy_(att_cpu.proj_bias)
q_gpu = q_cpu.detach().clone().to(self.device)
q_gpu.requires_grad = True
k_gpu = k_cpu.detach().clone().to(self.device)
k_gpu.requires_grad = True
v_gpu = v_cpu.detach().clone().to(self.device)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_gpu.backward(out_grad.to(self.device))
# check forward
self.assertTrue(torch.allclose(out_cpu.to(self.device), out_gpu, atol=atol, rtol=rtol))
# check input gradients:
self.assertTrue(torch.allclose(q_cpu.grad.to(self.device), q_gpu.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_cpu.grad.to(self.device), k_gpu.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_cpu.grad.to(self.device), v_gpu.grad, atol=atol, rtol=rtol))
# check weight gradients
self.assertTrue(torch.allclose(att_cpu.q_weights.grad.to(self.device), att_gpu.q_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.k_weights.grad.to(self.device), att_gpu.k_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.v_weights.grad.to(self.device), att_gpu.v_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.proj_weights.grad.to(self.device), att_gpu.proj_weights.grad, atol=atol, rtol=rtol))
# check bias gradients
self.assertTrue(torch.allclose(att_cpu.q_bias.grad.to(self.device), att_gpu.q_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.k_bias.grad.to(self.device), att_gpu.k_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.v_bias.grad.to(self.device), att_gpu.v_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.proj_bias.grad.to(self.device), att_gpu.proj_bias.grad, atol=atol, rtol=rtol))
@parameterized.expand(
[
# self attention
[8, 8, 8, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
[8, 8, 8, 2, (17, 32), (17, 32), "legendre-gauss", "legendre-gauss", 2e-4, 1e-5],
[8, 8, 8, 2, (17, 32), (17, 32), "lobatto", "lobatto", 2e-4, 1e-5],
[8, 8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
]
)
def test_full_attention(self, batch_size, channels_in, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
k_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
k_cpu.requires_grad = True
v_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
v_cpu.requires_grad = True
q_cpu = torch.randn(batch_size, channels_in, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
q_cpu.requires_grad = True
att_cpu = AttentionS2(in_channels=channels_in, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)
out = att_cpu(q_cpu, k_cpu, v_cpu)
# Helper: create inputs
inputs_ref = {
"k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
"v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
"q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
}
inputs = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()}
# reference input and model
model_ref = AttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Device model and inputs
model = NeighborhoodAttentionS2(
in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=2 * torch.pi
)
# Synchronize parameters of model
model.load_state_dict(model_ref.state_dict())
model = model.to(self.device)
for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
# reference forward passes
out_ref = model_ref(inputs_ref["q"], inputs_ref["k"], inputs_ref["v"])
out = model(inputs["q"], inputs["k"], inputs["v"])
# Check forward equivalence
self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")
# Backward passes
grad = torch.randn_like(out_ref)
out_ref.backward(grad)
out.backward(grad.to(self.device))
# Check input gradient equivalence
for inp in ["q", "k", "v"]:
grad_ref = inputs_ref[inp].grad
grad = inputs[inp].grad
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")
# Check parameter gradient equivalence (q, k, v weights only)
for key in ["q_weights", "k_weights", "v_weights"]:
grad_ref = getattr(model_ref, key).grad
grad = getattr(model, key).grad
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
# check if output is sane
self.assertFalse(torch.isnan(out).any())
if __name__ == "__main__":
unittest.main()
@@ -107,7 +107,7 @@ class AttentionS2(nn.Module):
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in
# we need to tile and flatten them accordingly
quad_weights = torch.tile(quad_weights, (1, self.nlon_in)).flatten()
quad_weights = torch.tile(quad_weights.reshape(-1, 1), (1, self.nlon_in)).flatten()
# take the log because the weights are applied additively as an 'attn_mask' prior to the softmax, which exponentiates them again.
# see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
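The quadrature-weight fix above changes how the per-latitude weights are spread over the flattened (lat, lon) grid: tiling the 1-D weight vector repeats the whole vector, while tiling it as a column repeats each latitude's weight nlon_in times. A small sketch of the difference, with arbitrary sizes and random stand-in weights rather than the module's actual values:

import torch

nlat_in, nlon_in = 3, 4
wgl = torch.rand(nlat_in)  # stand-in per-latitude quadrature weights
quad_weights = 2.0 * torch.pi * wgl / nlon_in

# previous form: the 1-D vector is broadcast to shape (1, nlat_in) and repeated,
# yielding [w0, w1, w2, w0, w1, w2, ...] after flattening
old = torch.tile(quad_weights, (1, nlon_in)).flatten()

# fixed form: reshape to a column first, so each latitude weight is repeated
# nlon_in times, matching the row-major flattening of the (nlat, nlon) grid
new = torch.tile(quad_weights.reshape(-1, 1), (1, nlon_in)).flatten()

print(old[:nlon_in])  # mixes weights of different latitudes
print(new[:nlon_in])  # constant: the weight of the first latitude row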
@@ -76,7 +76,3 @@ torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
int s2_idx_offset_cuda(const at::Tensor &psi_col_idx,
const at::Tensor &psi_row_idx, at::Tensor &row_offset,
at::Tensor &row_count);
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -433,7 +433,7 @@ s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
@@ -559,7 +559,7 @@ __global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
@@ -675,7 +675,7 @@ __global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int
}
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
@@ -731,7 +731,7 @@ at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
}
at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
@@ -782,7 +782,7 @@ at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydq;
}
@@ -840,7 +840,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(dydk, dydv, dydq);
}
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -172,9 +172,9 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -33,7 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("compute_row_offset", &s2_idx_offset_cuda, "Row offset on S2");
m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
m.def("backward_dq", &s2_attention_bwd_dq_cuda,
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ATen/core/TensorAccessor.h"
#include <cmath>
#include <cstdint>
#include <torch/extension.h>
#include <torch/torch.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
__global__ void alpha_count_kernel(int len_alpha_count,
int len_psi_row_idx,
torch::PackedTensorAccessor64<int64_t, 1> psi_row_idx,
int64_t* alpha_start,
torch::PackedTensorAccessor64<int64_t, 1> alpha_count
) {
int ho = blockIdx.x * blockDim.x + threadIdx.x;
if(ho < len_alpha_count) {
// initialize alpha_count;
alpha_count[ho] = 0;
// NOTE: Assumes that psi_row_idx is sorted
for(int i=alpha_start[ho]; i<len_psi_row_idx; i++) {
if(psi_row_idx[i] == ho) alpha_count[ho]++;
else if(psi_row_idx[i] > ho) break;
}
}
}
int s2_idx_offset_cuda(const at::Tensor& psi_col_idx,
const at::Tensor& psi_row_idx,
at::Tensor& row_offset,
at::Tensor& row_count) {
auto stream = at::cuda::getCurrentCUDAStream();
int64_t* d_alpha_start;
int64_t* d_sequence;
int64_t* d_alpha_count = row_count.data_ptr<int64_t>();
int64_t* d_alpha_offset = row_offset.data_ptr<int64_t>();
C10_CUDA_CHECK(cudaMalloc(&d_alpha_start, row_offset.size(0)*sizeof(int64_t)));
// Find the first time each index occurs in psi_row_idx
// psi_row_idx = [0,0,0,0,1,1,1,1,2,2,2...]
// 0 starts at idx=0, 1 starts at idx=4, 2 starts at idx=8, etc
// this assumes that psi_row_idx is sorted!
C10_CUDA_CHECK(cudaMalloc(&d_sequence, row_offset.size(0)*sizeof(int64_t)));
thrust::sequence(thrust::device, d_sequence, d_sequence+row_offset.size(0), 0);
thrust::counting_iterator<int> start(0);
// thrust::lower_bound(thrust::device,
// psi_row_idx.data_ptr<int64_t>(),
// psi_row_idx.data_ptr<int64_t>()+psi_row_idx.size(0),
// start, start+psi_row_idx.size(0), d_alpha_start);
thrust::lower_bound(thrust::device,
psi_row_idx.data_ptr<int64_t>(),
psi_row_idx.data_ptr<int64_t>()+psi_row_idx.size(0),
d_sequence, d_sequence+row_offset.size(0), d_alpha_start);
alpha_count_kernel<<<at::cuda::detail::GET_BLOCKS(row_offset.size(0),512),512,
0,stream.stream()>>>(row_count.size(0),
psi_row_idx.size(0),
psi_row_idx.packed_accessor64<int64_t, 1>(),
d_alpha_start,
row_count.packed_accessor64<int64_t , 1>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
int maxAlphaSize = thrust::reduce(thrust::device,
d_alpha_count,
d_alpha_count+row_count.size(0),
0,
thrust::maximum<int>());
thrust::exclusive_scan(thrust::device,
d_alpha_count,
d_alpha_count+row_count.size(0),
d_alpha_offset);
C10_CUDA_CHECK(cudaFree(d_alpha_start));
C10_CUDA_CHECK(cudaFree(d_sequence));
return maxAlphaSize;
}
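The helper above, which this commit removes, derives CSR-style row offsets and per-row counts from a sorted row-index array (thrust::lower_bound for the first occurrence of each row, a count per row, then an exclusive scan). For reference, the same bookkeeping can be expressed on the torch side; this is an illustrative sketch with made-up data, not part of the extension:

import torch

# sorted row indices of the nonzero entries of psi: rows 0 and 1 have 4 entries, row 2 has 3
psi_row_idx = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2], dtype=torch.int64)
nrows = 3

# first occurrence of each row index (the lower_bound step)
row_start = torch.searchsorted(psi_row_idx, torch.arange(nrows, dtype=torch.int64))

# nonzeros per row and the exclusive prefix sum used as the row offset
row_count = torch.bincount(psi_row_idx, minlength=nrows)
row_offset = torch.cumsum(row_count, dim=0) - row_count

max_alpha_size = int(row_count.max())  # analogue of the kernel's maxAlphaSize reduction
print(row_start.tolist(), row_count.tolist(), row_offset.tolist(), max_alpha_size)
# [0, 4, 8] [4, 4, 3] [0, 4, 8] 4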