Unverified commit 60b3b5a2 authored by Boris Bonev, committed by GitHub

Changing channel dimension of distributed SHT from 1 to -3 (#52)

* changing distributed SHT to use dim=-3 as the channel dimension for the distributed transpose

* Fixed formatting in tests

* Adding a bash script for running the test suite

* Replaced asserts on the number of tensor dimensions with explicit checks that raise errors
parent 4a6ed467
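As a quick illustration of the new convention (a minimal sketch, not part of this commit; the serial transforms are used only to show the expected layout), the distributed SHT now assumes channels at dim=-3, i.e. inputs shaped like (..., channel, lat, lon):

import torch
import torch_harmonics as th

nlat, nlon = 64, 128
batch, channels = 4, 8

# channels sit at dim=-3; leading dimensions are treated as batch dimensions
x = torch.randn(batch, channels, nlat, nlon)

sht = th.RealSHT(nlat, nlon, grid="equiangular")
coeffs = sht(x)   # complex coefficients of shape (batch, channels, lmax, mmax)
print(coeffs.shape, coeffs.dtype)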
......@@ -5,6 +5,7 @@
### v0.7.2
* Added resampling modules for convenience
* Changed behavior of the distributed SHT to use `dim=-3` as the channel dimension
### v0.7.1
......
......@@ -347,7 +347,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
torch.cuda.manual_seed(333)
# set device
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
torch.cuda.set_device(device.index)
......
#!/bin/bash
# Set default parameters
default_grid_size_lat=1
default_grid_size_lon=1
default_run_distributed=false
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
bold=$(tput bold)
normal=$(tput sgr0)
echo "Runs the torch-harmonics test suite."
echo "${bold}Arguments:${normal}"
echo " ${bold}-h | --help:${normal} Prints this text."
echo " ${bold}-d | --run_distributed:${normal} Run the distributed test suite."
echo " ${bold}-lat | --grid_size_lat:${normal} Number of ranks in latitudinal direction for distributed case."
echo " ${bold}-lon | --grid_size_lon:${normal} Number of ranks in longitudinal direction for distributed case."
shift
exit 0
;;
-lat|--grid_size_lat)
grid_size_lat="$2"
shift 2
;;
-lon|--grid_size_lon)
grid_size_lon="$2"
shift 2
;;
-d|--run_distributed)
run_distributed=true
shift
;;
*)
echo "Unknown argument: $1"
exit 1
;;
esac
done
# Use default values if arguments were not provided
grid_size_lat=${grid_size_lat:-$default_grid_size_lat}
grid_size_lon=${grid_size_lon:-$default_grid_size_lon}
run_distributed=${run_distributed:-$default_run_distributed}
echo "Running sequential tests:"
python3 -m pytest tests/test_convolution.py tests/test_sht.py
# Run distributed tests if requested
if [ "$run_distributed" = "true" ]; then
echo "Running distributed tests with the following parameters:"
echo "Grid size latitude: $grid_size_lat"
echo "Grid size longitude: $grid_size_lon"
ngpu=$(( ${grid_size_lat} * ${grid_size_lon} ))
mpirun --allow-run-as-root -np ${ngpu} bash -c "
export CUDA_LAUNCH_BLOCKING=1;
export WORLD_RANK=\${OMPI_COMM_WORLD_RANK};
export WORLD_SIZE=\${OMPI_COMM_WORLD_SIZE};
export RANK=\${OMPI_COMM_WORLD_RANK};
export MASTER_ADDR=localhost;
export MASTER_PORT=29501;
export GRID_H=${grid_size_lat};
export GRID_W=${grid_size_lon};
python3 -m pytest tests/test_distributed_sht.py
python3 -m pytest tests/test_distributed_convolution.py
"
else
echo "Skipping distributed tests."
fi
......@@ -40,6 +40,7 @@ from torch_harmonics import *
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
"""
helper routine to compute the values of the isotropic kernel densely
......@@ -47,7 +48,7 @@ def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutof
kernel_size = (nr // 2) + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
# compute the support
if nr % 2 == 1:
......@@ -71,7 +72,7 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
kernel_size = (nr // 2) * nphi + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
dphi = 2.0 * math.pi / nphi
# disambiguate even and odd cases and compute the support
......@@ -87,7 +88,7 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) , 0.0)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
else:
......@@ -99,8 +100,8 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
vals = r_vals * phi_vals
# in the even case, the inner basis functions overlap into areas with a negative radius
rn = - r
phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi)
rn = -r
phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
......@@ -109,6 +110,7 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
return vals
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
......@@ -120,7 +122,7 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation
psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:,:,:1], dim=(1, 4), keepdim=True) / scale_factor
psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1], dim=(1, 4), keepdim=True) / scale_factor
if merge_quadrature:
psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi
else:
......@@ -131,7 +133,17 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
return psi / (psi_norm + eps)
def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False):
def _precompute_convolution_tensor_dense(
in_shape,
out_shape,
kernel_shape,
quad_weights,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
merge_quadrature=False,
):
"""
Helper routine to compute the convolution Tensor in a dense fashion
"""
......@@ -143,7 +155,7 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
kernel_size = math.ceil( kernel_shape[0] / 2)
kernel_size = math.ceil(kernel_shape[0] / 2)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
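For reference, a small worked example (illustrative only, mirroring the formulas above) of how the kernel size and the cutoff radius used in these tests come out:

import math

# isotropic kernel: kernel_shape = [nr]
nr = 3
iso_size = (nr // 2) + nr % 2            # 2, identical to math.ceil(nr / 2)
assert iso_size == math.ceil(nr / 2)

# anisotropic kernel: kernel_shape = [nr, nphi]
nr, nphi = 2, 3
aniso_size = (nr // 2) * nphi + nr % 2   # 3

# cutoff used further down, e.g. kernel_shape = [3] and nlat_out = 65:
theta_cutoff = (3 + 1) / 2 * math.pi / float(65 - 1)   # pi / 32 ≈ 0.098
print(iso_size, aniso_size, theta_cutoff)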
......@@ -250,30 +262,25 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
theta_cutoff = (kernel_shape[0] + 1) / 2 * torch.pi / float(nlat_out - 1)
Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv(
in_channels,
out_channels,
in_shape,
out_shape,
kernel_shape,
groups=1,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff
).to(self.device)
conv = Conv(in_channels, out_channels, in_shape, out_shape, kernel_shape, groups=1, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=theta_cutoff).to(
self.device
)
_, wgl = _precompute_latitudes(nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in
if transpose:
psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True).to(self.device)
psi_dense = _precompute_convolution_tensor_dense(
out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
).to(self.device)
psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()
self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)))
else:
psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True).to(self.device)
psi_dense = _precompute_convolution_tensor_dense(
in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
).to(self.device)
psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
......
......@@ -114,8 +114,8 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@classmethod
def tearDownClass(cls):
thd.finalize()
dist.destroy_process_group(None)
def _split_helper(self, tensor):
with torch.no_grad():
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -46,11 +46,11 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv('WORLD_RANK', 0))
cls.grid_size_h = int(os.getenv('GRID_H', 1))
cls.grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', '29501'))
master_address = os.getenv('MASTER_ADDR', 'localhost')
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1))
cls.grid_size_w = int(os.getenv("GRID_W", 1))
port = int(os.getenv("MASTER_PORT", "29501"))
master_address = os.getenv("MASTER_ADDR", "localhost")
cls.world_size = cls.grid_size_h * cls.grid_size_w
if torch.cuda.is_available():
......@@ -59,24 +59,21 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
local_rank = cls.world_rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{local_rank}")
torch.cuda.manual_seed(333)
proc_backend = 'nccl'
proc_backend = "nccl"
else:
if cls.world_rank == 0:
print("Running test on CPU")
cls.device = torch.device('cpu')
proc_backend = 'gloo'
cls.device = torch.device("cpu")
proc_backend = "gloo"
torch.manual_seed(333)
dist.init_process_group(backend = proc_backend,
init_method = f"tcp://{master_address}:{port}",
rank = cls.world_rank,
world_size = cls.world_size)
dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)
cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w
# now set up the comm groups:
#set default
# set default
cls.w_group = None
cls.h_group = None
......@@ -110,7 +107,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
# set seed
torch.manual_seed(333)
if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")
......@@ -123,7 +119,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
thd.finalize()
dist.destroy_process_group(None)
def _split_helper(self, tensor):
with torch.no_grad():
# split in W
......@@ -135,8 +130,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
tensor_local = tensor_list_local[self.hrank]
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector):
# we need the shapes
l_shapes = transform_dist.l_shapes
......@@ -199,25 +193,26 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_gather
@parameterized.expand([
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
])
@parameterized.expand(
[
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
]
)
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
......@@ -246,7 +241,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# BWD pass
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
......@@ -270,7 +265,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
......@@ -280,30 +275,31 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
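The assertions above (and in the inverse test below) compare gathered distributed results against the local reference using a mean relative Frobenius error over the last two dimensions; a standalone sketch of that metric (hypothetical helper name, not part of the test file):

import torch

def relative_frobenius_error(approx: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Frobenius norm over the trailing (spatial or spectral) dimensions,
    # averaged over the remaining batch/channel dimensions
    num = torch.norm(approx - reference, p="fro", dim=(-1, -2))
    den = torch.norm(reference, p="fro", dim=(-1, -2))
    return torch.mean(num / den)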
@parameterized.expand([
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
])
@parameterized.expand(
[
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
]
)
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
......@@ -311,7 +307,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
......@@ -325,8 +321,8 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
#############################################################
# local transform
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = backward_transform_local(inp_full)
......@@ -363,7 +359,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
......@@ -373,10 +369,11 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -55,7 +55,7 @@ class DistributedRealSHT(nn.Module):
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
"""
Initializes the SHT Layer, precomputing the necessary quadrature weights
Distributed SHT layer. Expects the last three dimensions of the input tensor to be channels, latitude, longitude.
Parameters:
nlat: input grid resolution in the latitudinal direction
......@@ -108,9 +108,9 @@ class DistributedRealSHT(nn.Module):
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# split weights
weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
......@@ -125,12 +125,15 @@ class DistributedRealSHT(nn.Module):
def forward(self, x: torch.Tensor):
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")
# we need to ensure that we can split the channels evenly
num_chans = x.shape[1]
num_chans = x.shape[-3]
# h and w are split. First we make w local by transposing it into the channel dim
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_shapes)
x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes)
# apply real fft in the longitudinal direction: make sure to truncate to nlon
x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")
......@@ -141,11 +144,11 @@ class DistributedRealSHT(nn.Module):
# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)
# transpose: after this, c is split and h is local
if self.comm_size_polar > 1:
x = distributed_transpose_polar.apply(x, (1, -2), self.lat_shapes)
x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes)
# do the Legendre-Gauss quadrature
x = torch.view_as_real(x)
......@@ -159,8 +162,8 @@ class DistributedRealSHT(nn.Module):
# transpose: after this, l is split and c is local
if self.comm_size_polar > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)
return x
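To make the transpose pattern above easier to follow, here is a shape walk-through for a hypothetical 2 x 2 (polar x azimuth) process grid; split_shapes below only mimics compute_split_shapes and is not the library function:

def split_shapes(n, chunks):
    # distribute n entries over chunks, larger chunks first (assumption)
    base, rem = divmod(n, chunks)
    return [base + 1 if i < rem else base for i in range(chunks)]

C, H, W = 8, 64, 128      # channels, nlat, nlon
P, A = 2, 2               # polar / azimuth communicator sizes

print(split_shapes(H, P), split_shapes(W, A))   # latitude / longitude splits
# on entry each rank holds x of shape (B, C, H_loc, W_loc)
# transpose (-3, -1): W becomes local, C is split  -> (B, C_loc, H_loc, W)
# rfft along dim=-1; only mmax modes are kept      -> (B, C_loc, H_loc, m)
# transpose (-1, -3): C local again, m split       -> (B, C, H_loc, m_loc)
# transpose (-3, -2): H becomes local, C split     -> (B, C_loc, H, m_loc)
# Legendre-Gauss quadrature over latitude          -> (B, C_loc, l, m_loc)
# transpose (-2, -3): C local again, l split       -> (B, C, l_loc, m_loc)
print(split_shapes(C, A), split_shapes(C, P))     # channel splits used above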
......@@ -210,7 +213,7 @@ class DistributedInverseRealSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# compute splits
self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
......@@ -234,12 +237,15 @@ class DistributedInverseRealSHT(nn.Module):
def forward(self, x: torch.Tensor):
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")
# we need to ensure that we can split the channels evenly
num_chans = x.shape[1]
num_chans = x.shape[-3]
# transpose: after that, channels are split, l is local:
if self.comm_size_polar > 1:
x = distributed_transpose_polar.apply(x, (1, -2), self.l_shapes)
x = distributed_transpose_polar.apply(x, (-3, -2), self.l_shapes)
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
......@@ -255,11 +261,11 @@ class DistributedInverseRealSHT(nn.Module):
if self.comm_size_polar > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)
# transpose: after this, channels are split and m is local
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.m_shapes)
x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)
# apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......@@ -267,7 +273,7 @@ class DistributedInverseRealSHT(nn.Module):
# transpose: after this, m is split and channels are local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)
return x
......@@ -351,7 +357,7 @@ class DistributedRealVectorSHT(nn.Module):
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
def extra_repr(self):
"""
Pretty print module
......@@ -360,14 +366,15 @@ class DistributedRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(len(x.shape) >= 3)
if x.dim() < 4:
raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")
# we need to ensure that we can split the channels evenly
num_chans = x.shape[1]
num_chans = x.shape[-4]
# h and w are split. First we make w local by transposing it into the channel dim
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_shapes)
x = distributed_transpose_azimuth.apply(x, (-4, -1), self.lon_shapes)
# apply real fft in the longitudinal direction: make sure to truncate to nlon
x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")
......@@ -378,11 +385,11 @@ class DistributedRealVectorSHT(nn.Module):
# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)
# transpose: after this, c is split and h is local
if self.comm_size_polar > 1:
x = distributed_transpose_polar.apply(x, (1, -2), self.lat_shapes)
x = distributed_transpose_polar.apply(x, (-4, -2), self.lat_shapes)
# do the Legendre-Gauss quadrature
x = torch.view_as_real(x)
......@@ -412,7 +419,7 @@ class DistributedRealVectorSHT(nn.Module):
# transpose: after this, l is split and c is local
if self.comm_size_polar > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)
return x
......@@ -459,7 +466,7 @@ class DistributedInverseRealVectorSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# compute splits
self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
......@@ -483,12 +490,15 @@ class DistributedInverseRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
if x.dim() < 4:
raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")
# store num channels
num_chans = x.shape[1]
num_chans = x.shape[-4]
# transpose: after that, channels are split, l is local:
if self.comm_size_polar > 1:
x = distributed_transpose_polar.apply(x, (1, -2), self.l_shapes)
x = distributed_transpose_polar.apply(x, (-4, -2), self.l_shapes)
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
......@@ -519,11 +529,11 @@ class DistributedInverseRealVectorSHT(nn.Module):
if self.comm_size_polar > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)
# transpose: after this, channels are split and m is local
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.m_shapes)
x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)
# apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......@@ -531,6 +541,6 @@ class DistributedInverseRealVectorSHT(nn.Module):
# transpose: after this, m is split and channels are local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)
return x
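Analogously to the scalar case, the distributed vector SHT now expects the channel dimension at dim=-4, with the two vector components at dim=-3. A small sketch of that layout (the serial RealVectorSHT is used here purely to illustrate the shapes):

import torch
import torch_harmonics as th

nlat, nlon = 64, 128
batch, channels = 2, 3

# layout: (..., channel, 2, lat, lon); the size-2 axis carries the vector components
u = torch.randn(batch, channels, 2, nlat, nlon)

vsht = th.RealVectorSHT(nlat, nlon, grid="equiangular")
coeffs = vsht(u)   # complex coefficients of shape (batch, channels, 2, lmax, mmax)
print(coeffs.shape)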
......@@ -272,7 +272,6 @@ class SpectralConvS2(nn.Module):
scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat, 2)
scale[0] *= math.sqrt(2)
self.weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))
# self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
# get the right contraction function
self._contract = _contract
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -57,7 +57,7 @@ class SpectralFilterLayer(nn.Module):
separable = False,
rank = 1e-2,
bias = True):
super(SpectralFilterLayer, self).__init__()
if factorization is None:
self.filter = SpectralConvS2(forward_transform,
......@@ -67,7 +67,7 @@ class SpectralFilterLayer(nn.Module):
gain = gain,
operator_type = operator_type,
bias = bias)
elif factorization is not None:
self.filter = FactorizedSpectralConvS2(forward_transform,
inverse_transform,
......@@ -117,7 +117,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0
# convolution layer
self.filter = SpectralFilterLayer(forward_transform,
inverse_transform,
......@@ -146,14 +146,14 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
# first normalisation layer
self.norm0 = norm_layer()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
gain_factor = 1.0
if outer_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.
if use_mlp == True:
mlp_hidden_dim = int(output_dim * mlp_ratio)
self.mlp = MLP(in_features = output_dim,
......@@ -355,7 +355,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
norm_layer0 = nn.Identity
norm_layer1 = norm_layer0
else:
raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
if pos_embed == "latlon" or pos_embed==True:
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
......@@ -402,7 +402,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
nn.init.constant_(fc.bias, 0.0)
encoder_layers.append(fc)
self.encoder = nn.Sequential(*encoder_layers)
# prepare the spectral transform
if self.spectral_transform == "sht":
......@@ -424,7 +424,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
else:
raise(ValueError("Unknown spectral transform"))
......@@ -508,7 +508,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
......@@ -529,5 +529,5 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
x = self.decoder(x)
return x
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
......@@ -85,7 +85,7 @@ class RealSHT(nn.Module):
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# combine quadrature weights with the legendre weights
......@@ -105,26 +105,29 @@ class RealSHT(nn.Module):
def forward(self, x: torch.Tensor):
if x.dim() < 2:
raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")
assert(x.shape[-2] == self.nlat)
assert(x.shape[-1] == self.nlon)
# apply real fft in the longitudinal direction
x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")
# do the Legendre-Gauss quadrature
x = torch.view_as_real(x)
# distributed contraction: fork
out_shape = list(x.size())
out_shape[-3] = self.lmax
out_shape[-2] = self.mmax
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
# contraction
xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights.to(x.dtype) )
xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights.to(x.dtype) )
x = torch.view_as_complex(xout)
return x
class InverseRealSHT(nn.Module):
......@@ -164,7 +167,7 @@ class InverseRealSHT(nn.Module):
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
......@@ -181,12 +184,15 @@ class InverseRealSHT(nn.Module):
def forward(self, x: torch.Tensor):
if len(x.shape) < 2:
raise ValueError(f"Expected tensor with at least 2 dimensions but got {len(x.shape)} instead")
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
xs = torch.stack((rl, im), -1)
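As a sanity check of the forward/inverse pair above (a sketch, not part of the diff), a band-limited signal should survive a round trip through RealSHT and InverseRealSHT up to numerical precision:

import torch
import torch_harmonics as th

nlat, nlon = 64, 128
sht = th.RealSHT(nlat, nlon, grid="legendre-gauss")
isht = th.InverseRealSHT(nlat, nlon, grid="legendre-gauss")

# project random data onto the resolvable modes, then round-trip it
x = isht(sht(torch.randn(4, 8, nlat, nlon)))
x_roundtrip = isht(sht(x))

err = torch.norm(x - x_roundtrip) / torch.norm(x)
print(f"round-trip relative error: {err.item():.2e}")   # expected to be small (float32 precision)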
......@@ -243,13 +249,13 @@ class RealVectorSHT(nn.Module):
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
weights = torch.from_numpy(w)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights and normalization factor into one:
l = torch.arange(0, self.lmax)
norm_factor = 1. / l / (l+1)
......@@ -269,14 +275,18 @@ class RealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(len(x.shape) >= 3)
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")
assert(x.shape[-2] == self.nlat)
assert(x.shape[-1] == self.nlon)
# apply real fft in the longitudinal direction
x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")
# do the Legendre-Gauss quadrature
x = torch.view_as_real(x)
# distributed contraction: fork
out_shape = list(x.size())
out_shape[-3] = self.lmax
......@@ -286,19 +296,19 @@ class RealVectorSHT(nn.Module):
# contraction - spheroidal component
# real component
xout[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[0].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[1].to(x.dtype))
# imag component
xout[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[0].to(x.dtype)) \
+ torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[1].to(x.dtype))
# contraction - toroidal component
# real component
xout[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[1].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[0].to(x.dtype))
# imag component
xout[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[1].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[0].to(x.dtype))
return torch.view_as_complex(xout)
......@@ -307,7 +317,7 @@ class InverseRealVectorSHT(nn.Module):
r"""
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
......@@ -337,7 +347,7 @@ class InverseRealVectorSHT(nn.Module):
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
......@@ -354,28 +364,31 @@ class InverseRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
# contraction - spheroidal component
# real component
srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
# imag component
sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \
+ torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))
# contraction - toroidal component
# real component
trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
# imag component
tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))
# reassemble
s = torch.stack((srl, sim), -1)
t = torch.stack((trl, tim), -1)
......