# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
from typing import Union

import torch
import torch.nn.functional as F
from torch.amp import custom_fwd, custom_bwd

try:
    import attention_cuda_extension
    _cuda_extension_available = True
except ImportError as err:
    attention_cuda_extension = None
    _cuda_extension_available = False

# s2 neighborhood attention forward pass
# uses the running-maximum (qdotk_max) update trick to compute the softmax in a
# single loop over the neighbors instead of two
# see e.g., https://arxiv.org/abs/1805.02867
# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """
    Forward pass implementation of neighborhood attention on the sphere (S2).

    For every output point (ho, wo) the function loops over the neighbors stored
    in the CSR-like structure (row_off, col_idx), computes the channel-wise dot
    product between the query and each neighboring key, and accumulates a
    softmax-weighted average of the neighboring values. Each softmax term is
    additionally multiplied by the quadrature weight of the neighbor's latitude.

    The softmax is computed in a single pass using a running maximum: whenever
    the maximum grows, the partial sums accumulated so far are rescaled.

    Parameters
    -----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi) where B is batch size, C is channels,
        Hi is input height (latitude), Wi is input width (longitude)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo) where Ho is output height, Wo is output width
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation; each entry encodes an input
        point as ``hi * nlon_in + wi`` for the output longitude ``wo = 0``
    row_off : torch.Tensor
        Row offsets for sparse computation; the neighbors of output latitude
        ``ho`` are ``col_idx[row_off[ho]:row_off[ho+1]]``
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Output tensor with shape (B, C, Ho, Wo) after neighborhood attention computation.
        Output points with an empty neighborhood evaluate to 0/0 (NaN).
    """

    # prepare result tensor
    y = torch.zeros_like(qy)
    nbatch = y.shape[0]

    # convert the sparsity pattern to Python integers once instead of once per use
    col_idx_list = col_idx.tolist()
    row_off_list = row_off.tolist()

    for ho in range(nlat_out):

        # range of nonzeros belonging to this output latitude
        zstart = row_off_list[ho]
        zend = row_off_list[ho + 1]

        for wo in range(nlon_out):

            q_ho_wo = qy[:, :, ho, wo]

            alpha_sum = torch.zeros((nbatch,), dtype=y.dtype, device=y.device)
            # start the running maximum at -inf so that it always equals the true
            # maximum of the logits seen so far. Starting at 0 makes every exp()
            # underflow when all logits are strongly negative, which yields 0/0.
            # NOTE: logits are assumed to be finite.
            qdotk_max = torch.full((nbatch,), float("-inf"), dtype=y.dtype, device=y.device)

            for idz in range(zstart, zend):

                # compute input indices from psi datastructure
                hi, wi = divmod(col_idx_list[idz], nlon_in)
                # account for output shift and ensure positive index due to circular condition
                wip = (wi + wo) % nlon_in

                # compute correlation & softmax numerator
                qdotk = torch.sum(q_ho_wo * kx[:, :, hi, wip], dim=1)

                # tmp max
                qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)

                # factor which moves the previously accumulated terms to the new maximum
                rescale = torch.exp(qdotk_max - qdotk_max_tmp)

                # alpha sum update
                alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
                alpha_sum = alpha + alpha_sum * rescale

                # update output
                y[:, :, ho, wo] = y[:, :, ho, wo] * rescale.unsqueeze(1) + alpha[:, None] * vx[:, :, hi, wip]

                # define new max
                qdotk_max = qdotk_max_tmp

            # normalize by the softmax denominator
            y[:, :, ho, wo] = y[:, :, ho, wo] / alpha_sum[:, None]

    return y


# Explicit gradient w.r.t. vx: dM/dv
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """
    Backward pass implementation for value gradients in neighborhood attention on S2.

    For every output point (ho, wo) the quadrature-weighted softmax weights over
    the neighborhood are recomputed, and the incoming gradient ``dy[:, :, ho, wo]``
    is scattered onto the neighboring value locations scaled by those weights.

    Parameters
    -----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo)
    dy : torch.Tensor
        Gradient of the output with shape (B, C, Ho, Wo)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Gradient of the value tensor with shape (B, C, Hi, Wi)
    """

    # shapes:
    # input
    # kx: B, C, Hi, Wi
    # vx: B, C, Hi, Wi
    # qy: B, C, Ho, Wo
    # quad_weights: Hi
    # output
    # dvx: B, C, Hi, Wi

    dvx = torch.zeros_like(vx)
    nbatch = dy.shape[0]

    # convert the sparsity pattern to Python integers once instead of once per use
    col_idx_list = col_idx.tolist()
    row_off_list = row_off.tolist()

    for ho in range(nlat_out):

        # range of nonzeros belonging to this output latitude
        zstart = row_off_list[ho]
        zend = row_off_list[ho + 1]
        nnz = zend - zstart

        for wo in range(nlon_out):

            q_ho_wo = qy[:, :, ho, wo]
            dy_ho_wo = dy[:, :, ho, wo]

            # compute input indices from psi datastructure once per output point:
            # (latitude, longitude shifted by wo with circular wrap-around)
            neighbors = []
            for idz in range(zstart, zend):
                hi, wi = divmod(col_idx_list[idz], nlon_in)
                neighbors.append((hi, (wi + wo) % nlon_in))

            # compute correlation
            qdotk_nz = torch.zeros((nbatch, nnz), dtype=dy.dtype, device=dy.device)
            for inz, (hi, wip) in enumerate(neighbors):
                qdotk_nz[:, inz] = torch.sum(q_ho_wo * kx[:, :, hi, wip], dim=1)

            # subtracting the maximum keeps the exponentials bounded by one
            qdotk_max, _ = torch.max(qdotk_nz, dim=1)

            # softmax numerators and denominator
            alpha_nz = torch.zeros((nbatch, nnz), dtype=dy.dtype, device=dy.device)
            alpha_sum = torch.zeros((nbatch,), dtype=dy.dtype, device=dy.device)
            for inz, (hi, _) in enumerate(neighbors):
                alpha_nz[:, inz] = torch.exp(qdotk_nz[:, inz] - qdotk_max) * quad_weights[hi]
                alpha_sum += alpha_nz[:, inz]

            # scatter the output gradient onto the neighboring values
            for inz, (hi, wip) in enumerate(neighbors):
                dvx[:, :, hi, wip] += (alpha_nz[:, None, inz] / alpha_sum[:, None]) * dy_ho_wo

    return dvx


# Explicit gradient w.r.t. kx: dM/dk
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """
    Backward pass implementation for key gradients in neighborhood attention on S2.

    For every output point (ho, wo) the quadrature-weighted softmax weights over
    the neighborhood are recomputed. Each neighboring key location then receives
    the query scaled by its softmax weight times the difference between
    ``dy . v`` at that neighbor and the softmax-weighted mean of ``dy . v`` over
    the whole neighborhood.

    Parameters
    -----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo)
    dy : torch.Tensor
        Gradient of the output with shape (B, C, Ho, Wo)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Gradient of the key tensor with shape (B, C, Hi, Wi)
    """

    # shapes:
    # input
    # kx: B, C, Hi, Wi
    # vx: B, C, Hi, Wi
    # qy: B, C, Ho, Wo
    # quad_weights: Hi
    # output
    # dkx: B, C, Hi, Wi

    dkx = torch.zeros_like(kx)
    nbatch = dy.shape[0]

    # convert the sparsity pattern to Python integers once instead of once per use
    col_idx_list = col_idx.tolist()
    row_off_list = row_off.tolist()

    for ho in range(nlat_out):

        # range of nonzeros belonging to this output latitude
        zstart = row_off_list[ho]
        zend = row_off_list[ho + 1]
        nnz = zend - zstart

        for wo in range(nlon_out):

            q_ho_wo = qy[:, :, ho, wo]
            dy_ho_wo = dy[:, :, ho, wo]

            # compute input indices from psi datastructure once per output point:
            # (latitude, longitude shifted by wo with circular wrap-around)
            neighbors = []
            for idz in range(zstart, zend):
                hj, wj = divmod(col_idx_list[idz], nlon_in)
                neighbors.append((hj, (wj + wo) % nlon_in))

            # compute correlation
            qdotk_nz = torch.zeros((nbatch, nnz), dtype=dy.dtype, device=dy.device)
            for inz, (hj, wjp) in enumerate(neighbors):
                qdotk_nz[:, inz] = torch.sum(q_ho_wo * kx[:, :, hj, wjp], dim=1)

            # subtracting the maximum keeps the exponentials bounded by one
            qdotk_max, _ = torch.max(qdotk_nz, dim=1)

            alpha = torch.zeros((nbatch, nnz), dtype=dy.dtype, device=dy.device)
            alpha_sum = torch.zeros((nbatch,), dtype=dy.dtype, device=dy.device)
            integral = torch.zeros((nbatch,), dtype=dy.dtype, device=dy.device)
            # dy . v for every neighbor, kept for reuse in the scatter loop below
            gdotv_nz = []
            for inz, (hj, wjp) in enumerate(neighbors):

                # softmax numerator and denominator
                alpha[:, inz] = torch.exp(qdotk_nz[:, inz] - qdotk_max) * quad_weights[hj]
                alpha_sum += alpha[:, inz]

                # input dot
                gdotv = torch.sum(dy_ho_wo * vx[:, :, hj, wjp], dim=1)
                gdotv_nz.append(gdotv)

                # integral term
                integral += alpha[:, inz] * gdotv

            # softmax-weighted mean of dy . v over the neighborhood
            integral = integral / alpha_sum

            # scatter the contribution of this output point onto the neighboring keys
            for inz, (hi, wip) in enumerate(neighbors):
                dkx[:, :, hi, wip] += q_ho_wo * (alpha[:, None, inz] / alpha_sum[:, None]) * (gdotv_nz[inz][:, None] - integral[:, None])

    return dkx

# Explicit gradient w.r.t. qy: dM/dq
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """
    Backward pass implementation for query gradients in neighborhood attention on S2.

    For every output point (ho, wo) the quadrature-weighted softmax numerators
    ``alpha`` over the neighborhood are recomputed together with the sums
    ``sum(alpha)``, ``sum(alpha * k)``, ``sum(alpha * (dy . v))`` and
    ``sum(alpha * k * (dy . v))``. The query gradient at (ho, wo) is assembled
    from these sums via the quotient rule.

    Parameters
    -----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo)
    dy : torch.Tensor
        Gradient of the output with shape (B, C, Ho, Wo)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Gradient of the query tensor with shape (B, C, Ho, Wo)
    """

    # shapes:
    # input
    # kx: B, C, Hi, Wi
    # vx: B, C, Hi, Wi
    # qy: B, C, Ho, Wo
    # quad_weights: Hi
    # output
    # dqy: B, C, Ho, Wo

    dqy = torch.zeros_like(qy)
    nbatch, nchan = dy.shape[0], dy.shape[1]

    # convert the sparsity pattern to Python integers once instead of once per use
    col_idx_list = col_idx.tolist()
    row_off_list = row_off.tolist()

    for ho in range(nlat_out):

        # range of nonzeros belonging to this output latitude
        zstart = row_off_list[ho]
        zend = row_off_list[ho + 1]
        nnz = zend - zstart

        for wo in range(nlon_out):

            q_ho_wo = qy[:, :, ho, wo]
            dy_ho_wo = dy[:, :, ho, wo]

            # compute input indices from psi datastructure once per output point:
            # (latitude, longitude shifted by wo with circular wrap-around)
            neighbors = []
            for idz in range(zstart, zend):
                hi, wi = divmod(col_idx_list[idz], nlon_in)
                neighbors.append((hi, (wi + wo) % nlon_in))

            # compute correlation
            qdotk_nz = torch.zeros((nbatch, nnz), dtype=dy.dtype, device=dy.device)
            for inz, (hi, wip) in enumerate(neighbors):
                qdotk_nz[:, inz] = torch.sum(q_ho_wo * kx[:, :, hi, wip], dim=1)

            # subtracting the maximum keeps the exponentials bounded by one
            qdotk_max, _ = qdotk_nz.max(dim=1)

            alpha = torch.zeros((nbatch, nnz), dtype=dy.dtype, device=dy.device)
            alpha_sum = torch.zeros((nbatch,), dtype=dy.dtype, device=dy.device)
            alpha_k = torch.zeros((nbatch, nchan), dtype=dy.dtype, device=dy.device)
            alpha_vw = torch.zeros((nbatch, nchan), dtype=dy.dtype, device=dy.device)
            alpha_kvw = torch.zeros((nbatch, nchan), dtype=dy.dtype, device=dy.device)
            for inz, (hi, wip) in enumerate(neighbors):

                k_hi_wip = kx[:, :, hi, wip]

                # softmax numerator and denominator
                alpha[:, inz] = torch.exp(qdotk_nz[:, inz] - qdotk_max) * quad_weights[hi]
                alpha_sum += alpha[:, inz]

                # input dot
                gdotv = torch.sum(dy_ho_wo * vx[:, :, hi, wip], dim=1)

                # alpha-weighted sums needed by the quotient rule below
                alpha_k += alpha[:, None, inz] * k_hi_wip
                alpha_vw += alpha[:, None, inz] * gdotv[:, None]
                alpha_kvw += alpha[:, None, inz] * k_hi_wip * gdotv[:, None]

            dqy[:, :, ho, wo] = (alpha_kvw * alpha_sum[:, None] - alpha_vw * alpha_k) / (alpha_sum[:, None] * alpha_sum[:, None])

    return dqy

class _NeighborhoodAttentionS2(torch.autograd.Function):
    r"""
    CPU implementation of neighborhood attention on the sphere (S2).
    This class provides the forward and backward passes built on the pure-PyTorch
    reference routines: the inputs are projected with ``F.conv2d``, the attention
    heads are folded into the batch dimension, and the explicit forward / gradient
    functions are applied to the folded tensors.
    """

    @staticmethod
    @custom_fwd(device_type="cpu")
    def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for CPU neighborhood attention on S2.

        Parameters
        -----------
        k: torch.Tensor
            Key tensor
        v: torch.Tensor
            Value tensor
        q: torch.Tensor
            Query tensor
        wk: torch.Tensor
            Key weight tensor
        wv: torch.Tensor
            Value weight tensor
        wq: torch.Tensor
            Query weight tensor
        bk: torch.Tensor or None
            Key bias tensor (optional)
        bv: torch.Tensor or None
            Value bias tensor (optional)
        bq: torch.Tensor or None
            Query bias tensor (optional)
        quad_weights: torch.Tensor
            Quadrature weights for spherical integration
        col_idx: torch.Tensor
            Column indices for sparse computation
        row_off: torch.Tensor
            Row offsets for sparse computation
        nh: int
            Number of attention heads
        nlon_in: int
            Number of input longitude points
        nlat_out: int
            Number of output latitude points
        nlon_out: int
            Number of output longitude points

        Returns
        --------
        output: torch.Tensor
            Attention output with the heads unfolded from the batch dimension.
            The attention itself is evaluated in float32 and the result is
            returned in that dtype.
        """
        ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
        ctx.nh = nh
        ctx.nlon_in = nlon_in
        ctx.nlat_out = nlat_out
        ctx.nlon_out = nlon_out

        # input projections
        kw = F.conv2d(k, weight=wk, bias=bk)
        vw = F.conv2d(v, weight=wv, bias=bv)
        qw = F.conv2d(q, weight=wq, bias=bq)

        # reshape, folding num heads into batch dim
        B, _, H, W = kw.shape
        kw = kw.reshape(B*nh, -1, H, W)
        B, _, H, W = vw.shape
        vw = vw.reshape(B*nh, -1, H, W)
        B, _, H, W = qw.shape
        qw = qw.reshape(B*nh, -1, H, W)

        # the reference kernel is evaluated in single precision
        kw = kw.to(torch.float32)
        vw = vw.to(torch.float32)
        qw = qw.to(torch.float32)

        output = _neighborhood_attention_s2_fwd_torch(kw, vw, qw, quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)

        # unfold the heads from the batch dim
        _, _, H, W = output.shape
        output = output.reshape(B, -1, H, W)

        return output

    @staticmethod
    @custom_bwd(device_type="cpu")
    def backward(ctx, grad_output):
        r"""
        Backward pass for CPU neighborhood attention on S2.

        Parameters
        -----------
        grad_output: torch.Tensor
            Gradient of the output

        Returns
        --------
        dk: torch.Tensor
            Gradient of the key tensor
        dv: torch.Tensor
            Gradient of the value tensor
        dq: torch.Tensor
            Gradient of the query tensor
        dwk: torch.Tensor
            Gradient of the key weight tensor
        dwv: torch.Tensor
            Gradient of the value weight tensor
        dwq: torch.Tensor
            Gradient of the query weight tensor
        dbk: torch.Tensor or None
            Gradient of the key bias tensor
        dbv: torch.Tensor or None
            Gradient of the value bias tensor
        dbq: torch.Tensor or None
            Gradient of the query bias tensor

        The remaining seven entries of the returned tuple are ``None``; they
        correspond to the non-differentiable arguments of ``forward``
        (quad_weights, col_idx, row_off, nh, nlon_in, nlat_out, nlon_out).
        """
        col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
        nh = ctx.nh
        nlon_in = ctx.nlon_in
        nlat_out = ctx.nlat_out
        nlon_out = ctx.nlon_out

        # recompute the input projections instead of storing them in forward
        kw = F.conv2d(k, weight=wk, bias=bk)
        vw = F.conv2d(v, weight=wv, bias=bv)
        qw = F.conv2d(q, weight=wq, bias=bq)

        # reshape, folding num heads into batch dim
        B, _, H, W = kw.shape
        kw = kw.reshape(B*nh, -1, H, W)
        B, _, H, W = vw.shape
        vw = vw.reshape(B*nh, -1, H, W)
        B, _, H, W = qw.shape
        qw = qw.reshape(B*nh, -1, H, W)
        B, _, H, W = grad_output.shape
        grad_output = grad_output.reshape(B*nh, -1, H, W)

        dvw = _neighborhood_attention_s2_bwd_dv_torch(kw, vw, qw, grad_output,
                                                      quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)

        dkw = _neighborhood_attention_s2_bwd_dk_torch(kw, vw, qw, grad_output,
                                                      quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)

        dqw = _neighborhood_attention_s2_bwd_dq_torch(kw, vw, qw, grad_output,
                                                      quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)

        # reshape again, unfolding the heads from the batch dim
        _, _, H, W = dkw.shape
        dkw = dkw.reshape(B, -1, H, W)
        _, _, H, W = dvw.shape
        dvw = dvw.reshape(B, -1, H, W)
        _, _, H, W = dqw.shape
        dqw = dqw.reshape(B, -1, H, W)

        # input grads: apply the projection weights with in/out channels swapped
        dv = F.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
        dk = F.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
        dq = F.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)

        # weight grads
        # the (out_channels, in_channels) result is reshaped to the weight shape,
        # which only matches in element count for 1x1 kernels
        dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
        dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
        dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()

        # bias grads: sum over batch and both spatial dims
        if bv is not None:
            dbv = torch.sum(dvw, dim=(0,2,3))
        else:
            dbv = None

        if bk is not None:
            dbk = torch.sum(dkw, dim=(0,2,3))
        else:
            dbk = None

        if bq is not None:
            dbq = torch.sum(dqw, dim=(0,2,3))
        else:
            dbq = None

        return (dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq,
                None, None, None, None, None, None, None)


# functional wrapper around the pure-PyTorch autograd implementation
def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                                     wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                                     bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None],
                                     bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
                                     col_idx: torch.Tensor, row_off: torch.Tensor,
                                     nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """
    Torch implementation of neighborhood attention on the sphere (S2).

    This function is a thin wrapper which forwards all of its arguments,
    unchanged and in the same order, to ``_NeighborhoodAttentionS2.apply``.

    Parameters
    -----------
    k : torch.Tensor
        Key tensor
    v : torch.Tensor
        Value tensor
    q : torch.Tensor
        Query tensor
    wk : torch.Tensor
        Key weight tensor
    wv : torch.Tensor
        Value weight tensor
    wq : torch.Tensor
        Query weight tensor
    bk : torch.Tensor or None
        Key bias tensor (optional)
    bv : torch.Tensor or None
        Value bias tensor (optional)
    bq : torch.Tensor or None
        Query bias tensor (optional)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nh : int
        Number of attention heads
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Output tensor after neighborhood attention computation
    """
    return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
                                          quad_weights, col_idx, row_off,
                                          nh, nlon_in, nlat_out, nlon_out)


class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
    r"""
    Autograd function for neighborhood attention on the sphere (S2), CUDA path.

    The forward pass projects keys, values and queries with ``F.conv2d``, folds
    the attention heads into the batch dimension and hands the float32 tensors
    to ``attention_cuda_extension.forward``. The backward pass recomputes the
    projections from the saved inputs, obtains the gradients of the projected
    tensors from ``attention_cuda_extension.backward_dkvq`` and propagates them
    to the inputs, weights and biases.
    """

    @staticmethod
    def _project_and_fold(x: torch.Tensor, weight: torch.Tensor,
                          bias: Union[torch.Tensor, None], nh: int):
        r"""
        Project ``x`` with ``F.conv2d`` and fold the heads into the batch dim.

        Parameters
        -----------
        x: torch.Tensor
            Tensor to project
        weight: torch.Tensor
            Convolution weight
        bias: torch.Tensor or None
            Convolution bias (optional)
        nh: int
            Number of attention heads

        Returns
        --------
        folded: torch.Tensor
            Contiguous float32 tensor whose leading dimension is ``batch * nh``
        dtype: torch.dtype
            Dtype of the projection before the float32 conversion
        batch: int
            Leading dimension of the projection before folding
        """
        projected = F.conv2d(x, weight=weight, bias=bias)
        batch, _, height, width = projected.shape
        folded = projected.reshape(batch * nh, -1, height, width)
        return folded.to(torch.float32).contiguous(), folded.dtype, batch

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for CUDA neighborhood attention on S2.

        Parameters
        -----------
        k: torch.Tensor
            Key tensor
        v: torch.Tensor
            Value tensor
        q: torch.Tensor
            Query tensor
        wk: torch.Tensor
            Key weight tensor
        wv: torch.Tensor
            Value weight tensor
        wq: torch.Tensor
            Query weight tensor
        bk: torch.Tensor or None
            Key bias tensor (optional)
        bv: torch.Tensor or None
            Value bias tensor (optional)
        bq: torch.Tensor or None
            Query bias tensor (optional)
        quad_weights: torch.Tensor
            Quadrature weights for spherical integration
        col_idx: torch.Tensor
            Column indices for sparse computation
        row_off: torch.Tensor
            Row offsets for sparse computation
        max_psi_nnz: int
            Maximum number of non-zero elements in sparse tensor; only stored
            on ``ctx`` here, it is not passed to the extension call
        nh: int
            Number of attention heads
        nlon_in: int
            Number of input longitude points
        nlat_out: int
            Number of output latitude points
        nlon_out: int
            Number of output longitude points

        Returns
        --------
        torch.Tensor
            Attention output with the heads unfolded from the batch dimension,
            cast back to the dtype of the projected keys
        """
        # the raw inputs are stored; backward recomputes the projections from them
        ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
        ctx.nh, ctx.max_psi_nnz = nh, max_psi_nnz
        ctx.nlon_in, ctx.nlat_out, ctx.nlon_out = nlon_in, nlat_out, nlon_out

        fold = _NeighborhoodAttentionS2Cuda._project_and_fold
        kw, inp_dtype, _ = fold(k, wk, bk, nh)
        vw, _, _ = fold(v, wv, bv, nh)
        qw, _, batch = fold(q, wq, bq, nh)

        attn = attention_cuda_extension.forward(kw, vw, qw, quad_weights,
                                                col_idx, row_off,
                                                nlon_in, nlat_out, nlon_out)

        # unfold the heads again and restore the incoming precision
        _, _, out_h, out_w = attn.shape
        return attn.reshape(batch, -1, out_h, out_w).to(dtype=inp_dtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for CUDA neighborhood attention on S2.

        Parameters
        -----------
        grad_output: torch.Tensor
            Gradient of the output

        Returns
        --------
        dk: torch.Tensor
            Gradient of the key tensor
        dv: torch.Tensor
            Gradient of the value tensor
        dq: torch.Tensor
            Gradient of the query tensor
        dwk: torch.Tensor
            Gradient of the key weight tensor
        dwv: torch.Tensor
            Gradient of the value weight tensor
        dwq: torch.Tensor
            Gradient of the query weight tensor
        dbk: torch.Tensor or None
            Gradient of the key bias tensor
        dbv: torch.Tensor or None
            Gradient of the value bias tensor
        dbq: torch.Tensor or None
            Gradient of the query bias tensor

        The tuple is padded with eight ``None`` entries, one for each of the
        remaining forward arguments (``quad_weights`` through ``nlon_out``).
        """
        col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
        nh = ctx.nh

        # recompute the projections instead of having stored them in forward
        fold = _NeighborhoodAttentionS2Cuda._project_and_fold
        kw, kw_dtype, _ = fold(k, wk, bk, nh)
        vw, vw_dtype, _ = fold(v, wv, bv, nh)
        qw, qw_dtype, _ = fold(q, wq, bq, nh)

        # fold the heads of the incoming gradient the same way
        batch, _, grad_h, grad_w = grad_output.shape
        grad_folded = grad_output.reshape(batch * nh, -1, grad_h, grad_w)
        grad_folded = grad_folded.to(torch.float32).contiguous()

        dkw, dvw, dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_folded,
                                                               quad_weights,
                                                               col_idx, row_off,
                                                               ctx.nlon_in, ctx.nlat_out, ctx.nlon_out)

        def unfold(grad, dtype):
            # move the heads back out of the batch dim and restore the precision
            _, _, height, width = grad.shape
            return grad.reshape(batch, -1, height, width).to(dtype=dtype)

        def input_grad(grad_proj, weight):
            # convolution with in/out channels swapped maps back to the input
            return F.conv2d(grad_proj, weight=weight.permute([1, 0, 2, 3]), bias=None)

        def weight_grad(grad_proj, inp, weight):
            # contract batch and spatial dims, then restore the weight layout
            return torch.einsum("bchw,bfhw->cf", grad_proj, inp).reshape(*weight.shape).contiguous()

        def bias_grad(grad_proj, bias):
            # no gradient is produced for an absent bias
            return None if bias is None else torch.sum(grad_proj, dim=(0, 2, 3))

        dkw = unfold(dkw, kw_dtype)
        dvw = unfold(dvw, vw_dtype)
        dqw = unfold(dqw, qw_dtype)

        dk, dv, dq = input_grad(dkw, wk), input_grad(dvw, wv), input_grad(dqw, wq)
        dwk, dwv, dwq = weight_grad(dkw, k, wk), weight_grad(dvw, v, wv), weight_grad(dqw, q, wq)
        dbk, dbv, dbq = bias_grad(dkw, bk), bias_grad(dvw, bv), bias_grad(dqw, bq)

        return (dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq) + (None,) * 8


Boris Bonev's avatar
Boris Bonev committed
908
def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                                    wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                                    bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None],
                                    bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
                                    col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
                                    nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """
    Functional entry point for CUDA neighborhood attention on the sphere (S2).

    All arguments are forwarded unchanged, in order, to
    ``_NeighborhoodAttentionS2Cuda.apply``.

    Parameters
    -----------
    k : torch.Tensor
        Key tensor
    v : torch.Tensor
        Value tensor
    q : torch.Tensor
        Query tensor
    wk : torch.Tensor
        Key weight tensor
    wv : torch.Tensor
        Value weight tensor
    wq : torch.Tensor
        Query weight tensor
    bk : torch.Tensor or None
        Key bias tensor (optional)
    bv : torch.Tensor or None
        Value bias tensor (optional)
    bq : torch.Tensor or None
        Query bias tensor (optional)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    max_psi_nnz : int
        Maximum number of non-zero elements in sparse tensor
    nh : int
        Number of attention heads
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Output tensor after neighborhood attention computation
    """
    # grouped only for readability; the positional order matches forward()
    projection_args = (k, v, q, wk, wv, wq, bk, bv, bq)
    sparsity_args = (quad_weights, col_idx, row_off, max_psi_nnz)
    grid_args = (nh, nlon_in, nlat_out, nlon_out)
    return _NeighborhoodAttentionS2Cuda.apply(*projection_args, *sparsity_args, *grid_args)