<!--
SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.

SPDX-License-Identifier: BSD-3-Clause

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

<!-- <div align="center">
    <img src="https://raw.githubusercontent.com/NVIDIA/torch-harmonics/main/images/logo/logo.png"  width="568">
    <br>
    <a href="https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml"><img src="https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml/badge.svg"></a>
    <a href="https://pypi.org/project/torch_harmonics/"><img src="https://img.shields.io/pypi/v/torch_harmonics"></a>
</div> -->

<!--
[![pypi](https://img.shields.io/pypi/v/torch_harmonics)](https://pypi.org/project/torch_harmonics/)
-->

<!-- # spherical harmonic transforms -->

# torch-harmonics

[**Overview**](#overview) | [**Installation**](#installation) | [**More information**](#more-about-torch-harmonics) | [**Getting started**](#getting-started) | [**Contributors**](#contributors) | [**Cite us**](#cite-us) | [**References**](#references)

[![tests](https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml/badge.svg)](https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml)
[![pypi](https://img.shields.io/pypi/v/torch_harmonics)](https://pypi.org/project/torch_harmonics/)

## Overview

torch-harmonics implements differentiable signal processing on the sphere. This includes differentiable implementations of the spherical harmonic transforms, vector spherical harmonic transforms and discrete-continuous convolutions on the sphere. The package was originally implemented to enable Spherical Fourier Neural Operators (SFNO) [1].

The SHT algorithm uses FFTs for the projection onto the Fourier basis in longitude and quadrature rules to compute the projection onto the associated Legendre polynomials in latitude. For most practical problem sizes, this approach tends to outperform algorithms with better asymptotic scaling [2].

Because torch-harmonics implements these operations with PyTorch primitives, they are fully differentiable. Moreover, the quadrature can be distributed across multiple ranks, so that the transforms can be computed in a spatially distributed fashion.

torch-harmonics has been used to implement a variety of differentiable PDE solvers, which generated the animations below, and it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1].

<div align="center">
<table border="0" cellspacing="0" cellpadding="0">
    <tr>
        <td><img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/sfno.gif"  width="240"></td>
        <td><img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/zonal_jet.gif"  width="240"></td>
        <td><img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/allen-cahn.gif"  width="240"></td>
    </tr>
<!--     <tr>
        <td style="text-align:center; border-style : hidden!important;">Shallow Water Eqns.</td>
        <td style="text-align:center; border-style : hidden!important;">Ginzburg-Landau Eqn.</td>
        <td style="text-align:center; border-style : hidden!important;">Allen-Cahn Eqn.</td>
    </tr>  -->
</table>
</div>


## Installation
Download directly from PyPI:

```bash
pip install torch-harmonics
```
If you would like accelerated CUDA extensions for the discrete-continuous convolutions, please use the `--cuda_ext` flag:
```bash
pip install --global-option --cuda_ext torch-harmonics
```
:warning: Please note that the custom CUDA extensions currently only support CUDA architectures >= 7.0.

If you want to actively develop torch-harmonics, we recommend building it in your environment from GitHub:

```bash
git clone git@github.com:NVIDIA/torch-harmonics.git
cd torch-harmonics
pip install -e .
```

Alternatively, use the Dockerfile to build your custom container after cloning:

```bash
git clone git@github.com:NVIDIA/torch-harmonics.git
cd torch-harmonics
docker build . -t torch_harmonics
docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 torch_harmonics
```

## More about torch-harmonics

### Spherical harmonics

The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonics) are special functions defined on the two-dimensional sphere $S^2$ (embedded in three dimensions). They form an orthonormal basis of the space $L^2(S^2)$ of square-integrable functions on the sphere and play the same role as the Fourier modes on a circle/torus. The spherical harmonics are defined as

$$
Y_l^m(\theta, \lambda) = \sqrt{\frac{(2l + 1)}{4 \pi} \frac{(l - m)!}{(l + m)!}} P_l^m(\cos \theta)  \exp(im\lambda),
$$

where $\theta$ and $\lambda$ are colatitude and longitude, respectively, and $P_l^m$ are the [associated Legendre polynomials](https://en.wikipedia.org/wiki/Associated_Legendre_polynomials).

<div align="center">
<img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/spherical_harmonics.gif" width="432">
<br>
Spherical harmonics up to degree 5
</div>
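
To make the definition concrete, the following NumPy sketch (an illustration, not part of torch-harmonics) evaluates $Y_l^m$ for $m \ge 0$ directly from the formula above and checks orthonormality numerically. It assumes $P_l^m$ is computed with the standard three-term recurrence for the unnormalized associated Legendre functions (including the Condon-Shortley phase $(-1)^m$), so that the square-root prefactor supplies the normalization; `assoc_legendre` and `spherical_harmonic` are illustrative names. The inner products use Gauss-Legendre quadrature in $\cos \theta$ and the rectangle rule in $\lambda$, both of which are exact for these low degrees.

```python
import math

import numpy as np


def assoc_legendre(l: int, m: int, x):
    """Unnormalized associated Legendre function P_l^m(x), 0 <= m <= l, with Condon-Shortley phase."""
    # start from P_m^m(x) = (-1)^m (2m - 1)!! (1 - x^2)^(m/2)
    somx2 = np.sqrt((1.0 - x) * (1.0 + x))
    p_prev, p_curr = np.zeros_like(x), np.ones_like(x)
    for i in range(m):
        p_curr = -(2 * i + 1) * somx2 * p_curr
    # raise the degree: (k - m) P_k^m = (2k - 1) x P_{k-1}^m - (k + m - 1) P_{k-2}^m
    for k in range(m + 1, l + 1):
        p_prev, p_curr = p_curr, ((2 * k - 1) * x * p_curr - (k + m - 1) * p_prev) / (k - m)
    return p_curr


def spherical_harmonic(l: int, m: int, theta, lam):
    """Y_l^m(theta, lambda) for m >= 0, following the definition above."""
    norm = math.sqrt((2 * l + 1) / (4 * math.pi) * math.factorial(l - m) / math.factorial(l + m))
    return norm * assoc_legendre(l, m, np.cos(theta)) * np.exp(1j * m * lam)


# Gauss-Legendre nodes in x = cos(theta), equispaced longitudes
nlat, nlon = 16, 32
x, w = np.polynomial.legendre.leggauss(nlat)
theta = np.arccos(x)[:, None]
lam = (2 * np.pi * np.arange(nlon) / nlon)[None, :]
# dmu = sin(theta) dtheta dlambda = dx dlambda -> weights w_j * (2 pi / nlon)
dmu = w[:, None] * (2 * np.pi / nlon)

modes = [(l, m) for l in range(4) for m in range(l + 1)]
Y = [spherical_harmonic(l, m, theta, lam) for l, m in modes]
gram = np.array([[np.sum(np.conj(a) * b * dmu) for b in Y] for a in Y])
print(np.allclose(gram, np.eye(len(modes))))  # True
```

Negative orders follow from $Y_l^{-m} = (-1)^m \overline{Y_l^m}$ under this phase convention, so they are omitted here.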

### Spherical harmonic transform

The spherical harmonic transform (SHT)

$$
f_l^m = \int_{S^2}  \overline{Y_{l}^{m}}(\theta, \lambda) f(\theta, \lambda) \mathrm{d} \mu(\theta, \lambda)
$$

realizes the projection of a signal $f(\theta, \lambda)$ on $S^2$ onto the spherical harmonic basis. The SHT generalizes the Fourier transform to the sphere. Conversely, a truncated series expansion of a function $f$ can be written in terms of spherical harmonics as

$$
f (\theta, \lambda) = \sum_{m=-M}^{M} \exp(im\lambda) \sum_{l=|m|}^{M} \hat f_l^m  P_l^m (\cos \theta),
$$

where $\hat{f}_l^m$ are the expansion coefficients associated with the mode $(l, m)$.

The implementation of the SHT follows the algorithm presented in [2]. A forward spherical harmonic transform can be computed by a Fourier transform

$$
\hat f^m(\theta) = \frac{1}{2 \pi} \int_{0}^{2\pi} f(\theta, \lambda) \exp(-im\lambda) \mathrm{d} \lambda
$$

in longitude and a Legendre transform

$$
\hat f_l^m = \frac{1}{2} \int^{\pi}_0 \hat f^{m} (\theta) P_l^m (\cos \theta) \sin \theta \mathrm{d} \theta
$$

in latitude.
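
On a grid of $N_\lambda$ equispaced longitudes $\lambda_k = 2\pi k / N_\lambda$, the rectangle rule turns the longitude integral into $\hat f^m(\theta) \approx \frac{1}{N_\lambda} \sum_k f(\theta, \lambda_k) \exp(-im\lambda_k)$, which is exactly a real FFT scaled by $1/N_\lambda$, and exact for signals containing only frequencies $|m| < N_\lambda / 2$. A minimal NumPy sketch, assuming this equispaced layout; it illustrates the formula rather than torch-harmonics' internal code:

```python
import numpy as np

nlon = 16
lam = 2 * np.pi * np.arange(nlon) / nlon      # equispaced longitudes in [0, 2*pi)

# band-limited test signal along one circle of latitude:
# 3 + 2cos(2 lam) - 4sin(5 lam) has coefficients c_0 = 3, c_2 = 1, c_5 = 2i (conjugates for -m)
f = 3.0 + 2.0 * np.cos(2 * lam) - 4.0 * np.sin(5 * lam)

# rectangle rule: (1 / 2pi) * sum_k f(lam_k) exp(-i m lam_k) * (2pi / nlon), for m = 0..nlon/2
m = np.arange(nlon // 2 + 1)
direct = (f[None, :] * np.exp(-1j * m[:, None] * lam[None, :])).sum(axis=1) / nlon

# the same sums computed with the real FFT
fm = np.fft.rfft(f) / nlon

print(np.allclose(fm, direct))  # True
```

Only $m \ge 0$ is computed because, for a real signal, $\hat f^{-m} = \overline{\hat f^m}$.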

### Discrete Legendre transform

The second integral, which computes the projection onto the associated Legendre polynomials, is approximated with quadrature. On the Gaussian grid, we use Gauss-Legendre quadrature in the $\cos \theta$ domain. The integral

$$
\hat f_l^m = \frac{1}{2} \int_{-1}^1 \hat{f}^m(\arccos x) P_l^m (x) \mathrm{d} x
$$

is obtained with the substitution $x = \cos \theta$ and then approximated by the sum

$$
\hat f_l^m \approx \frac{1}{2} \sum_{j=1}^{N_\theta} \hat{f}^m(\arccos x_j) P_l^m(x_j) w_j.
$$

Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.
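
Combining the Fourier step with this quadrature sum gives a complete discrete transform. The following NumPy sketch (an illustration, not torch-harmonics' implementation) assumes a Gaussian grid, stores only $m \ge 0$ for real signals, and assumes $P_l^m$ is normalized so that $\frac{1}{2}\int_{-1}^{1} (P_l^m)^2 \, \mathrm{d}x = 1$, which makes the Legendre transform and the truncated expansion mutually inverse. The $\frac{1}{2}$ prefactor of the integral is kept explicitly, with $w_j$ the standard Gauss-Legendre weights (which sum to 2). `legendre_table`, `forward_sht` and `inverse_sht` are hypothetical names. Since $N_\theta$-point Gauss-Legendre quadrature is exact for polynomials of degree up to $2N_\theta - 1$, synthesizing a band-limited signal and transforming it back recovers its coefficients to machine precision:

```python
import math

import numpy as np


def legendre_table(lmax: int, m: int, x: np.ndarray) -> np.ndarray:
    """Rows l = 0..lmax-1 of P_l^m(x) (rows with l < m stay zero), scaled so that
    (1/2) * int_{-1}^{1} P_l^m(x)**2 dx = 1 (assumed normalization)."""
    table = np.zeros((lmax, x.size))
    somx2 = np.sqrt((1.0 - x) * (1.0 + x))
    p_prev, p_curr = np.zeros_like(x), np.ones_like(x)
    for i in range(m):  # P_m^m = (-1)^m (2m - 1)!! (1 - x^2)^(m/2)
        p_curr = -(2 * i + 1) * somx2 * p_curr
    for l in range(m, lmax):
        if l > m:  # (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m
            p_prev, p_curr = p_curr, ((2 * l - 1) * x * p_curr - (l + m - 1) * p_prev) / (l - m)
        table[l] = math.sqrt((2 * l + 1) * math.factorial(l - m) / math.factorial(l + m)) * p_curr
    return table


def forward_sht(f: np.ndarray, lmax: int, x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """f has shape (nlat, nlon), sampled at x_j = cos(theta_j); returns coeffs[l, m], 0 <= m < lmax."""
    nlon = f.shape[1]
    fm = np.fft.rfft(f, axis=1) / nlon  # Fourier transform in longitude
    coeffs = np.zeros((lmax, lmax), dtype=complex)
    for m in range(lmax):
        # Legendre transform in latitude: (1/2) * sum_j fm(x_j) P_l^m(x_j) w_j
        coeffs[:, m] = 0.5 * legendre_table(lmax, m, x) @ (w * fm[:, m])
    return coeffs


def inverse_sht(coeffs: np.ndarray, x: np.ndarray, nlon: int) -> np.ndarray:
    """Evaluate the truncated expansion on the grid; only m >= 0 is stored and the
    m < 0 terms of a real signal follow by complex conjugation (handled by irfft)."""
    lmax = coeffs.shape[0]
    fm = np.zeros((x.size, nlon // 2 + 1), dtype=complex)
    for m in range(lmax):
        fm[:, m] = legendre_table(lmax, m, x).T @ coeffs[:, m]
    return np.fft.irfft(fm * nlon, n=nlon, axis=1)


lmax, nlat, nlon = 8, 12, 24
x, w = np.polynomial.legendre.leggauss(nlat)  # Gauss-Legendre nodes x_j and weights w_j

# random band-limited coefficients: zero for l < m, and real for m = 0 so that f is real
rng = np.random.default_rng(0)
coeffs = np.tril(rng.standard_normal((lmax, lmax)) + 1j * rng.standard_normal((lmax, lmax)))
coeffs[:, 0] = coeffs[:, 0].real

f = inverse_sht(coeffs, x, nlon)  # synthesis on the 12 x 24 Gaussian grid
print(np.allclose(forward_sht(f, lmax, x, w), coeffs))  # True: the quadrature is exact here
```

The unnormalized recurrence used here overflows or loses accuracy at high degree because of the factorial ratios; for large $l$ one would recur directly on normalized values instead.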

### Discrete-continuous convolutions

torch-harmonics also provides local discrete-continuous (DISCO) convolutions on the sphere, as outlined in [4].

## Getting started

The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Module`s for composability. A minimal example is given by:

```python
import torch
import torch_harmonics as th

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

nlat = 512
nlon = 2*nlat
batch_size = 32
# create the input signal on the same device as the transform
signal = torch.randn(batch_size, nlat, nlon, device=device)

# transform data on an equiangular grid
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device)

coeffs = sht(signal)
```

To enable scalable model-parallelism, `torch-harmonics` implements a distributed variant of the SHT located in `torch_harmonics.distributed`.

Detailed usage of torch-harmonics, alongside helpful analysis, is provided in a series of notebooks:

1. [Getting started](./notebooks/getting_started.ipynb)
2. [Quadrature](./notebooks/quadrature.ipynb)
3. [Visualizing the spherical harmonics](./notebooks/plot_spherical_harmonics.ipynb)
4. [Spectral fitting vs. SHT](./notebooks/gradient_analysis.ipynb)
5. [Conditioning of the Gramian](./notebooks/conditioning_sht.ipynb)
6. [Solving the Helmholtz equation](./notebooks/helmholtz.ipynb)
7. [Solving the shallow water equations](./notebooks/shallow_water_equations.ipynb)
8. [Training Spherical Fourier Neural Operators](./notebooks/train_sfno.ipynb)

## Remarks on automatic mixed precision (AMP) support

Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library on GPUs. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the transformed dimensions are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast specifically for the harmonic transform:

```python
import torch
import torch_harmonics as th

sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

# placeholders for your own model code (hypothetical helpers)
def some_math(x):
    return 2.0 * x

def some_more_math(xt):
    return xt.abs()

x = torch.randn(4, 512, 1024, device="cuda")

with torch.cuda.amp.autocast(enabled=True):
    # do some AMP converted math here
    x = some_math(x)
    # convert tensor to float32
    x = x.to(torch.float32)
    # now disable autocast specifically for the transform,
    # making sure that the tensors are not converted
    # back to reduced precision internally
    with torch.cuda.amp.autocast(enabled=False):
        xt = sht(x)

    # continue operating on the transformed tensor
    xt = some_more_math(xt)
```

Depending on the problem, it might be beneficial to upcast data to `float64` instead of `float32` precision for numerical stability.
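
Whether autocast can stay enabled for the transform therefore depends on the grid. Assuming, as in the algorithm described above, that the FFT runs along the longitude dimension only (the Legendre step uses quadrature, not an FFT), a quick check on `nlon` tells you in advance whether the reduced-precision path is possible. This is a minimal sketch; `is_power_of_two` and `can_use_half_precision_fft` are hypothetical helpers, not part of torch-harmonics:

```python
def is_power_of_two(n: int) -> bool:
    """True for 1, 2, 4, 8, ...: a positive power of two has exactly one set bit."""
    return n > 0 and (n & (n - 1)) == 0


def can_use_half_precision_fft(nlon: int) -> bool:
    """Hypothetical helper: cuFFT accepts float16/bfloat16 inputs only if the transformed
    length is a power of two; here we assume longitude is the only FFT'd dimension."""
    return is_power_of_two(nlon)


for nlat, nlon in [(512, 1024), (721, 1440)]:
    if can_use_half_precision_fft(nlon):
        print(f"{nlat} x {nlon}: the FFT can run in reduced precision")
    else:
        print(f"{nlat} x {nlon}: disable autocast around the SHT")
```

For a 721 x 1440 grid (0.25 degree resolution), the second branch applies, so the pattern shown above is needed.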

## Contributors

[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ), [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino)

## Cite us

If you use `torch-harmonics` in an academic paper, please cite [1]:

```bibtex
@misc{bonev2023spherical,
      title={Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere},
      author={Boris Bonev and Thorsten Kurth and Christian Hundt and Jaideep Pathak and Maximilian Baust and Karthik Kashinath and Anima Anandkumar},
      year={2023},
      eprint={2306.03838},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

## References

<a id="1">[1]</a>
Bonev B., Kurth T., Hundt C., Pathak J., Baust M., Kashinath K., Anandkumar A.;
Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere;
arXiv 2306.03838, 2023.

<a id="1">[2]</a>
Schaeffer N.;
Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations;
G3: Geochemistry, Geophysics, Geosystems, 2013.

<a id="1">[3]</a>
Wang B., Wang L., Xie Z.;
Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids;
Adv Comput Math, 2018.

<a id="1">[4]</a>
Ocampo J., Price M. A., McEwen J. D.;
Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions;
ICLR, 2023, arXiv:2209.13603.