Commit 6f4fe730 authored by Boris Bonev's avatar Boris Bonev
Browse files

v0.4 commit

parents
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import torch.distributed as dist
from .utils import get_model_parallel_group, is_initialized
# general helpers
def get_memory_format(tensor):
    """Return the memory format of ``tensor``.

    Reports ``torch.channels_last`` when the tensor is contiguous in that
    layout, and ``torch.contiguous_format`` otherwise. Used to preserve the
    layout of activations across the distributed primitives below.
    """
    is_channels_last = tensor.is_contiguous(memory_format=torch.channels_last)
    return torch.channels_last if is_channels_last else torch.contiguous_format
def split_tensor_along_dim(tensor, dim, num_chunks):
    """Split ``tensor`` into ``num_chunks`` equally sized chunks along ``dim``.

    Parameters:
    tensor: input tensor
    dim: dimension along which to split; must be a valid dimension of ``tensor``
    num_chunks: number of chunks; must evenly divide ``tensor.shape[dim]``

    Returns a tuple of tensor views (torch.split does not copy). Note that the
    resulting views are generally not contiguous.
    """
    assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
    # fixed typo in the original message ("numnber" -> "number")
    assert (tensor.shape[dim] % num_chunks == 0), f"Error, cannot split dim {dim} evenly. Dim size is \
                                                    {tensor.shape[dim]} and requested number of splits is {num_chunks}"
    chunk_size = tensor.shape[dim] // num_chunks
    tensor_list = torch.split(tensor, chunk_size, dim=dim)
    return tensor_list
# split
def _split(input_, dim_, group=None):
    """Split ``input_`` along dimension ``dim_`` and keep this rank's slice."""
    # remember the memory layout so the output slice matches the input
    fmt = get_memory_format(input_)
    # nothing to do on a single-rank group
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    chunks = split_tensor_along_dim(input_, dim_, world_size)
    # torch.split yields views that are generally not contiguous, so
    # materialize the local chunk in the original memory format
    local_rank = dist.get_rank(group=group)
    return chunks[local_rank].contiguous(memory_format=fmt)
# those are used by the various helper functions
def _reduce(input_, use_fp32=True, group=None):
    """All-reduce ``input_`` across ``group``, optionally accumulating in fp32."""
    # single-rank groups need no communication
    if dist.get_world_size(group=group) == 1:
        return input_
    if not use_fp32:
        dist.all_reduce(input_, group=group)
        return input_
    # accumulate in float32 for numerical stability, then cast back
    orig_dtype = input_.dtype
    buf = input_.float()
    dist.all_reduce(buf, group=group)
    return buf.to(orig_dtype)
class _CopyToParallelRegion(torch.autograd.Function):
    """Pass the input to the parallel region.

    Identity in the forward pass; the backward pass all-reduces the gradient
    over the model-parallel group (the adjoint of broadcasting the same
    activation to every model-parallel rank).
    """
    @staticmethod
    def symbolic(graph, input_):
        # ONNX symbolic: plain identity
        return input_
    @staticmethod
    def forward(ctx, input_):
        return input_
    @staticmethod
    def backward(ctx, grad_output):
        # sum the gradient contributions from all model-parallel ranks
        return _reduce(grad_output, group=get_model_parallel_group())
# write a convenient functional wrapper
def copy_to_parallel_region(input_):
    """Functional wrapper around _CopyToParallelRegion.

    Falls back to the identity when model parallelism has not been
    initialized, so the same code path works in single-rank runs.
    """
    if is_initialized():
        return _CopyToParallelRegion.apply(input_)
    return input_
# reduce
class _ReduceFromParallelRegion(torch.autograd.Function):
    """All-reduce the input from the parallel region.

    Sum-reduces across the model-parallel group in the forward pass; the
    backward pass is the identity (the adjoint of an all-reduce).
    """
    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_, group=get_model_parallel_group())
    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_, group=get_model_parallel_group())
    @staticmethod
    def backward(ctx, grad_output):
        # gradient passes through unchanged
        return grad_output
def reduce_from_parallel_region(input_):
    """Functional wrapper around _ReduceFromParallelRegion.

    Identity when model parallelism has not been initialized.
    """
    if is_initialized():
        return _ReduceFromParallelRegion.apply(input_)
    return input_
# gather
def _gather(input_, dim_, group=None):
    """Gather tensors from all ranks and concatenate them along dimension ``dim_``.

    Parameters:
    input_: local shard of the tensor
    dim_: dimension along which the shards are concatenated
    group: process group to gather over (default group when None)

    Returns the concatenated tensor in the input's memory format.
    """
    # get input format so the gathered output preserves the layout
    input_format = get_memory_format(input_)
    comm_size = dist.get_world_size(group=group)
    # Bypass the function if we are using only 1 GPU.
    if comm_size == 1:
        return input_
    # sanity checks
    assert(dim_ < input_.dim()), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions."
    # Size and dimension.
    comm_rank = dist.get_rank(group=group)
    # input needs to be contiguous for the collective
    input_ = input_.contiguous(memory_format=input_format)
    tensor_list = [torch.empty_like(input_) for _ in range(comm_size)]
    tensor_list[comm_rank] = input_
    dist.all_gather(tensor_list, input_, group=group)
    output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format)
    return output
class _GatherFromParallelRegion(torch.autograd.Function):
    """Gather the input from the parallel region and concatenate.

    Forward all-gathers the shards along ``dim_``; backward splits the
    gradient along the same dimension and keeps the local slice.
    """
    @staticmethod
    def symbolic(graph, input_, dim_):
        return _gather(input_, dim_, group=get_model_parallel_group())
    @staticmethod
    def forward(ctx, input_, dim_):
        # remember the dimension so backward can split along it
        ctx.dim = dim_
        return _gather(input_, dim_, group=get_model_parallel_group())
    @staticmethod
    def backward(ctx, grad_output):
        # None for the non-tensor dim_ argument
        return _split(grad_output, ctx.dim, group=get_model_parallel_group()), None
def gather_from_parallel_region(input_, dim):
    """Functional wrapper around _GatherFromParallelRegion.

    Identity when model parallelism has not been initialized.
    """
    if is_initialized():
        return _GatherFromParallelRegion.apply(input_, dim)
    return input_
# scatter
class _ScatterToParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank.

    Forward splits along ``dim_`` and keeps the local slice; backward
    all-gathers the gradient shards along the same dimension.
    """
    @staticmethod
    def symbolic(graph, input_, dim_):
        return _split(input_, dim_, group=get_model_parallel_group())
    @staticmethod
    def forward(ctx, input_, dim_):
        # remember the dimension so backward can gather along it
        ctx.dim = dim_
        return _split(input_, dim_, group=get_model_parallel_group())
    @staticmethod
    def backward(ctx, grad_output):
        # None for the non-tensor dim_ argument
        return _gather(grad_output, ctx.dim, group=get_model_parallel_group()), None
def scatter_to_parallel_region(input_, dim):
    """Functional wrapper around _ScatterToParallelRegion.

    Identity when model parallelism has not been initialized.
    """
    if is_initialized():
        return _ScatterToParallelRegion.apply(input_, dim)
    return input_
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# we need this in order to enable distributed
import torch
import torch.distributed as dist
# those need to be global
# handle of the process group used for model parallelism; None until init()
_MODEL_PARALLEL_GROUP = None

def get_model_parallel_group():
    """Return the model-parallel process group (None if not initialized)."""
    return _MODEL_PARALLEL_GROUP

def init(process_group):
    """Register ``process_group`` as the global model-parallel group."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = process_group

def is_initialized() -> bool:
    """Return True once init() has been called with a process group."""
    return _MODEL_PARALLEL_GROUP is not None
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
import torch
def clm(l, m):
    """
    defines the normalization factor to orthonormalize the Spherical Harmonics

    c_l^m = sqrt((2l+1)/(4 pi)) * sqrt((l-m)! / (l+m)!)
    """
    # np.math was merely an alias of the stdlib math module and was removed in
    # NumPy 2.0, so call math.factorial directly
    import math
    return np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m))
def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
    r"""
    Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
    The resulting tensor has shape (mmax, lmax, len(t)).
    The Condon-Shortley Phase (-1)^m can be turned off optionally.

    method of computation follows
    [1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982;
    https://apps.dtic.mil/sti/citations/ADA123406
    [3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients
    """
    # compute the tensor P^m_n, stored as (m, l, theta)
    pct = np.zeros((mmax, lmax, len(t)), dtype=np.float64)
    sint = np.sin(t)  # NOTE(review): unused in this routine
    cost = np.cos(t)
    # "ortho" uses unit factor; otherwise 4pi-normalization. "inverse" divides
    # instead of multiplies so forward/inverse transforms compose to identity.
    norm_factor = 1. if norm == "ortho" else np.sqrt(4 * np.pi)
    norm_factor = 1. / norm_factor if inverse else norm_factor
    # initial values to start the recursion: P^0_0 = 1/sqrt(4 pi)
    pct[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
    # fill the diagonal (m == l) and the first sub-diagonal (m == l-1)
    for l in range(1, min(mmax,lmax)):
        pct[l-1, l, :] = np.sqrt(2*l + 1) * cost * pct[l-1, l-1, :]
        # note (1 + cost)(1 - cost) = sin^2(theta)
        pct[l, l, :] = np.sqrt( (2*l + 1) * (1 + cost) * (1 - cost) / 2 / l ) * pct[l-1, l-1, :]
    # fill the remaining values on the upper triangle (m < l-1) via the
    # standard three-term recurrence in l
    for l in range(2, lmax):
        for m in range(0, l-1):
            pct[m, l, :] = cost * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * pct[m, l-1, :] \
                - np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * pct[m, l-2, :]
    # optional rescaling to Schmidt semi-normalization
    if norm == "schmidt":
        for l in range(0, lmax):
            if inverse:
                pct[:, l, : ] = pct[:, l, : ] * np.sqrt(2*l + 1)
            else:
                pct[:, l, : ] = pct[:, l, : ] / np.sqrt(2*l + 1)
    # apply the Condon-Shortley phase (-1)^m by negating odd m
    if csphase:
        for m in range(1, mmax, 2):
            pct[m] *= -1
    return torch.from_numpy(pct)
def precompute_dlegpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
    r"""
    Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
    at the positions specified by x (theta), as well as $\frac{m}{\sin \theta} P^m_l(\cos \theta)$,
    needed for the computation of the vector spherical harmonics. The resulting tensor has shape
    (2, mmax, lmax, len(x)).

    computation follows
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    # one extra degree/order is needed because the derivative formulas couple
    # (m, l) to (m-1, l), (m+1, l) and (m+-1, l+1); csphase is applied at the end
    pct = precompute_legpoly(mmax+1, lmax+1, x, norm=norm, inverse=inverse, csphase=False)
    # dpct[0]: d/dtheta term, dpct[1]: m/sin(theta) term (purely imaginary part)
    dpct = torch.zeros((2, mmax, lmax, len(x)), dtype=torch.float64)
    # fill the derivative terms wrt theta
    for l in range(0, lmax):
        # m = 0
        dpct[0, 0, l] = - np.sqrt(l*(l+1)) * pct[1, l]
        # 0 < m < l
        for m in range(1, min(l, mmax)):
            dpct[0, m, l] = 0.5 * ( np.sqrt((l+m)*(l-m+1)) * pct[m-1, l] - np.sqrt((l-m)*(l+m+1)) * pct[m+1, l] )
        # m == l
        if mmax > l:
            dpct[0, l, l] = np.sqrt(l/2) * pct[l-1, l]
        # fill the - 1j m P^m_l / sin(phi). as this component is purely imaginary,
        # we won't store it explicitly in a complex array
        for m in range(1, min(l+1, mmax)):
            # this component is implicitly complex
            # we do not divide by m here as this cancels with the derivative of the exponential
            dpct[1, m, l] = 0.5 * np.sqrt((2*l+1)/(2*l+3)) * \
                ( np.sqrt((l-m+1)*(l-m+2)) * pct[m-1, l+1] + np.sqrt((l+m+1)*(l+m+2)) * pct[m+1, l+1] )
    # apply the Condon-Shortley phase (-1)^m to both components
    if csphase:
        for m in range(1, mmax, 2):
            dpct[:, m] *= -1
    return dpct
\ No newline at end of file
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
def legendre_gauss_weights(n, a=-1.0, b=1.0):
    """
    Helper routine which returns the Legendre-Gauss nodes and weights
    on the interval [a, b]
    """
    nodes, weights = np.polynomial.legendre.leggauss(n)
    # affine map of nodes from [-1, 1] onto [a, b]; weights pick up the Jacobian
    scale = 0.5 * (b - a)
    shift = 0.5 * (b + a)
    return scale * nodes + shift, scale * weights
def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
    """
    Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
    on the interval [a, b]

    Parameters:
    n: number of nodes (endpoints a and b are always included)
    a, b: interval endpoints
    tol: convergence tolerance for the Newton iteration
    maxiter: maximum number of Newton iterations

    Nodes are found by Newton iteration on the roots of (1 - x^2) P'_{n-1}(x),
    starting from the Chebyshev-Gauss-Lobatto points; the Legendre Vandermonde
    matrix is built with the three-term recurrence. (Removed the dead
    pre-allocations of wlg/tmp from the original implementation.)
    """
    # Legendre Vandermonde matrix, columns k hold P_k evaluated at the nodes
    vdm = np.zeros((n, n))
    # Chebyshev-Gauss-Lobatto nodes as initial guess
    tlg = -np.cos(np.pi * np.arange(n) / (n - 1))
    for _ in range(maxiter):
        tlg_prev = tlg
        vdm[:, 0] = 1.0
        vdm[:, 1] = tlg
        for k in range(2, n):
            vdm[:, k] = ((2*k - 1) * tlg * vdm[:, k-1] - (k - 1) * vdm[:, k-2]) / k
        # Newton update for the roots of (1 - x^2) P'_{n-1}(x)
        tlg = tlg_prev - (tlg * vdm[:, n-1] - vdm[:, n-2]) / (n * vdm[:, n-1])
        if np.max(np.abs(tlg - tlg_prev)) < tol:
            break
    # Lobatto weights w_i = 2 / (n (n-1) P_{n-1}(x_i)^2)
    wlg = 2.0 / ((n * (n - 1)) * (vdm[:, n-1]**2))
    # rescale to [a, b]
    tlg = (b - a) * 0.5 * tlg + (b + a) * 0.5
    wlg = wlg * (b - a) * 0.5
    return tlg, wlg
def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
    """
    Computation of the Clenshaw-Curtis quadrature nodes and weights.
    This implementation follows
    [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001-018.
    """
    assert(n > 1)
    # nodes are the Chebyshev extreme points, ordered from -1 to +1
    tcc = np.cos(np.linspace(np.pi, 0, n))
    if n == 2:
        # trapezoidal rule on [-1, 1]
        wcc = np.array([1., 1.])
    else:
        n1 = n - 1
        N = np.arange(1, n1, 2)
        l = len(N)
        m = n1 - l
        # weights of the interior rule (Waldvogel eq. for w via an inverse FFT)
        v = np.concatenate([2 / N / (N-2), 1 / N[-1:], np.zeros(m)])
        v = 0 - v[:-1] - v[-1:0:-1]
        # boundary correction term g
        g0 = -np.ones(n1)
        g0[l] = g0[l] + n1
        g0[m] = g0[m] + n1
        g = g0 / (n1**2 - 1 + (n1%2))
        # weights come out real; ifft implements the discrete cosine sum
        wcc = np.fft.ifft(v + g).real
        # duplicate the first weight for the second endpoint
        wcc = np.concatenate((wcc, wcc[:1]))
    # rescale nodes and weights from [-1, 1] to [a, b]
    tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
    wcc = wcc * (b - a) * 0.5
    return tcc, wcc
def fejer2_weights(n, a=-1.0, b=1.0):
    """
    Computation of the Fejer quadrature nodes and weights.
    This implementation follows
    [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001-018.
    """
    assert(n > 2)
    # nodes are the Chebyshev extreme points, ordered from -1 to +1
    tcc = np.cos(np.linspace(np.pi, 0, n))
    n1 = n - 1
    N = np.arange(1, n1, 2)
    l = len(N)
    m = n1 - l
    # Fejer-2 weight vector (Waldvogel); same v as Clenshaw-Curtis but without
    # the boundary correction term g
    v = np.concatenate([2 / N / (N-2), 1 / N[-1:], np.zeros(m)])
    v = 0 - v[:-1] - v[-1:0:-1]
    wcc = np.fft.ifft(v).real
    # duplicate the first weight for the second endpoint
    wcc = np.concatenate((wcc, wcc[:1]))
    # rescale nodes and weights from [-1, 1] to [a, b]
    tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
    wcc = wcc * (b - a) * 0.5
    return tcc, wcc
\ No newline at end of file
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
import torch
import torch.nn as nn
import torch.fft
from .quadrature import *
from .legendre import *
from .distributed import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
class RealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        """
        Initializes the SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: maximum degree; inferred from nlat and the grid type when None
        mmax: maximum order; defaults to nlon // 2 + 1 (rfft output size)
        grid: grid in the latitude direction (for now only tensor product grids are supported)
        norm: normalization convention forwarded to precompute_legpoly
        csphase: whether to include the Condon-Shortley phase
        """
        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points; cost are the cos(theta) nodes, w the weights
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            # Lobatto includes both poles, so one fewer exact degree
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise(ValueError("Unknown quadrature mode"))

        # apply cosine transform and flip so theta runs from 0 to pi
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # combine quadrature weights with the legendre weights
        weights = torch.from_numpy(w)
        pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # shard the weights along n, because we want to be distributed in spectral space:
        weights = scatter_to_parallel_region(weights, dim=1)
        # local number of l modes on this model-parallel rank
        self.lmax_local = weights.shape[1]

        # remember quadrature weights (not persisted; recomputed on load)
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        # x: real tensor with latitude/longitude in the last two dims
        assert(x.shape[-2] == self.nlat)
        assert(x.shape[-1] == self.nlon)

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the Legendre-Gauss quadrature; treat real/imag parts separately
        x = torch.view_as_real(x)

        # distributed contraction: fork (gradients are reduced in backward)
        x = copy_to_parallel_region(x)

        # output has local l modes in place of latitudes, m in place of longitudes
        out_shape = list(x.size())
        out_shape[-3] = self.lmax_local
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction over the latitude (k) index, separately for real and imag
        # NOTE(review): weights are float64 from precompute_legpoly — presumably
        # inputs are float64 too, otherwise the einsum dtypes mismatch; confirm
        xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights )
        xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights )
        x = torch.view_as_complex(xout)

        return x
class InverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (weights not needed for the inverse)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise(ValueError("Unknown quadrature mode"))

        # apply cosine transform and flip so theta runs from 0 to pi
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # inverse=True applies the inverse normalization in the Legendre table
        pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # shard the pct along the n dim (model-parallel in spectral space)
        pct = scatter_to_parallel_region(pct, dim=1)
        # local number of l modes on this model-parallel rank
        self.lmax_local = pct.shape[1]

        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        # x: complex spectral coefficients with (l, m) in the last two dims
        assert(x.shape[-2] == self.lmax_local)
        assert(x.shape[-1] == self.mmax)

        # Evaluate associated Legendre functions on the output nodes,
        # separately for real and imaginary parts
        x = torch.view_as_real(x)

        rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct )
        im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct )
        xs = torch.stack((rl, im), -1)

        # distributed contraction: join (sum the partial l contributions)
        xs = reduce_from_parallel_region(xs)

        # apply the inverse (real) FFT in the longitudinal direction
        x = torch.view_as_complex(xs)
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x
class RealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        """
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: maximum degree; inferred from nlat and the grid type when None
        mmax: maximum order; defaults to nlon // 2 + 1 (rfft output size)
        grid: type of grid the data lives on
        norm: normalization convention forwarded to precompute_dlegpoly
        csphase: whether to include the Condon-Shortley phase
        """
        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise(ValueError("Unknown quadrature mode"))

        # apply cosine transform and flip so theta runs from 0 to pi
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        weights = torch.from_numpy(w)
        # dpct[0]: d/dtheta table, dpct[1]: m/sin(theta) table
        dpct = precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights, normalization factor in to one:
        # 1/(l(l+1)) from the vector spherical harmonic normalization;
        # the l=0 entry would be 1/0 and is overwritten below
        l = torch.arange(0, self.lmax)
        norm_factor = 1. / l / (l+1)
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # shard the weights along n, because we want to be distributed in spectral space:
        # (dim=2 is the l dimension in the (d, m, l, k) layout)
        weights = scatter_to_parallel_region(weights, dim=2)
        self.lmax_local = weights.shape[2]

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        # x: real tensor; dim -3 indexes the two vector components
        # (spheroidal/toroidal split is produced on output)
        assert(len(x.shape) >= 3)

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the Legendre-Gauss quadrature; treat real/imag parts separately
        x = torch.view_as_real(x)

        # distributed contraction: fork (gradients are reduced in backward)
        x = copy_to_parallel_region(x)

        out_shape = list(x.size())
        out_shape[-3] = self.lmax_local
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction - spheroidal component
        # real component
        xout[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[0]) \
            - torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[1])
        # imag component
        xout[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[0]) \
            + torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[1])

        # contraction - toroidal component
        # real component
        xout[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[1]) \
            - torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[0])
        # imag component
        xout[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[1]) \
            - torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[0])

        return torch.view_as_complex(xout)
class InverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (weights not needed for the inverse)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise(ValueError("Unknown quadrature mode"))

        # apply cosine transform and flip so theta runs from 0 to pi
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # inverse=True applies the inverse normalization in the Legendre tables
        dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # shard the pct along the n dim (dim=2 is l in the (d, m, l, k) layout)
        dpct = scatter_to_parallel_region(dpct, dim=2)
        self.lmax_local = dpct.shape[2]

        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        # x: complex spectral coefficients; dim -3 indexes the
        # spheroidal/toroidal components, last two dims are (l, m)
        assert(x.shape[-2] == self.lmax_local)
        assert(x.shape[-1] == self.mmax)

        # Evaluate associated Legendre functions on the output nodes,
        # separately for real and imaginary parts
        x = torch.view_as_real(x)

        # contraction - spheroidal component
        # real component
        srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0]) \
            - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1])
        # imag component
        sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0]) \
            + torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1])

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1]) \
            - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0])
        # imag component
        tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1]) \
            - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0])

        # reassemble the (component, ..., real/imag) layout
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # distributed contraction: join (sum the partial l contributions)
        xs = reduce_from_parallel_region(xs)

        # apply the inverse (real) FFT in the longitudinal direction
        x = torch.view_as_complex(xs)
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import unittest
import numpy as np
import torch
from torch_harmonics import *
# Progress-bar shim: the real tqdm is intentionally disabled for these tests.
def tqdm(iterable):
    """Identity stand-in for tqdm: return *iterable* unchanged, no progress bar."""
    return iterable
class TestLegendrePolynomials(unittest.TestCase):
def setUp(self):
self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
self.pml = dict()
# preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
# for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
self.pml[(0, 0)] = lambda x : np.ones_like(x)
self.pml[(0, 1)] = lambda x : x
self.pml[(1, 1)] = lambda x : - np.sqrt(1. - x**2)
self.pml[(0, 2)] = lambda x : 0.5 * (3*x**2 - 1)
self.pml[(1, 2)] = lambda x : - 3 * x * np.sqrt(1. - x**2)
self.pml[(2, 2)] = lambda x : 3 * (1 - x**2)
self.pml[(0, 3)] = lambda x : 0.5 * (5*x**3 - 3*x)
self.pml[(1, 3)] = lambda x : 1.5 * (1 - 5*x**2) * np.sqrt(1. - x**2)
self.pml[(2, 3)] = lambda x : 15 * x * (1 - x**2)
self.pml[(3, 3)] = lambda x : -15 * np.sqrt(1. - x**2)**3
self.lmax = self.mmax = 4
def test_legendre(self):
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import precompute_legpoly
TOL = 1e-9
t = np.linspace(0, np.pi, 100)
pct = precompute_legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax):
for m in range(l+1):
diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t))
self.assertTrue(diff.max() <= TOL)
print("done.")
class TestSphericalHarmonicTransform(unittest.TestCase):
def __init__(self, testname, norm="ortho"):
super(TestSphericalHarmonicTransform, self).__init__(testname) # calling the super class init varies for different python versions. This works for 2.7
self.norm = norm
def setUp(self):
if torch.cuda.is_available():
print("Running test on GPU")
self.device = torch.device('cuda')
else:
print("Running test on CPU")
self.device = torch.device('cpu')
self.batch_size = 128
self.nlat = 256
self.nlon = 2*self.nlat
def test_sht_leggauss(self):
print(f"Testing real-valued SHT on Legendre-Gauss grid with {self.norm} normalization")
TOL = 1e-9
testiters = [1, 2, 4, 8, 16]
mmax = self.nlat
lmax = mmax
sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device)
isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device)
coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
signal = isht(coeffs)
for iter in testiters:
with self.subTest(i = iter):
print(f"{iter} iterations of batchsize {self.batch_size}:")
base = signal
for _ in tqdm(range(iter)):
base = isht(sht(base))
# err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item()
err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item()
print(f"final relative error: {err}")
self.assertTrue(err <= TOL)
def test_sht_equiangular(self):
print(f"Testing real-valued SHT on equiangular grid with {self.norm} normalization")
TOL = 1e-1
testiters = [1, 2, 4, 8]
mmax = self.nlat // 2
lmax = mmax
sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device)
isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device)
coeffs = torch.zeros(self.batch_size, sht.lmax, sht.mmax, device=self.device, dtype=torch.complex128)
coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
signal = isht(coeffs)
for iter in testiters:
with self.subTest(i = iter):
print(f"{iter} iterations of batchsize {self.batch_size}:")
base = signal
for _ in tqdm(range(iter)):
base = isht(sht(base))
# err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item()
err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item()
print(f"final relative error: {err}")
self.assertTrue(err <= TOL)
if __name__ == '__main__':
    # assemble the suite explicitly so each grid is exercised under every
    # supported normalization, in a deterministic order
    suite = unittest.TestSuite()
    suite.addTest(TestLegendrePolynomials('test_legendre'))
    for norm in ("ortho", "four-pi", "schmidt"):
        suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm=norm))
        suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm=norm))
    unittest.TextTestRunner(verbosity=2).run(suite)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment