_disco_convolution.py 10.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

32
from typing import Optional
33
34
35
import math

import torch
36
from torch.amp import custom_fwd, custom_bwd
37

Boris Bonev's avatar
Boris Bonev committed
38
39
40
41
try:
    import disco_cuda_extension
except ImportError as err:
    disco_cuda_extension = None
42

43
44
# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
    """Assemble the coalesced sparse COO tensor ``psi`` from indices and values.

    Args:
        kernel_size: Size of the leading (kernel) dimension of the result.
        psi_idx: Integer tensor with three rows of indices, one column per
            non-zero entry. Row 0 indexes the kernel dimension. Without
            ``semi_transposed`` rows 1 and 2 are used verbatim as indices of
            the second and third dimension. With ``semi_transposed`` row 1 is
            read as a latitude index and row 2 as a flattened
            ``lat * nlon_out + lon`` index (see below).
        psi_vals: Values of the non-zero entries, one per column of ``psi_idx``.
        nlat_in: Number of input latitudes (fallback for ``nlat_in_local``).
        nlon_in: Number of input longitudes.
        nlat_out: Number of output latitudes (fallback for ``nlat_out_local``).
        nlon_out: Number of output longitudes.
        nlat_in_local: Number of input latitudes used to size the result;
            ``nlat_in`` is used when this is None.
        nlat_out_local: Number of output latitudes used to size the result;
            ``nlat_out`` is used when this is None.
        semi_transposed: When truthy, the indices are rearranged before the
            tensor is built: row 2 is split into ``(lat, lon)`` using
            ``nlon_out``, the longitude is mirrored (``nlon_out - 1 - lon``),
            and the latitude taken from row 2 swaps places with row 1.

    Returns:
        A coalesced sparse COO tensor of shape
        ``(kernel_size, nlat_out_local, nlat_in_local * nlon_in)``, or
        ``(kernel_size, nlat_out_local, nlat_in_local * nlon_out)`` when
        ``semi_transposed`` is truthy. Duplicate indices are summed by
        ``coalesce()``.
    """
    # fall back to the global latitude counts when no local ones are supplied
    if nlat_in_local is None:
        nlat_in_local = nlat_in
    if nlat_out_local is None:
        nlat_out_local = nlat_out

    if not semi_transposed:
        # plain case: indices are used exactly as given
        plain_shape = (kernel_size, nlat_out_local, nlat_in_local * nlon_in)
        return torch.sparse_coo_tensor(psi_idx, psi_vals, size=plain_shape).coalesce()

    # semi-transposition: split the flattened index into (latitude, longitude) ...
    ker, lat_from_row1, flat = psi_idx[0], psi_idx[1], psi_idx[2]
    lat_from_flat = flat // nlon_out
    # ... and mirror the longitude axis
    lon_mirrored = nlon_out - 1 - (flat % nlon_out)

    # the two latitude indices trade places; the mirrored longitude is folded
    # back into the flattened last index
    swapped_idx = torch.stack([ker, lat_from_flat, lat_from_row1 * nlon_out + lon_mirrored], dim=0)
    transposed_shape = (kernel_size, nlat_out_local, nlat_in_local * nlon_out)
    return torch.sparse_coo_tensor(swapped_idx, psi_vals, size=transposed_shape).coalesce()

91

Boris Bonev's avatar
Boris Bonev committed
92
class _DiscoS2ContractionCuda(torch.autograd.Function):
    """Autograd wrapper around the forward kernel of ``disco_cuda_extension``.

    ``forward`` calls ``disco_cuda_extension.forward`` and ``backward`` calls
    ``disco_cuda_extension.backward``. Both cast their tensor argument to
    float32 before the call and cast the result back to the original dtype.
    Only ``x`` receives a gradient.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        """Run the extension's forward kernel on ``x``.

        ``roff_idx``, ``ker_idx``, ``row_idx``, ``col_idx`` and ``vals`` are
        handed to the extension unchanged (presumably a sparse description of
        psi -- confirm against the extension's interface).
        """
        # everything backward needs: the index/value tensors and the input extents
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in, ctx.nlon_in = x.shape[-2], x.shape[-1]

        # the extension is always called with a contiguous float32 tensor
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32).contiguous()
        result = disco_cuda_extension.forward(
            x_fp32, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out
        )

        # hand the result back in the caller's dtype
        return result.to(input_dtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to ``x`` through the extension's backward kernel."""
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors

        grad_dtype = grad_output.dtype
        grad_fp32 = grad_output.to(torch.float32).contiguous()
        # the input extents saved in forward are passed as the target extents
        grad_x = disco_cuda_extension.backward(
            grad_fp32, roff_idx, ker_idx, row_idx, col_idx, vals,
            ctx.kernel_size, ctx.nlat_in, ctx.nlon_in,
        )
        grad_x = grad_x.to(grad_dtype)

        # one entry per forward argument; only x is differentiable
        return (grad_x,) + (None,) * 8
123

Boris Bonev's avatar
Boris Bonev committed
124
125

class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    """Autograd wrapper around the backward kernel of ``disco_cuda_extension``.

    The roles of the two extension kernels are swapped relative to the
    non-transposed contraction: ``forward`` calls
    ``disco_cuda_extension.backward`` and ``backward`` calls
    ``disco_cuda_extension.forward``. Both cast their tensor argument to
    float32 before the call and cast the result back to the original dtype.
    Only ``x`` receives a gradient.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        """Run the extension's backward kernel on ``x``.

        ``roff_idx``, ``ker_idx``, ``row_idx``, ``col_idx`` and ``vals`` are
        handed to the extension unchanged (presumably a sparse description of
        psi -- confirm against the extension's interface).
        """
        # everything backward needs: the index/value tensors and the input extents
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in, ctx.nlon_in = x.shape[-2], x.shape[-1]

        # the extension is always called with a contiguous float32 tensor
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32).contiguous()
        result = disco_cuda_extension.backward(
            x_fp32, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out
        )

        # hand the result back in the caller's dtype
        return result.to(input_dtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to ``x`` through the extension's forward kernel."""
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors

        grad_dtype = grad_output.dtype
        grad_fp32 = grad_output.to(torch.float32).contiguous()
        # the input extents saved in forward are passed as the target extents
        grad_x = disco_cuda_extension.forward(
            grad_fp32, roff_idx, ker_idx, row_idx, col_idx, vals,
            ctx.kernel_size, ctx.nlat_in, ctx.nlon_in,
        )
        grad_x = grad_x.to(grad_dtype)

        # one entry per forward argument; only x is differentiable
        return (grad_x,) + (None,) * 8
156

Boris Bonev's avatar
Boris Bonev committed
157
158
159
160
161
162
# CUDA
def _disco_s2_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                               row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                               kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Apply ``_DiscoS2ContractionCuda`` to ``x``.

    All arguments are forwarded positionally, in the order given, to
    ``_DiscoS2ContractionCuda.apply``; its result is returned unchanged.
    """
    forwarded = (x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
    return _DiscoS2ContractionCuda.apply(*forwarded)
163

Boris Bonev's avatar
Boris Bonev committed
164
165
166
167
168
def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                                         row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                                         kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Apply ``_DiscoS2TransposeContractionCuda`` to ``x``.

    All arguments are forwarded positionally, in the order given, to
    ``_DiscoS2TransposeContractionCuda.apply``; its result is returned
    unchanged.
    """
    forwarded = (x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
    return _DiscoS2TransposeContractionCuda.apply(*forwarded)
169
170
171


def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """Pure-PyTorch contraction of ``x`` with ``psi``.

    For every output longitude index ``pout`` the input, cyclically shifted by
    ``pout * (nlon_in // nlon_out)`` points along longitude, is flattened over
    (latitude, longitude) and contracted with ``psi`` through ``torch.bmm``.

    Args:
        x: Tensor of shape ``(batch_size, n_chans, nlat_in, nlon_in)``.
        psi: 3D tensor of shape ``(kernel_size, nlat_out, nlat_in * nlon_in)``;
            it is moved to ``x.device`` before use.
        nlon_out: Number of output longitudes; must divide ``nlon_in``.

    Returns:
        Tensor of shape ``(batch_size, n_chans, kernel_size, nlat_out, nlon_out)``
        with the dtype and device of ``x``.

    Raises:
        AssertionError: If the ranks of ``x``/``psi`` are wrong, if the last
            dimension of ``psi`` is not ``nlat_in * nlon_in``, or if
            ``nlon_in`` is not a multiple of ``nlon_out`` at least as large
            as ``nlon_out``.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    # the longitude stride below requires at least as many input as output
    # longitudes; the number of output latitudes places no constraint here
    assert nlon_in >= nlon_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)

    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)

    return y


def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """Pure-PyTorch transpose contraction of ``x`` with ``psi``.

    The input is spread onto a zero-filled longitude grid of ``nlon_out``
    points (one value every ``nlon_out // nlon_in`` points). For each output
    longitude that grid is cyclically shifted by one more point, flattened
    over (latitude, longitude) and contracted with ``psi`` through
    ``torch.bmm``; the per-kernel results are then summed.

    Args:
        x: Tensor of shape ``(batch_size, n_chans, kernel_size, nlat_in, nlon_in)``.
        psi: 3D tensor of shape ``(kernel_size, nlat_out, n_out)``; it is moved
            to ``x.device`` before use. ``n_out`` must be a multiple of
            ``nlon_out``; the ``torch.bmm`` call additionally needs it to
            equal ``nlat_in * nlon_out``.
        nlon_out: Number of output longitudes; must not be smaller than ``nlon_in``.

    Returns:
        Contiguous tensor of shape ``(batch_size, n_chans, nlat_out, nlon_out)``
        with the dtype and device of ``x``.

    Raises:
        AssertionError: If the ranks of ``x``/``psi`` are wrong, if ``n_out``
            is not a multiple of ``nlon_out``, or if ``nlon_out < nlon_in``.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 5
    psi = psi.to(x.device)

    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
    # from here on the kernel size is the one reported by psi
    kernel_size, nlat_out, n_out = psi.shape

    assert n_out % nlon_out == 0
    assert nlon_out >= nlon_in
    lon_stride = nlon_out // nlon_in
    n_flat = batch_size * n_chans

    # interleave zeros along the longitude dimension so that fractional offsets can be represented
    upsampled = torch.zeros(kernel_size, nlat_in, nlon_out, n_flat, device=x.device, dtype=x.dtype)

    # bring x to shape kernel_size x nlat_in x nlon_in x (batch_size * n_chans)
    x = x.reshape(n_flat, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)

    # only the longitude stride is applied here; the latitude stride is handled by the kernel
    upsampled[:, :, ::lon_stride, :] = x

    # output buffer, one slab per output longitude
    per_kernel = torch.zeros(kernel_size, nlon_out, nlat_out, n_flat, device=x.device, dtype=x.dtype)

    for lon in range(nlon_out):
        # shift first, then contract: each step rolls the grid by one more longitude point
        # TODO: double-check why this has to happen first
        upsampled = torch.roll(upsampled, -1, dims=2)
        flattened = upsampled.reshape(kernel_size, nlat_in * nlon_out, -1)
        # sparse contraction with the modified psi
        per_kernel[:, lon, :, :] = torch.bmm(psi, flattened)

    # sum over the kernel dimension and restore the (batch, channel, lat, lon) layout
    summed = per_kernel.sum(dim=0)
    y = summed.permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)

    return y.contiguous()