Unverified commit bd92cdf7, authored by Thorsten Kurth, committed by GitHub

Tkurth/remove sparse coo tensor (#89)

* refactoring disco backend code

* removed get_psi as a member function and moved it into _disco_convolution instead

* setting seeds in tests more consistently

* parametrized test classes to ensure that tests are always run on both CPU and GPU (if available); a minimal sketch of this pattern is included right after the commit header below

* cleaning up
parent 9959a7a6
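
Not part of the commit itself: the following is a minimal, self-contained sketch of the device-parametrization pattern that the refactored tests adopt. The test class and method names here are illustrative only; the parameterized_class usage mirrors what the diff below adds to the test files.

# Illustrative sketch: run every test case on CPU and, when available, on CUDA,
# using parameterized_class in the same way as the refactored test files.
import unittest

import torch
from parameterized import parameterized_class

# build the device list once; CUDA is appended only if it is actually available
_devices = [(torch.device("cpu"),)]
if torch.cuda.is_available():
    _devices.append((torch.device("cuda"),))


@parameterized_class(("device"), _devices)
class ExampleDeviceTest(unittest.TestCase):
    def setUp(self):
        # seed consistently for every parametrized device, as the tests now do
        torch.manual_seed(333)
        if self.device.type == "cuda":
            torch.cuda.manual_seed(333)

    def test_clone_roundtrip(self):
        x = torch.randn(4, 4, device=self.device)
        self.assertTrue(torch.allclose(x, x.clone()))


if __name__ == "__main__":
    unittest.main()

With this pattern each test class is instantiated once per entry in _devices, so a single test body covers both the CPU and the GPU code paths.
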
@@ -30,7 +30,7 @@
#
import unittest
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
# import math
import numpy as np
@@ -58,17 +58,19 @@ except ImportError as err:
attention_cuda_extension = None
_cuda_extension_available = False
_devices = [(torch.device("cpu"),)]
if torch.cuda.is_available():
_devices.append((torch.device("cuda"),))
_perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
@parameterized_class(("device"), _devices)
class TestNeighborhoodAttentionS2(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index)
torch.cuda.manual_seed(333)
else:
self.device = torch.device("cpu")
torch.manual_seed(333)
if self.device.type == "cuda":
torch.cuda.manual_seed(333)
@parameterized.expand(
[
@@ -107,7 +109,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
model.load_state_dict(model_ref.state_dict())
model = model.to(self.device)
for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
self.assertTrue(torch.allclose(p_ref, p))
# reference forward passes
out_ref = _neighborhood_attention_s2_torch(
@@ -187,7 +189,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
model.load_state_dict(model_ref.state_dict())
model = model.to(self.device)
for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
assert torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}"
self.assertTrue(torch.allclose(p_ref, p))
# reference forward passes
out_ref = model_ref(inputs_ref["q"], inputs_ref["k"], inputs_ref["v"])
@@ -217,7 +219,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
@parameterized.expand(
[
# self attention
#[1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
[1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
],
skip_on_empty=True,
@@ -225,6 +226,11 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
@unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
# skip this test if we are not running on a GPU; it would take very long otherwise
if self.device.type != "cuda":
self.assertFalse(False)
return
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
@@ -280,11 +286,11 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
att_gpu.k_bias.copy_(att_gpu.k_bias)
att_gpu.v_bias.copy_(att_gpu.v_bias)
q_gpu = q_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
q_gpu = q_gpu.detach().clone().to(self.device)
q_gpu.requires_grad = True
k_gpu = k_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
k_gpu = k_gpu.detach().clone().to(self.device)
k_gpu.requires_grad = True
v_gpu = v_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
v_gpu = v_gpu.detach().clone().to(self.device)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
......
@@ -30,7 +30,7 @@
#
import unittest
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from functools import partial
import math
import numpy as np
@@ -41,6 +41,11 @@ from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContin
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
_devices = [(torch.device("cpu"),)]
if torch.cuda.is_available():
_devices.append((torch.device("cuda"),))
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
@@ -161,14 +166,12 @@ def _precompute_convolution_tensor_dense(
return out
@parameterized_class(("device"), _devices)
class TestDiscreteContinuousConvolution(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index)
torch.manual_seed(333)
if self.device.type == "cuda":
torch.cuda.manual_seed(333)
else:
self.device = torch.device("cpu")
@parameterized.expand(
[
......
@@ -30,12 +30,17 @@
#
import unittest
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
import math
import torch
from torch.autograd import gradcheck
import torch_harmonics as th
_devices = [(torch.device("cpu"),)]
if torch.cuda.is_available():
_devices.append((torch.device("cuda"),))
class TestLegendrePolynomials(unittest.TestCase):
def setUp(self):
@@ -72,26 +77,36 @@ class TestLegendrePolynomials(unittest.TestCase):
self.assertTrue(diff.max() <= self.tol)
@parameterized_class(("device"), _devices)
class TestSphericalHarmonicTransform(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
torch.manual_seed(333)
if self.device.type == "cuda":
torch.cuda.manual_seed(333)
@parameterized.expand(
[
[256, 512, 32, "ortho", "equiangular", 1e-9, False],
[256, 512, 32, "ortho", "legendre-gauss", 1e-9, False],
[256, 512, 32, "ortho", "lobatto", 1e-9, False],
[256, 512, 32, "four-pi", "equiangular", 1e-9, False],
[256, 512, 32, "four-pi", "legendre-gauss", 1e-9, False],
[256, 512, 32, "four-pi", "lobatto", 1e-9, False],
[256, 512, 32, "schmidt", "equiangular", 1e-9, False],
[256, 512, 32, "schmidt", "legendre-gauss", 1e-9, False],
[256, 512, 32, "schmidt", "lobatto", 1e-9, False],
# even-even
[32, 64, 32, "ortho", "equiangular", 1e-9, False],
[32, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
[32, 64, 32, "ortho", "lobatto", 1e-9, False],
[32, 64, 32, "four-pi", "equiangular", 1e-9, False],
[32, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
[32, 64, 32, "four-pi", "lobatto", 1e-9, False],
[32, 64, 32, "schmidt", "equiangular", 1e-9, False],
[32, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
[32, 64, 32, "schmidt", "lobatto", 1e-9, False],
# odd-even
[33, 64, 32, "ortho", "equiangular", 1e-9, False],
[33, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
[33, 64, 32, "ortho", "lobatto", 1e-9, False],
[33, 64, 32, "four-pi", "equiangular", 1e-9, False],
[33, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
[33, 64, 32, "four-pi", "lobatto", 1e-9, False],
[33, 64, 32, "schmidt", "equiangular", 1e-9, False],
[33, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
[33, 64, 32, "schmidt", "lobatto", 1e-9, False],
]
)
def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
@@ -133,6 +148,7 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
@parameterized.expand(
[
# even-even
[12, 24, 2, "ortho", "equiangular", 1e-5, False],
[12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
[12, 24, 2, "ortho", "lobatto", 1e-5, False],
@@ -142,6 +158,7 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
[12, 24, 2, "schmidt", "equiangular", 1e-5, False],
[12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
[12, 24, 2, "schmidt", "lobatto", 1e-5, False],
# odd-even
[15, 30, 2, "ortho", "equiangular", 1e-5, False],
[15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
[15, 30, 2, "ortho", "lobatto", 1e-5, False],
......
@@ -29,6 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import Optional
import math
import torch
@@ -39,6 +40,26 @@ try:
except ImportError as err:
disco_cuda_extension = None
# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
if semi_transposed:
# do partial transpose
# we do a semi-transposition to facilitate the computation
tout = psi_idx[2] // nlon_out
pout = psi_idx[2] % nlon_out
# flip the axis of longitudes
pout = nlon_out - 1 - pout
tin = psi_idx[1]
idx = torch.stack([psi_idx[0], tout, tin * nlon_out + pout], dim=0)
psi = torch.sparse_coo_tensor(idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_out)).coalesce()
else:
psi = torch.sparse_coo_tensor(psi_idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_in)).coalesce()
return psi
class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod
......
@@ -42,7 +42,7 @@ from functools import partial
from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
@@ -362,6 +362,9 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# also store psi as a sparse COO matrix for the pure-PyTorch code path
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)
def extra_repr(self):
r"""
Pretty print module
@@ -372,10 +375,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
def psi_idx(self):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
return psi
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.is_cuda and _cuda_extension_available:
@@ -385,8 +384,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
else:
if x.is_cuda:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
psi = self.get_psi()
x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)
# extract shape
B, C, K, H, W = x.shape
@@ -469,6 +467,9 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# also store the semi-transposed psi for the pure-PyTorch code path
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)
def extra_repr(self):
r"""
Pretty print module
@@ -479,21 +480,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
def psi_idx(self):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self, semi_transposed: bool = False):
if semi_transposed:
# we do a semi-transposition to facilitate the computation
tout = self.psi_idx[2] // self.nlon_out
pout = self.psi_idx[2] % self.nlon_out
# flip the axis of longitudes
pout = self.nlon_out - 1 - pout
tin = self.psi_idx[1]
idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
else:
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
return psi
def forward(self, x: torch.Tensor) -> torch.Tensor:
# extract shape
B, C, H, W = x.shape
@@ -510,8 +496,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
else:
if x.is_cuda:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
psi = self.get_psi(semi_transposed=True)
out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)
if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1)
......
@@ -42,7 +42,7 @@ import torch.nn as nn
from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis
from torch_harmonics.convolution import (
@@ -283,6 +283,9 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# store psi for the pure-PyTorch code path
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local)
def extra_repr(self):
r"""
Pretty print module
@@ -293,10 +296,6 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
def psi_idx(self):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_in)).coalesce()
return psi
def forward(self, x: torch.Tensor) -> torch.Tensor:
# store number of channels
@@ -314,9 +313,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
if x.is_cuda:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
psi = self.get_psi()
x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)
# perform reduce scatter in polar region
x = reduce_from_polar_region(x)
@@ -426,6 +423,9 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# store the semi-transposed psi as a sparse COO tensor for the pure-PyTorch code path
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True)
def extra_repr(self):
r"""
Pretty print module
@@ -436,21 +436,6 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
def psi_idx(self):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self, semi_transposed: bool = False):
if semi_transposed:
# do partial transpose
# we do a semi-transposition to facilitate the computation
tout = self.psi_idx[2] // self.nlon_out
pout = self.psi_idx[2] % self.nlon_out
# flip the axis of longitudes
pout = self.nlon_out - 1 - pout
tin = self.psi_idx[1]
idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_out)).coalesce()
else:
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in_local, self.nlat_out_local * self.nlon_out)).coalesce()
return psi
def forward(self, x: torch.Tensor) -> torch.Tensor:
# extract shape
@@ -477,8 +462,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
else:
if x.is_cuda:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
psi = self.get_psi(semi_transposed=True)
out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)
# now we can transpose back the result, so that lon is split and channels are local
if self.comm_size_azimuth > 1:
......
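
Taken together, the changes above replace the per-call get_psi() member functions with the single module-level _get_psi helper, whose result is now built once at construction time and reused in the forward pass. Below is a minimal sketch of the resulting pattern, paraphrasing the diff: the wrapper class SketchDiscoConvS2 is illustrative only, while _get_psi and _disco_s2_contraction_torch are the actual helpers shown in the diff.

import torch

from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch


class SketchDiscoConvS2(torch.nn.Module):
    """Illustrative wrapper: builds the sparse COO operator psi once and reuses it."""

    def __init__(self, kernel_size, psi_idx, psi_vals, nlat_in, nlon_in, nlat_out, nlon_out):
        super().__init__()
        self.nlon_out = nlon_out
        # previously psi was reassembled by a get_psi() method on every forward
        # call; it is now built once from the registered index/value buffers
        self.psi = _get_psi(kernel_size, psi_idx, psi_vals, nlat_in, nlon_in, nlat_out, nlon_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pure-PyTorch fallback path; the CUDA extension path is unchanged
        return _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)
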