# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
from typing import List, Tuple, Union, Optional

import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import reduce_from_azimuth_region, copy_to_azimuth_region
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes


class DistributedResampleS2(nn.Module):
    r"""
    Distributed resampling module for spherical data on the 2-sphere.

    This module performs distributed resampling of spherical data across multiple processes,
    supporting both upscaling and downscaling operations. The data is distributed across
    polar and azimuthal directions, and the module handles the necessary communication
    and interpolation operations.

    Latitudes are interpolated first, after a polar transpose that makes the latitude
    dimension local and splits the channel dimension. Longitudes are interpolated second,
    after an azimuthal transpose that makes the longitude dimension local. Longitudes are
    treated as periodic. If some output latitudes lie outside the range spanned by the
    input latitudes, the input is first extended with one averaged pole row on each side
    (at colatitude 0 and pi).

    Parameters
    ----------
    nlat_in : int
        Number of input latitude points
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    mode : str, optional
        Interpolation mode ("bilinear" or "bilinear-spherical"), by default "bilinear"

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the supported interpolation modes.
    """

    def __init__(
        self,
        nlat_in: int,
        nlon_in: int,
        nlat_out: int,
        nlon_out: int,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        mode: Optional[str] = "bilinear",
    ):

        super().__init__()

        # plain bilinear and its spherical variant are the supported modes
        if mode in ["bilinear", "bilinear-spherical"]:
            self.mode = mode
        else:
            raise NotImplementedError(f"unknown interpolation mode {mode}")

        self.nlat_in, self.nlon_in = nlat_in, nlon_in
        self.nlat_out, self.nlon_out = nlat_out, nlon_out

        self.grid_in = grid_in
        self.grid_out = grid_out

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # compute splits. The input latitude splits describe the tensor as it arrives in
        # forward: pole expansion only happens after the polar transpose that uses them.
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # input and output coordinates (radians)
        self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
        self.lons_in = _precompute_longitudes(nlon_in)
        self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
        self.lons_out = _precompute_longitudes(nlon_out)

        # in the case where some points lie outside of the range spanned by lats_in,
        # we need to expand the solution to the poles before interpolating
        self.expand_poles = bool(
            (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
        )
        if self.expand_poles:
            self.lats_in = torch.cat(
                [
                    torch.tensor([0.0], dtype=torch.float64),
                    self.lats_in,
                    torch.tensor([math.pi], dtype=torch.float64),
                ]
            ).contiguous()

        # for each output latitude, lat_idx is its left input neighbour and lat_idx + 1 its right one
        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
        # an output point coinciding with the last input latitude would otherwise make lat_idx + 1
        # index past the end; shift it left so that it is reached with weight 1 instead
        lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)

        # compute the interpolation weights along the latitude
        lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
        # trailing singleton dim so the weights broadcast over the longitude axis
        lat_weights = lat_weights.unsqueeze(-1)

        # register buffers
        self.register_buffer("lat_idx", lat_idx, persistent=False)
        self.register_buffer("lat_weights", lat_weights, persistent=False)

        # get left and right indices, handling periodicity in the longitude:
        # output points at or beyond the last input longitude wrap around to index 0
        lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
        lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)

        # spacing between the two neighbours; the wrapped interval has a negative raw
        # difference, so add one full period to it
        diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
        diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
        lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

        # register buffers
        self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
        self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
        self.register_buffer("lon_weights", lon_weights, persistent=False)

        # identical input and output grids: forward becomes the identity
        self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    def _interpolate(self, x_left: torch.Tensor, x_right: torch.Tensor, lwgt: torch.Tensor) -> torch.Tensor:
        """
        Blend neighbouring samples according to ``self.mode``.

        Parameters
        ----------
        x_left, x_right : torch.Tensor
            Samples at the left / right neighbour of each output point.
        lwgt : torch.Tensor
            Fractional position of each output point between its neighbours
            (0 -> left, 1 -> right); broadcast against ``x_left``.

        Returns
        -------
        torch.Tensor
            Interpolated values.
        """
        if self.mode == "bilinear":
            return torch.lerp(x_left, x_right, lwgt)

        # "bilinear-spherical": slerp-style weights derived from the difference between the
        # neighbouring values. Where sin(omega) <= 1e-4 (including negative values), the
        # weights fall back to plain linear ones.
        omega = x_right - x_left
        somega = torch.sin(omega)
        start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
        end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
        return start_prefac * x_left + end_prefac * x_right

    def _upscale_longitudes(self, x: torch.Tensor) -> torch.Tensor:
        """
        Interpolate along the longitude dimension.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon_in); the last dimension must hold
            all input longitudes (i.e. it must be local, not split).

        Returns
        -------
        torch.Tensor
            Tensor with shape (..., nlat, nlon_out)
        """
        lwgt = self.lon_weights.to(x.dtype)
        return self._interpolate(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)

    def _expand_poles(self, x: torch.Tensor) -> torch.Tensor:
        """
        Pad the latitude dimension with one pole row on each side.

        The new first row is filled with the mean over longitudes of the original first
        latitude row. The new last row is filled with the mean of the original last row.
        When the longitude dimension is split across the azimuth group, the partial sums
        and point counts go through ``reduce_from_azimuth_region`` (presumably a sum
        all-reduce — confirm). This way every rank uses the mean over all longitudes.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon_local)

        Returns
        -------
        torch.Tensor
            Tensor with shape (..., nlat + 2, nlon_local)
        """
        x_north = x[..., 0, :].sum(dim=-1, keepdim=True)
        x_south = x[..., -1, :].sum(dim=-1, keepdim=True)
        x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)

        if self.comm_size_azimuth > 1:
            x_north = reduce_from_azimuth_region(x_north.contiguous())
            x_south = reduce_from_azimuth_region(x_south.contiguous())
            x_count = reduce_from_azimuth_region(x_count)
        x_north = x_north / x_count
        x_south = x_south / x_count

        if self.comm_size_azimuth > 1:
            x_north = copy_to_azimuth_region(x_north)
            x_south = copy_to_azimuth_region(x_south)

        # add one zero row on each side of the latitude dim, then overwrite it with the
        # pole values (broadcast across the local longitudes)
        x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode="constant")
        x[..., 0, :] = x_north[...]
        x[..., -1, :] = x_south[...]

        return x

    def _upscale_latitudes(self, x: torch.Tensor) -> torch.Tensor:
        """
        Interpolate along the latitude dimension.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (..., nlat_in, nlon) (or nlat_in + 2 rows when the
            poles were expanded); the latitude dimension must be local, not split.

        Returns
        -------
        torch.Tensor
            Tensor with shape (..., nlat_out, nlon)
        """
        lwgt = self.lat_weights.to(x.dtype)
        return self._interpolate(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Resample a distributed tensor from the input grid to the output grid.

        Parameters
        ----------
        x : torch.Tensor
            Local shard with shape (..., channels, nlat_local, nlon_local). The latitude
            dimension is split across the polar group according to ``lat_in_shapes``, and
            the longitude dimension across the azimuth group according to ``lon_in_shapes``.
            NOTE(review): this layout is inferred from the transpose calls below; confirm
            against callers.

        Returns
        -------
        torch.Tensor
            Local shard of the resampled tensor, or ``x`` itself if the input and output
            grids are identical.
        """
        if self.skip_resampling:
            return x

        num_chans = x.shape[-3]

        # step 1: polar transpose so that latitude is local and channels are split
        if self.comm_size_polar > 1:
            channels_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_in_shapes)

        # expand poles if requested
        if self.expand_poles:
            x = self._expand_poles(x)

        # interpolate in latitude
        x = self._upscale_latitudes(x)

        # transpose back: split latitude again, gather the channels
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-2, -3), channels_shapes)

        # step 2: azimuth transpose so that longitude is local and channels are split
        if self.comm_size_azimuth > 1:
            channels_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_in_shapes)

        # interpolate in longitude
        x = self._upscale_longitudes(x)

        # transpose back: split longitude again, gather the channels
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-1, -3), channels_shapes)

        return x