# test_attention.py
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
from parameterized import parameterized

# import math
import numpy as np
import torch

# from torch.autograd import gradcheck
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2

from torch_harmonics._neighborhood_attention import (
    _neighborhood_attention_s2_torch,
    _neighborhood_attention_s2_fwd_torch,
    _neighborhood_attention_s2_bwd_dv_torch,
    _neighborhood_attention_s2_bwd_dk_torch,
    _neighborhood_attention_s2_bwd_dq_torch,
)

# import custom C++/CUDA extensions
# The compiled attention kernels are optional: the tests below consult `_cuda_extension_available`
# before calling into `attention_cuda_extension`.
try:
    import attention_cuda_extension

    _cuda_extension_available = True
except ImportError as err:
    # keep the module importable without the extension, but report why the import failed
    print(f"Warning: Couldn't Import cuda attention: {err}")
    attention_cuda_extension = None
    _cuda_extension_available = False


# this routine is only supposed to be used in this test, since it is numerically not stable but supports
# autograd which some of the better kernels do not
def _neighborhood_attention_s2_torch_test(
    kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_idx: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int
):

    out = torch.zeros_like(qy)

    for ho in range(nlat_out):

        # get nonzero indices in output row
        idx_ho = col_idx[row_idx == ho]

        for wo in range(nlon_out):
            alpha_sum = torch.zeros((out.shape[0],), dtype=out.dtype, device=out.device)
            alpha = torch.zeros((out.shape[0], len(idx_ho)), dtype=out.dtype, device=out.device)
            for inz, nz_col_idx in enumerate(idx_ho):
                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi + wo) % nlon_in

                # compute correlation & softmax numerator
                q_ho_wo = qy[:, :, ho, wo]
                k_hi_wip = kx[:, :, hi, wip]
                alpha[:, inz] = torch.exp(torch.sum(q_ho_wo * k_hi_wip, dim=1)) * quad_weights[hi]
                # softmax denominator
                alpha_sum[:] = alpha_sum[:] + alpha[:, inz]

            for inz, nz_col_idx in enumerate(idx_ho):
                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi + wo) % nlon_in

                # compute matmul of attention matrix with V-vector
                out[:, :, ho, wo] = out[:, :, ho, wo] + (alpha[:, None, inz] / alpha_sum[:, None]) * vx[:, :, hi, wip]

    return out


class TestNeighborhoodAttention(unittest.TestCase):

    def setUp(self):
        """Pick the device used by the tests and reset the random seeds before every test."""
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if use_cuda else "cpu")

        if use_cuda:
            # make the selected GPU the current device and seed its generator
            torch.cuda.set_device(self.device.index)
            torch.cuda.manual_seed(333)

        # seed the default generator as well
        torch.manual_seed(333)

    @parameterized.expand(
        [
            # regular convolution
            [8, 4, 6, (17, 32), 1e-6, 1e-5],
        ]
    )
    def test_batched_linear(self, batch_size, in_channels, out_channels, shape, atol, rtol):
        """Check the autograd gradients of a 1x1 convolution against hand-written expressions.

        The input, weight and bias gradients obtained from `backward` are compared with an explicit
        transposed 1x1 convolution, an einsum contraction and a plain sum of the upstream gradient.
        """
        tensor_kwargs = dict(dtype=torch.float32, device=self.device)

        # learnable parameters of the pointwise convolution
        weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, 1, 1, **tensor_kwargs))
        bias = torch.nn.Parameter(torch.randn(out_channels, **tensor_kwargs))

        # input field, which also receives a gradient
        inp = torch.randn(batch_size, in_channels, *shape, **tensor_kwargs).requires_grad_(True)

        # forward pass followed by backpropagation of a random upstream gradient
        out = torch.nn.functional.conv2d(inp, weight=weight, bias=bias)
        out_grad = torch.randn(batch_size, out_channels, *shape, **tensor_kwargs)
        out.backward(out_grad)

        # gradients produced by autograd, in the order input, weight, bias
        grads_autograd = (inp.grad.clone(), weight.grad.clone(), bias.grad.clone())

        # the same gradients written out explicitly, in the same order
        grads_explicit = (
            torch.nn.functional.conv2d(out_grad, weight=weight.permute([1, 0, 2, 3]), bias=None),
            torch.einsum("bchw,bfhw->cf", out_grad, inp).reshape(out_channels, in_channels, 1, 1).contiguous(),
            torch.sum(out_grad, dim=(0, 2, 3)),
        )

        # check consistency
        for grad_auto, grad_expl in zip(grads_autograd, grads_explicit):
            self.assertTrue(torch.allclose(grad_auto, grad_expl, atol=atol, rtol=rtol))

    @parameterized.expand(
        [
            # self attention
            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
        ]
    )
    def test_fwd(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """Compare the forward attention kernels against the naive reference implementation."""

        # grid sizes needed by the kernels
        _, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # the layer is only used as a source of the quadrature weights and the neighborhood indices
        att = NeighborhoodAttentionS2(
            in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False
        ).to(self.device)

        # random keys, values and queries; none of them needs a gradient here
        tensor_kwargs = dict(dtype=torch.float32, device=self.device)
        k_inp = torch.randn(batch_size, channels, *in_shape, **tensor_kwargs)
        v_inp = torch.randn(batch_size, channels, *in_shape, **tensor_kwargs)
        q_inp = torch.randn(batch_size, channels, *out_shape, **tensor_kwargs)

        # the reference takes per-entry row indices ...
        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)

        # ... whereas the kernels under test take the row offsets
        kernel_args = (k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)

        with torch.no_grad():
            out_torch_explicit = _neighborhood_attention_s2_fwd_torch(*kernel_args)
        self.assertTrue(torch.allclose(out_torch_explicit, out_torch, atol=atol, rtol=rtol))

        if _cuda_extension_available:
            out_cuda = attention_cuda_extension.forward(*kernel_args)
            self.assertTrue(torch.allclose(out_torch, out_cuda, atol=atol, rtol=rtol))

    @parameterized.expand(
        [
            # self attention
            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-4],
        ]
    )
    def test_bwd_dv(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """Compare the explicit value-gradient kernels against autograd through the naive reference."""

        # grid sizes needed by the kernels
        _, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # the layer is only used as a source of the quadrature weights and the neighborhood indices
        att = NeighborhoodAttentionS2(
            in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False
        ).to(self.device)

        # random inputs; only the values take part in differentiation
        tensor_kwargs = dict(dtype=torch.float32, device=self.device)
        k_inp = torch.randn(batch_size, channels, *in_shape, **tensor_kwargs)
        v_inp = torch.randn(batch_size, channels, *in_shape, **tensor_kwargs).requires_grad_(True)
        q_inp = torch.randn(batch_size, channels, *out_shape, **tensor_kwargs)
        out_grad = torch.randn(batch_size, channels, *out_shape, **tensor_kwargs)

        # gradient with respect to v obtained from autograd
        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
        out_torch.backward(out_grad)
        dv_inp_torch = v_inp.grad.clone()

        # arguments shared by the explicit backward kernels
        kernel_args = (k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)

        with torch.no_grad():
            dv_inp_torch_explicit = _neighborhood_attention_s2_bwd_dv_torch(*kernel_args)
        self.assertTrue(torch.allclose(dv_inp_torch_explicit, dv_inp_torch, atol=atol, rtol=rtol))

        if _cuda_extension_available:
            dv_inp_cuda_explicit = attention_cuda_extension.backward_dv(*kernel_args)
            self.assertTrue(torch.allclose(dv_inp_cuda_explicit, dv_inp_torch, atol=atol, rtol=rtol))

    @parameterized.expand(
        [
            # self attention
            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-6, 1e-3],
        ]
    )
    def test_bwd_dk(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """Compare the explicit key-gradient kernels against autograd through the naive reference."""

        # grid sizes needed by the kernels
        _, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # the layer is only used as a source of the quadrature weights and the neighborhood indices
        att = NeighborhoodAttentionS2(
            in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False
        ).to(self.device)

        def random_field(grid_shape):
            # random float32 field on the test device for the given (nlat, nlon) grid
            return torch.randn(batch_size, channels, *grid_shape, dtype=torch.float32, device=self.device)

        # only the keys take part in differentiation
        k_inp = random_field(in_shape).requires_grad_(True)
        v_inp = random_field(in_shape)
        q_inp = random_field(out_shape)
        out_grad = random_field(out_shape)

        # gradient with respect to k obtained from autograd
        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
        out_torch.backward(out_grad)
        dk_inp_torch = k_inp.grad.clone()

        # arguments shared by the explicit backward kernels
        kernel_args = (k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)

        with torch.no_grad():
            dk_inp_torch_explicit = _neighborhood_attention_s2_bwd_dk_torch(*kernel_args)
        self.assertTrue(torch.allclose(dk_inp_torch_explicit, dk_inp_torch, atol=atol, rtol=rtol))

        if _cuda_extension_available:
            dk_inp_cuda_explicit = attention_cuda_extension.backward_dk(*kernel_args)
            self.assertTrue(torch.allclose(dk_inp_cuda_explicit, dk_inp_torch, atol=atol, rtol=rtol))

    @parameterized.expand(
        [
            # self attention
            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 4e-6, 1e-3],
        ]
    )
    def test_bwd_dq(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """Compare the explicit query-gradient kernels against autograd through the naive reference."""

        # grid sizes needed by the kernels
        _, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # the layer is only used as a source of the quadrature weights and the neighborhood indices
        att = NeighborhoodAttentionS2(
            in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False
        ).to(self.device)

        in_size = (batch_size, channels, *in_shape)
        out_size = (batch_size, channels, *out_shape)
        tensor_kwargs = dict(dtype=torch.float32, device=self.device)

        # only the queries take part in differentiation
        k_inp = torch.randn(in_size, **tensor_kwargs)
        v_inp = torch.randn(in_size, **tensor_kwargs)
        q_inp = torch.randn(out_size, **tensor_kwargs).requires_grad_(True)
        out_grad = torch.randn(out_size, **tensor_kwargs)

        # gradient with respect to q obtained from autograd
        out_torch = _neighborhood_attention_s2_torch_test(k_inp, v_inp, q_inp, att.quad_weights, att.psi_col_idx, att.psi_row_idx, nlon_in, nlat_out, nlon_out)
        out_torch.backward(out_grad)
        dq_inp_torch = q_inp.grad.clone()

        # arguments shared by the explicit backward kernels
        kernel_args = (k_inp, v_inp, q_inp, out_grad, att.quad_weights, att.psi_col_idx, att.psi_roff_idx, nlon_in, nlat_out, nlon_out)

        with torch.no_grad():
            dq_inp_torch_explicit = _neighborhood_attention_s2_bwd_dq_torch(*kernel_args)
        self.assertTrue(torch.allclose(dq_inp_torch_explicit, dq_inp_torch, atol=atol, rtol=rtol))

        if _cuda_extension_available:
            dq_inp_cuda_explicit = attention_cuda_extension.backward_dq(*kernel_args)
            self.assertTrue(torch.allclose(dq_inp_cuda_explicit, dq_inp_torch, atol=atol, rtol=rtol))

    @parameterized.expand(
        [
            # self attention
            [1, 73, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
        ]
    )
    def test_big(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """Smoke test: run a forward and a backward pass of the layer on a large grid on the GPU.

        No numerical values are compared; the test succeeds when an inference pass and a training pass
        (forward + backward) both complete and the inputs received gradients. `atol` and `rtol` are accepted
        for signature parity with the other tests and are not used.

        The test is skipped when CUDA or the compiled attention extension is unavailable, because all
        tensors are allocated on "cuda:0".
        """

        # this test only makes sense when the CUDA version is available
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        if not _cuda_extension_available:
            self.skipTest("Problem loading CUDA attention module")

        device = torch.device("cuda:0")
        tensor_kwargs = dict(dtype=torch.float32, device=device)

        k_gpu = torch.randn(batch_size, channels, *in_shape, **tensor_kwargs)
        v_gpu = torch.randn(batch_size, channels, *in_shape, **tensor_kwargs)
        q_gpu = torch.randn(batch_size, channels, *out_shape, **tensor_kwargs)

        # set up layers
        att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                          in_shape=in_shape, out_shape=out_shape,
                                          grid_in=grid_in, grid_out=grid_out, bias=True).to(device)

        with torch.no_grad():
            # random weights
            att_gpu.q_weights.normal_()
            att_gpu.k_weights.normal_()
            att_gpu.v_weights.normal_()
            att_gpu.q_bias.normal_()
            att_gpu.k_bias.normal_()
            att_gpu.v_bias.normal_()

            # inference pass without recording an autograd graph
            att_gpu(q_gpu, k_gpu, v_gpu)

        # the inputs are leaf tensors, so gradients can be enabled in place without copying them
        for inp in (q_gpu, k_gpu, v_gpu):
            inp.requires_grad_(True)

        # training pass: forward followed by backward with a random upstream gradient
        out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
        out_grad = torch.randn(out_gpu.shape, **tensor_kwargs)
        out_gpu.backward(out_grad)

        for inp in (q_gpu, k_gpu, v_gpu):
            self.assertIsNotNone(inp.grad)

    @parameterized.expand(
        [
            # self attention
            [10, 2, 1, (17, 32), (17, 32), "equiangular", "equiangular", 1e-5, 1e-5],
        ]
    )
    def test_neighborhood(self, batch_size, num_channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """
        Check that keys placed around a single query point only affect the output at that point.

        The query is set to 1.0 at one output point (ho, wo) and to 0.0 elsewhere, and the values are a
        sinusoidal field. The forward kernel is evaluated twice: once with all keys equal to zero and once
        with keys set to 1.0 on a small neighborhood of (ho, wo). The two outputs must differ in exactly one
        grid point, namely (ho, wo), and the size of that difference is compared against a hardcoded
        regression value.

        NOTE(review): grid_in and grid_out are accepted but not forwarded to the layer constructors.
        """

        # extract some parameters
        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        device = "cpu"

        # layer providing the neighborhoods used for the actual attention evaluation
        nas2_2 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
                                         in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
                                         theta_cutoff=torch.pi / 128 * 10)
        nas2_2.to(device)

        # queries live on the output grid
        qo = torch.zeros((batch_size, num_channels, nlat_out, nlon_out)).to(device)

        # sinusoidal values on the input grid: nlat_in points along latitude, nlon_in points along longitude
        lat_coords = torch.linspace(0, 2 * np.pi, nlat_in)
        lon_coords = torch.linspace(0, 2 * np.pi, nlon_in)
        lat_mesh, lon_mesh = torch.meshgrid(lat_coords, lon_coords, indexing="ij")
        vi = torch.ones((batch_size, num_channels, nlat_in, nlon_in)).to(device)
        vi[:, :, :, :] = (torch.sin(lat_mesh) + torch.sin(lon_mesh))[None, None, :, :]

        ki = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)
        ki2 = torch.zeros((batch_size, num_channels, nlat_in, nlon_in)).to(device)

        ho = 10
        wo = 15
        qo[:, 0, ho, wo] = 1.0

        # second layer with a smaller cutoff, used only to enumerate a neighborhood of (ho, wo)
        nas3 = NeighborhoodAttentionS2(in_channels=num_channels, num_heads=heads,
                                       in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
                                       theta_cutoff=torch.pi / 128 * 7)
        zstart = int(nas3.psi_roff_idx[ho])
        zend = int(nas3.psi_roff_idx[ho + 1])

        # set a small neighborhood of k around (ho,wo) to 1
        for idz in range(zstart, zend):
            nz_col_idx = nas3.psi_col_idx[idz]
            # compute input indices from psi datastructure
            hi = nz_col_idx // nlon_in
            # account for output shift and ensure positive index due to circular condition
            wi = nz_col_idx % nlon_in
            wip = (wi + wo) % nlon_in
            ki2[:, 0, hi, wip] = 1.0

        kernel_tail = (nas2_2.quad_weights, nas2_2.psi_col_idx, nas2_2.psi_roff_idx, nlon_in, nlat_out, nlon_out)
        # run with k zero
        out_zero_k = _neighborhood_attention_s2_fwd_torch(ki, vi, qo, *kernel_tail)
        # run with k 1 at neighborhood of ho,wo
        out_local_k = _neighborhood_attention_s2_fwd_torch(ki2, vi, qo, *kernel_tail)

        # compare zero k vs. nonzero k and ensure difference only occurs at ho,wo
        nz_h, nz_w = torch.where((out_zero_k[0, 0, :, :] - out_local_k[0, 0, :, :]).abs() > 0)
        self.assertEqual(nz_h.numel(), 1, "outputs are expected to differ in exactly one grid point")
        h, w = nz_h.item(), nz_w.item()
        self.assertEqual(h, ho)
        self.assertEqual(w, wo)
        diff_hw = out_zero_k[0, 0, h, w] - out_local_k[0, 0, h, w]

        # regression test the difference. Unfortunately difficult to come up with an
        # analytical value, so we just have it hardcoded.
        self.assertTrue(torch.allclose(diff_hw, torch.tensor([0.00753], device=device), rtol=rtol, atol=atol))

    @parameterized.expand(
        [
            # self attention
            [8, 4, 1, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
            [8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
        ]
    )
    def test_full(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """Compare forward output and all gradients of the layer between the CPU and `self.device`.

        A layer with random weights is evaluated on the CPU, its parameters are copied into a second layer
        living on `self.device`, and outputs, input gradients and parameter gradients are required to agree.
        """

        cpu_kwargs = dict(dtype=torch.float32, device="cpu")
        layer_kwargs = dict(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)

        # random CPU inputs which all receive gradients
        k_cpu = torch.randn(batch_size, channels, *in_shape, **cpu_kwargs).requires_grad_(True)
        v_cpu = torch.randn(batch_size, channels, *in_shape, **cpu_kwargs).requires_grad_(True)
        q_cpu = torch.randn(batch_size, channels, *out_shape, **cpu_kwargs).requires_grad_(True)

        # reference layer on the CPU with randomized q/k/v weights and biases
        att_cpu = NeighborhoodAttentionS2(**layer_kwargs)
        with torch.no_grad():
            for name in ("q_weights", "k_weights", "v_weights", "q_bias", "k_bias", "v_bias"):
                getattr(att_cpu, name).normal_()

        out_cpu = att_cpu(q_cpu, k_cpu, v_cpu)
        out_grad = torch.randn(out_cpu.shape, **cpu_kwargs)
        out_cpu.backward(out_grad)

        # second layer on the test device, synchronized with the reference layer
        att_dev = NeighborhoodAttentionS2(**layer_kwargs).to(self.device)
        param_names = ("q_weights", "k_weights", "v_weights", "proj_weights", "q_bias", "k_bias", "v_bias", "proj_bias")
        with torch.no_grad():
            for name in param_names:
                getattr(att_dev, name).copy_(getattr(att_cpu, name))

        # independent copies of the inputs on the test device
        q_dev = q_cpu.detach().clone().to(self.device).requires_grad_(True)
        k_dev = k_cpu.detach().clone().to(self.device).requires_grad_(True)
        v_dev = v_cpu.detach().clone().to(self.device).requires_grad_(True)

        out_dev = att_dev(q_dev, k_dev, v_dev)
        out_dev.backward(out_grad.to(self.device))

        def check_close(reference, candidate):
            # compare a CPU reference tensor with its counterpart on the test device
            self.assertTrue(torch.allclose(reference.to(self.device), candidate, atol=atol, rtol=rtol))

        # check forward
        check_close(out_cpu, out_dev)

        # check input gradients
        for ref_inp, dev_inp in ((q_cpu, q_dev), (k_cpu, k_dev), (v_cpu, v_dev)):
            check_close(ref_inp.grad, dev_inp.grad)

        # check weight gradients first, then bias gradients
        for name in param_names:
            check_close(getattr(att_cpu, name).grad, getattr(att_dev, name).grad)

    @parameterized.expand(
        [
            # self attention
            [8, 8, 8, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
            [8, 8, 8, 2, (17, 32), (17, 32), "legendre-gauss", "legendre-gauss", 2e-4, 1e-5],
            [8, 8, 8, 2, (17, 32), (17, 32), "lobatto", "lobatto", 2e-4, 1e-5],
            [8, 8, 4, 2, (17, 32), (17, 32), "equiangular", "equiangular", 2e-4, 1e-5],
        ]
    )
    def test_full_attention(self, batch_size, channels_in, channels_out, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
        """Smoke test: the AttentionS2 layer must produce an output free of NaNs for random CPU inputs."""
        cpu_kwargs = dict(dtype=torch.float32, device="cpu")

        # keys and values live on the input grid, queries on the output grid
        k_cpu = torch.randn(batch_size, channels_in, *in_shape, **cpu_kwargs).requires_grad_(True)
        v_cpu = torch.randn(batch_size, channels_in, *in_shape, **cpu_kwargs).requires_grad_(True)
        q_cpu = torch.randn(batch_size, channels_in, *out_shape, **cpu_kwargs).requires_grad_(True)

        att_cpu = AttentionS2(
            in_channels=channels_in,
            out_channels=channels_out,
            num_heads=heads,
            in_shape=in_shape,
            out_shape=out_shape,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=True,
        )

        result = att_cpu(q_cpu, k_cpu, v_cpu)

        # check if output is sane
        self.assertFalse(torch.isnan(result).any())
        

# run all tests in this module when the file is executed directly as a script
if __name__ == "__main__":
    unittest.main()