<!--
SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.

SPDX-License-Identifier: BSD-3-Clause

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

<!-- <div align="center">
    <img src="https://raw.githubusercontent.com/NVIDIA/torch-harmonics/main/images/logo/logo.png"  width="568">
    <br>
    <a href="https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml"><img src="https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml/badge.svg"></a>
    <a href="https://pypi.org/project/torch_harmonics/"><img src="https://img.shields.io/pypi/v/torch_harmonics"></a>
</div> -->

<!--
[![pypi](https://img.shields.io/pypi/v/torch_harmonics)](https://pypi.org/project/torch_harmonics/)
-->

<!-- # spherical harmonic transforms -->

# torch-harmonics

[**Overview**](#overview) | [**Installation**](#installation) | [**More information**](#more-about-torch-harmonics) | [**Getting started**](#getting-started) | [**Contributors**](#contributors) | [**Cite us**](#cite-us) | [**References**](#references)

[![tests](https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml/badge.svg)](https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml)
[![pypi](https://img.shields.io/pypi/v/torch_harmonics)](https://pypi.org/project/torch_harmonics/)

## Overview

torch-harmonics implements differentiable signal processing on the sphere. This includes differentiable implementations of the spherical harmonic transform, the vector spherical harmonic transform, and discrete-continuous convolutions on the sphere. The package was originally written to enable Spherical Fourier Neural Operators (SFNO) [1].

The SHT algorithm uses quadrature rules to compute the projection onto the associated Legendre polynomials, and FFTs for the projection onto the harmonic basis. For most practical problem sizes, this approach tends to outperform algorithms with better asymptotic scaling [2].

torch-harmonics implements these operations using PyTorch primitives, making them fully differentiable. Moreover, the quadrature can be distributed across multiple ranks, enabling spatially distributed transforms.

torch-harmonics has been used to implement a variety of differentiable PDE solvers, which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNO) [1].

<div align="center">
<table border="0" cellspacing="0" cellpadding="0">
    <tr>
        <td><img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/sfno.gif"  width="240"></td>
        <td><img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/zonal_jet.gif"  width="240"></td>
        <td><img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/allen-cahn.gif"  width="240"></td>
    </tr>
<!--     <tr>
        <td style="text-align:center; border-style : hidden!important;">Shallow Water Eqns.</td>
        <td style="text-align:center; border-style : hidden!important;">Ginzburg-Landau Eqn.</td>
        <td style="text-align:center; border-style : hidden!important;">Allen-Cahn Eqn.</td>
    </tr>  -->
</table>
</div>


## Installation
A basic installation can be done directly from PyPI:

```bash
pip install torch-harmonics
```
If you are planning to use spherical convolutions, we recommend building the corresponding custom CUDA kernels. To enforce this, you can set the `FORCE_CUDA_EXTENSION` environment variable. You may also want to select the target architectures via the `TORCH_CUDA_ARCH_LIST` environment variable. Finally, make sure to pass the `--no-build-isolation` flag so that the custom kernels are built against the existing torch installation.
```bash
export FORCE_CUDA_EXTENSION=1
export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
pip install --no-build-isolation torch-harmonics
```
:warning: Please note that the custom CUDA extensions currently only support CUDA architectures >= 7.0.

If you want to actively develop torch-harmonics, we recommend building it in your environment from GitHub:

```bash
git clone git@github.com:NVIDIA/torch-harmonics.git
cd torch-harmonics
pip install -e .
```

Alternatively, use the Dockerfile to build your custom container after cloning:

```bash
git clone git@github.com:NVIDIA/torch-harmonics.git
cd torch-harmonics
docker build . -t torch_harmonics
docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 torch_harmonics
```

## More about torch-harmonics

### Spherical harmonics

The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonics) are special functions defined on the two-dimensional sphere $S^2$ (embedded in three dimensions). They form an orthonormal basis of the space $L^2(S^2)$ of square-integrable functions on the sphere and are the spherical analogue of the harmonic functions on the circle/torus. The spherical harmonics are defined as

$$
Y_l^m(\theta, \lambda) = \sqrt{\frac{(2l + 1)}{4 \pi} \frac{(l - m)!}{(l + m)!}} P_l^m(\cos \theta) \exp(im\lambda),
$$

where $\theta$ and $\lambda$ are colatitude and longitude respectively, and $P_l^m$ are the normalized [associated Legendre polynomials](https://en.wikipedia.org/wiki/Associated_Legendre_polynomials).

<div align="center">
<img src="https://media.githubusercontent.com/media/NVIDIA/torch-harmonics/main/images/spherical_harmonics.gif" width="432">
<br>
Spherical harmonics up to degree 5
</div>
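
The definition above can be checked numerically. The following is a minimal sketch using SciPy, which is not a dependency of torch-harmonics (`scipy.special.sph_harm` is deprecated in recent SciPy releases in favor of `sph_harm_y`):

```python
import math
import numpy as np
from scipy.special import lpmv, sph_harm

# evaluate Y_l^m at colatitude theta and longitude lam directly from the formula above
l, m = 2, 1
theta, lam = 0.7, 1.3

norm = math.sqrt((2 * l + 1) / (4 * math.pi) * math.factorial(l - m) / math.factorial(l + m))
y_formula = norm * lpmv(m, l, math.cos(theta)) * np.exp(1j * m * lam)

# compare against SciPy's spherical harmonics (azimuthal angle first, then polar angle)
y_scipy = sph_harm(m, l, lam, theta)
print(np.allclose(y_formula, y_scipy))  # True
```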

### Spherical harmonic transform

The spherical harmonic transform (SHT)

$$
\hat f_l^m = \int_{S^2} \overline{Y_{l}^{m}}(\theta, \lambda) \, f(\theta, \lambda) \, \mathrm{d} \mu(\theta, \lambda)
$$

realizes the projection of a signal $f(\theta, \lambda)$ on $S^2$ onto the spherical harmonic basis. The SHT generalizes the Fourier transform to the sphere. Conversely, a truncated series expansion of a function $f$ can be written in terms of spherical harmonics as

$$
f(\theta, \lambda) = \sum_{m=-M}^{M} \exp(im\lambda) \sum_{l=|m|}^{M} \hat f_l^m \, P_l^m(\cos \theta),
$$

where $\hat{f}_l^m$ are the expansion coefficients associated with the mode $(l, m)$.

The implementation of the SHT follows the algorithm as presented in [2]. A direct spherical harmonic transform can be accomplished by a Fourier transform

$$
\hat f^m(\theta) = \frac{1}{2 \pi} \int_{0}^{2\pi} f(\theta, \lambda) \exp(-im\lambda) \, \mathrm{d} \lambda
$$

in longitude and a Legendre transform

$$
\hat f_l^m = \frac{1}{2} \int_{0}^{\pi} \hat f^{m}(\theta) \, P_l^m(\cos \theta) \sin \theta \, \mathrm{d} \theta
$$

in latitude.

### Discrete Legendre transform

The second integral, which computes the projection onto the Legendre polynomials, is realized with quadrature. On the Gaussian grid, we use Gaussian quadrature in the $\cos \theta$ domain. The integral

$$
\hat f_l^m = \frac{1}{2} \int_{-1}^{1} \hat{f}^m(\arccos x) \, P_l^m(x) \, \mathrm{d} x
$$

is obtained with the substitution $x = \cos \theta$ and then approximated by the sum

$$
\hat f_l^m = \sum_{j=1}^{N_\theta} \hat{f}^m(\arccos x_j) \, P_l^m(x_j) \, w_j.
$$

Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.
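
The following minimal sketch (plain NumPy, independent of torch-harmonics) illustrates the quadrature step: Gauss-Legendre nodes and weights integrate polynomials of degree up to $2 N_\theta - 1$ exactly, which is what makes the discrete Legendre transform exact for appropriately band-limited signals.

```python
import numpy as np

nlat = 6
# quadrature nodes x_j and weights w_j on [-1, 1]
x, w = np.polynomial.legendre.leggauss(nlat)

# the Legendre polynomial P_2(x) = (3x^2 - 1)/2 integrates against itself to 2/(2*2 + 1) = 0.4
P2 = 0.5 * (3.0 * x**2 - 1.0)
print(np.sum(P2 * P2 * w))  # ~0.4, matching the analytic value
```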

### Discrete-continuous convolutions on the sphere

torch-harmonics provides local discrete-continuous (DISCO) convolutions on the sphere, as outlined in [4]. These are used in local neural operators to generalize convolutions to structured and unstructured meshes on the sphere. A usage sketch is given below.
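
The class name `DiscreteContinuousConvS2` and the argument names shown here are assumptions that may differ between torch-harmonics versions, so please consult the package documentation for the exact signature.

```python
import torch
import torch_harmonics as th

nlat, nlon = 128, 256

# hypothetical sketch of a DISCO convolution mapping 3 input channels to 8 output channels
# on an equiangular grid; argument names may differ between versions
conv = th.DiscreteContinuousConvS2(
    in_channels=3, out_channels=8,
    in_shape=(nlat, nlon), out_shape=(nlat, nlon),
    kernel_shape=3, grid_in="equiangular", grid_out="equiangular",
)

x = torch.randn(4, 3, nlat, nlon)  # (batch, channels, nlat, nlon)
y = conv(x)                        # (batch, 8, nlat, nlon)
```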

### Spherical (neighborhood) attention

torch-harmonics introduces spherical attention mechanisms which correctly generalize the attention mechanism to the sphere. The use of quadrature rules makes the resulting operations approximately equivariant, and exactly equivariant in the continuous limit. Moreover, neighborhood attention is generalized to the sphere by using the geodesic distance to determine the size of the neighborhood.

## Getting started

The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Module`s for composability. A minimal example is given by:

```python
import torch
import torch_harmonics as th

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

nlat = 512
nlon = 2*nlat
batch_size = 32
signal = torch.randn(batch_size, nlat, nlon, device=device)

# transform data on an equiangular grid
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device)

coeffs = sht(signal)
```
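
The transform modules compose with other PyTorch modules and are fully differentiable. The sketch below continues the example above with an inverse transform (assuming `th.InverseRealSHT`, which mirrors the constructor of `th.RealSHT`) and a gradient computation through the transform.

```python
# map the spectral coefficients back to the grid
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular").to(device)
reconstruction = isht(coeffs)  # shape (batch_size, nlat, nlon)

# gradients flow through the transform like through any other torch.nn.Module
signal.requires_grad_(True)
loss = torch.abs(sht(signal)).pow(2).sum()
loss.backward()
print(signal.grad.shape)  # torch.Size([32, 512, 1024])
```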

To enable scalable model-parallelism, `torch-harmonics` implements a distributed variant of the SHT located in `torch_harmonics.distributed`.

Detailed usage of torch-harmonics, alongside helpful analysis, is provided in a series of notebooks:

1. [Getting started](./notebooks/getting_started.ipynb)
2. [Quadrature](./notebooks/quadrature.ipynb)
3. [Visualizing the spherical harmonics](./notebooks/plot_spherical_harmonics.ipynb)
4. [Spectral fitting vs. SHT](./notebooks/gradient_analysis.ipynb)
5. [Conditioning of the Gramian](./notebooks/conditioning_sht.ipynb)
6. [Solving the Helmholtz equation](./notebooks/helmholtz.ipynb)
7. [Solving the shallow water equations](./notebooks/shallow_water_equations.ipynb)
8. [Training Spherical Fourier Neural Operators (SFNO)](./notebooks/train_sfno.ipynb)
9. [Resampling signals on the sphere](./notebooks/resample_sphere.ipynb)

## Examples and reproducibility

The `examples` folder contains training scripts for three distinct tasks:

* [solution of the shallow water equations on the rotating sphere](./examples/shallow_water_equations/train.py)
* [depth estimation on the sphere](./examples/depth/train.py)
* [semantic segmentation on the sphere](./examples/segmentation/train.py)

Results from the papers can generally be reproduced by running `python train.py`. For some older results, the number of epochs and the learning rate may need to be adjusted by passing the corresponding command-line arguments.

## Remarks on automatic mixed precision (AMP) support

Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonic transform specifically:

```python
import torch
import torch_harmonics as th

sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

# placeholder input and weights standing in for your actual computation
x = torch.randn(4, 512, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

with torch.autocast(device_type="cuda", enabled=True):
    # do some AMP-converted math here; matmuls run in reduced precision under autocast
    x = x @ w
    # convert the tensor back to float32
    x = x.to(torch.float32)
    # now disable autocast specifically for the transform,
    # making sure that the tensors are not converted
    # back to reduced precision internally
    with torch.autocast(device_type="cuda", enabled=False):
        xt = sht(x)

    # continue operating on the transformed tensor
    xt = 2.0 * xt
```

Depending on the problem, it might be beneficial to upcast data to `float64` instead of `float32` precision for numerical stability.
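
A sketch of such an upcast is shown below, assuming the transform module follows the usual `torch.nn.Module.double()` semantics for converting its precomputed buffers.

```python
# construct a float64 transform and upcast the data before applying it
sht64 = th.RealSHT(512, 1024, grid="equiangular").double().cuda()
xt64 = sht64(x.double())
```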

## Contributors

[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Max Rietmann](https://github.com/rietmann-nv), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ), [Andrea Paris](https://github.com/apaaris), [Alberto Carpentieri](https://github.com/albertocarpentieri), [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino)

## Cite us

If you use `torch-harmonics` in an academic paper, please cite [1]

```bibtex
@misc{bonev2023spherical,
      title={Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere},
      author={Boris Bonev and Thorsten Kurth and Christian Hundt and Jaideep Pathak and Maximilian Baust and Karthik Kashinath and Anima Anandkumar},
      year={2023},
      eprint={2306.03838},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

## References

<a id="1">[1]</a>
Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere;
International Conference on Machine Learning, 2023. [arxiv link](https://arxiv.org/abs/2306.03838)

<a id="2">[2]</a>
Schaeffer N.;
Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations;
G3: Geochemistry, Geophysics, Geosystems, 2013.

<a id="3">[3]</a>
Wang B., Wang L., Xie Z.;
Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids;
Adv Comput Math, 2018.
<a id="4">[4]</a>
Ocampo J., Price M. A., McEwen J. D.;
Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions;
International Conference on Learning Representations, 2023. [arxiv link](https://arxiv.org/abs/2209.13603)