resample.py 8.21 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
import math
Thorsten Kurth's avatar
Thorsten Kurth committed
34
#import numpy as np
Boris Bonev's avatar
Boris Bonev committed
35
36
37
38

import torch
import torch.nn as nn

Thorsten Kurth's avatar
Thorsten Kurth committed
39
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
Boris Bonev's avatar
Boris Bonev committed
40
41
42


class ResampleS2(nn.Module):
    """
    Resampling module for signals on the 2-sphere.

    Resamples spherical signals between different grid resolutions and grid types by
    separable interpolation: first along the latitude axis, then along the longitude
    axis (which is treated as periodic). Inputs are indexed as ``x[..., lat, lon]``,
    i.e. the last two dimensions are latitude and longitude.

    Parameters
    ----------
    nlat_in : int
        Number of latitude points in the input grid
    nlon_in : int
        Number of longitude points in the input grid
    nlat_out : int
        Number of latitude points in the output grid
    nlon_out : int
        Number of longitude points in the output grid
    grid_in : str, optional
        Input grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    grid_out : str, optional
        Output grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    mode : str, optional
        Interpolation mode ("bilinear", "bilinear-spherical"), by default "bilinear".
        "bilinear-spherical" applies slerp-style weights ``sin((1-w)*omega)/sin(omega)``
        where ``omega`` is the difference of the two neighbouring *values*
        (presumably meant for angle-valued fields — TODO confirm intended use).

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the supported interpolation modes.
    """

    def __init__(
        self,
        nlat_in: int,
        nlon_in: int,
        nlat_out: int,
        nlon_out: int,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        mode: Optional[str] = "bilinear",
    ):

        super().__init__()

        # supported modes: plain (bi)linear interpolation or its slerp-style variant
        if mode not in ("bilinear", "bilinear-spherical"):
            raise NotImplementedError(f"unknown interpolation mode {mode}")
        self.mode = mode

        self.nlat_in, self.nlon_in = nlat_in, nlon_in
        self.nlat_out, self.nlon_out = nlat_out, nlon_out

        self.grid_in = grid_in
        self.grid_out = grid_out

        # node positions of both grids; the quadrature weights are not needed here
        self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
        self.lons_in = _precompute_longitudes(nlon_in)
        self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
        self.lons_out = _precompute_longitudes(nlon_out)

        # if some output latitudes fall outside the range spanned by lats_in, the input is
        # extended by one ring at each pole (latitude values 0 and pi) before interpolating.
        # Stored as a plain Python bool so that `forward` does not branch on a tensor.
        outside = (self.lats_out > self.lats_in[-1]) | (self.lats_out < self.lats_in[0])
        self.expand_poles = bool(outside.any())
        if self.expand_poles:
            device = self.lats_in.device
            self.lats_in = torch.cat(
                [
                    torch.as_tensor([0.0], dtype=torch.float64, device=device),
                    self.lats_in,
                    torch.as_tensor([math.pi], dtype=torch.float64, device=device),
                ]
            ).contiguous()

        # for every output latitude find the input interval [lat_idx, lat_idx + 1] containing it
        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
        # an output latitude equal to the last input latitude would make lat_idx + 1 run past
        # the end, so it is assigned to the last interval instead (with weight 1)
        lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)

        # fractional position of each output latitude inside its interval
        lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
        # trailing singleton dim so the weights broadcast along the longitude axis
        lat_weights = lat_weights.unsqueeze(-1)

        self.register_buffer("lat_idx", lat_idx, persistent=False)
        self.register_buffer("lat_weights", lat_weights, persistent=False)

        # longitude neighbours; the axis is periodic, so output points at or beyond the last
        # input longitude interpolate between the last and the first input longitude
        lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
        lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)

        # spacing between the two neighbours, unwrapped across the 2*pi seam
        diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
        diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
        lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

        self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
        self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
        self.register_buffer("lon_weights", lon_weights, persistent=False)

        # identical input and output grids: forward becomes the identity
        self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    @staticmethod
    def _spherical_lerp(start: torch.Tensor, end: torch.Tensor, weight: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
        """Slerp-style interpolation between ``start`` and ``end`` at fraction ``weight``.

        Uses ``sin((1-w)*omega)/sin(omega)`` and ``sin(w*omega)/sin(omega)`` with
        ``omega = end - start``. Where ``|sin(omega)| <= eps`` it falls back to the linear
        weights ``1-w`` and ``w``.
        """
        omega = end - start
        somega = torch.sin(omega)
        # the slerp coefficients are even in omega, so the threshold must be symmetric too
        use_sin = somega.abs() > eps
        # torch.where evaluates both branches: divide by 1 where the fallback is used so the
        # masked branch never computes 0/0, which would turn the gradients into NaN
        safe_somega = torch.where(use_sin, somega, torch.ones_like(somega))
        start_prefac = torch.where(use_sin, torch.sin((1.0 - weight) * omega) / safe_somega, 1.0 - weight)
        end_prefac = torch.where(use_sin, torch.sin(weight * omega) / safe_somega, weight)
        return start_prefac * start + end_prefac * end

    def _upscale_longitudes(self, x: torch.Tensor):
        """Interpolate the last (longitude) axis of ``x`` onto the output longitudes."""
        # do the interpolation in the precision of x
        lwgt = self.lon_weights.to(x.dtype)
        x_left = x[..., self.lon_idx_left]
        x_right = x[..., self.lon_idx_right]
        if self.mode == "bilinear":
            return torch.lerp(x_left, x_right, lwgt)
        return self._spherical_lerp(x_left, x_right, lwgt)

    def _expand_poles(self, x: torch.Tensor):
        """Add one latitude ring at each end of ``x``.

        The new first/last rings hold the longitudinal mean of the original first/last
        ring, so the result has two more latitudes than ``x``.
        """
        north = x[..., :1, :].mean(dim=-1, keepdim=True).expand_as(x[..., :1, :])
        south = x[..., -1:, :].mean(dim=-1, keepdim=True).expand_as(x[..., -1:, :])
        return torch.cat([north, x, south], dim=-2)

    def _upscale_latitudes(self, x: torch.Tensor):
        """Interpolate the second-to-last (latitude) axis of ``x`` onto the output latitudes."""
        # do the interpolation in the precision of x
        lwgt = self.lat_weights.to(x.dtype)
        x_low = x[..., self.lat_idx, :]
        x_high = x[..., self.lat_idx + 1, :]
        if self.mode == "bilinear":
            return torch.lerp(x_low, x_high, lwgt)
        return self._spherical_lerp(x_low, x_high, lwgt)

    def forward(self, x: torch.Tensor):
        """Resample ``x`` (indexed as ``[..., lat, lon]``) onto the output grid."""
        if self.skip_resampling:
            return x

        if self.expand_poles:
            x = self._expand_poles(x)

        x = self._upscale_latitudes(x)
        x = self._upscale_longitudes(x)
        return x