distributed_resample.py 9.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
import math

import torch
import torch.nn as nn

Thorsten Kurth's avatar
Thorsten Kurth committed
38
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
39
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
40
from torch_harmonics.distributed import reduce_from_azimuth_region, copy_to_azimuth_region
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes


class DistributedResampleS2(nn.Module):
    r"""
    Distributed resampling of spherical data between two latitude/longitude grids.

    Interpolation is separable: the latitude direction is interpolated first, then the
    longitude direction (treated as periodic). Before each pass the data is transposed
    across the corresponding process group so that the interpolated dimension is local
    to every rank, and it is transposed back afterwards.

    Parameters
    ----------
    nlat_in, nlon_in : int
        Number of latitudes / longitudes of the input grid.
    nlat_out, nlon_out : int
        Number of latitudes / longitudes of the output grid.
    grid_in, grid_out : str, optional
        Grid type passed to ``_precompute_latitudes`` for the input / output grid.
    mode : str, optional
        ``"bilinear"`` (linear blend of the two neighbouring values in each direction) or
        ``"bilinear-spherical"`` (sine-weighted blend of the two neighbouring values,
        falling back to linear weights where ``sin(difference) <= 1e-4``).

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the supported modes.

    Notes
    -----
    ``forward`` treats dimension -3 as channels, -2 as latitude and -1 as longitude.
    The latitude / longitude split sizes handed to the distributed transposes are
    ``compute_split_shapes(nlat_in, polar_group_size())`` and
    ``compute_split_shapes(nlon_in, azimuth_group_size())``. The output layout is
    presumably split analogously for the output grid -- confirm against the transpose
    helpers in ``torch_harmonics.distributed``.
    """

    def __init__(
        self,
        nlat_in: int,
        nlon_in: int,
        nlat_out: int,
        nlon_out: int,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        mode: Optional[str] = "bilinear",
    ):

        super().__init__()

        # both supported modes share the neighbour indices and weights computed below;
        # they only differ in how the two neighbouring values are blended (see _interpolate)
        if mode in ["bilinear", "bilinear-spherical"]:
            self.mode = mode
        else:
            raise NotImplementedError(f"unknown interpolation mode {mode}")

        self.nlat_in, self.nlon_in = nlat_in, nlon_in
        self.nlat_out, self.nlon_out = nlat_out, nlon_out

        self.grid_in = grid_in
        self.grid_out = grid_out

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # compute splits. The input splits refer to the original (unexpanded) grid. That is
        # what forward needs: pole rows are only appended after the latitude transpose has
        # made the latitude dimension local.
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # coordinates of the input and output grids
        self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
        self.lons_in = _precompute_longitudes(nlon_in)
        self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
        self.lons_out = _precompute_longitudes(nlon_out)

        # if some output latitudes lie outside the range spanned by lats_in, the input is
        # extended with one pole row on each side (at 0 and pi) before interpolating
        self.expand_poles = bool((self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any())
        if self.expand_poles:
            self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
                                      self.lats_in,
                                      torch.tensor([math.pi], dtype=torch.float64)]).contiguous()

        # index of the input latitude at or below each output latitude; the upper neighbour is lat_idx + 1
        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
        # an output point coinciding with the last input latitude must use the last interval,
        # otherwise lat_idx + 1 would run past the end
        lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)

        # fractional position of each output latitude inside its interval
        lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
        # trailing singleton dim so the weights broadcast over the longitude dimension
        lat_weights = lat_weights.unsqueeze(-1)

        # register buffers
        self.register_buffer("lat_idx", lat_idx, persistent=False)
        self.register_buffer("lat_weights", lat_weights, persistent=False)

        # left/right neighbours in longitude; points past the last input longitude wrap to index 0
        lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
        lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)

        # interval length, corrected by 2*pi for the interval that wraps around
        diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
        diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
        lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

        # register buffers
        self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
        self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
        self.register_buffer("lon_weights", lon_weights, persistent=False)

        # identical input and output grids: forward is the identity
        self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    def _interpolate(self, x_start: torch.Tensor, x_end: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """
        Blend two neighbouring values with fractional weights according to ``self.mode``.

        ``"bilinear"`` uses ``torch.lerp``. ``"bilinear-spherical"`` uses sine-weighted
        prefactors where ``sin(x_end - x_start) > 1e-4`` and plain linear weights elsewhere.
        """
        if self.mode == "bilinear":
            return torch.lerp(x_start, x_end, weights)

        omega = x_end - x_start
        somega = torch.sin(omega)
        use_sin = somega > 1e-4
        # torch.where still differentiates the unselected branch. Dividing by the raw somega
        # would give 0/0 where neighbours are equal and poison the backward pass with NaNs,
        # so masked-out entries divide by 1 instead. Forward values are unaffected.
        safe_somega = torch.where(use_sin, somega, torch.ones_like(somega))
        start_prefac = torch.where(use_sin, torch.sin((1.0 - weights) * omega) / safe_somega, (1.0 - weights))
        end_prefac = torch.where(use_sin, torch.sin(weights * omega) / safe_somega, weights)
        return start_prefac * x_start + end_prefac * x_end

    def _upscale_longitudes(self, x: torch.Tensor):
        """Interpolate along the last dimension (longitude), which must be fully local."""
        lwgt = self.lon_weights.to(x.dtype)
        return self._interpolate(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)

    def _expand_poles(self, x: torch.Tensor):
        """
        Prepend/append one latitude row holding the longitude mean of the first/last row.

        Expects the latitude dimension (-2) to be local. The longitude dimension (-1) may
        still be split across the azimuth group; the partial sums are then all-reduced.
        """
        x_north = x[..., 0, :].sum(dim=-1, keepdim=True)
        x_south = x[..., -1, :].sum(dim=-1, keepdim=True)

        if self.comm_size_azimuth > 1:
            x_north = reduce_from_azimuth_region(x_north.contiguous())
            x_south = reduce_from_azimuth_region(x_south.contiguous())

        # the local longitude chunks add up to the full input longitude count, so no extra
        # collective is needed to obtain it
        x_north = x_north / self.nlon_in
        x_south = x_south / self.nlon_in

        if self.comm_size_azimuth > 1:
            # pairs with the reduction above so gradients are also summed across the group
            x_north = copy_to_azimuth_region(x_north)
            x_south = copy_to_azimuth_region(x_south)

        # broadcast the pole values over the local longitude chunk and stack them around x
        row_shape = (*x.shape[:-2], 1, x.shape[-1])
        x_north = x_north.unsqueeze(-2).expand(row_shape)
        x_south = x_south.unsqueeze(-2).expand(row_shape)
        return torch.cat([x_north, x, x_south], dim=-2)

    def _upscale_latitudes(self, x: torch.Tensor):
        """Interpolate along dimension -2 (latitude), which must be fully local."""
        lwgt = self.lat_weights.to(x.dtype)
        return self._interpolate(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)

    def forward(self, x: torch.Tensor):

        if self.skip_resampling:
            return x

        # channel count, needed to split/gather the channel dimension in the transposes
        num_chans = x.shape[-3]

        # latitude pass: make the latitude dimension local by transposing it with the channel
        # dimension (channels become split across the polar group)
        if self.comm_size_polar > 1:
            chans_polar = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_in_shapes)

        # expand poles if requested (latitude is local here, longitude may still be split)
        if self.expand_poles:
            x = self._expand_poles(x)

        x = self._upscale_latitudes(x)

        # split latitude again and gather the channels back
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-2, -3), chans_polar)

        # longitude pass: same procedure with the azimuth group
        if self.comm_size_azimuth > 1:
            chans_azimuth = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_in_shapes)

        x = self._upscale_longitudes(x)

        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chans_azimuth)

        return x