"README_ORIGIN.md" did not exist on "dd1929ba7668226bf77563a411475f7e7c4ca076"
distributed_convolution.py 16.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
35
from itertools import accumulate
from warnings import warn
36
37
38
39
40
41
42
43

import math

import torch
import torch.nn as nn

from functools import partial

Thorsten Kurth's avatar
Thorsten Kurth committed
44
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
45
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
Boris Bonev's avatar
Boris Bonev committed
46
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
47
from torch_harmonics.filter_basis import get_filter_basis
48
from torch_harmonics.convolution import (
Thorsten Kurth's avatar
Thorsten Kurth committed
49
    _precompute_convolution_tensor_s2,
50
51
52
    DiscreteContinuousConv,
)

53

54
55
from torch_harmonics.distributed import polar_group_size, azimuth_group_size
from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
56
from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region
57
58
59
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim

60
# import custom C++/CUDA extensions if available
Boris Bonev's avatar
Boris Bonev committed
61
try:
62
    from disco_helpers import preprocess_psi
Boris Bonev's avatar
Boris Bonev committed
63
    import disco_cuda_extension
64

Boris Bonev's avatar
Boris Bonev committed
65
66
67
68
69
70
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


Thorsten Kurth's avatar
Thorsten Kurth committed
71
72
73
74
75
def _split_distributed_convolution_tensor_s2(
    idx: torch.Tensor,
    vals: torch.Tensor,
    in_shape: Tuple[int],
    out_shape: Tuple[int],
):
    """
    Restrict a sparse convolution tensor to the input latitudes owned by the local polar rank.

    The input latitudes are partitioned across the polar communicator with
    ``compute_split_shapes``; only the entries whose input latitude falls into
    the chunk of the calling rank are kept, and their column index is rewritten
    relative to the first latitude of that chunk.

    Parameters
    ----------
    idx: torch.Tensor
        Integer index tensor with (at least) three rows. Row 2 is read as the
        flattened input index ``lat * nlon_in + lon``; rows 0 and 1 are passed
        through unchanged for the retained entries.
    vals: torch.Tensor
        Values associated with the entries (columns) of ``idx``.
    in_shape: Tuple[int]
        ``(nlat_in, nlon_in)`` of the global input grid.
    out_shape: Tuple[int]
        Not used by this function; kept so the signature stays unchanged for callers.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The filtered, contiguous index tensor of shape ``(3, n_local)`` with
        local column indices, and the matching values cast to ``torch.float32``.
    """
    nlat_in, nlon_in = in_shape

    # latitude range [start_idx, end_idx) owned by this polar rank
    comm_size_polar = polar_group_size()
    comm_rank_polar = polar_group_rank()
    split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
    offsets = [0] + list(accumulate(split_shapes))
    start_idx = offsets[comm_rank_polar]
    end_idx = offsets[comm_rank_polar + 1]

    # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
    lats = idx[2] // nlon_in
    lons = idx[2] % nlon_in
    # squeeze only the trailing dim: a bare squeeze() would collapse a single
    # match to a 0-d tensor and drop the entry dimension from the results
    ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze(-1)
    vals = vals[ilats]
    # for the indices we need to recompute them to refer to local indices of the input tensor
    idx = torch.stack([idx[0, ilats], idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)

    # make results contiguous
    idx = idx.contiguous()
    vals = vals.to(dtype=torch.float32).contiguous()

    return idx, vals
100

Boris Bonev's avatar
Boris Bonev committed
101

102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603

    We assume the data can be split in polar and azimuthal directions.

    Parameters
    ----------
    in_channels, out_channels, kernel_shape, basis_type, groups, bias
        Forwarded unchanged to ``DiscreteContinuousConv.__init__``.
    in_shape: Tuple[int]
        Global input grid size ``(nlat_in, nlon_in)``.
    out_shape: Tuple[int]
        Global output grid size ``(nlat_out, nlon_out)``.
    basis_norm_mode: str
        Forwarded to ``_precompute_convolution_tensor_s2``.
    grid_in, grid_out: str
        Grid identifiers forwarded to ``_precompute_convolution_tensor_s2``.
    theta_cutoff: float, optional
        Cutoff passed to the precomputation. Defaults to ``pi / (nlat_out - 1)``.

    Raises
    ------
    ValueError
        If ``theta_cutoff`` is not strictly positive.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # we need those shapes:
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution,
        # we will keep it local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
        # of atomic reduction calls inside the actual kernel

        # set local shapes according to distributed mode:
        # the input latitudes are split across the polar group, the output latitudes are not
        self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
        self.nlat_out_local = self.nlat_out

        # compute global convolution tensor
        idx, vals, _ = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the convolution tensor along latitude
        idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape, out_shape)

        # unpack the three index rows into separate contiguous tensors
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures (non-persistent: they are not written to the state dict)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # store psi jic: it is used by the PyTorch fallback path in forward
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stack the kernel, row and column index buffers into one contiguous ``(3, n)`` tensor."""
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the distributed DISCO convolution.

        Parameters
        ----------
        x: torch.Tensor
            Local input shard; dimension 1 is read as the channel dimension.
            NOTE(review): presumably laid out as (batch, channels, lat_local, lon_local) - confirm against callers.

        Returns
        -------
        torch.Tensor
            Local output shard with four dimensions, with the bias added when present.
        """

        # store number of channels
        num_chans = x.shape[1]

        # h and w are split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")

            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)

        # perform reduce scatter in polar region
        x = reduce_from_polar_region(x)
        x = scatter_to_polar_region(x, -2)

        # now we can transpose back the result, so that lon is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(out.shape[0], -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Distributed version of Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603

    We assume the data can be split in polar and azimuthal directions.

    Parameters
    ----------
    in_channels, out_channels, kernel_shape, basis_type, groups, bias
        Forwarded unchanged to ``DiscreteContinuousConv.__init__``.
    in_shape: Tuple[int]
        Global input grid size ``(nlat_in, nlon_in)``.
    out_shape: Tuple[int]
        Global output grid size ``(nlat_out, nlon_out)``.
    basis_norm_mode: str
        Forwarded to ``_precompute_convolution_tensor_s2``.
    grid_in, grid_out: str
        Grid identifiers; they are passed to ``_precompute_convolution_tensor_s2``
        with their roles swapped because of the transpose.
    theta_cutoff: float, optional
        Cutoff passed to the precomputation. Defaults to ``pi / (nlat_in - 1)``.

    Raises
    ------
    ValueError
        If ``theta_cutoff`` is not strictly positive.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # we need those shapes:
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution,
        # we will keep it local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
        # of atomic reduction calls inside the actual kernel

        # set local shapes according to distributed mode:
        # here the output latitudes are split across the polar group, the input latitudes are not
        self.nlat_in_local = self.nlat_in
        self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]

        # compute global convolution tensor
        # switch in_shape and out_shape since we want transpose conv
        # distributed mode here is swapped because of the transpose
        idx, vals, _ = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the convolution tensor along latitude, again, we need to swap the meaning
        # of in_shape and out_shape
        idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, out_shape, in_shape)

        # unpack the three index rows into separate contiguous tensors
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures (non-persistent: they are not written to the state dict)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # store psi as COO; it is used by the PyTorch fallback path in forward
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stack the kernel, row and column index buffers into one contiguous ``(3, n)`` tensor."""
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the distributed DISCO transpose convolution.

        Parameters
        ----------
        x: torch.Tensor
            Local input shard with four dimensions, unpacked as ``(B, C, H, W)``.
            NOTE(review): presumably H and W are the local latitude/longitude extents - confirm against callers.

        Returns
        -------
        torch.Tensor
            Local output shard, with the bias added when present.
        """

        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)
        # channel count after the weight multiplication, needed to undo the azimuth transpose below
        num_chans = x.shape[1]

        # transpose such that lon is local, channels are split
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

        # gather input tensor and set up backward reduction hooks
        x = gather_from_polar_region(x, -2, self.lat_in_shapes)
        x = copy_to_polar_region(x)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        # now we can transpose back the result, so that lon is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            out = distributed_transpose_azimuth.apply(out, (-1, 1), chan_shapes)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out