Commit 30f7802b authored by Thorsten Kurth's avatar Thorsten Kurth
Browse files

fixing some more missing device statements

parent f30ec30a
......@@ -77,7 +77,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype)
#Square root of the eigenvalues of C.
sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
sqrt_eig = torch.as_tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0)))
sqrt_eig[0,0] = 0.0
sqrt_eig = sqrt_eig.unsqueeze(0)
......@@ -85,8 +85,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
#Save mean and var of the standard Gaussian.
#Need these to re-initialize distribution on a new device.
mean = torch.tensor([0.0]).to(dtype=dtype)
var = torch.tensor([1.0]).to(dtype=dtype)
mean = torch.as_tensor([0.0]).to(dtype=dtype)
var = torch.as_tensor([1.0]).to(dtype=dtype)
self.register_buffer('mean', mean)
self.register_buffer('var', var)
......
......@@ -75,9 +75,9 @@ class ResampleS2(nn.Module):
# we need to expand the solution to the poles before interpolating
self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
if self.expand_poles:
self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
self.lats_in = torch.cat([torch.as_tensor([0.], dtype=torch.float64),
self.lats_in,
torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
torch.as_tensor([math.pi], dtype=torch.float64)]).contiguous()
# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment