Commit 8680e023 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

formating changes to resample module

parent 4d8755b5
......@@ -143,19 +143,12 @@ class DistributedResampleS2(nn.Module):
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x
# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1]
......@@ -171,8 +164,8 @@ class DistributedResampleS2(nn.Module):
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x
......
......@@ -128,19 +128,12 @@ class ResampleS2(nn.Module):
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x
# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1]
......@@ -156,8 +149,8 @@ class ResampleS2(nn.Module):
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment