Unverified commit b5c410c0, authored by Thorsten Kurth, committed by GitHub

Merge pull request #93 from NVIDIA/tkurth/device-fixes

Tkurth/device fixes
parents 4aaff021 3d604f85
@@ -39,7 +39,7 @@ from torch.autograd import gradcheck
 from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
+from torch_harmonics.convolution import _precompute_convolution_tensor_s2
 
 _devices = [(torch.device("cpu"),)]
 if torch.cuda.is_available():
@@ -127,7 +127,7 @@ def _precompute_convolution_tensor_dense(
     quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
 
     # array for accumulating non-zero indices
-    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
+    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64, device=lons_in.device)
 
     for t in range(nlat_out):
         for p in range(nlon_out):
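The pattern in this hunk — deriving the allocation device from an input tensor instead of relying on the CPU default — is the one most of the diff repeats. A minimal sketch of the failure mode it avoids (names here are illustrative, not from the patch):

```python
import torch

# factory calls without device= land on the default (CPU) device; combining the
# result with a CUDA input then fails with a device-mismatch RuntimeError
x = torch.randn(4, device="cuda" if torch.cuda.is_available() else "cpu")
out_bad = torch.zeros(4)                  # always on the default device
out_ok = torch.zeros(4, device=x.device)  # follows the input tensor instead
y = out_ok + x                            # same device on both sides, always works
```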
@@ -199,9 +199,10 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_disco_convolution(
+    def test_forward_backward(
         self,
         batch_size,
         in_channels,
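`skip_on_empty=True` is threaded through every `parameterized.expand` call in this patch. As I read the `parameterized` library, it marks the test as skipped when the parameter list turns out empty, rather than silently generating zero test cases; a small sketch of the intended behavior (test and list names hypothetical):

```python
import unittest
from parameterized import parameterized

_gpu_cases = []  # hypothetical: empty on machines without a GPU

class ExampleTest(unittest.TestCase):
    # with skip_on_empty=True an empty parameter list yields a skipped test
    # instead of no test at all
    @parameterized.expand(_gpu_cases, skip_on_empty=True)
    def test_gpu_path(self, *case):
        self.assertTrue(True)
```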
@@ -315,6 +316,70 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
         self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))
 
+    @parameterized.expand(
+        [
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
+        ],
+        skip_on_empty=True,
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
+    def test_device_instantiation(self, batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, transpose, tol, verbose):
+
+        nlat_in, nlon_in = in_shape
+        nlat_out, nlon_out = out_shape
+
+        if isinstance(kernel_shape, int):
+            theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
+        else:
+            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
+
+        # get handle to the convolution type under test
+        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
+
+        # init on cpu
+        conv_host = Conv(
+            in_channels,
+            out_channels,
+            in_shape,
+            out_shape,
+            kernel_shape,
+            basis_type=basis_type,
+            basis_norm_mode=basis_norm_mode,
+            groups=1,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=False,
+            theta_cutoff=theta_cutoff,
+        )
+
+        # init on device
+        with torch.device(self.device):
+            conv_device = Conv(
+                in_channels,
+                out_channels,
+                in_shape,
+                out_shape,
+                kernel_shape,
+                basis_type=basis_type,
+                basis_norm_mode=basis_norm_mode,
+                groups=1,
+                grid_in=grid_in,
+                grid_out=grid_out,
+                bias=False,
+                theta_cutoff=theta_cutoff,
+            )
+
+        # since the device is passed explicitly throughout the precomputation,
+        # both instantiations should produce identical precomputed tensors
+        self.assertTrue(torch.allclose(conv_host.psi_col_idx.cpu(), conv_device.psi_col_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_row_idx.cpu(), conv_device.psi_row_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_roff_idx.cpu(), conv_device.psi_roff_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_vals.cpu(), conv_device.psi_vals.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_idx.cpu(), conv_device.psi_idx.cpu()))
+
 
 if __name__ == "__main__":
     unittest.main()
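The new test leans on `torch.device` acting as a context manager (available since PyTorch 2.0), which redirects factory functions to that device for the duration of the block. A minimal standalone sketch of the mechanism:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with torch.device(device):
    w = torch.zeros(4, 4)  # allocated on `device`, not the CPU default

assert w.device.type == device.type
```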
@@ -101,15 +101,16 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
             [33, 64, 32, "ortho", "equiangular", 1e-9, False],
             [33, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "ortho", "lobatto", 1e-9, False],
             [33, 64, 32, "four-pi", "equiangular", 1e-9, False],
             [33, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "four-pi", "lobatto", 1e-9, False],
             [33, 64, 32, "schmidt", "equiangular", 1e-9, False],
             [33, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "schmidt", "lobatto", 1e-9, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+    def test_forward_inverse(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
         if verbose:
             print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")
@@ -168,9 +169,10 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
             [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
             [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
             [15, 30, 2, "schmidt", "lobatto", 1e-5, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+    def test_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
         if verbose:
             print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
@@ -202,6 +204,40 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
         test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
         self.assertTrue(test_result)
 
+    @parameterized.expand(
+        [
+            # even-even
+            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
+            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
+            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
+        ],
+        skip_on_empty=True,
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
+    def test_device_instantiation(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+
+        if verbose:
+            print(f"Testing device instantiation of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
+
+        if grid == "equiangular":
+            mmax = nlat // 2
+        elif grid == "lobatto":
+            mmax = nlat - 1
+        else:
+            mmax = nlat
+        lmax = mmax
+
+        # init on cpu
+        sht_host = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+        isht_host = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+
+        # init on device
+        with torch.device(self.device):
+            sht_device = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+            isht_device = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+
+        self.assertTrue(torch.allclose(sht_host.weights.cpu(), sht_device.weights.cpu()))
+        self.assertTrue(torch.allclose(isht_host.pct.cpu(), isht_device.pct.cpu()))
+
 
 if __name__ == "__main__":
     unittest.main()
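Outside the test harness, the same host/device parity check can be reproduced in a few lines; a minimal sketch, assuming `torch_harmonics` is importable as `th` (as in the test file) and using the `weights` buffer the test above also compares:

```python
import torch
import torch_harmonics as th

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sht_host = th.RealSHT(12, 24, grid="equiangular", norm="ortho")
with torch.device(device):
    sht_dev = th.RealSHT(12, 24, grid="equiangular", norm="ortho")

# precomputed quadrature weights should agree regardless of where they were built
assert torch.allclose(sht_host.weights.cpu(), sht_dev.weights.cpu())
```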
@@ -92,8 +92,8 @@ def _normalize_convolution_tensor_s2(
     q = quad_weights[ilat_in].reshape(-1)
 
     # buffer to store intermediate values
-    vnorm = torch.zeros(kernel_size, nlat_out)
-    support = torch.zeros(kernel_size, nlat_out)
+    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
+    support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
 
     # loop through dimensions to compute the norms
     for ik in range(kernel_size):
@@ -207,7 +207,7 @@ def _precompute_convolution_tensor_s2(
     sgamma = torch.sin(gamma)
 
     # compute row offsets
-    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64)
+    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
     out_roff[0] = 0
     for t in range(nlat_out):
         # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
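Both hunks allocate precomputation buffers on the device of the tensors they are derived from. A minimal sketch of the norm accumulation this enables, with hypothetical index data (`ik`, `it`, and the sizes are made up for illustration):

```python
import torch

psi_vals = torch.rand(100)  # hypothetical nonzero values of the filter tensor
kernel_size, nlat_out = 4, 5

# accumulation buffer follows psi_vals.device, matching the patched code
vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
ik = torch.randint(kernel_size, (100,), device=psi_vals.device)  # kernel index per value
it = torch.randint(nlat_out, (100,), device=psi_vals.device)     # output-latitude index per value
vnorm.index_put_((ik, it), psi_vals**2, accumulate=True)         # sum of squares per (k, lat)
```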
...
@@ -104,13 +104,22 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
     CHECK_INPUT_TENSOR(col_idx);
     CHECK_INPUT_TENSOR(val);
 
+    // get the input device and make sure all tensors are on the same device
+    auto device = ker_idx.device();
+    TORCH_INTERNAL_ASSERT(device.type() == row_idx.device().type() && (device.type() == col_idx.device().type()) && (device.type() == val.device().type()));
+
+    // move to cpu
+    ker_idx = ker_idx.to(torch::kCPU);
+    row_idx = row_idx.to(torch::kCPU);
+    col_idx = col_idx.to(torch::kCPU);
+    val = val.to(torch::kCPU);
+
     int64_t nnz = val.size(0);
     int64_t *ker_h = ker_idx.data_ptr<int64_t>();
     int64_t *row_h = row_idx.data_ptr<int64_t>();
     int64_t *col_h = col_idx.data_ptr<int64_t>();
     int64_t *roff_h = new int64_t[Ho * K + 1];
     int64_t nrows;
-    // float *val_h = val.data_ptr<float>();
 
     AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
         preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
@@ -118,13 +127,19 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
     }));
 
     // create output tensor
-    auto options = torch::TensorOptions().dtype(row_idx.dtype());
-    auto roff_idx = torch::empty({nrows + 1}, options);
+    auto roff_idx = torch::empty({nrows + 1}, row_idx.options());
     int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
     for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
     delete[] roff_h;
 
+    // move to original device
+    ker_idx = ker_idx.to(device);
+    row_idx = row_idx.to(device);
+    col_idx = col_idx.to(device);
+    val = val.to(device);
+    roff_idx = roff_idx.to(device);
+
     return roff_idx;
 }
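The C++ change stages the sparse-index tensors on the CPU because the row-offset computation walks raw host pointers, then ships the result back to the caller's device. A rough Python analog of that round-trip (a sketch of the pattern, not the extension's actual API):

```python
import torch

def build_row_offsets(row_idx: torch.Tensor) -> torch.Tensor:
    device = row_idx.device           # remember the caller's device
    row_h = row_idx.cpu()             # the host loop below needs CPU memory
    roff = torch.zeros(int(row_h.max()) + 2, dtype=torch.int64)
    for r in row_h.tolist():          # count nonzeros per row on the host
        roff[r + 1] += 1
    roff = torch.cumsum(roff, dim=0)  # counts -> CSR-style row offsets
    return roff.to(device)            # hand the result back where it came from
```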
...
@@ -254,7 +254,7 @@ class MorletFilterBasis(FilterBasis):
         mkernel = ikernel // self.kernel_shape[1]
 
         # get relevant indices
-        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
+        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))
 
         # get corresponding r, phi, x and y coordinates
         r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
@@ -316,10 +316,10 @@ class ZernikeFilterBasis(FilterBasis):
         """
 
         # enumerator for basis function
-        ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
+        ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
 
         # get relevant indices
-        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
+        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))
 
         # indexing logic for zernike polynomials
         # the total index is given by (n * (n + 2) + l ) // 2 which needs to be reversed
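Both filter bases build an index tensor and a boolean mask that have to live on the device of the coordinate tensor `r`. A compact sketch of the selection pattern with made-up shapes and cutoff:

```python
import torch

r = torch.rand(6, 8)  # hypothetical (nlat, nlon) radii
ikernel = torch.arange(3, device=r.device).reshape(-1, 1, 1)

# broadcasting (3,1,1) against (6,8) yields a (3,6,8) mask on r's device
mask = (r <= 0.5) & torch.full_like(ikernel, True, dtype=torch.bool)
iidx = torch.argwhere(mask)             # (kernel, lat, lon) index triples
vals = r[iidx[:, 1], iidx[:, 2]] / 0.5  # normalized radii of the selected points
```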
...
@@ -57,10 +57,10 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
     # compute the tensor P^m_n:
     nmax = max(mmax,lmax)
-    vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, requires_grad=False)
+    vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, device=x.device, requires_grad=False)
 
-    norm_factor = 1. if norm == "ortho" else math.sqrt(4 * math.pi)
-    norm_factor = 1. / norm_factor if inverse else norm_factor
+    norm_factor = 1.0 if norm == "ortho" else math.sqrt(4 * math.pi)
+    norm_factor = 1.0 / norm_factor if inverse else norm_factor
 
     # initial values to start the recursion
     vdm[0,0,:] = norm_factor / math.sqrt(4 * math.pi)
@@ -123,7 +123,7 @@ def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
     pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)
 
-    dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, requires_grad=False)
+    dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, device=t.device, requires_grad=False)
 
     # fill the derivative terms wrt theta
     for l in range(0, lmax):
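For reference, the recursion seed visible in the first hunk (`vdm[0,0,:] = norm_factor / math.sqrt(4 * math.pi)`) evaluates, for the forward direction, to

```latex
P_0^0 \;=\; \frac{\text{norm\_factor}}{\sqrt{4\pi}} \;=\;
\begin{cases}
\dfrac{1}{\sqrt{4\pi}} & \text{norm} = \text{"ortho"} \\[1ex]
1 & \text{otherwise (e.g. "four-pi")},
\end{cases}
```

i.e., the orthonormal convention starts from the constant spherical harmonic $Y_0^0 = 1/\sqrt{4\pi}$, while the $4\pi$-normalized convention rescales it to unit mean square over the sphere.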
...
@@ -169,7 +169,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
     tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))
 
     if n == 2:
-        wcc = torch.tensor([1.0, 1.0], dtype=torch.float64)
+        wcc = torch.as_tensor([1.0, 1.0], dtype=torch.float64)
     else:
         n1 = n - 1
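The `torch.tensor` → `torch.as_tensor` swaps here and in the hunks below are behavior-preserving for Python-list inputs; the difference matters when the input is already a tensor, where `as_tensor` avoids a copy. A short illustration (not from the patch):

```python
import torch

w = torch.ones(3, dtype=torch.float64)

a = torch.as_tensor(w)  # reuses w's storage when dtype/device already match
b = torch.as_tensor([1.0, 1.0], dtype=torch.float64)  # list input: allocates, like torch.tensor

assert a.data_ptr() == w.data_ptr()  # a and w share memory
```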
...
@@ -77,7 +77,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
         self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype)
 
         #Square root of the eigenvalues of C.
-        sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
+        sqrt_eig = torch.as_tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
         sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0)))
         sqrt_eig[0,0] = 0.0
         sqrt_eig = sqrt_eig.unsqueeze(0)
@@ -85,8 +85,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
         #Save mean and var of the standard Gaussian.
         #Need these to re-initialize distribution on a new device.
-        mean = torch.tensor([0.0]).to(dtype=dtype)
-        var = torch.tensor([1.0]).to(dtype=dtype)
+        mean = torch.as_tensor([0.0]).to(dtype=dtype)
+        var = torch.as_tensor([1.0]).to(dtype=dtype)
         self.register_buffer('mean', mean)
         self.register_buffer('var', var)
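The surrounding code registers `mean` and `var` as buffers so that a later `.to(device)` carries them along with the module, letting the distribution be rebuilt on the new device. A minimal sketch of that pattern (class name hypothetical):

```python
import torch

class NoiseSource(torch.nn.Module):
    def __init__(self, dtype=torch.float32):
        super().__init__()
        # buffers travel with .to(...)/.cuda(), unlike plain tensor attributes
        self.register_buffer("mean", torch.as_tensor([0.0], dtype=dtype))
        self.register_buffer("var", torch.as_tensor([1.0], dtype=dtype))

    def sample(self, shape):
        # re-create the distribution from the (possibly moved) buffers
        return torch.distributions.Normal(self.mean, self.var).sample(shape)
```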
...
@@ -75,9 +75,9 @@ class ResampleS2(nn.Module):
         # we need to expand the solution to the poles before interpolating
         self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
         if self.expand_poles:
-            self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
+            self.lats_in = torch.cat([torch.as_tensor([0.], dtype=torch.float64, device=self.lats_in.device),
                                       self.lats_in,
-                                      torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
+                                      torch.as_tensor([math.pi], dtype=torch.float64, device=self.lats_in.device)]).contiguous()
 
         # prepare the interpolation by computing indices to the left and right of each output latitude
         lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
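A standalone sketch of the pole-padding plus `searchsorted` lookup this hunk patches, with a made-up latitude grid (the clamp is an illustrative detail, not from the patch):

```python
import math
import torch

lats_in = torch.linspace(0.2, math.pi - 0.2, 9, dtype=torch.float64)
lats_out = torch.linspace(0.0, math.pi, 13, dtype=torch.float64)

# pad with the poles on the same dtype/device as the existing grid
lats_in = torch.cat([
    torch.as_tensor([0.0], dtype=torch.float64, device=lats_in.device),
    lats_in,
    torch.as_tensor([math.pi], dtype=torch.float64, device=lats_in.device),
]).contiguous()

# index of the left neighbor of every output latitude
lat_idx = torch.searchsorted(lats_in, lats_out, side="right") - 1
lat_idx = lat_idx.clamp(0, lats_in.shape[0] - 2)
```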
...