Commit 4350ba9f authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

fixing bug in quadrature weights for full attention. Adding better unit tests...

fixing bug in quadrature weights for full attention. Adding better unit tests for attention. Cleanup in the cuda code.
parent b6c48457
...@@ -99,7 +99,6 @@ def get_ext_modules(): ...@@ -99,7 +99,6 @@ def get_ext_modules():
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu", "torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu", "torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu", "torch_harmonics/csrc/attention/attention_interface.cu",
"torch_harmonics/csrc/attention/attention_row_offset.cu"
], ],
extra_compile_args=get_compile_args("neighborhood_attention") extra_compile_args=get_compile_args("neighborhood_attention")
) )
......
...@@ -35,6 +35,7 @@ from parameterized import parameterized ...@@ -35,6 +35,7 @@ from parameterized import parameterized
# import math # import math
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
# from torch.autograd import gradcheck # from torch.autograd import gradcheck
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2 from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
...@@ -58,519 +59,157 @@ except ImportError as err: ...@@ -58,519 +59,157 @@ except ImportError as err:
_cuda_extension_available = False _cuda_extension_available = False
# this routine is only supposed to be used in this test, since it is numerically not stable but supports class TestNeighborhoodAttentionS2(unittest.TestCase):
# autograd which some of the better kernels do not
def _neighborhood_attention_s2_torch_test(
kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_idx: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int
):
out = torch.zeros_like(qy)
for ho in range(nlat_out):
# get nonzero indices in output row
idx_ho = col_idx[row_idx == ho]
for wo in range(nlon_out):
alpha_sum = torch.zeros((out.shape[0],), dtype=out.dtype, device=out.device)
alpha = torch.zeros((out.shape[0], len(idx_ho)), dtype=out.dtype, device=out.device)
for inz, nz_col_idx in enumerate(idx_ho):
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hi_wip = kx[:, :, hi, wip]
alpha[:, inz] = torch.exp(torch.sum(q_ho_wo * k_hi_wip, dim=1)) * quad_weights[hi]
# softmax denominator
alpha_sum[:] = alpha_sum[:] + alpha[:, inz]
for inz, nz_col_idx in enumerate(idx_ho):
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
# compute matmul of attention matrix with V-vector
out[:, :, ho, wo] = out[:, :, ho, wo] + (alpha[:, None, inz] / alpha_sum[:, None]) * vx[:, :, hi, wip]
return out
class TestNeighborhoodAttention(unittest.TestCase):
def setUp(self): def setUp(self):
if torch.cuda.is_available(): if torch.cuda.is_available():
self.device = torch.device("cuda:0") self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index) torch.cuda.set_device(self.device.index)
torch.cuda.manual_seed(333) torch.cuda.manual_seed(333)
torch.manual_seed(333)
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
torch.manual_seed(333) torch.manual_seed(333)
@parameterized.expand( @parameterized.expand(
[ [
# regular convolution # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[8, 4, 6, (17, 32), 1e-6, 1e-5], [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
] ]
) )
def test_batched_linear(self, batch_size, in_channels, out_channels, shape, atol, rtol): def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
# weight """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 1, 1, dtype=torch.float32, device=self.device))
bias = torch.nn.Parameter(torch.randn(out_channels, dtype=torch.float32, device=self.device))
# input
inp = torch.randn(batch_size, in_channels, *shape, dtype=torch.float32, device=self.device)
inp.requires_grad = True
# operation
out = torch.nn.functional.conv2d(inp, weight=weight, bias=bias)
out_grad = torch.randn(batch_size, out_channels, *shape, dtype=torch.float32, device=self.device)
out.backward(out_grad)
# store for comparison
wgrad = weight.grad.clone()
bgrad = bias.grad.clone()
igrad = inp.grad.clone()
# explicit layers
igrad_explicit = torch.nn.functional.conv2d(out_grad, weight=weight.permute([1, 0, 2, 3]), bias=None)
wgrad_explicit = torch.einsum("bchw,bfhw->cf", out_grad, inp).reshape(out_channels, in_channels, 1, 1).contiguous()
bgrad_explicit = torch.sum(out_grad, dim=(0, 2, 3))
# check consistency
self.assertTrue(torch.allclose(igrad, igrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(wgrad, wgrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(bgrad, bgrad_explicit, atol=atol, rtol=rtol))
@parameterized.expand(
[
# self attention
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
]
)
def test_fwd(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape nlat_out, nlon_out = out_shape
# set up neighbor matrix # Helper: create inputs
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, inputs_ref = {
in_shape=in_shape, out_shape=out_shape, "k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device) "v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
"q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
# Execute and compare }
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device) inputs = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()}
k_inp.requires_grad = False
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device) # reference input and model
v_inp.requires_grad = False model_ref = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True).to(
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device) self.device
q_inp.requires_grad = False
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
with torch.no_grad():
out_torch_explicit = _neighborhood_attention_s2_fwd_torch(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
self.assertTrue(torch.allclose(out_torch_explicit, out_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
out_cuda = attention_cuda_extension.forward(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
self.assertTrue(torch.allclose(out_torch, out_cuda, atol=atol, rtol=rtol))
@parameterized.expand(
[
# regular convolution
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
]
) )
def test_bwd_dv(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
_, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# set up neighbor matrix # Device model and inputs
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, model = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device) # Synchronize parameters of model
model.load_state_dict(model_ref.state_dict())
# Execute and compare model = model.to(self.device)
k_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device) for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
k_inp.requires_grad = False assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
v_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device)
v_inp.requires_grad = True # reference forward passes
q_inp = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device) out_ref = _neighborhood_attention_s2_torch(
q_inp.requires_grad = False inputs_ref["k"],
out_grad = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device) inputs_ref["v"],
inputs_ref["q"] * model_ref.scale,
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out) model_ref.k_weights,
model_ref.v_weights,
# need 'retain_graph' to avoid an error in the tests after this one model_ref.q_weights,
out_torch.backward(out_grad) model_ref.k_bias,
dv_inp_torch = v_inp.grad.clone() model_ref.v_bias,
model_ref.q_bias,
with torch.no_grad(): model_ref.quad_weights,
dv_inp_torch_explicit = _neighborhood_attention_s2_bwd_dv_torch( model_ref.psi_col_idx,
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out model_ref.psi_roff_idx,
model_ref.num_heads,
model_ref.nlon_in,
model_ref.nlat_out,
model_ref.nlon_out,
) )
out_ref = nn.functional.conv2d(out_ref, model_ref.proj_weights, bias=model_ref.proj_bias)
out = model(inputs["q"], inputs["k"], inputs["v"])
self.assertTrue(torch.allclose(dv_inp_torch_explicit, dv_inp_torch, atol=atol, rtol=rtol)) # Check forward equivalence
self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")
if _cuda_extension_available: # Backward passes
grad = torch.randn_like(out_ref)
out_ref.backward(grad)
out.backward(grad.to(self.device))
dv_inp_cuda_explicit = attention_cuda_extension.backward_dv( # Check input gradient equivalence
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out for inp in ["q", "k", "v"]:
) grad_ref = inputs_ref[inp].grad.cpu()
grad = inputs[inp].grad.cpu()
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")
self.assertTrue(torch.allclose(dv_inp_cuda_explicit, dv_inp_torch, atol=atol, rtol=rtol)) # Check parameter gradient equivalence
for p_ref, p in zip(model_ref.parameters(), model.parameters()):
self.assertTrue(torch.allclose(p.grad, p_ref.grad, atol=atol, rtol=rtol), f"Parameter gradient mismatch: {type(p_ref).__name__}")
# caution: multihead-implementation between full and neighborhood attention still seem to differ. tests are only done for single head
@parameterized.expand( @parameterized.expand(
[ [
# regular convolution # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-3], [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
# [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
# [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
] ]
) )
def test_bwd_dk(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol): def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
"""Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set ot the whole sphere"""
# extract some parameters
nlat_in, nlon_in = in_shape nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape nlat_out, nlon_out = out_shape
# set up neighbor matrix # Helper: create inputs
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, inputs_ref = {
in_shape=in_shape, out_shape=out_shape, "k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device) "v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
"q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
# Execute and compare }
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device) inputs = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()}
k_inp.requires_grad = True
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_inp.requires_grad = False
out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dk_inp_torch = k_inp.grad.clone()
with torch.no_grad():
dk_inp_torch_explicit = _neighborhood_attention_s2_bwd_dk_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dk_inp_torch_explicit, dk_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available: # reference input and model
model_ref = AttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
dk_inp_cuda_explicit = attention_cuda_extension.backward_dk( # Device model and inputs
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out model = NeighborhoodAttentionS2(
in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=2 * torch.pi
) )
self.assertTrue(torch.allclose(dk_inp_cuda_explicit, dk_inp_torch, atol=atol, rtol=rtol)) # Synchronize parameters of model
model.load_state_dict(model_ref.state_dict())
@parameterized.expand( model = model.to(self.device)
[ for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
# regular convolution assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 4e-6, 1e-3],
] # reference forward passes
) out_ref = model_ref(inputs_ref["q"], inputs_ref["k"], inputs_ref["v"])
def test_bwd_dq(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol): out = model(inputs["q"], inputs["k"], inputs["v"])
# extract some parameters # Check forward equivalence
nlat_in, nlon_in = in_shape self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")
nlat_out, nlon_out = out_shape
# Backward passes
# set up neighbor matrix grad = torch.randn_like(out_ref)
att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, out_ref.backward(grad)
in_shape=in_shape, out_shape=out_shape, out.backward(grad.to(self.device))
grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
# Check input gradient equivalence
# Execute and compare for inp in ["q", "k", "v"]:
k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device) grad_ref = inputs_ref[inp].grad
k_inp.requires_grad = False grad = inputs[inp].grad
v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device) self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")
v_inp.requires_grad = False
q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device) # Check parameter gradient equivalence - check only q,k, v weights
q_inp.requires_grad = True for key in ["q_weights", "k_weights", "v_weights"]:
out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device) grad_ref = getattr(model_ref, key).grad
grad = getattr(model, key).grad
out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out) self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
# need 'retain_graph' to avoid an error in the tests after this one
out_torch.backward(out_grad)
dq_inp_torch = q_inp.grad.clone()
with torch.no_grad():
dq_inp_torch_explicit = _neighborhood_attention_s2_bwd_dq_torch(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dq_inp_torch_explicit, dq_inp_torch, atol=atol, rtol=rtol))
if _cuda_extension_available:
dq_inp_cuda_explicit = attention_cuda_extension.backward_dq(
k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
)
self.assertTrue(torch.allclose(dq_inp_cuda_explicit, dq_inp_torch, atol=atol, rtol=rtol))
@parameterized.expand(
    [
        # self attention
        [1, 73, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
    ]
)
def test_big(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
    # Smoke test: run forward + backward of NeighborhoodAttentionS2 at a large
    # (721 x 1440) resolution on the GPU. No numerical comparison is made; the
    # test only checks that the kernels execute without error.
    # this test only makes sense when CUDA version is available
    if torch.cuda.is_available():
        if not _cuda_extension_available:
            print("WARNING: Problem loading CUDA attention module")
            return
        # extract some parameters
        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape
        # TODO: this test seems hardcoded for GPU. Is this necessary?
        k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
        k_gpu.requires_grad = False
        v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
        v_gpu.requires_grad = False
        q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cuda:0")
        q_gpu.requires_grad = False
        # set up layers
        att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                          in_shape=in_shape, out_shape=out_shape,
                                          grid_in=grid_in, grid_out=grid_out, bias=True).to("cuda:0")
        # random weights
        with torch.no_grad():
            att_gpu.q_weights.normal_()
            att_gpu.k_weights.normal_()
            att_gpu.v_weights.normal_()
            att_gpu.q_bias.normal_()
            att_gpu.k_bias.normal_()
            att_gpu.v_bias.normal_()
        out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
        # sync weights:
        # NOTE(review): each parameter is copied onto itself, so this whole
        # section is a no-op — presumably a second (CPU) model was intended
        # as the copy source; confirm against the repository history.
        with torch.no_grad():
            att_gpu.q_weights.copy_(att_gpu.q_weights)
            att_gpu.k_weights.copy_(att_gpu.k_weights)
            att_gpu.v_weights.copy_(att_gpu.v_weights)
            att_gpu.q_bias.copy_(att_gpu.q_bias)
            att_gpu.k_bias.copy_(att_gpu.k_bias)
            att_gpu.v_bias.copy_(att_gpu.v_bias)
        # re-create the inputs as leaf tensors with gradients enabled so the
        # backward pass below exercises the input-gradient kernels
        q_gpu = q_gpu.detach().clone().to(self.device)
        q_gpu.requires_grad = True
        k_gpu = k_gpu.detach().clone().to(self.device)
        k_gpu.requires_grad = True
        v_gpu = v_gpu.detach().clone().to(self.device)
        v_gpu.requires_grad = True
        out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
        out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0")
        out_gpu.backward(out_grad.to("cuda:0"))
@parameterized.expand(
    [
        # self attention
        [10, 2, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-5, 1e-5],
    ]
)
def test_neighborhood(self, batch_size, num_channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
    """
    This test sets a specific q[ho,wo] value to 1.0 (elsewhere 0), and then a neighborhood of k around ho,wo to 1.0 (else 0.0). Also vi is set to a sinusoidal input. We also run it with fully 0 q,k. We test that the output of the nonzero q,k is only different to the zero q,k in a single point. We also check the value of this difference (as a regression test).
    """
    # extract some parameters
    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape
    from torch_harmonics import _neighborhood_attention_s2_fwd_torch
    device = "cpu"
    # module whose quadrature weights / sparsity pattern drive the forward runs
    nas2_2 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
                                     in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
                                     theta_cutoff=torch.pi / 128 * 10)
    nas2_2.to(device)
    qo = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
    x = torch.linspace(0, 2 * np.pi, nlat_in)  # 100 points in x direction
    y = torch.linspace(0, 2 * np.pi, nlon_in)  # 100 points in y direction
    # Create a meshgrid
    X, Y = torch.meshgrid(x, y, indexing="ij")
    # sinusoidal value field, broadcast over batch and channel
    vi = torch.ones((batch_size, num_channels, nlat_in, nlon_in)).to(device)
    vi[:, :, :, :] = (torch.sin(X) + torch.sin(Y))[None, None, :, :]
    ki = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
    ki2 = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
    # output location whose query is switched on
    ho = 10
    wo = 15
    qo[:, 0, ho, wo] = 1.0
    # second module with a smaller cutoff; its (strictly contained) sparsity
    # pattern defines the k-neighborhood that is set to 1 below
    nas3 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
                                   in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
                                   theta_cutoff=torch.pi / 128 * 7)
    zstart = nas3.psi_roff_idx[ho]
    zend = nas3.psi_roff_idx[ho + 1]
    # set a small neighborhood of k around (ho,wo) to 1
    for idz in range(zstart, zend):
        nz_col_idx = nas3.psi_col_idx[idz]
        # compute input indices from psi datastructure
        hi = nz_col_idx // nlon_in
        # account for output shift and ensure positive index due to circular condition
        wi = nz_col_idx % nlon_in
        wip = (wi + wo) % nlon_in
        ki2[:, 0, hi, wip] = 1.0
    # run with k zero
    y = _neighborhood_attention_s2_fwd_torch(ki, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
    # run with k 1 at neighborhood of ho,wo
    y2 = _neighborhood_attention_s2_fwd_torch(ki2, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
    # for viz if desired
    # plt.matshow((y[0,0,:,:]-y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
    # plt.matshow((y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
    # compare zero k vs. nonzero k and ensure difference only occurs at ho,wo
    nz_x, nz_y = torch.where((y[0, 0, :, :] - y2[0, 0, :, :]).abs() > 0)
    self.assertTrue(nz_x.item() == ho)
    self.assertTrue(nz_y.item() == wo)
    h, w = nz_x.item(), nz_y.item()
    diff_hw = y[0, 0, h, w] - y2[0, 0, h, w]
    # print("diff_hw=", diff_hw.item())
    # regression test the difference. Unfortunately difficult to come up with an
    # analytical value, so we just have it hardcoded.
    self.assertTrue(torch.allclose(diff_hw, torch.tensor([0.00753], device=device), rtol=rtol, atol=atol))
@parameterized.expand(
    [
        # self attention
        [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
        [8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
    ]
)
def test_full(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
    # Compares a CPU instance against a self.device instance of
    # NeighborhoodAttentionS2 with synchronized weights: forward outputs,
    # input gradients, and all weight/bias gradients must agree.
    # extract some parameters
    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape
    k_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
    k_cpu.requires_grad = True
    v_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
    v_cpu.requires_grad = True
    q_cpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
    q_cpu.requires_grad = True
    # set up layers
    att_cpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                      in_shape=in_shape, out_shape=out_shape,
                                      grid_in=grid_in, grid_out=grid_out, bias=True)
    # random weights
    # NOTE(review): only q/k/v weights and biases are randomized here; the
    # projection parameters keep their initialization and are synced below.
    with torch.no_grad():
        att_cpu.q_weights.normal_()
        att_cpu.k_weights.normal_()
        att_cpu.v_weights.normal_()
        att_cpu.q_bias.normal_()
        att_cpu.k_bias.normal_()
        att_cpu.v_bias.normal_()
    out_cpu = att_cpu(q_cpu, k_cpu, v_cpu)
    out_grad = torch.randn(out_cpu.shape, dtype=torch.float32, device="cpu")
    out_cpu.backward(out_grad)
    att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                      in_shape=in_shape, out_shape=out_shape,
                                      grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
    # sync weights:
    with torch.no_grad():
        att_gpu.q_weights.copy_(att_cpu.q_weights)
        att_gpu.k_weights.copy_(att_cpu.k_weights)
        att_gpu.v_weights.copy_(att_cpu.v_weights)
        att_gpu.proj_weights.copy_(att_cpu.proj_weights)
        att_gpu.q_bias.copy_(att_cpu.q_bias)
        att_gpu.k_bias.copy_(att_cpu.k_bias)
        att_gpu.v_bias.copy_(att_cpu.v_bias)
        att_gpu.proj_bias.copy_(att_cpu.proj_bias)
    # independent leaf copies of the inputs on the target device
    q_gpu = q_cpu.detach().clone().to(self.device)
    q_gpu.requires_grad = True
    k_gpu = k_cpu.detach().clone().to(self.device)
    k_gpu.requires_grad = True
    v_gpu = v_cpu.detach().clone().to(self.device)
    v_gpu.requires_grad = True
    out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
    out_gpu.backward(out_grad.to(self.device))
    # check forward
    self.assertTrue(torch.allclose(out_cpu.to(self.device), out_gpu, atol=atol, rtol=rtol))
    # check input gradients:
    self.assertTrue(torch.allclose(q_cpu.grad.to(self.device), q_gpu.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(k_cpu.grad.to(self.device), k_gpu.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(v_cpu.grad.to(self.device), v_gpu.grad, atol=atol, rtol=rtol))
    # check weight gradients
    self.assertTrue(torch.allclose(att_cpu.q_weights.grad.to(self.device), att_gpu.q_weights.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(att_cpu.k_weights.grad.to(self.device), att_gpu.k_weights.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(att_cpu.v_weights.grad.to(self.device), att_gpu.v_weights.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(att_cpu.proj_weights.grad.to(self.device), att_gpu.proj_weights.grad, atol=atol, rtol=rtol))
    # check bias gradients
    self.assertTrue(torch.allclose(att_cpu.q_bias.grad.to(self.device), att_gpu.q_bias.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(att_cpu.k_bias.grad.to(self.device), att_gpu.k_bias.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(att_cpu.v_bias.grad.to(self.device), att_gpu.v_bias.grad, atol=atol, rtol=rtol))
    self.assertTrue(torch.allclose(att_cpu.proj_bias.grad.to(self.device), att_gpu.proj_bias.grad, atol=atol, rtol=rtol))
@parameterized.expand(
    [
        # self attention
        [8, 8, 8, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
        [8, 8, 8, 2, (17, 32), (17, 32), "legendre-gauss", "legendre-gauss", 2e-4, 1e-5],
        [8, 8, 8, 2, (17, 32), (17, 32), "lobatto", "lobatto", 2e-4, 1e-5],
        [8, 8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
    ]
)
def test_full_attention(self, batch_size, channels_in, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
    # Sanity check for the global AttentionS2 module: one forward pass on CPU
    # must produce finite (non-NaN) outputs for several grids and head counts.
    # extract some parameters
    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape
    k_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
    k_cpu.requires_grad = True
    v_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
    v_cpu.requires_grad = True
    q_cpu = torch.randn(batch_size, channels_in, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
    q_cpu.requires_grad = True
    att_cpu = AttentionS2(in_channels=channels_in, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)
    out = att_cpu(q_cpu, k_cpu, v_cpu)
    # check if output is sane
    self.assertFalse(torch.isnan(out).any())
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -107,7 +107,7 @@ class AttentionS2(nn.Module): ...@@ -107,7 +107,7 @@ class AttentionS2(nn.Module):
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in) _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in
# we need to tile and flatten them accordingly # we need to tile and flatten them accordingly
quad_weights = torch.tile(quad_weights, (1, self.nlon_in)).flatten() quad_weights = torch.tile(quad_weights.reshape(-1, 1), (1, self.nlon_in)).flatten()
# compute log because they are applied as an addition prior to the softmax ('attn_mask'), which includes an exponential. # compute log because they are applied as an addition prior to the softmax ('attn_mask'), which includes an exponential.
# see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
......
...@@ -76,7 +76,3 @@ torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx, ...@@ -76,7 +76,3 @@ torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor psi_col_idx, at::Tensor psi_col_idx,
at::Tensor psi_row_off, at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out); int nlon_in, int nlat_out, int nlon_out);
int s2_idx_offset_cuda(const at::Tensor &psi_col_idx,
const at::Tensor &psi_row_idx, at::Tensor &row_offset,
at::Tensor &row_count);
// coding=utf-8 // coding=utf-8
// //
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause // SPDX-License-Identifier: BSD-3-Clause
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
......
// coding=utf-8 // coding=utf-8
// //
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause // SPDX-License-Identifier: BSD-3-Clause
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
......
// coding=utf-8 // coding=utf-8
// //
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause // SPDX-License-Identifier: BSD-3-Clause
// //
// Redistribution and use in source and binary forms, with or without // Redistribution and use in source and binary forms, with or without
...@@ -33,7 +33,6 @@ ...@@ -33,7 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2"); m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("compute_row_offset", &s2_idx_offset_cuda, "Row offset on S2");
m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)"); m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)"); m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
m.def("backward_dq", &s2_attention_bwd_dq_cuda, m.def("backward_dq", &s2_attention_bwd_dq_cuda,
......
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ATen/core/TensorAccessor.h"

#include <cmath>
#include <cstdint>

#include <torch/extension.h>
#include <torch/torch.h>

#include <ATen/cuda/CUDAUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/system/cuda/execution_policy.h>
// Counts, for each output row `ho`, how many entries of psi_row_idx are equal
// to `ho`, and writes that count to alpha_count[ho]. One thread per row.
//
// len_alpha_count: number of output rows (length of alpha_count/alpha_start)
// len_psi_row_idx: total number of sparse entries in psi_row_idx
// psi_row_idx:     row index of every sparse entry, sorted ascending
// alpha_start:     per row, the first position in psi_row_idx whose value is
//                  >= that row (precomputed with a lower_bound search)
// alpha_count:     output; number of sparse entries belonging to each row
//
// NOTE: correctness relies on psi_row_idx being sorted in ascending order.
__global__ void alpha_count_kernel(int len_alpha_count,
                                   int len_psi_row_idx,
                                   torch::PackedTensorAccessor64<int64_t, 1> psi_row_idx,
                                   int64_t* alpha_start,
                                   torch::PackedTensorAccessor64<int64_t, 1> alpha_count
                                   ) {

    int ho = blockIdx.x * blockDim.x + threadIdx.x;

    if(ho < len_alpha_count) {
        // use a 64-bit loop index: alpha_start holds int64_t positions and the
        // previous `int` index silently truncated them for very large patterns
        int64_t count = 0;
        for(int64_t i = alpha_start[ho]; i < len_psi_row_idx; i++) {
            int64_t row = psi_row_idx[i];
            if(row > ho) break; // sorted input: no further matches possible
            if(row == ho) count++;
        }
        alpha_count[ho] = count;
    }
}
// Computes, per output row of the sparse neighborhood pattern, the offset of
// its first nonzero (row_offset, exclusive prefix sum) and the number of
// nonzeros in the row (row_count). Returns the maximum row length, i.e. the
// largest neighborhood size across all output points.
//
// psi_col_idx: column indices of the sparse entries (unused here; kept for
//              interface compatibility with the caller)
// psi_row_idx: row index of each sparse entry; MUST be sorted ascending
// row_offset:  output tensor (int64), exclusive prefix sum of row_count
// row_count:   output tensor (int64), number of entries per row
int s2_idx_offset_cuda(const at::Tensor& psi_col_idx,
                       const at::Tensor& psi_row_idx,
                       at::Tensor& row_offset,
                       at::Tensor& row_count) {

    auto stream = at::cuda::getCurrentCUDAStream();
    const int64_t nrows = row_offset.size(0);

    int64_t* d_alpha_count = row_count.data_ptr<int64_t>();
    int64_t* d_alpha_offset = row_offset.data_ptr<int64_t>();

    // allocate scratch through the torch caching allocator instead of raw
    // cudaMalloc/cudaFree: no leak if anything below throws, and the memory
    // is managed consistently with the rest of the extension
    auto alpha_start = at::empty({nrows}, row_count.options());
    int64_t* d_alpha_start = alpha_start.data_ptr<int64_t>();

    // run all thrust calls on the current torch stream so they are ordered
    // with the kernel launch below (previously they used the default stream)
    auto policy = thrust::cuda::par.on(stream.stream());

    // Find the first occurrence of each row index in the sorted psi_row_idx:
    // psi_row_idx = [0,0,0,0,1,1,1,1,2,...] -> alpha_start = [0,4,8,...].
    // A counting iterator replaces the explicit device-side sequence buffer.
    // This assumes psi_row_idx is sorted!
    thrust::counting_iterator<int64_t> rows_begin(0);
    thrust::lower_bound(policy,
                        psi_row_idx.data_ptr<int64_t>(),
                        psi_row_idx.data_ptr<int64_t>() + psi_row_idx.size(0),
                        rows_begin, rows_begin + nrows, d_alpha_start);

    // count the number of sparse entries in each row
    alpha_count_kernel<<<at::cuda::detail::GET_BLOCKS(nrows, 512), 512,
                         0, stream.stream()>>>(row_count.size(0),
                                               psi_row_idx.size(0),
                                               psi_row_idx.packed_accessor64<int64_t, 1>(),
                                               d_alpha_start,
                                               row_count.packed_accessor64<int64_t, 1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // reduce in int64_t: the previous int accumulator with thrust::maximum<int>
    // silently truncated counts larger than INT_MAX
    const int64_t maxAlphaSize = thrust::reduce(policy,
                                                d_alpha_count,
                                                d_alpha_count + row_count.size(0),
                                                static_cast<int64_t>(0),
                                                thrust::maximum<int64_t>());

    // row offsets are the exclusive prefix sum of the per-row counts
    thrust::exclusive_scan(policy,
                           d_alpha_count,
                           d_alpha_count + row_count.size(0),
                           d_alpha_offset);

    return static_cast<int>(maxAlphaSize);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment