# s2_convolutions.py
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.disco_convolutions import (
    _disco_s2_contraction_torch,
    _disco_s2_transpose_contraction_torch,
    _disco_s2_contraction_triton,
    _disco_s2_transpose_contraction_triton,
)


def _compute_support_vals_isotropic(
    theta: torch.Tensor, phi: torch.Tensor, kernel_size: int, theta_cutoff: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the index set that falls into the isotropic kernel's support and returns both indices and values.

    The kernel consists of `kernel_size` piecewise-linear ("hat") profiles in the polar angle,
    centred at `k * dtheta` for `k = 0, ..., kernel_size - 1` with `dtheta = theta_cutoff / kernel_size`.

    Args:
        theta: 2D tensor of polar angles; it is indexed with two indices below.
        phi: azimuthal angles. Unused by the isotropic kernel, accepted so that all kernel
            handles can share the same call signature.
        kernel_size: number of hat profiles, must be at least 1.
        theta_cutoff: angle beyond which the kernel is zero.

    Returns:
        A tuple `(iidx, vals)`. `iidx` has shape `(nnz, 3)` and holds
        `(kernel index, first index into theta, second index into theta)` for every entry inside
        the support; `vals` has shape `(nnz,)` and holds the normalized kernel value of each entry.

    Raises:
        ValueError: if `kernel_size` is smaller than 1.
    """

    if kernel_size < 1:
        raise ValueError("Error, kernel_size has to be a positive integer.")

    # compute the support: one hat profile per kernel index, spaced dtheta apart
    dtheta = (theta_cutoff - 0.0) / kernel_size
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    itheta = ikernel * dtheta

    # NOTE(review): the two cos(theta_cutoff - dtheta) terms cancel algebraically; the expression
    # is kept verbatim so that the floating-point result is unchanged.
    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)

    # find the indices where the rotated position falls into the support of the kernel
    # (broadcasting theta against itheta yields a kernel_size x theta.shape mask)
    iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff))
    # hat function: 1 at the profile centre, falling linearly to 0 at a distance of dtheta
    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
    return iidx, vals


def _precompute_convolution_tensor(
    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in)

    Args:
        in_shape: `(nlat_in, nlon_in)` of the input grid.
        out_shape: `(nlat_out, nlon_out)` of the output grid; only `nlat_out` is used here.
        kernel_shape: sequence describing the kernel; only one-dimensional (isotropic) kernels
            are currently implemented.
        grid_in: name of the input latitude grid, forwarded to `_precompute_latitudes`.
        grid_out: name of the output latitude grid, forwarded to `_precompute_latitudes`.
        theta_cutoff: angular cutoff forwarded to the kernel support computation.

    Returns:
        A tuple `(out_idx, out_vals)` in COO layout: `out_idx` is a `(3, nnz)` long tensor of
        `(kernel index, output latitude, flattened input position)` and `out_vals` holds the
        `nnz` matching kernel values.

    Raises:
        ValueError: if `kernel_shape` is not one-dimensional.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_support_vals_isotropic, kernel_size=kernel_shape[0], theta_cutoff=theta_cutoff)
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, _ = out_shape

    lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # chunks of non-zero indices / values, concatenated once after the loop.
    # The empty seeds keep torch.cat valid when nlat_out == 0 and fix the result dtypes.
    idx_chunks = [torch.empty([3, 0], dtype=torch.long)]
    val_chunks = [torch.empty([0], dtype=torch.float32)]

    # quantities that do not depend on the output latitude
    alpha = -lats_in.reshape(-1, 1)
    # compute the phi differences: the longitude grid is periodic, so the endpoint 2*pi
    # (which coincides with 0) must be excluded to obtain an equidistant spacing of 2*pi / nlon_in
    beta = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    cos_alpha, sin_alpha = torch.cos(alpha), torch.sin(alpha)
    cos_beta, sin_beta = torch.cos(beta), torch.sin(beta)

    for t in range(nlat_out):
        gamma = lats_out[t]
        cos_gamma, sin_gamma = torch.cos(gamma), torch.sin(gamma)

        # compute latitude of the rotated position
        z = cos_alpha * cos_gamma - cos_beta * sin_alpha * sin_gamma
        # round-off can push z marginally outside [-1, 1], where arccos returns NaN
        z = torch.clamp(z, min=-1.0, max=1.0)
        theta = torch.arccos(z)

        # compute cartesian coordinates of the rotated position, i.e. the x and y components of
        # Y(alpha) Z(beta) Y(gamma) applied to the north pole (consistent with z above)
        x = cos_alpha * cos_beta * sin_gamma + cos_gamma * sin_alpha
        y = sin_beta * sin_gamma
        phi = torch.arctan2(y, x)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # collect indices and values for the COO datastructure
        idx_chunks.append(idx)
        val_chunks.append(vals)

    out_idx = torch.cat(idx_chunks, dim=-1)
    out_vals = torch.cat(val_chunks, dim=-1)

    return out_idx, out_vals


# TODO:
# - parameter initialization
# - add anisotropy
class DiscreteContinuousConvS2(nn.Module):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    The module precomputes a sparse tensor `psi` of rotated filter values and the quadrature
    weights of the input grid once at construction time. The forward pass contracts the input
    with `psi` and then mixes channels with the learnable `weight`.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Args:
            in_channels: number of input channels, must be divisible by `groups`.
            out_channels: number of output channels, must be divisible by `groups`.
            in_shape: `(nlat_in, nlon_in)` of the input grid.
            out_shape: `(nlat_out, nlon_out)` of the output grid.
            kernel_shape: kernel size as an int or a list of ints.
            groups: number of channel groups, must be at least 1.
            grid_in: name of the input latitude grid.
            grid_out: name of the output latitude grid.
            bias: whether to add a learnable per-output-channel bias.
            theta_cutoff: angular cutoff of the kernel; if None it is derived from
                `kernel_shape[0]` and `nlat_in`.

        Raises:
            ValueError: if `groups` is not positive, if the channel counts are not divisible
                by `groups`, or if `theta_cutoff` is not positive.
        """
        super().__init__()

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        if isinstance(kernel_shape, int):
            kernel_shape = [kernel_shape]
        self.kernel_size = 1
        for kdim in kernel_shape:
            self.kernel_size *= kdim

        # groups: validate before the expensive precomputation so bad arguments fail fast
        if groups < 1:
            raise ValueError("Error, the number of groups has to be a positive integer")
        self.groups = groups
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights: latitude quadrature weights spread evenly over the longitudes
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # sparse filter tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in)
        idx, vals = _precompute_convolution_tensor(
            in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
        )
        psi = torch.sparse_coo_tensor(
            idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
        ).coalesce()
        self.register_buffer("psi", psi, persistent=False)

        # weight tensor
        self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        """
        Applies the DISCO convolution to `x`.

        Args:
            x: input tensor; presumably of shape (batch, in_channels, nlat_in, nlon_in) —
                TODO confirm against the contraction routines.
            use_triton_kernel: use the triton contraction when `x` lives on a CUDA device;
                otherwise the torch implementation is used.

        Returns:
            A 4D tensor whose second dimension has size `out_channels`.
        """
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        if x.is_cuda and use_triton_kernel:
            x = _disco_s2_contraction_triton(x, self.psi, self.nlon_out)
        else:
            x = _disco_s2_contraction_torch(x, self.psi, self.nlon_out)

        # extract shape and split the channel dimension into groups
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication: view the weight as (groups, out_channels / groups, groupsize, K)
        weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        out = torch.einsum("bgckxy,gock->bgoxy", x, weight)
        # merge (groups, out_channels / groups) back into a single channel dimension
        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(nn.Module):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    The module precomputes a sparse tensor `psi` of rotated filter values (with the roles of the
    input and output grid swapped) and the quadrature weights once at construction time. The
    forward pass first mixes channels with the learnable `weight` and then applies the
    transpose contraction with `psi`.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Args:
            in_channels: number of input channels, must be divisible by `groups`.
            out_channels: number of output channels, must be divisible by `groups`.
            in_shape: `(nlat_in, nlon_in)` of the input grid.
            out_shape: `(nlat_out, nlon_out)` of the output grid.
            kernel_shape: kernel size as an int or a list of ints.
            groups: number of channel groups, must be at least 1.
            grid_in: name of the input latitude grid.
            grid_out: name of the output latitude grid.
            bias: whether to add a learnable per-output-channel bias.
            theta_cutoff: angular cutoff of the kernel; if None it is derived from
                `kernel_shape[0]` and `nlat_in`.

        Raises:
            ValueError: if `groups` is not positive, if the channel counts are not divisible
                by `groups`, or if `theta_cutoff` is not positive.
        """
        super().__init__()

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        if isinstance(kernel_shape, int):
            kernel_shape = [kernel_shape]
        self.kernel_size = 1
        for kdim in kernel_shape:
            self.kernel_size *= kdim

        # groups: validate before the expensive precomputation so bad arguments fail fast
        if groups < 1:
            raise ValueError("Error, the number of groups has to be a positive integer")
        self.groups = groups
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights: latitude quadrature weights spread evenly over the longitudes
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # switch in_shape and out_shape since we want transpose conv
        idx, vals = _precompute_convolution_tensor(
            out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
        )
        psi = torch.sparse_coo_tensor(
            idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
        ).coalesce()
        self.register_buffer("psi", psi, persistent=False)

        # weight tensor
        self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        """
        Applies the transpose DISCO convolution to `x`.

        Args:
            x: 4D input tensor whose second dimension must have size `groups * groupsize`
                (i.e. `in_channels`); presumably laid out as
                (batch, in_channels, nlat_in, nlon_in) — TODO confirm against the contraction routines.
            use_triton_kernel: use the triton contraction when `x` lives on a CUDA device;
                otherwise the torch implementation is used.

        Returns:
            The result of the transpose contraction, with the per-channel bias added if present.
        """
        # extract shape and split the channel dimension into groups
        B, F, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication: view the weight as (groups, out_channels / groups, groupsize, K)
        weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        x = torch.einsum("bgcxy,gock->bgokxy", x, weight)
        # merge (groups, out_channels / groups) back into a single channel dimension
        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])

        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        if x.is_cuda and use_triton_kernel:
            out = _disco_s2_transpose_contraction_triton(x, self.psi, self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_torch(x, self.psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out