Commit ec413b4d authored by Thorsten Kurth

Merge branch 'mr/bwd-channel-permute-experiments' of https://github.com/rietmann-nv/torch-harmonics into mr/bwd-channel-permute-experiments
parents 1a47fa08 37b08bb8
@@ -43,7 +43,10 @@ except ImportError as err:
     attention_cuda_extension = None
     _cuda_extension_available = False
 
 # s2 neighborhood attention forward pass
+# uses qdotk_max update trick to avoid two loops when computing the softmax
+# see e.g., https://arxiv.org/abs/1805.02867
+# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
 def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
                                          quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                          nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
@@ -61,7 +64,7 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
         for wo in range(nlon_out):
             alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
-            qdotk_nz = torch.zeros((y.shape[0], zend-zstart,), dtype=y.dtype, device=y.device)
+            qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
             for idz in range(zstart, zend):
                 nz_col_idx = col_idx[idz]
@@ -75,24 +78,19 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
                 # compute correlation & softmax numerator
                 q_ho_wo = qy[:, :, ho, wo]
                 k_hi_wip = kx[:, :, hi, wip]
-                qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wip, dim=1)
-            qdotk_max, _ = torch.max(qdotk_nz, dim=1)
-            for idz in range(zstart, zend):
-                nz_col_idx = col_idx[idz]
-                # compute input indices from psi datastructure
-                hi = nz_col_idx // nlon_in
-                # account for output shift and ensure positive index due to circular condition
-                wi = nz_col_idx % nlon_in
-                wip = (wi + wo) % nlon_in
-                alpha = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max)
-                # softmax denominator
-                alpha_sum[:] += alpha[:] * quad_weights[hi]
-                y[:,:,ho,wo] += alpha[:, None] * vx[:,:,hi,wip] * quad_weights[hi]
+                qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
+                # tmp max
+                qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
+                # alpha sum update
+                alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
+                alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
+                # update output
+                y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
+                # define new max
+                qdotk_max = qdotk_max_tmp
             y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
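The added comments refer to the online ("qdotk_max update") softmax: instead of one pass to find the maximum score and a second pass to form the softmax-weighted sum, a running maximum is kept and the partial numerator and denominator are rescaled whenever that maximum grows, so a single pass over the neighborhood suffices. The following is a minimal, self-contained sketch of that recurrence, not part of this commit; the names `online_softmax_weighted_sum`, `scores`, and `values` are illustrative, and the quadrature weighting applied in the kernel is omitted. It is checked against a two-pass reference:

    import torch

    def online_softmax_weighted_sum(scores: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        # scores: (n,), values: (n, c); returns sum_i softmax(scores)[i] * values[i] in one pass
        running_max = torch.full((), float("-inf"))
        alpha_sum = torch.zeros(())
        acc = torch.zeros(values.shape[1])
        for s, v in zip(scores, values):
            new_max = torch.maximum(running_max, s)
            # rescale the previously accumulated numerator and denominator to the new max
            correction = torch.exp(running_max - new_max)
            alpha = torch.exp(s - new_max)
            alpha_sum = alpha + alpha_sum * correction
            acc = acc * correction + alpha * v
            running_max = new_max
        return acc / alpha_sum

    scores = torch.randn(16)
    values = torch.randn(16, 4)
    reference = torch.sum(torch.softmax(scores, dim=0)[:, None] * values, dim=0)
    assert torch.allclose(online_softmax_weighted_sum(scores, values), reference, atol=1e-5)

In the diff above, `qdotk_max`, `alpha_sum`, and `y[:,:,ho,wo]` play the roles of `running_max`, `alpha_sum`, and `acc`, with each term additionally multiplied by the quadrature weight `quad_weights[hi]` of its source point.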
@@ -256,7 +256,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
   CHECK_CUDA(cudaEventRecord(stop, stream));
   CHECK_CUDA(cudaEventSynchronize(stop));
   CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
-  // printf("s2_attention_kernel_mbT execution time: %f ms\n", milliseconds);
+  // printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
   CHECK_CUDA(cudaEventDestroy(start));
   CHECK_CUDA(cudaEventDestroy(stop));
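As a side note, the commented-out timing in the CUDA source uses CUDA events; the same measurement pattern is available from Python via `torch.cuda.Event`. A hedged sketch follows (the helper name `time_cuda_op` and the averaging over iterations are illustrative, not taken from this repository):

    import torch

    def time_cuda_op(op, *args, n_iters: int = 10) -> float:
        # mirrors the cudaEventRecord / cudaEventElapsedTime pattern from the C++ code
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(n_iters):
            op(*args)
        stop.record()
        stop.synchronize()
        # elapsed_time returns milliseconds, like cudaEventElapsedTime
        return start.elapsed_time(stop) / n_iters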