Commit f30ec30a authored by Thorsten Kurth's avatar Thorsten Kurth
Browse files

adding more missing device statements

parent ea8d1a2e
...@@ -92,8 +92,8 @@ def _normalize_convolution_tensor_s2( ...@@ -92,8 +92,8 @@ def _normalize_convolution_tensor_s2(
q = quad_weights[ilat_in].reshape(-1) q = quad_weights[ilat_in].reshape(-1)
# buffer to store intermediate values # buffer to store intermediate values
vnorm = torch.zeros(kernel_size, nlat_out) vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
support = torch.zeros(kernel_size, nlat_out) support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
# loop through dimensions to compute the norms # loop through dimensions to compute the norms
for ik in range(kernel_size): for ik in range(kernel_size):
...@@ -207,7 +207,7 @@ def _precompute_convolution_tensor_s2( ...@@ -207,7 +207,7 @@ def _precompute_convolution_tensor_s2(
sgamma = torch.sin(gamma) sgamma = torch.sin(gamma)
# compute row offsets # compute row offsets
out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64) out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
out_roff[0] = 0 out_roff[0] = 0
for t in range(nlat_out): for t in range(nlat_out):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
......
...@@ -123,7 +123,7 @@ def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor, ...@@ -123,7 +123,7 @@ def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False) pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)
dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, requires_grad=False) dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, device=t.device, requires_grad=False)
# fill the derivative terms wrt theta # fill the derivative terms wrt theta
for l in range(0, lmax): for l in range(0, lmax):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment