resample.py 9.59 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
import math
Thorsten Kurth's avatar
Thorsten Kurth committed
34
#import numpy as np
Boris Bonev's avatar
Boris Bonev committed
35
36
37
38

import torch
import torch.nn as nn

Thorsten Kurth's avatar
Thorsten Kurth committed
39
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
Boris Bonev's avatar
Boris Bonev committed
40
41
42


class ResampleS2(nn.Module):
    """
    Resampling module for signals on the 2-sphere.

    Resamples spherical signals between grid resolutions and grid types by separable
    interpolation: first along the latitude dimension (colatitudes in [0, pi]; the pole
    padding below uses 0 and pi), then along the periodic longitude dimension.

    Parameters
    -----------
    nlat_in : int
        Number of latitude points in the input grid
    nlon_in : int
        Number of longitude points in the input grid
    nlat_out : int
        Number of latitude points in the output grid
    nlon_out : int
        Number of longitude points in the output grid
    grid_in : str, optional
        Input grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    grid_out : str, optional
        Output grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    mode : str, optional
        Interpolation mode ("bilinear", "bilinear-spherical"), by default "bilinear"

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the supported interpolation modes.
    """

    def __init__(
        self,
        nlat_in: int,
        nlon_in: int,
        nlat_out: int,
        nlon_out: int,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        mode: Optional[str] = "bilinear",
    ):

        super().__init__()

        # currently only (spherical) bilinear interpolation is supported
        if mode not in ("bilinear", "bilinear-spherical"):
            raise NotImplementedError(f"unknown interpolation mode {mode}")
        self.mode = mode

        self.nlat_in, self.nlon_in = nlat_in, nlon_in
        self.nlat_out, self.nlon_out = nlat_out, nlon_out

        self.grid_in = grid_in
        self.grid_out = grid_out

        # grid coordinates; only needed here to build the interpolation tables below
        self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
        self.lons_in = _precompute_longitudes(nlon_in)
        self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
        self.lons_out = _precompute_longitudes(nlon_out)

        # if some output latitudes lie outside the range spanned by lats_in, the input is
        # padded with pole rows (see _expand_poles) before interpolating.
        # Stored as a plain Python bool so that forward() does not branch on a tensor.
        self.expand_poles = bool(
            (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
        )
        if self.expand_poles:
            # keep the dtype of lats_in so that searchsorted below sees matching dtypes
            north = torch.as_tensor([0.0], dtype=self.lats_in.dtype, device=self.lats_in.device)
            south = torch.as_tensor([math.pi], dtype=self.lats_in.dtype, device=self.lats_in.device)
            self.lats_in = torch.cat([north, self.lats_in, south]).contiguous()

        # for each output latitude, find the input interval [lat_idx, lat_idx + 1] containing it
        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
        # an output latitude equal to the last input latitude would get lat_idx = nlat - 1,
        # which has no right neighbour; move it into the last interval (it then gets weight 1)
        lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)

        # linear interpolation weights along the latitude
        lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
        # trailing singleton dimension so that the weights broadcast over the longitude axis
        lat_weights = lat_weights.unsqueeze(-1)

        # register buffers
        self.register_buffer("lat_idx", lat_idx, persistent=False)
        self.register_buffer("lat_weights", lat_weights, persistent=False)

        # left/right longitude neighbours; the right neighbour of the last input longitude
        # wraps around to index 0 (periodicity)
        lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
        lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)

        # spacing between neighbours, corrected by 2*pi across the periodic seam
        diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
        diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
        lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

        # register buffers
        self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
        self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
        self.register_buffer("lon_weights", lon_weights, persistent=False)

        # identical input and output grids: forward() returns its input unchanged
        self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    @staticmethod
    def _spherical_lerp(x_start: torch.Tensor, x_end: torch.Tensor, wgt: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
        """
        Blend two tensors with slerp-style weights.

        With ``omega = x_end - x_start`` the result is
        ``sin((1 - wgt) * omega) / sin(omega) * x_start + sin(wgt * omega) / sin(omega) * x_end``,
        falling back to plain linear weights where ``|sin(omega)| <= eps``.

        NOTE(review): omega is computed from the data values, not from the grid spacing;
        confirm this is the intended meaning of "bilinear-spherical".

        Parameters
        -----------
        x_start : torch.Tensor
            Values at the left/lower neighbour
        x_end : torch.Tensor
            Values at the right/upper neighbour
        wgt : torch.Tensor
            Interpolation weights in the dtype of the inputs, broadcastable against them
        eps : float, optional
            Threshold on ``|sin(omega)|`` below which linear weights are used, by default 1e-4

        Returns
        -------
        torch.Tensor
            Interpolated values
        """
        omega = x_end - x_start
        somega = torch.sin(omega)
        # the slerp weights are even in omega, so the guard must test the magnitude only
        use_slerp = somega.abs() > eps
        # torch.where evaluates both branches: keep the unselected branch finite (no 0/0),
        # otherwise NaNs leak into the backward pass even though the forward looks fine
        safe_somega = torch.where(use_slerp, somega, torch.ones_like(somega))
        start_prefac = torch.where(use_slerp, torch.sin((1.0 - wgt) * omega) / safe_somega, 1.0 - wgt)
        end_prefac = torch.where(use_slerp, torch.sin(wgt * omega) / safe_somega, wgt)
        return start_prefac * x_start + end_prefac * x_end

    def _upscale_longitudes(self, x: torch.Tensor):
        """
        Interpolate the input tensor along the longitude dimension.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon)

        Returns
        -------
        torch.Tensor
            Interpolated tensor along longitude dimension
        """
        # do the interpolation in precision of x
        lwgt = self.lon_weights.to(x.dtype)
        x_left = x[..., self.lon_idx_left]
        x_right = x[..., self.lon_idx_right]
        if self.mode == "bilinear":
            return torch.lerp(x_left, x_right, lwgt)
        return self._spherical_lerp(x_left, x_right, lwgt)

    def _expand_poles(self, x: torch.Tensor):
        """
        Expand the input tensor to include pole points for interpolation.

        A row is prepended and appended along the latitude dimension. Each holds the mean
        over longitudes of the first (resp. last) input latitude row.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon)

        Returns
        -------
        torch.Tensor
            Expanded tensor with shape (..., nlat + 2, nlon)
        """
        first_row = x[..., :1, :]
        last_row = x[..., -1:, :]
        x_north = first_row.mean(dim=-1, keepdim=True).expand_as(first_row)
        x_south = last_row.mean(dim=-1, keepdim=True).expand_as(last_row)
        return torch.cat([x_north, x, x_south], dim=-2)

    def _upscale_latitudes(self, x: torch.Tensor):
        """
        Interpolate the input tensor along the latitude dimension.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon)

        Returns
        -------
        torch.Tensor
            Interpolated tensor along latitude dimension
        """
        # do the interpolation in precision of x
        lwgt = self.lat_weights.to(x.dtype)
        x_lower = x[..., self.lat_idx, :]
        x_upper = x[..., self.lat_idx + 1, :]
        if self.mode == "bilinear":
            return torch.lerp(x_lower, x_upper, lwgt)
        return self._spherical_lerp(x_lower, x_upper, lwgt)

    def forward(self, x: torch.Tensor):
        """
        Forward pass of the resampling module.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (..., nlat_in, nlon_in)

        Returns
        -------
        torch.Tensor
            Resampled tensor with shape (..., nlat_out, nlon_out)
        """
        if self.skip_resampling:
            return x

        if self.expand_poles:
            x = self._expand_poles(x)

        x = self._upscale_latitudes(x)

        x = self._upscale_longitudes(x)

        return x