# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
import math
#import numpy as np

import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes


class ResampleS2(nn.Module):
    r"""
    Resample a field from one latitude/longitude grid to another.

    The interpolation is separable: the input is first interpolated along the
    latitude axis (dim ``-2``) and then along the longitude axis (dim ``-1``).
    All indices and weights are precomputed in the constructor and stored as
    non-persistent buffers, so they follow the module across devices but are
    not written to the ``state_dict``.

    Parameters
    ----------
    nlat_in, nlon_in : int
        Number of latitude / longitude points of the input grid.
    nlat_out, nlon_out : int
        Number of latitude / longitude points of the output grid.
    grid_in, grid_out : str
        Grid names forwarded to ``_precompute_latitudes``.
    mode : str
        ``"bilinear"`` (plain linear interpolation) or ``"bilinear-spherical"``
        (slerp-style weighting, see ``_blend``).

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the two supported values.
    """

    def __init__(
        self,
        nlat_in: int,
        nlon_in: int,
        nlat_out: int,
        nlon_out: int,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        mode: Optional[str] = "bilinear",
    ):

        super().__init__()

        # only the linear and the slerp-style variants are implemented
        if mode not in ("bilinear", "bilinear-spherical"):
            raise NotImplementedError(f"unknown interpolation mode {mode}")
        self.mode = mode

        self.nlat_in, self.nlon_in = nlat_in, nlon_in
        self.nlat_out, self.nlon_out = nlat_out, nlon_out

        self.grid_in = grid_in
        self.grid_out = grid_out

        # coordinates of the input and output grids (second return value of
        # _precompute_latitudes is not needed here)
        self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
        self.lons_in = _precompute_longitudes(nlon_in)
        self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
        self.lons_out = _precompute_longitudes(nlon_out)

        # in the case where some points lie outside of the range spanned by lats_in,
        # we need to expand the solution to the poles before interpolating.
        # Stored as a plain Python bool so that `forward` branches on a python
        # value rather than on a 0-dim tensor.
        outside = (self.lats_out > self.lats_in[-1]) | (self.lats_out < self.lats_in[0])
        self.expand_poles = bool(outside.any())
        if self.expand_poles:
            # add the two pole coordinates (0 and pi); the matching data rows
            # are created at runtime by `_expand_poles`
            device = self.lats_in.device
            self.lats_in = torch.cat(
                [
                    torch.as_tensor([0.0], dtype=torch.float64, device=device),
                    self.lats_in,
                    torch.as_tensor([math.pi], dtype=torch.float64, device=device),
                ]
            ).contiguous()

        # index of the input latitude directly below (or equal to) each output latitude
        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
        # an output point that coincides with the last input latitude has no right
        # neighbour, so it is shifted into the last interval (its weight becomes 1)
        lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)

        # relative position of each output latitude inside its interval
        lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
        # trailing singleton dimension so the weights broadcast over longitudes
        lat_weights = lat_weights.unsqueeze(-1)

        # register buffers
        self.register_buffer("lat_idx", lat_idx, persistent=False)
        self.register_buffer("lat_weights", lat_weights, persistent=False)

        # get left and right indices but this time make sure periodicity in the longitude is handled:
        # output points at or beyond the last input longitude wrap around to index 0
        lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
        lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)

        # spacing between the two neighbours; a non-positive value means the interval
        # wraps around the periodic boundary. `<=` (rather than `<`) also covers a
        # single input longitude, where left == right and the spacing would be 0,
        # which would otherwise yield NaN weights.
        diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
        diff = torch.where(diff <= 0.0, diff + 2 * math.pi, diff)
        lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

        # register buffers
        self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
        self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
        self.register_buffer("lon_weights", lon_weights, persistent=False)

        # identical input and output grids: forward becomes the identity
        self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    def _blend(self, x_left: torch.Tensor, x_right: torch.Tensor, weight: torch.Tensor):
        r"""
        Interpolate between ``x_left`` and ``x_right`` with ``weight`` according to ``self.mode``.

        ``"bilinear"`` uses ``torch.lerp``. Any other mode uses the slerp prefactors
        ``sin((1 - w) * omega) / sin(omega)`` and ``sin(w * omega) / sin(omega)`` with
        ``omega = x_right - x_left``, falling back to the linear prefactors
        wherever ``sin(omega) <= 1e-4``.
        """
        if self.mode == "bilinear":
            return torch.lerp(x_left, x_right, weight)

        omega = x_right - x_left
        somega = torch.sin(omega)
        # NOTE(review): the test is on the signed value, so negative sin(omega)
        # also takes the linear fallback - confirm this is intended.
        use_slerp = somega > 1e-4
        start_prefac = torch.where(use_slerp, torch.sin((1.0 - weight) * omega) / somega, (1.0 - weight))
        end_prefac = torch.where(use_slerp, torch.sin(weight * omega) / somega, weight)
        return start_prefac * x_left + end_prefac * x_right

    def _upscale_longitudes(self, x: torch.Tensor):
        r"""
        Interpolate ``x`` along its last dimension onto the output longitudes.
        """
        # do the interpolation in precision of x
        lwgt = self.lon_weights.to(x.dtype)
        return self._blend(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)

    def _expand_poles(self, x: torch.Tensor):
        r"""
        Add one row at each end of the latitude dimension (dim ``-2``).

        The new first / last row is filled with the mean over the last dimension of
        the original first / last row. The input tensor is not modified.
        """
        first_row = x[..., :1, :]
        last_row = x[..., -1:, :]
        x_north = first_row.mean(dim=-1, keepdim=True).expand_as(first_row)
        x_south = last_row.mean(dim=-1, keepdim=True).expand_as(last_row)
        return torch.cat([x_north, x, x_south], dim=-2)

    def _upscale_latitudes(self, x: torch.Tensor):
        r"""
        Interpolate ``x`` along its second-to-last dimension onto the output latitudes.
        """
        # do the interpolation in precision of x
        lwgt = self.lat_weights.to(x.dtype)
        return self._blend(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)

    def forward(self, x: torch.Tensor):
        r"""
        Resample ``x`` from the input grid to the output grid.

        The last two dimensions of ``x`` are treated as (latitude, longitude).
        Returns ``x`` itself, unchanged, when the input and output grids are identical.
        """
        if self.skip_resampling:
            return x

        if self.expand_poles:
            x = self._expand_poles(x)

        x = self._upscale_latitudes(x)
        x = self._upscale_longitudes(x)

        return x