# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import Optional
import math

import torch
from torch.amp import custom_fwd, custom_bwd

# optional CUDA extension providing fast kernels for the DISCO contractions;
# the handle is set to None when the extension is not installed
try:
    import disco_cuda_extension
except ImportError:
    disco_cuda_extension = None

# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int,
             nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None,
             nlat_out_local: Optional[int] = None, semi_transposed: bool = False):

    nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
    nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
    
    if semi_transposed:
        # perform a semi-transposition (partial transpose) to facilitate the computation
        tout = psi_idx[2] // nlon_out
        pout = psi_idx[2] % nlon_out
        # flip the axis of longitudes
        pout = nlon_out - 1 - pout
        tin = psi_idx[1]
        idx = torch.stack([psi_idx[0], tout, tin * nlon_out + pout], dim=0)
        psi = torch.sparse_coo_tensor(idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_out)).coalesce()
    else:
        psi = torch.sparse_coo_tensor(psi_idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_in)).coalesce()
    return psi
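
# Illustrative sketch only (hypothetical index data, not taken from the library): in the
# non-transposed case the three rows of `psi_idx` index [kernel basis, output latitude,
# flattened input grid point], so the resulting sparse tensor has shape
# (kernel_size, nlat_out, nlat_in * nlon_in):
#
#   >>> psi_idx = torch.tensor([[0, 1, 1], [0, 0, 1], [0, 5, 13]])   # shape (3, nnz)
#   >>> psi_vals = torch.tensor([0.5, 0.25, 0.25])
#   >>> psi = _get_psi(kernel_size=2, psi_idx=psi_idx, psi_vals=psi_vals,
#   ...                nlat_in=4, nlon_in=8, nlat_out=4, nlon_out=8)
#   >>> psi.shape
#   torch.Size([2, 4, 32])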


class _DiscoS2ContractionCuda(torch.autograd.Function):
    r"""
    CUDA implementation of the discrete-continuous convolution contraction on the sphere.
    This class provides the forward and backward passes for efficient GPU computation
    of the S2 convolution operation using custom CUDA kernels.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for CUDA S2 convolution contraction.
        
        Parameters
        -----------
        ctx: torch.autograd.function.Context
            Context object
        x: torch.Tensor
            Input tensor
        roff_idx: torch.Tensor
            Row offset indices for sparse computation
        ker_idx: torch.Tensor
            Kernel indices
        row_idx: torch.Tensor
            Row indices for sparse computation
        col_idx: torch.Tensor
            Column indices for sparse computation
        vals: torch.Tensor
            Values for sparse computation
        kernel_size: int
            Size of the kernel
        nlat_out: int
            Number of output latitude points
        nlon_out: int
            Number of output longitude points

        Returns
        --------
        output: torch.Tensor
            Output tensor of the contraction
        """
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in = x.shape[-2]
        ctx.nlon_in = x.shape[-1]
        # the CUDA kernels operate on contiguous float32 data; cast the input and
        # restore the original dtype afterwards
        xtype = x.dtype
        x = x.to(torch.float32).contiguous()
        output = disco_cuda_extension.forward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
        output = output.to(xtype)

        return output

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for CUDA S2 convolution contraction.
        
        Parameters
        -----------
        grad_output: torch.Tensor
            Gradient of the output
        
        Returns
        --------
        grad_input: torch.Tensor
            Gradient of the input
        """
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
        grad_input = disco_cuda_extension.backward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                         ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
        grad_input = grad_input.to(gtype)

        return grad_input, None, None, None, None, None, None, None, None


class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    r"""
    CUDA implementation of the transpose discrete-continuous convolution contraction on the sphere.
    This class provides the forward and backward passes for efficient GPU computation
    of the transpose S2 convolution operation using custom CUDA kernels.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for CUDA transpose S2 convolution contraction.
        
        Parameters
        -----------
        ctx: torch.autograd.function.Context
            Context object
        x: torch.Tensor
            Input tensor
        roff_idx: torch.Tensor
            Row offset indices for sparse computation
        ker_idx: torch.Tensor
            Kernel indices
        row_idx: torch.Tensor
            Row indices for sparse computation
        col_idx: torch.Tensor
            Column indices for sparse computation
        vals: torch.Tensor
            Values for sparse computation
        kernel_size: int
            Size of the kernel
        nlat_out: int
            Number of output latitude points
        nlon_out: int
            Number of output longitude points

        Returns
        --------
        output: torch.Tensor
            Output tensor of the transpose contraction
        """
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in = x.shape[-2]
        ctx.nlon_in = x.shape[-1]
        xtype = x.dtype
        x = x.to(torch.float32).contiguous()
        # the transpose contraction reuses the backward CUDA kernel of the regular contraction
        output = disco_cuda_extension.backward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
        output = output.to(xtype)

        return output

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for CUDA transpose S2 convolution contraction.
        
        Parameters
        -----------
        grad_output: torch.Tensor
            Gradient of the output
        
        Returns
        --------
        grad_input: torch.Tensor
            Gradient of the input
        """
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
        # the backward pass of the transpose contraction reuses the forward CUDA kernel
        grad_input = disco_cuda_extension.forward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                        ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
        grad_input = grad_input.to(gtype)

        return grad_input, None, None, None, None, None, None, None, None

# thin CUDA entry points that dispatch to the custom autograd Functions above
def _disco_s2_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                               row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                               kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    return _DiscoS2ContractionCuda.apply(x, roff_idx, ker_idx, row_idx, col_idx, vals,
                                         kernel_size, nlat_out, nlon_out)

def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                                         row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                                         kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    return _DiscoS2TransposeContractionCuda.apply(x, roff_idx, ker_idx, row_idx, col_idx, vals,
                                                  kernel_size, nlat_out, nlon_out)


def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in CUDA.

    Parameters
    -----------
    x: torch.Tensor
        Input tensor
    psi: torch.Tensor
        Kernel tensor
    nlon_out: int   
        Number of output longitude points

    Returns
    --------
    y: torch.Tensor
        Output tensor
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    assert nlon_in >= nlat_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)

    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)

    return y
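
# Usage sketch (hypothetical shapes, for illustration only): with nlat_in = nlat_out = 6,
# nlon_in = nlon_out = 12 and a kernel with 3 basis functions, psi has shape (3, 6, 6 * 12) and
#
#   >>> x = torch.randn(2, 4, 6, 12)          # (batch_size, n_chans, nlat_in, nlon_in)
#   >>> y = _disco_s2_contraction_torch(x, psi, nlon_out=12)
#   >>> y.shape
#   torch.Size([2, 4, 3, 6, 12])              # (batch_size, n_chans, kernel_size, nlat_out, nlon_out)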


def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in CUDA.

    Parameters
    -----------
    x: torch.Tensor
        Input tensor
    psi: torch.Tensor
        Kernel tensor   
    nlon_out: int
        Number of output longitude points

    Returns
    --------
    y: torch.Tensor
        Output tensor
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 5
    psi = psi.to(x.device)

    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, n_out = psi.shape

    assert n_out % nlon_out == 0
    assert nlon_out >= nlon_in
    pscale = nlon_out // nlon_in

    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
    x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
    x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)

    # x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans
    # we only need to apply the nlon stride here, since the nlat stride is taken care of by the kernel
    x_ext[:, :, ::pscale, :] = x[...]

    # create output tensor
    y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        # TODO: double-check why this has to happen first
        x_ext = torch.roll(x_ext, -1, dims=2)
        # sparse contraction with the modified psi
        y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))

    # sum over the kernel dimension and reshape to the correct output size
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()

    return y
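
# Usage sketch (hypothetical shapes, for illustration only): this transpose contraction expects
# the semi-transposed psi (built via _get_psi with semi_transposed=True), whose last dimension
# equals nlat_in * nlon_out. With nlat_in = nlat_out = 6, nlon_in = nlon_out = 12 and a kernel
# with 3 basis functions:
#
#   >>> x = torch.randn(2, 4, 3, 6, 12)       # (batch_size, n_chans, kernel_size, nlat_in, nlon_in)
#   >>> y = _disco_s2_transpose_contraction_torch(x, psi, nlon_out=12)
#   >>> y.shape
#   torch.Size([2, 4, 6, 12])                 # (batch_size, n_chans, nlat_out, nlon_out)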