Commit 4350ba9f authored by Boris Bonev, committed by Boris Bonev


fixing bug in quadrature weights for full attention. Adding better unit tests for attention. Cleanup in the cuda code.
parent b6c48457
@@ -99,7 +99,6 @@ def get_ext_modules():
                 "torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
                 "torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
                 "torch_harmonics/csrc/attention/attention_interface.cu",
-                "torch_harmonics/csrc/attention/attention_row_offset.cu"
             ],
             extra_compile_args=get_compile_args("neighborhood_attention")
         )
...
@@ -35,6 +35,7 @@ from parameterized import parameterized
 # import math
 import numpy as np
 import torch
+import torch.nn as nn
 # from torch.autograd import gradcheck
 from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
@@ -58,520 +59,158 @@ except ImportError as err:
     _cuda_extension_available = False
-# this routine is only supposed to be used in this test, since it is numerically not stable but supports
-# autograd which some of the better kernels do not
-def _neighborhood_attention_s2_torch_test(
-    kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_idx: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int
-):
-    out = torch.zeros_like(qy)
-    for ho in range(nlat_out):
-        # get nonzero indices in output row
-        idx_ho = col_idx[row_idx == ho]
-        for wo in range(nlon_out):
-            alpha_sum = torch.zeros((out.shape[0],), dtype=out.dtype, device=out.device)
-            alpha = torch.zeros((out.shape[0], len(idx_ho)), dtype=out.dtype, device=out.device)
-            for inz, nz_col_idx in enumerate(idx_ho):
-                # compute input indices from psi datastructure
-                hi = nz_col_idx // nlon_in
-                # account for output shift and ensure positive index due to circular condition
-                wi = nz_col_idx % nlon_in
-                wip = (wi + wo) % nlon_in
-                # compute correlation & softmax numerator
-                q_ho_wo = qy[:, :, ho, wo]
-                k_hi_wip = kx[:, :, hi, wip]
-                alpha[:, inz] = torch.exp(torch.sum(q_ho_wo * k_hi_wip, dim=1)) * quad_weights[hi]
-                # softmax denominator
-                alpha_sum[:] = alpha_sum[:] + alpha[:, inz]
-            for inz, nz_col_idx in enumerate(idx_ho):
-                # compute input indices from psi datastructure
-                hi = nz_col_idx // nlon_in
-                # account for output shift and ensure positive index due to circular condition
-                wi = nz_col_idx % nlon_in
-                wip = (wi + wo) % nlon_in
-                # compute matmul of attention matrix with V-vector
-                out[:, :, ho, wo] = out[:, :, ho, wo] + (alpha[:, None, inz] / alpha_sum[:, None]) * vx[:, :, hi, wip]
-    return out
-
-
-class TestNeighborhoodAttention(unittest.TestCase):
+class TestNeighborhoodAttentionS2(unittest.TestCase):
     def setUp(self):
         if torch.cuda.is_available():
             self.device = torch.device("cuda:0")
             torch.cuda.set_device(self.device.index)
             torch.cuda.manual_seed(333)
+            torch.manual_seed(333)
         else:
             self.device = torch.device("cpu")
-        torch.manual_seed(333)
+            torch.manual_seed(333)
@parameterized.expand(
[
# regular convolution
[8, 4, 6, (17, 32), 1e-6, 1e-5],
]
)
def test_batched_linear(self, batch_size, in_channels, out_channels, shape, atol, rtol):
# weight
weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 1, 1, dtype=torch.float32, device=self.device))
bias = torch.nn.Parameter(torch.randn(out_channels, dtype=torch.float32, device=self.device))
# input
inp = torch.randn(batch_size, in_channels, *shape, dtype=torch.float32, device=self.device)
inp.requires_grad = True
# operation
out = torch.nn.functional.conv2d(inp, weight=weight, bias=bias)
out_grad = torch.randn(batch_size, out_channels, *shape, dtype=torch.float32, device=self.device)
out.backward(out_grad)
# store for comparison
wgrad = weight.grad.clone()
bgrad = bias.grad.clone()
igrad = inp.grad.clone()
# explicit layers
igrad_explicit = torch.nn.functional.conv2d(out_grad, weight=weight.permute([1, 0, 2, 3]), bias=None)
wgrad_explicit = torch.einsum("bchw,bfhw->cf", out_grad, inp).reshape(out_channels, in_channels, 1, 1).contiguous()
bgrad_explicit = torch.sum(out_grad, dim=(0, 2, 3))
# check consistency
self.assertTrue(torch.allclose(igrad, igrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(wgrad, wgrad_explicit, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(bgrad, bgrad_explicit, atol=atol, rtol=rtol))
-    @parameterized.expand(
-        [
-            # self attention
-            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
-        ]
-    )
-    def test_fwd(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
-
-        # extract some parameters
-        nlat_in, nlon_in = in_shape
-        nlat_out, nlon_out = out_shape
-
-        # set up neighbor matrix
-        att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
-                                      in_shape=in_shape, out_shape=out_shape,
-                                      grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
-
-        # Execute and compare
-        k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
-        k_inp.requires_grad = False
-        v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
-        v_inp.requires_grad = False
-        q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
-        q_inp.requires_grad = False
-
-        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
-
-        with torch.no_grad():
-            out_torch_explicit = _neighborhood_attention_s2_fwd_torch(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
-
-        self.assertTrue(torch.allclose(out_torch_explicit, out_torch, atol=atol, rtol=rtol))
-
-        if _cuda_extension_available:
-            out_cuda = attention_cuda_extension.forward(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)
-
-            self.assertTrue(torch.allclose(out_torch, out_cuda, atol=atol, rtol=rtol))
-
-    @parameterized.expand(
-        [
-            # regular convolution
-            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
-        ]
-    )
-    def test_bwd_dv(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
-
-        # extract some parameters
-        _, nlon_in = in_shape
-        nlat_out, nlon_out = out_shape
-
-        # set up neighbor matrix
-        att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
-                                      in_shape=in_shape, out_shape=out_shape,
-                                      grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
-
-        # Execute and compare
-        k_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device)
-        k_inp.requires_grad = False
-        v_inp = torch.randn(batch_size, channels, *in_shape, dtype=torch.float32, device=self.device)
-        v_inp.requires_grad = True
-        q_inp = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device)
-        q_inp.requires_grad = False
-        out_grad = torch.randn(batch_size, channels, *out_shape, dtype=torch.float32, device=self.device)
-
-        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
-
-        # need 'retain_graph' to avoid an error in the tests after this one
-        out_torch.backward(out_grad)
-        dv_inp_torch = v_inp.grad.clone()
-
-        with torch.no_grad():
-            dv_inp_torch_explicit = _neighborhood_attention_s2_bwd_dv_torch(
-                k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
-            )
-
-        # caution: multihead-implementation between full and neighborhood attention still seem to differ. tests are only done for single head
-        self.assertTrue(torch.allclose(dv_inp_torch_explicit, dv_inp_torch, atol=atol, rtol=rtol))
-
-        if _cuda_extension_available:
-            dv_inp_cuda_explicit = attention_cuda_extension.backward_dv(
-                k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
-            )
-
-            self.assertTrue(torch.allclose(dv_inp_cuda_explicit, dv_inp_torch, atol=atol, rtol=rtol))
+    @parameterized.expand(
+        [
+            # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
+            [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
+            [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
+            [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
+            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
+            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
+        ]
+    )
+    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+        """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
+        nlat_in, nlon_in = in_shape
+        nlat_out, nlon_out = out_shape
+
+        # Helper: create inputs
+        inputs_ref = {
+            "k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
+            "v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
+            "q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
+        }
+        inputs = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()}
+
+        # reference input and model
+        model_ref = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True).to(
+            self.device
+        )
+
+        # Device model and inputs
+        model = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)
+
+        # Synchronize parameters of model
+        model.load_state_dict(model_ref.state_dict())
+        model = model.to(self.device)
+        for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
+            assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
+
+        # reference forward passes
+        out_ref = _neighborhood_attention_s2_torch(
+            inputs_ref["k"],
+            inputs_ref["v"],
+            inputs_ref["q"] * model_ref.scale,
+            model_ref.k_weights,
+            model_ref.v_weights,
+            model_ref.q_weights,
+            model_ref.k_bias,
+            model_ref.v_bias,
+            model_ref.q_bias,
+            model_ref.quad_weights,
+            model_ref.psi_col_idx,
+            model_ref.psi_roff_idx,
+            model_ref.num_heads,
+            model_ref.nlon_in,
+            model_ref.nlat_out,
+            model_ref.nlon_out,
+        )
+        out_ref = nn.functional.conv2d(out_ref, model_ref.proj_weights, bias=model_ref.proj_bias)
+        out = model(inputs["q"], inputs["k"], inputs["v"])
+
+        # Check forward equivalence
+        self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")
+
+        # Backward passes
+        grad = torch.randn_like(out_ref)
+        out_ref.backward(grad)
+        out.backward(grad.to(self.device))
+
+        # Check input gradient equivalence
+        for inp in ["q", "k", "v"]:
+            grad_ref = inputs_ref[inp].grad.cpu()
+            grad = inputs[inp].grad.cpu()
+            self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")
+
+        # Check parameter gradient equivalence
+        for p_ref, p in zip(model_ref.parameters(), model.parameters()):
+            self.assertTrue(torch.allclose(p.grad, p_ref.grad, atol=atol, rtol=rtol), f"Parameter gradient mismatch: {type(p_ref).__name__}")
-    @parameterized.expand(
-        [
-            # regular convolution
-            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-3],
-        ]
-    )
-    def test_bwd_dk(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
-
-        # extract some parameters
-        nlat_in, nlon_in = in_shape
-        nlat_out, nlon_out = out_shape
-
-        # set up neighbor matrix
-        att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
-                                      in_shape=in_shape, out_shape=out_shape,
-                                      grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
-
-        # Execute and compare
-        k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
-        k_inp.requires_grad = True
-        v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
-        v_inp.requires_grad = False
-        q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
-        q_inp.requires_grad = False
-        out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
-
-        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
-
-        # need 'retain_graph' to avoid an error in the tests after this one
-        out_torch.backward(out_grad)
-        dk_inp_torch = k_inp.grad.clone()
-
-        with torch.no_grad():
-            dk_inp_torch_explicit = _neighborhood_attention_s2_bwd_dk_torch(
-                k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
-            )
-
-        self.assertTrue(torch.allclose(dk_inp_torch_explicit, dk_inp_torch, atol=atol, rtol=rtol))
-
-        if _cuda_extension_available:
-            dk_inp_cuda_explicit = attention_cuda_extension.backward_dk(
-                k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
-            )
-
-            self.assertTrue(torch.allclose(dk_inp_cuda_explicit, dk_inp_torch, atol=atol, rtol=rtol))
-
-    @parameterized.expand(
-        [
-            # regular convolution
-            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 4e-6, 1e-3],
-        ]
-    )
-    def test_bwd_dq(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
-
-        # extract some parameters
-        nlat_in, nlon_in = in_shape
-        nlat_out, nlon_out = out_shape
-
-        # set up neighbor matrix
-        att = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
-                                      in_shape=in_shape, out_shape=out_shape,
-                                      grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
-
-        # Execute and compare
-        k_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
-        k_inp.requires_grad = False
-        v_inp = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
-        v_inp.requires_grad = False
-        q_inp = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
-        q_inp.requires_grad = True
-        out_grad = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
-
-        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
-
-        # need 'retain_graph' to avoid an error in the tests after this one
-        out_torch.backward(out_grad)
-        dq_inp_torch = q_inp.grad.clone()
-
-        with torch.no_grad():
-            dq_inp_torch_explicit = _neighborhood_attention_s2_bwd_dq_torch(
-                k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
-            )
-
-        self.assertTrue(torch.allclose(dq_inp_torch_explicit, dq_inp_torch, atol=atol, rtol=rtol))
-
-        if _cuda_extension_available:
-            dq_inp_cuda_explicit = attention_cuda_extension.backward_dq(
-                k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out
-            )
-
-            self.assertTrue(torch.allclose(dq_inp_cuda_explicit, dq_inp_torch, atol=atol, rtol=rtol))
+    @parameterized.expand(
+        [
+            # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
+            [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
+            # [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
+            # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
+            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
+            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
+        ]
+    )
+    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+        """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set to the whole sphere"""
+        nlat_in, nlon_in = in_shape
+        nlat_out, nlon_out = out_shape
+
+        # Helper: create inputs
+        inputs_ref = {
+            "k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
+            "v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
+            "q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
+        }
+        inputs = {k: v.detach().clone().to(self.device).requires_grad_() for k, v in inputs_ref.items()}
+
+        # reference input and model
+        model_ref = AttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)
+
+        # Device model and inputs
+        model = NeighborhoodAttentionS2(
+            in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=2 * torch.pi
+        )
+
+        # Synchronize parameters of model
+        model.load_state_dict(model_ref.state_dict())
+        model = model.to(self.device)
+        for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
+            assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
+
+        # reference forward passes
+        out_ref = model_ref(inputs_ref["q"], inputs_ref["k"], inputs_ref["v"])
+        out = model(inputs["q"], inputs["k"], inputs["v"])
+
+        # Check forward equivalence
+        self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")
+
+        # Backward passes
+        grad = torch.randn_like(out_ref)
+        out_ref.backward(grad)
+        out.backward(grad.to(self.device))
+
+        # Check input gradient equivalence
+        for inp in ["q", "k", "v"]:
+            grad_ref = inputs_ref[inp].grad
+            grad = inputs[inp].grad
+            self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")
+
+        # Check parameter gradient equivalence - check only q, k, v weights
+        for key in ["q_weights", "k_weights", "v_weights"]:
+            grad_ref = getattr(model_ref, key).grad
+            grad = getattr(model, key).grad
+            self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
@parameterized.expand(
[
# self attention
[1, 73, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
]
)
def test_big(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# this test only makes sense when CUDA version is available
if torch.cuda.is_available():
if not _cuda_extension_available:
print("WARNING: Problem loading CUDA attention module")
return
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# TODO: this test seems hardcoded for GPU. Is this necessary?
k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
k_gpu.requires_grad = False
v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
v_gpu.requires_grad = False
q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cuda:0")
q_gpu.requires_grad = False
# set up layers
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to("cuda:0")
# random weights
with torch.no_grad():
att_gpu.q_weights.normal_()
att_gpu.k_weights.normal_()
att_gpu.v_weights.normal_()
att_gpu.q_bias.normal_()
att_gpu.k_bias.normal_()
att_gpu.v_bias.normal_()
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_gpu.q_weights)
att_gpu.k_weights.copy_(att_gpu.k_weights)
att_gpu.v_weights.copy_(att_gpu.v_weights)
att_gpu.q_bias.copy_(att_gpu.q_bias)
att_gpu.k_bias.copy_(att_gpu.k_bias)
att_gpu.v_bias.copy_(att_gpu.v_bias)
q_gpu = q_gpu.detach().clone().to(self.device)
q_gpu.requires_grad = True
k_gpu = k_gpu.detach().clone().to(self.device)
k_gpu.requires_grad = True
v_gpu = v_gpu.detach().clone().to(self.device)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0")
out_gpu.backward(out_grad.to("cuda:0"))
@parameterized.expand(
[
# self attention
[10, 2, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-5, 1e-5],
]
)
def test_neighborhood(self, batch_size, num_channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
"""
This test sets a specific q[ho,wo] value to 1.0 (elsewhere 0), and then a neighborhood of k around ho,wo to 1.0 (else 0.0). Also vi is set to a sinusoidal input. We also run it with fully 0 q,k. We test that the output of the nonzero q,k is only different to the zero q,k in a single point. We also check the value of this difference (as a regression test).
"""
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
from torch_harmonics import _neighborhood_attention_s2_fwd_torch
device = "cpu"
nas2_2 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
theta_cutoff=torch.pi / 128 * 10)
nas2_2.to(device)
qo = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
x = torch.linspace(0, 2 * np.pi, nlat_in) # 100 points in x direction
y = torch.linspace(0, 2 * np.pi, nlon_in) # 100 points in y direction
# Create a meshgrid
X, Y = torch.meshgrid(x, y, indexing="ij")
vi = torch.ones((batch_size, num_channels, nlat_in, nlon_in)).to(device)
vi[:, :, :, :] = (torch.sin(X) + torch.sin(Y))[None, None, :, :]
ki = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
ki2 = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
ho = 10
wo = 15
qo[:, 0, ho, wo] = 1.0
nas3 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
theta_cutoff=torch.pi / 128 * 7)
zstart = nas3.psi_roff_idx[ho]
zend = nas3.psi_roff_idx[ho + 1]
# set a small neighborhood of k around (ho,wo) to 1
for idz in range(zstart, zend):
nz_col_idx = nas3.psi_col_idx[idz]
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
ki2[:, 0, hi, wip] = 1.0
# run with k zero
y = _neighborhood_attention_s2_fwd_torch(ki, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
# run with k 1 at neighborhood of ho,wo
y2 = _neighborhood_attention_s2_fwd_torch(ki2, vi, qo, nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
# for viz if desired
# plt.matshow((y[0,0,:,:]-y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
# plt.matshow((y2[0,0,:,:]).detach().cpu())#, vmin=0, vmax=2.0)
# compare zero k vs. nonzero k and ensure difference only occurs at ho,wo
nz_x, nz_y = torch.where((y[0, 0, :, :] - y2[0, 0, :, :]).abs() > 0)
self.assertTrue(nz_x.item() == ho)
self.assertTrue(nz_y.item() == wo)
h, w = nz_x.item(), nz_y.item()
diff_hw = y[0, 0, h, w] - y2[0, 0, h, w]
# print("diff_hw=", diff_hw.item())
# regression test the difference. Unfortunately difficult to come up with an
# analytical value, so we just have it hardcoded.
self.assertTrue(torch.allclose(diff_hw, torch.tensor([0.00753], device=device), rtol=rtol, atol=atol))
@parameterized.expand(
[
# self attention
[8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
[8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
]
)
def test_full(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
k_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
k_cpu.requires_grad = True
v_cpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
v_cpu.requires_grad = True
q_cpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
q_cpu.requires_grad = True
# set up layers
att_cpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True)
# random weights
with torch.no_grad():
att_cpu.q_weights.normal_()
att_cpu.k_weights.normal_()
att_cpu.v_weights.normal_()
att_cpu.q_bias.normal_()
att_cpu.k_bias.normal_()
att_cpu.v_bias.normal_()
out_cpu = att_cpu(q_cpu, k_cpu, v_cpu)
out_grad = torch.randn(out_cpu.shape, dtype=torch.float32, device="cpu")
out_cpu.backward(out_grad)
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_cpu.q_weights)
att_gpu.k_weights.copy_(att_cpu.k_weights)
att_gpu.v_weights.copy_(att_cpu.v_weights)
att_gpu.proj_weights.copy_(att_cpu.proj_weights)
att_gpu.q_bias.copy_(att_cpu.q_bias)
att_gpu.k_bias.copy_(att_cpu.k_bias)
att_gpu.v_bias.copy_(att_cpu.v_bias)
att_gpu.proj_bias.copy_(att_cpu.proj_bias)
q_gpu = q_cpu.detach().clone().to(self.device)
q_gpu.requires_grad = True
k_gpu = k_cpu.detach().clone().to(self.device)
k_gpu.requires_grad = True
v_gpu = v_cpu.detach().clone().to(self.device)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_gpu.backward(out_grad.to(self.device))
# check forward
self.assertTrue(torch.allclose(out_cpu.to(self.device), out_gpu, atol=atol, rtol=rtol))
# check input gradients:
self.assertTrue(torch.allclose(q_cpu.grad.to(self.device), q_gpu.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_cpu.grad.to(self.device), k_gpu.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_cpu.grad.to(self.device), v_gpu.grad, atol=atol, rtol=rtol))
# check weight gradients
self.assertTrue(torch.allclose(att_cpu.q_weights.grad.to(self.device), att_gpu.q_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.k_weights.grad.to(self.device), att_gpu.k_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.v_weights.grad.to(self.device), att_gpu.v_weights.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.proj_weights.grad.to(self.device), att_gpu.proj_weights.grad, atol=atol, rtol=rtol))
# check bias gradients
self.assertTrue(torch.allclose(att_cpu.q_bias.grad.to(self.device), att_gpu.q_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.k_bias.grad.to(self.device), att_gpu.k_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.v_bias.grad.to(self.device), att_gpu.v_bias.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(att_cpu.proj_bias.grad.to(self.device), att_gpu.proj_bias.grad, atol=atol, rtol=rtol))
@parameterized.expand(
[
# self attention
[8, 8, 8, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
[8, 8, 8, 2, (17, 32), (17, 32), "legendre-gauss", "legendre-gauss", 2e-4, 1e-5],
[8, 8, 8, 2, (17, 32), (17, 32), "lobatto", "lobatto", 2e-4, 1e-5],
[8, 8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
]
)
def test_full_attention(self, batch_size, channels_in, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
k_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
k_cpu.requires_grad = True
v_cpu = torch.randn(batch_size, channels_in, nlat_in, nlon_in, dtype=torch.float32, device="cpu")
v_cpu.requires_grad = True
q_cpu = torch.randn(batch_size, channels_in, nlat_out, nlon_out, dtype=torch.float32, device="cpu")
q_cpu.requires_grad = True
att_cpu = AttentionS2(in_channels=channels_in, out_channels=channels_out, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)
out = att_cpu(q_cpu, k_cpu, v_cpu)
# check if output is sane
self.assertFalse(torch.isnan(out).any())
 if __name__ == "__main__":
     unittest.main()
@@ -107,7 +107,7 @@ class AttentionS2(nn.Module):
         _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
         quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in
         # we need to tile and flatten them accordingly
-        quad_weights = torch.tile(quad_weights, (1, self.nlon_in)).flatten()
+        quad_weights = torch.tile(quad_weights.reshape(-1, 1), (1, self.nlon_in)).flatten()
         # compute log because they are applied as an addition prior to the softmax ('attn_mask'), which includes an exponential.
         # see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
...
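The hunk above is the quadrature-weight fix mentioned in the commit message: wgl has shape (nlat_in,), and tiling it directly repeats the whole latitude vector nlon_in times instead of repeating each latitude weight nlon_in times, so the flattened weights no longer line up with the flattened grid index hi * nlon_in + wi used elsewhere in the attention code. A minimal standalone sketch of the difference (shapes and names here are illustrative, not code from the repository):

import torch

nlat, nlon = 3, 4
w = torch.tensor([0.1, 0.2, 0.3])  # per-latitude quadrature weights, shape (nlat,)

# old behaviour: repeats the whole vector -> [0.1, 0.2, 0.3, 0.1, 0.2, 0.3, ...]
wrong = torch.tile(w, (1, nlon)).flatten()

# fixed behaviour: repeats each weight nlon times -> [0.1, 0.1, 0.1, 0.1, 0.2, ...]
right = torch.tile(w.reshape(-1, 1), (1, nlon)).flatten()

# the flattened grid index is hi * nlon + wi, so only `right` returns the
# weight of latitude hi for every longitude wi
hi, wi = 1, 2
assert right[hi * nlon + wi] == w[hi]
assert wrong[hi * nlon + wi] != w[hi]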
@@ -76,7 +76,3 @@ torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
                                        at::Tensor psi_col_idx,
                                        at::Tensor psi_row_off,
                                        int nlon_in, int nlat_out, int nlon_out);
-
-int s2_idx_offset_cuda(const at::Tensor &psi_col_idx,
-                       const at::Tensor &psi_row_idx, at::Tensor &row_offset,
-                       at::Tensor &row_count);
 // coding=utf-8
 //
-// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
+// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions are met:
 //
@@ -433,7 +433,7 @@ s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon
     float qdotk_max = std::numeric_limits<float>::lowest();
     for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
         int idz = psi_block*blockDim.x + threadIdx.x;
         // skip if index >= length of psi_idx because last loop iteration will have extra threads
         if(idz >= psi_nnz_ho) break;
@@ -559,7 +559,7 @@ __global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int
     float qdotk_max = std::numeric_limits<float>::lowest();
     for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
         int idz = psi_block*blockDim.x + threadIdx.x;
         // skip if index >= length of psi_idx because last loop iteration will have extra threads
         if(idz >= psi_nnz_ho) break;
@@ -675,7 +675,7 @@ __global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int
 }

 at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
                                     at::Tensor vx,
                                     at::Tensor qy,
                                     at::Tensor dy,
@@ -731,7 +731,7 @@ at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
 }

 at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
                                     at::Tensor vx,
                                     at::Tensor qy,
                                     at::Tensor dy,
@@ -782,7 +782,7 @@ at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
     C10_CUDA_KERNEL_LAUNCH_CHECK();

     return dydq;
 }
@@ -840,7 +840,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
     C10_CUDA_KERNEL_LAUNCH_CHECK();

     return std::make_tuple(dydk, dydv, dydq);
 }
 // coding=utf-8
 //
-// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
+// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions are met:
 //
@@ -172,9 +172,9 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
 }

 torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                     at::Tensor vx,
                                     at::Tensor qy,
                                     at::Tensor quad_weights,
                                     at::Tensor psi_col_idx,
                                     at::Tensor psi_row_off,
...
 // coding=utf-8
 //
-// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
+// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions are met:
 //
@@ -33,7 +33,6 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
-    m.def("compute_row_offset", &s2_idx_offset_cuda, "Row offset on S2");
     m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
     m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
     m.def("backward_dq", &s2_attention_bwd_dq_cuda,
...
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ATen/core/TensorAccessor.h"
#include <cmath>
#include <cstdint>
#include <torch/extension.h>
#include <torch/torch.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
__global__ void alpha_count_kernel(int len_alpha_count,
int len_psi_row_idx,
torch::PackedTensorAccessor64<int64_t, 1> psi_row_idx,
int64_t* alpha_start,
torch::PackedTensorAccessor64<int64_t, 1> alpha_count
) {
int ho = blockIdx.x * blockDim.x + threadIdx.x;
if(ho < len_alpha_count) {
// initialize alpha_count;
alpha_count[ho] = 0;
// NOTE: Assumes that psi_row_idx is sorted
for(int i=alpha_start[ho]; i<len_psi_row_idx; i++) {
if(psi_row_idx[i] == ho) alpha_count[ho]++;
else if(psi_row_idx[i] > ho) break;
}
}
}
int s2_idx_offset_cuda(const at::Tensor& psi_col_idx,
const at::Tensor& psi_row_idx,
at::Tensor& row_offset,
at::Tensor& row_count) {
auto stream = at::cuda::getCurrentCUDAStream();
int64_t* d_alpha_start;
int64_t* d_sequence;
int64_t* d_alpha_count = row_count.data_ptr<int64_t>();
int64_t* d_alpha_offset = row_offset.data_ptr<int64_t>();
C10_CUDA_CHECK(cudaMalloc(&d_alpha_start, row_offset.size(0)*sizeof(int64_t)));
// Find the first time each index occurs in psi_row_idx
// psi_row_idx = [0,0,0,0,1,1,1,1,2,2,2...]
// 0 starts at idx=0, 1 starts at idx=4, 2 starts at idx=8, etc
// this assumes that psi_row_idx is sorted!
C10_CUDA_CHECK(cudaMalloc(&d_sequence, row_offset.size(0)*sizeof(int64_t)));
thrust::sequence(thrust::device, d_sequence, d_sequence+row_offset.size(0), 0);
thrust::counting_iterator<int> start(0);
// thrust::lower_bound(thrust::device,
// psi_row_idx.data_ptr<int64_t>(),
// psi_row_idx.data_ptr<int64_t>()+psi_row_idx.size(0),
// start, start+psi_row_idx.size(0), d_alpha_start);
thrust::lower_bound(thrust::device,
psi_row_idx.data_ptr<int64_t>(),
psi_row_idx.data_ptr<int64_t>()+psi_row_idx.size(0),
d_sequence, d_sequence+row_offset.size(0), d_alpha_start);
alpha_count_kernel<<<at::cuda::detail::GET_BLOCKS(row_offset.size(0),512),512,
0,stream.stream()>>>(row_count.size(0),
psi_row_idx.size(0),
psi_row_idx.packed_accessor64<int64_t, 1>(),
d_alpha_start,
row_count.packed_accessor64<int64_t , 1>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
int maxAlphaSize = thrust::reduce(thrust::device,
d_alpha_count,
d_alpha_count+row_count.size(0),
0,
thrust::maximum<int>());
thrust::exclusive_scan(thrust::device,
d_alpha_count,
d_alpha_count+row_count.size(0),
d_alpha_offset);
C10_CUDA_CHECK(cudaFree(d_alpha_start));
C10_CUDA_CHECK(cudaFree(d_sequence));
return maxAlphaSize;
}
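For reference, the deleted helper above builds CSR-style row offsets and per-row nonzero counts from a sorted COO row-index array (via thrust::lower_bound, a counting kernel, and an exclusive scan) and returns the maximum row length. A rough PyTorch sketch of the same computation, shown only to illustrate what the removed kernel produced (function and variable names here are hypothetical):

import torch

def row_offsets_from_sorted_rows(psi_row_idx: torch.Tensor, nrows: int):
    """Per-row nonzero counts, CSR row offsets, and maximum row length
    from a sorted COO row-index array, e.g. [0, 0, 0, 1, 1, 2, ...]."""
    # number of nonzeros in each row 0 .. nrows-1
    row_count = torch.bincount(psi_row_idx, minlength=nrows)
    # exclusive scan: offset of the first nonzero of each row
    row_offset = torch.cumsum(row_count, dim=0) - row_count
    max_row_len = int(row_count.max())
    return row_count, row_offset, max_row_len

# example: rows 0 and 1 have 4 nonzeros each, row 2 has 3
rows = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
counts, offsets, max_len = row_offsets_from_sorted_rows(rows, nrows=3)
# counts -> [4, 4, 3], offsets -> [0, 4, 8], max_len -> 4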