Commit ae8257b5 authored by Boris Bonev, committed by Boris Bonev

adapting header files

parent 13d6130e
# coding=utf-8
-# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -115,7 +115,7 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
    # dvx: B, C, Hi, Wi
    dvx = torch.zeros_like(vx)
    for ho in range(nlat_out):
        # get number of nonzeros
@@ -181,7 +181,7 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor,
    # quad_weights: Hi
    # output
    # dkx: B, C, Hi, Wi
    dkx = torch.zeros_like(kx)
    for ho in range(nlat_out):
@@ -262,15 +262,15 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
    # quad_weights: Hi
    # output
    # dqy: B, C, Ho, Wo
    dqy = torch.zeros_like(qy)
    for ho in range(nlat_out):
        # get number of nonzeros
        zstart = row_off[ho]
        zend = row_off[ho+1]
        for wo in range(nlon_out):
            alpha = torch.zeros((dy.shape[0], zend-zstart), dtype=dy.dtype, device=dy.device)
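
Note on the sparse layout used by all three backward kernels above: row_off holds nlat_out + 1 offsets, so zstart = row_off[ho] and zend = row_off[ho+1] bracket the nonzeros belonging to output latitude ho, and col_idx (always passed alongside it) indexes the input points those nonzeros refer to. A minimal, self-contained sketch of reading such a CSR-style structure, with toy values standing in for the library's precomputed tensors:

import torch

# toy CSR-style neighborhood: 3 output latitudes, 9 nonzeros in total;
# the values are illustrative, not produced by the library's precomputation
row_off = torch.tensor([0, 3, 7, 9])
col_idx = torch.tensor([0, 1, 5, 1, 2, 5, 6, 6, 7])

for ho in range(row_off.numel() - 1):
    zstart, zend = row_off[ho], row_off[ho + 1]
    neighbors = col_idx[zstart:zend]   # entries contributing to output latitude ho
    print(f"latitude {ho}: {int(zend - zstart)} nonzeros -> {neighbors.tolist()}")
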
@@ -353,7 +353,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
        kw = kw.to(torch.float32)
        vw = vw.to(torch.float32)
        qw = qw.to(torch.float32)

        output = _neighborhood_attention_s2_fwd_torch(kw, vw, qw, quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)
@@ -371,7 +371,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
        nlon_in = ctx.nlon_in
        nlat_out = ctx.nlat_out
        nlon_out = ctx.nlon_out

        kw = F.conv2d(k, weight=wk, bias=bk)
        vw = F.conv2d(v, weight=wv, bias=bv)
        qw = F.conv2d(q, weight=wq, bias=bq)
@@ -408,7 +408,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
        dvw = dvw.reshape(B, -1, H, W)
        _, C, H, W = dqw.shape
        dqw = dqw.reshape(B, -1, H, W)

        # input grads
        dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
        dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
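
Because the k/v/q projections are 1x1 convolutions, the input gradients above are the same convolutions applied with the weight tensor transposed in its first two dimensions, which is exactly what permute([1,0,2,3]) does. A small standalone check of that identity:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 6, requires_grad=True)
w = torch.randn(16, 8, 1, 1)

y = F.conv2d(x, w)
dy = torch.randn_like(y)
y.backward(dy)

# gradient w.r.t. the input equals a 1x1 convolution of dy with the permuted weight
dx_manual = F.conv2d(dy, w.permute([1, 0, 2, 3]))
print(torch.allclose(x.grad, dx_manual, atol=1e-5))  # expected: True
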
@@ -439,13 +439,13 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
                None, None, None, None, None, None, None


def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                                     wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                                     bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None],
                                     bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
                                     col_idx: torch.Tensor, row_off: torch.Tensor,
                                     nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:

    return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
                                          quad_weights, col_idx, row_off,
                                          nh, nlon_in, nlat_out, nlon_out)
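
For orientation, a hedged usage sketch of the pure-PyTorch entry point defined above. The shapes, the uniform quadrature weights, and the dense toy neighborhood (every output latitude attending to the whole input grid) are illustrative assumptions; in the library these tensors come from the grid definition and the neighborhood precomputation.

import torch

B, C, nh = 2, 8, 2                       # batch, channels, attention heads
nlat_in, nlon_in = 16, 32                # input grid
nlat_out, nlon_out = 16, 32              # output grid

k = torch.randn(B, C, nlat_in, nlon_in)
v = torch.randn(B, C, nlat_in, nlon_in)
q = torch.randn(B, C, nlat_out, nlon_out)

# 1x1 projection weights and biases for the k/v/q convolutions
wk, wv, wq = (torch.randn(C, C, 1, 1) for _ in range(3))
bk, bv, bq = (torch.zeros(C) for _ in range(3))

quad_weights = torch.full((nlat_in,), 1.0 / nlat_in)

# toy neighborhood: each output latitude attends to every input point
row_off = torch.arange(nlat_out + 1) * (nlat_in * nlon_in)
col_idx = torch.arange(nlat_in * nlon_in).repeat(nlat_out)

y = _neighborhood_attention_s2_torch(k, v, q, wk, wv, wq, bk, bv, bq,
                                     quad_weights, col_idx, row_off,
                                     nh, nlon_in, nlat_out, nlon_out)
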
@@ -457,7 +457,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
    @custom_fwd(device_type="cuda")
    def forward(ctx, k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
@@ -479,12 +479,12 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
        vw = vw.reshape(B*nh, -1, H, W)
        B, _, H, W = qw.shape
        qw = qw.reshape(B*nh, -1, H, W)

        # convert to float32
        kw = kw.to(torch.float32)
        vw = vw.to(torch.float32)
        qw = qw.to(torch.float32)

        output = attention_cuda_extension.forward(kw, vw, qw, quad_weights,
                                                  col_idx, row_off,
                                                  nlon_in, nlat_out, nlon_out)
@@ -561,13 +561,13 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
                None, None, None, None, None, None, None, None


def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                                    wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                                    bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None],
                                    bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
                                    col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
                                    nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:

    return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq,
                                              quad_weights, col_idx, row_off, max_psi_nnz,
                                              nh, nlon_in, nlat_out, nlon_out)
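
The CUDA entry point mirrors the PyTorch one, with the extra max_psi_nnz argument (presumably the maximum number of neighborhood nonzeros over the output rows) forwarded to the extension. A minimal sketch of a convenience dispatcher between the two backends; the wrapper name and the availability flag are hypothetical, not part of the module:

def attention_s2(k, v, q, wk, wv, wq, bk, bv, bq,
                 quad_weights, col_idx, row_off, max_psi_nnz,
                 nh, nlon_in, nlat_out, nlon_out,
                 use_cuda_extension=False):
    # hypothetical dispatcher: use the fused CUDA kernels when the extension was
    # built and the tensors live on the GPU, otherwise fall back to pure PyTorch
    if use_cuda_extension and k.is_cuda:
        return _neighborhood_attention_s2_cuda(k, v, q, wk, wv, wq, bk, bv, bq,
                                               quad_weights, col_idx, row_off, max_psi_nnz,
                                               nh, nlon_in, nlat_out, nlon_out)
    return _neighborhood_attention_s2_torch(k, v, q, wk, wv, wq, bk, bv, bq,
                                            quad_weights, col_idx, row_off,
                                            nh, nlon_in, nlat_out, nlon_out)
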
# coding=utf-8
-# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
...
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ATen/core/TensorAccessor.h"
#include <cmath>
#include <cstdint>
...
# coding=utf-8
-# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
...