# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
from typing import Optional

import torch
from torch.amp import custom_fwd, custom_bwd

try:
    import disco_cuda_extension
except ImportError as err:
    disco_cuda_extension = None


# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
Andrea Paris's avatar
Andrea Paris committed
45
    """Creates a sparse tensor for spherical harmonic convolution operations."""
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
    nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
    
    if semi_transposed:
        # do partial transpose
        # we do a semi-transposition to faciliate the computation
        tout = psi_idx[2] // nlon_out
        pout = psi_idx[2] % nlon_out
        # flip the axis of longitudes
        pout = nlon_out - 1 - pout
        tin = psi_idx[1]
        idx = torch.stack([psi_idx[0], tout, tin * nlon_out + pout], dim=0)
        psi = torch.sparse_coo_tensor(idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_out)).coalesce()
    else:
        psi = torch.sparse_coo_tensor(psi_idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_in)).coalesce()
    return psi

63

Boris Bonev's avatar
Boris Bonev committed
64
class _DiscoS2ContractionCuda(torch.autograd.Function):
    """Autograd function wrapping the DISCO contraction kernels of ``disco_cuda_extension``.

    The forward pass runs ``disco_cuda_extension.forward`` and the backward pass runs
    ``disco_cuda_extension.backward``. Both compute in float32 and cast the result back to
    the dtype of the tensor they received.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        # Remember the sparse structure and the input grid extent for the backward pass.
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in, ctx.nlon_in = x.shape[-2], x.shape[-1]

        in_dtype = x.dtype
        x_fp32 = x.to(torch.float32).contiguous()
        out = disco_cuda_extension.forward(x_fp32, roff_idx, ker_idx, row_idx, col_idx, vals,
                                           kernel_size, nlat_out, nlon_out)
        return out.to(in_dtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors

        grad_dtype = grad_output.dtype
        grad_fp32 = grad_output.to(torch.float32).contiguous()
        # The input grid extent saved in forward is passed in the output-grid slots,
        # so the gradient is produced on x's grid.
        grad_x = disco_cuda_extension.backward(grad_fp32, roff_idx, ker_idx, row_idx, col_idx, vals,
                                               ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
        # Only x receives a gradient; the other eight forward arguments get None.
        return (grad_x.to(grad_dtype),) + (None,) * 8


class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    """Autograd function for the transpose DISCO contraction on CUDA.

    The kernel roles are swapped with respect to the non-transposed contraction: the forward
    pass runs ``disco_cuda_extension.backward`` and the backward pass runs
    ``disco_cuda_extension.forward``. Both compute in float32 and cast the result back to the
    dtype of the tensor they received.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        # Remember the sparse structure and the input grid extent for the backward pass.
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in, ctx.nlon_in = x.shape[-2], x.shape[-1]

        in_dtype = x.dtype
        x_fp32 = x.to(torch.float32).contiguous()
        out = disco_cuda_extension.backward(x_fp32, roff_idx, ker_idx, row_idx, col_idx, vals,
                                            kernel_size, nlat_out, nlon_out)
        return out.to(in_dtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors

        grad_dtype = grad_output.dtype
        grad_fp32 = grad_output.to(torch.float32).contiguous()
        # The input grid extent saved in forward is passed in the output-grid slots,
        # so the gradient is produced on x's grid.
        grad_x = disco_cuda_extension.forward(grad_fp32, roff_idx, ker_idx, row_idx, col_idx, vals,
                                              ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
        # Only x receives a gradient; the other eight forward arguments get None.
        return (grad_x.to(grad_dtype),) + (None,) * 8


# CUDA
def _disco_s2_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                               row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                               kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Functional entry point for the differentiable CUDA contraction (``_DiscoS2ContractionCuda``)."""
    sparse_args = (roff_idx, ker_idx, row_idx, col_idx, vals)
    grid_args = (kernel_size, nlat_out, nlon_out)
    return _DiscoS2ContractionCuda.apply(x, *sparse_args, *grid_args)


def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                                         row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                                         kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Functional entry point for the differentiable CUDA transpose contraction (``_DiscoS2TransposeContractionCuda``)."""
    sparse_args = (roff_idx, ker_idx, row_idx, col_idx, vals)
    grid_args = (kernel_size, nlat_out, nlon_out)
    return _DiscoS2TransposeContractionCuda.apply(x, *sparse_args, *grid_args)


def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
Andrea Paris's avatar
Andrea Paris committed
142
143
144
145
146
147
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in CUDA.
    """
    
148
149
150
151
152
153
154
155
156
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
Boris Bonev's avatar
Boris Bonev committed
157
    assert nlon_in >= nlat_out
158
159
    pscale = nlon_in // nlon_out

160
    # add a dummy dimension for nkernel and move the batch and channel dims to the end
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to faciliate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)

    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)

    return y


def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    assert len(psi.shape) == 3
    assert len(x.shape) == 5
    psi = psi.to(x.device)

    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
184
    kernel_size, nlat_out, n_out = psi.shape
185
186

    assert n_out % nlon_out == 0
Boris Bonev's avatar
Boris Bonev committed
187
    assert nlon_out >= nlon_in
188
189
190
191
    pscale = nlon_out // nlon_in

    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
    x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
192
    x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
Boris Bonev's avatar
Boris Bonev committed
193

194
195
196
    # x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans
    # we only need to apoply the nlon stride here, since nlat stride is taken care of by the kernel
    x_ext[:, :, ::pscale, :] = x[...]
197

198
    # create output tensor
199
200
201
202
203
204
205
    y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # we need to repeatedly roll the input tensor to faciliate the shifted multiplication
        # TODO: double-check why this has to happen first
        x_ext = torch.roll(x_ext, -1, dims=2)
        # sparse contraction with the modified psi
206
        y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
207
208

    # sum over the kernel dimension and reshape to the correct output size
209
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()
210
211
212

    return y