# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import (
    _disco_s2_contraction_torch,
    _disco_s2_transpose_contraction_torch,
    _disco_s2_contraction_triton,
    _disco_s2_transpose_contraction_triton,
)


def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"):
    """
    Computes the index set that falls into the isotropic kernel's support and returns both indices and values.

    The kernel consists of ``nr`` piecewise-linear (hat) basis functions in the radial coordinate, centered at
    ``0, dr, ..., (nr - 1) * dr`` with ``dr = r_cutoff / nr``. Their sum is 1 on ``[0, r_cutoff - dr]`` and
    decays linearly to 0 on ``[r_cutoff - dr, r_cutoff]``. Unless ``norm="none"``, all values are divided by
    the integral of that sum over the disk of radius ``r_cutoff``. The integral is taken either in the
    euclidean plane (``"2d"``) or on the unit sphere (``"s2"``).

    Parameters
    ----------
    r : torch.Tensor
        Radial distances of the sample points. It is indexed with two indices below, so a 2D tensor is expected.
    phi : torch.Tensor
        Azimuthal angles of the sample points. Not used by the isotropic kernel; accepted so that the
        signature matches the anisotropic variant.
    nr : int
        Number of radial basis functions.
    r_cutoff : float
        Radius of the kernel support.
    norm : str, optional
        Normalization mode, one of ``"none"``, ``"2d"`` or ``"s2"``. Defaults to ``"s2"``.

    Returns
    -------
    iidx : torch.Tensor
        Integer tensor of shape ``(nnz, 3)`` holding ``(kernel index, first index of r, second index of r)``.
    vals : torch.Tensor
        Normalized kernel values of shape ``(nnz,)``.

    Raises
    ------
    ValueError
        If ``norm`` is not a known normalization mode.
    """

    # compute the support
    dr = (r_cutoff - 0.0) / nr
    ikernel = torch.arange(nr).reshape(-1, 1, 1)
    ir = ikernel * dr

    # radius up to which the basis functions sum to exactly one
    r_flat = r_cutoff - dr

    if norm == "none":
        norm_factor = 1.0
    elif norm == "2d":
        # pi * r_flat^2 (flat part) + 2 * pi * int_{r_flat}^{r_cutoff} (1 - (r - r_flat) / dr) r dr (linear rim).
        # This must use the same spacing dr = r_cutoff / nr as the basis functions above.
        norm_factor = math.pi * (r_flat**2 + r_flat * dr + dr**2 / 3)
    elif norm == "s2":
        # same integral on the unit sphere, where the area element is sin(r) dr dphi
        norm_factor = 2 * math.pi * (1 - math.cos(r_flat) + math.cos(r_flat) + (math.sin(r_flat) - math.sin(r_cutoff)) / dr)
    else:
        raise ValueError(f"Unknown normalization mode {norm}.")

    # find the indices where the rotated position falls into the support of the kernel
    iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
    vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
    return iidx, vals


def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float, norm: str = "s2"):
    """
    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.

    Kernel index 0 is an isotropic hat function centered at radius 0. Every further index is the product of a
    radial hat centered at ``ring * dr`` (``ring = 1, ..., nr - 1``, ``dr = r_cutoff / nr``) and a periodic
    angular hat centered at ``j * dphi`` (``j = 0, ..., nphi - 1``, ``dphi = 2 * pi / nphi``). This gives
    ``(nr - 1) * nphi + 1`` basis functions. The values are normalized exactly like in the isotropic case, by
    the integral over the support of the summed radial hats.

    Parameters
    ----------
    r : torch.Tensor
        Radial distances of the sample points. It is indexed with two indices below, so a 2D tensor is expected.
    phi : torch.Tensor
        Azimuthal angles of the sample points, same shape as ``r``.
    nr : int
        Number of radial rings (including the central ring 0).
    nphi : int
        Number of angular basis functions per outer ring.
    r_cutoff : float
        Radius of the kernel support.
    norm : str, optional
        Normalization mode, one of ``"none"``, ``"2d"`` or ``"s2"``. Defaults to ``"s2"``.

    Returns
    -------
    iidx : torch.Tensor
        Integer tensor of shape ``(nnz, 3)`` holding ``(kernel index, first index of r, second index of r)``.
    vals : torch.Tensor
        Normalized kernel values of shape ``(nnz,)``.

    Raises
    ------
    ValueError
        If ``norm`` is not a known normalization mode.
    """

    # compute the support
    dr = (r_cutoff - 0.0) / nr
    dphi = 2.0 * math.pi / nphi
    kernel_size = (nr - 1) * nphi + 1
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    # explicit floor rounding so that kernel index 0 maps to ring 0 (radius 0) independently of how the
    # installed torch version rounds `//` for negative integers
    ir = (torch.div(ikernel - 1, nphi, rounding_mode="floor") + 1) * dr
    iphi = ((ikernel - 1) % nphi) * dphi

    # radius up to which the summed radial hat functions equal exactly one
    r_flat = r_cutoff - dr

    if norm == "none":
        norm_factor = 1.0
    elif norm == "2d":
        # pi * r_flat^2 (flat part) + 2 * pi * int_{r_flat}^{r_cutoff} (1 - (r - r_flat) / dr) r dr (linear rim).
        # This must use the same spacing dr = r_cutoff / nr as the basis functions above.
        norm_factor = math.pi * (r_flat**2 + r_flat * dr + dr**2 / 3)
    elif norm == "s2":
        # same integral on the unit sphere, where the area element is sin(r) dr dphi
        norm_factor = 2 * math.pi * (1 - math.cos(r_flat) + math.cos(r_flat) + (math.sin(r_flat) - math.sin(r_cutoff)) / dr)
    else:
        raise ValueError(f"Unknown normalization mode {norm}.")

    # find the indices where the rotated position falls into the support of the kernel
    # (the angular distance is measured periodically, hence the 2 pi - |dphi| alternative)
    cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
    cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
    iidx = torch.argwhere(cond_r & cond_phi)
    vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
    # the central basis function (index 0) has no angular dependence
    vals *= torch.where(
        iidx[:, 0] > 0,
        (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi),
        1.0,
    )
    return iidx, vals


def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.

    The result describes, in COO format, a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in),
    where kernel_size is the number of basis functions implied by kernel_shape. The filters are only evaluated
    for output longitude 0, i.e. with $\phi_j = 0$ (the longitude differences are simply the input longitudes).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape : tuple of int
        ``(nlat_in, nlon_in)`` of the input grid.
    out_shape : tuple of int
        ``(nlat_out, nlon_out)`` of the output grid. Only ``nlat_out`` is used here.
    kernel_shape : list of int
        ``[nr]`` for an isotropic or ``[nr, nphi]`` for an anisotropic kernel.
    grid_in, grid_out : str, optional
        Latitude grid types, forwarded to ``_precompute_latitudes``.
    theta_cutoff : float, optional
        Support radius of the filter in radians.

    Returns
    -------
    out_idx : torch.Tensor
        ``torch.long`` tensor of shape ``(3, nnz)`` with rows
        (kernel index, output latitude, input latitude * nlon_in + input longitude).
    out_vals : torch.Tensor
        ``torch.float32`` tensor of shape ``(nnz,)`` with the filter values.

    Raises
    ------
    ValueError
        If ``kernel_shape`` is neither one- nor two-dimensional.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2")
    elif len(kernel_shape) == 2:
        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2")
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, _ = out_shape

    lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates; phi is shifted from arctan2's [-pi, pi] range into [0, 2pi]
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)

        # add the output latitude and reshape such that psi has dimensions kernel_size x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    return out_idx, out_vals


def _precompute_convolution_tensor_2d(grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False):
    r"""
    Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
    only that it assumes a subset of the euclidean plane instead of the sphere.

    Parameters
    ----------
    grid_in : torch.Tensor
        Input point cloud of shape ``(2, n_in)``.
    grid_out : torch.Tensor
        Output point cloud of shape ``(2, n_out)``.
    kernel_shape : list of int
        ``[nr]`` for an isotropic or ``[nr, nphi]`` for an anisotropic kernel.
    radius_cutoff : float, optional
        Radius of the kernel support.
    periodic : bool, optional
        If True, coordinate differences are wrapped with period 1 (the shorter of ``d`` and ``d -/+ 1`` is used),
        which presumably assumes a unit-periodic domain.

    Returns
    -------
    idx : torch.Tensor
        Index tensor of shape ``(3, nnz)`` with rows (kernel index, output point, input point).
    vals : torch.Tensor
        Kernel values of shape ``(nnz,)``, normalized with the ``"2d"`` mode.

    Raises
    ------
    ValueError
        If ``kernel_shape`` is neither one- nor two-dimensional.
    """

    # check that input arrays are valid point clouds in 2D
    assert len(grid_in) == 2
    assert len(grid_out) == 2
    assert grid_in.shape[0] == 2
    assert grid_out.shape[0] == 2

    n_in = grid_in.shape[-1]
    n_out = grid_out.shape[-1]

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=radius_cutoff, norm="2d")
    elif len(kernel_shape) == 2:
        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=radius_cutoff, norm="2d")
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    grid_in = grid_in.reshape(2, 1, n_in)
    grid_out = grid_out.reshape(2, n_out, 1)

    # pairwise differences of shape (2, n_out, n_in)
    diffs = grid_in - grid_out
    if periodic:
        # minimum-image convention with period 1
        periodic_diffs = torch.where(diffs > 0.0, diffs - 1, diffs + 1)
        diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)

    # polar coordinates of the differences, phi in [0, 2pi]
    r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
    phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi

    idx, vals = kernel_handle(r, phi)
    idx = idx.permute(1, 0)

    return idx, vals


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for DISCO convolutions.

    Normalizes ``kernel_shape`` to a list, derives the number of kernel basis functions (``kernel_size``),
    validates the channel/group configuration and allocates the learnable weight of shape
    ``(out_channels, in_channels // groups, kernel_size)`` plus an optional bias of shape ``(out_channels,)``.
    Subclasses must implement ``forward``.

    Raises
    ------
    ValueError
        If ``kernel_shape`` is neither one- nor two-dimensional, or a channel count is not divisible by ``groups``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        # a scalar kernel shape is shorthand for a purely radial (isotropic) kernel
        self.kernel_shape = [kernel_shape] if isinstance(kernel_shape, int) else kernel_shape

        ndim = len(self.kernel_shape)
        if ndim == 1:
            # isotropic: one basis function per radial ring
            self.kernel_size = self.kernel_shape[0]
        elif ndim == 2:
            # anisotropic: one central basis function plus nphi angular functions on each outer ring
            n_rings, n_angles = self.kernel_shape[0], self.kernel_shape[1]
            self.kernel_size = (n_rings - 1) * n_angles + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # groups
        self.groups = groups

        # both channel counts must split evenly across the groups (input is checked first)
        for n_channels, label in ((in_channels, "input"), (out_channels, "output")):
            if n_channels % self.groups != 0:
                raise ValueError(f"Error, the number of {label} channels has to be an integer multiple of the group size")

        # weight tensor, scaled with the inverse square root of the per-group fan-in
        self.groupsize = in_channels // self.groups
        init_scale = math.sqrt(1.0 / self.groupsize)
        self.weight = nn.Parameter(init_scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Parameters
        ----------
        in_channels, out_channels : int
            Number of input and output channels; both must be divisible by ``groups``.
        in_shape, out_shape : tuple of int
            ``(nlat, nlon)`` of the input and output grids.
        kernel_shape : int or list of int
            ``nr`` / ``[nr]`` for an isotropic or ``[nr, nphi]`` for an anisotropic kernel.
        groups : int, optional
            Number of channel groups.
        grid_in, grid_out : str, optional
            Latitude grid types of the input and output grids, forwarded to ``_precompute_latitudes``.
        bias : bool, optional
            Whether to add a learnable per-output-channel bias.
        theta_cutoff : float, optional
            Support radius of the filter in radians. Defaults to ``(nr + 1) * pi / (nlat_in - 1)``.

        Raises
        ------
        ValueError
            If ``theta_cutoff`` is not positive, or cannot be inferred because ``nlat_in < 2``.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
            if self.nlat_in < 2:
                raise ValueError("Error, theta_cutoff cannot be inferred from fewer than two input latitudes, please specify it explicitly.")
            theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights: one quadrature weight per input latitude (shape (nlat_in, 1)),
        # scaled by the longitudinal spacing 2 pi / nlon_in
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        idx, vals = _precompute_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff)

        # COO indices and values of the filter tensor, assembled into a sparse tensor in get_psi
        self.register_buffer("psi_idx", idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def get_psi(self):
        """
        Assembles the coalesced sparse filter tensor of shape ``kernel_size x nlat_out x (nlat_in * nlon_in)``.
        """
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        """
        Applies the DISCO convolution.

        ``x`` is expected as ``(batch, in_channels, nlat_in, nlon_in)`` (TODO confirm against the contraction
        kernels). The result has ``out_channels`` channels. ``use_triton_kernel`` selects the Triton
        contraction for CUDA inputs; otherwise the torch implementation is used.
        """
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        if x.is_cuda and use_triton_kernel:
            x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
        else:
            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Parameters
        ----------
        in_channels, out_channels : int
            Number of input and output channels; both must be divisible by ``groups``.
        in_shape, out_shape : tuple of int
            ``(nlat, nlon)`` of the input and output grids.
        kernel_shape : int or list of int
            ``nr`` / ``[nr]`` for an isotropic or ``[nr, nphi]`` for an anisotropic kernel.
        groups : int, optional
            Number of channel groups.
        grid_in, grid_out : str, optional
            Latitude grid types of the input and output grids, forwarded to ``_precompute_latitudes``.
        bias : bool, optional
            Whether to add a learnable per-output-channel bias.
        theta_cutoff : float, optional
            Support radius of the filter in radians. Defaults to ``(nr + 1) * pi / (nlat_in - 1)``.

        Raises
        ------
        ValueError
            If ``theta_cutoff`` is not positive, or cannot be inferred because ``nlat_in < 2``.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # bandlimit
        if theta_cutoff is None:
            if self.nlat_in < 2:
                raise ValueError("Error, theta_cutoff cannot be inferred from fewer than two input latitudes, please specify it explicitly.")
            theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights: one quadrature weight per input latitude (shape (nlat_in, 1)),
        # scaled by the longitudinal spacing 2 pi / nlon_in
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # switch in_shape and out_shape since we want transpose conv
        idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff)

        # COO indices and values of the (forward-direction) filter tensor, rearranged in get_psi
        self.register_buffer("psi_idx", idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def get_psi(self, use_triton_kernel=True):
        """
        Assembles the coalesced sparse filter tensor.

        For the Triton path the precomputed layout is used as is, with shape
        ``kernel_size x nlat_in x (nlat_out * nlon_out)``. For the torch path the indices are semi-transposed
        to shape ``kernel_size x nlat_out x (nlat_in * nlon_out)``, and the output longitude axis is flipped.
        """
        if not use_triton_kernel:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
        return psi

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        """
        Applies the transpose DISCO convolution.

        ``x`` is unpacked as ``(batch, in_channels, H, W)``; presumably ``(H, W) == (nlat_in, nlon_in)``
        given the quadrature weights (TODO confirm). The weights are applied first, then the sparse
        contraction maps onto the output grid. ``use_triton_kernel`` selects the Triton contraction for
        CUDA inputs; otherwise the torch implementation is used.
        """
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])

        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi(x.is_cuda and use_triton_kernel)

        if x.is_cuda and use_triton_kernel:
            out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out