Unverified Commit 4aaff021 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

Merge pull request #91 from NVIDIA/maurob/devel

Attention Backward improvement
parents ab44ba59 fa58767d
...@@ -61,6 +61,7 @@ def get_compile_args(module_name): ...@@ -61,6 +61,7 @@ def get_compile_args(module_name):
nvcc_extra_flags = [] nvcc_extra_flags = []
if profile_mode: if profile_mode:
nvcc_extra_flags.append("-lineinfo") nvcc_extra_flags.append("-lineinfo")
nvcc_extra_flags.append("-Xptxas=-v")
if debug_mode: if debug_mode:
print(f"WARNING: Compiling {module_name} with debugging flags") print(f"WARNING: Compiling {module_name} with debugging flags")
...@@ -102,6 +103,7 @@ def get_ext_modules(): ...@@ -102,6 +103,7 @@ def get_ext_modules():
CUDAExtension( CUDAExtension(
name="attention_cuda_extension", name="attention_cuda_extension",
sources=[ sources=[
"torch_harmonics/csrc/attention/attention_utils.cu",
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu", "torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu", "torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu", "torch_harmonics/csrc/attention/attention_interface.cu",
......
...@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3], [4, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3], [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
], ],
skip_on_empty=True, skip_on_empty=True,
...@@ -156,8 +157,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -156,8 +157,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[ [
# Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol] # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0], [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
# [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
# [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0], [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0], [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
], ],
......
...@@ -520,6 +520,16 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function): ...@@ -520,6 +520,16 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
B, _, H, W = grad_output.shape B, _, H, W = grad_output.shape
grad_output = grad_output.reshape(B*nh, -1, H, W) grad_output = grad_output.reshape(B*nh, -1, H, W)
# save type and convert to float32
kw_dtype = kw.dtype
vw_dtype = vw.dtype
qw_dtype = qw.dtype
kw = kw.to(torch.float32).contiguous()
vw = vw.to(torch.float32).contiguous()
qw = qw.to(torch.float32).contiguous()
grad_output = grad_output.to(torch.float32).contiguous()
dkw,dvw,dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output, dkw,dvw,dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output,
quad_weights, quad_weights,
col_idx, row_off, col_idx, row_off,
...@@ -533,6 +543,11 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function): ...@@ -533,6 +543,11 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
_, C, H, W = dqw.shape _, C, H, W = dqw.shape
dqw = dqw.reshape(B, -1, H, W) dqw = dqw.reshape(B, -1, H, W)
# convert precision
dkw = dkw.to(dtype=kw_dtype)
dvw = dvw.to(dtype=vw_dtype)
dqw = dqw.to(dtype=qw_dtype)
# input grads # input grads
dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
......
...@@ -34,7 +34,11 @@ ...@@ -34,7 +34,11 @@
#include <cstdint> #include <cstdint>
#include <torch/torch.h> #include <torch/torch.h>
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast))
#define CHECK_CUDA_INPUT_TENSOR(x) \
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights, torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out, at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
......
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
#include "cudamacro.h"
#include "attention_utils.cuh"
#define THREADS (64)
#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)
// BEGIN - CSR rows sorting kernels and functions
__global__ void set_rlen_rids_k(const int n,
const int64_t *__restrict__ offs,
int *__restrict__ rids,
int *__restrict__ rlen) {
const int nth = gridDim.x*blockDim.x;
const int tid = blockIdx.x*blockDim.x + threadIdx.x;
for(int i = tid; i < n; i += nth) {
rids[i] = i;
rlen[i] = offs[i+1]-offs[i];
}
return;
}
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {
int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());
auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());
torch::Tensor rids_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_d = torch::empty({nlat_out}, options);
int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());
const int grid = DIV_UP(nlat_out, THREADS);
const int block = THREADS;
set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
_row_off_d,
_rids_d,
_rlen_d);
torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);
int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());
size_t temp_storage_bytes = 0;
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
return rids_sort_d;
}
// END - CSR rows sorting kernels and functions
// BEGIN - 4D tensor permutation kernels and functions
__global__ void empty_k() {}
static int getPtxver() {
cudaFuncAttributes attrs;
CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
return attrs.ptxVersion*10;
}
at::Tensor permute_4D_to0231(at::Tensor src) {
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0231_k_tile_generic");
} else {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0231_k_tile_sm100");
}
return dst;
}
at::Tensor permute_4D_to0312(at::Tensor src) {
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0312_k_tile_generic");
} else {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0312_k_tile_sm100");
}
return dst;
}
// END - tensor permutation kernels and functions
// BEGIN - general host-side functions
unsigned int next_pow2(unsigned int x) {
x -= 1;
#pragma unroll
for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
x |= x >> i;
}
return x+1;
}
// END - general host-side functions
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <ATen/ATen.h>
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
// CSR rows sorting kernels and functions
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream);
// 4D tensor permutation kernels and functions
at::Tensor permute_4D_to0231(at::Tensor src);
at::Tensor permute_4D_to0312(at::Tensor src);
// Host tensor dump and CSR manipulation functions
void dump_tensor(const char *fname, at::Tensor t);
void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols);
int part_csr_rows(int *row_perm,
const at::Tensor roff,
const at::Tensor cols,
int **part_off,
int **part_val);
int verify_part(const int npart,
const int *part_off,
const int *part_val,
const at::Tensor roff,
const at::Tensor cols);
void verify_part_new(const int nlon_out,
const int nlat_in,
const int nlon_in,
const int npart, // partitioning data
const int *part_off,
const int *part_val,
const at::Tensor roff,
const at::Tensor cols);
unsigned int next_pow2(unsigned int x);
// utility host functions and templates
template<unsigned int ALIGN>
int is_aligned(const void *ptr) {
static_assert(0 == (ALIGN & (ALIGN-1)));
return (0 == (uintptr_t(ptr) & (ALIGN-1)));
}
// utility device functions and templates
template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x) {
static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset");
return FLOATV_T{};
}
template<>
__device__ float __forceinline__ __vset<float>(float x) {
return x;
}
__device__ float __forceinline__ __vmul(float a, float b) {
return a*b;
}
__device__ float __forceinline__ __vadd(float a, float b) {
return a+b;
}
__device__ float __forceinline__ __vsub(float a, float b) {
return a-b;
}
__device__ float __forceinline__ __vred(float a) {
return a;
}
__device__ float __forceinline__ __vscale(float s, float v) {
return v*s;
}
__device__ float __forceinline__ __vdiv(float s, float v) {
return v/s;
}
template<>
__device__ float4 __forceinline__ __vset<float4>(float x) {
return make_float4(x, x, x, x);
}
__device__ float4 __forceinline__ __vmul(float4 a, float4 b) {
return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
}
__device__ float4 __forceinline__ __vadd(float4 a, float4 b) {
return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
}
__device__ float4 __forceinline__ __vsub(float4 a, float4 b) {
return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w);
}
__device__ float __forceinline__ __vred(float4 a) {
return a.x + a.y + a.z + a.w;
}
__device__ float4 __forceinline__ __vscale(float s, float4 v) {
return make_float4(s*v.x, s*v.y, s*v.z, s*v.w);
}
__device__ float4 __forceinline__ __vdiv(float s, float4 v) {
return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);;
}
template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE);
}
return val;
}
template<int BDIM_X,
int BDIM_Y=1,
int BDIM_Z=1,
typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val) {
const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE;
val = __warp_sum(val);
if constexpr(NWARP > 1) {
int tid = threadIdx.x;
if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; }
if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; }
const int lid = tid%WARP_SIZE;
const int wid = tid/WARP_SIZE;
__shared__ VAL_T sh[NWARP];
if (lid == 0) {
sh[wid] = val;
}
__syncthreads();
if (wid == 0) {
val = (lid < NWARP) ? sh[lid] : 0;
val = __warp_sum(val);
__syncwarp();
if (!lid) {
sh[0] = val;
}
}
__syncthreads();
val = sh[0];
__syncthreads();
}
return val;
}
// transpose utils
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
const int nlat,
const int nlon,
const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int coff = blockIdx.x*BDIM_X; // channel offset
const int woff = blockIdx.y*BDIM_X; // width offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
} else {
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0);
}
}
__syncthreads();
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (woff + j+tidy < nlon) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0231(at::Tensor src, at::Tensor dst){
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
block.y = WARPS_X_TILE;
grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0231_k<WARP_SIZE, WARPS_X_TILE>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
const int nlat,
const int nlon,
const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int woff = blockIdx.x*BDIM_X; // width offset
const int coff = blockIdx.y*BDIM_X; // channel offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
} else {
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0);
}
}
__syncthreads();
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (coff + j+tidy < nchn) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];;
}
}
}
}
return;
}
template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0312(at::Tensor src, at::Tensor dst){
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
block.y = WARPS_X_TILE;
grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0312_k<WARP_SIZE, WARPS_X_TILE>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\
exit(EXIT_FAILURE); \
}}
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.amp as amp
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional from typing import Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -259,11 +260,14 @@ class W11LossS2(SphericalLossBase): ...@@ -259,11 +260,14 @@ class W11LossS2(SphericalLossBase):
self.register_buffer("k_theta_mesh", k_theta_mesh) self.register_buffer("k_theta_mesh", k_theta_mesh)
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor: def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real prdtype = prd.dtype
prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real with amp.autocast(device_type="cuda", enabled=False):
prd = prd.to(torch.float32)
tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real
# Return the element-wise loss term # Return the element-wise loss term
return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h) return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment