"...text-generation-inference.git" did not exist on "e496c9ba5b574ce4e9d04d3b16bce67759ff0445"
Commit 8680e023 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

formating changes to resample module

parent 4d8755b5
...@@ -143,19 +143,12 @@ class DistributedResampleS2(nn.Module): ...@@ -143,19 +143,12 @@ class DistributedResampleS2(nn.Module):
else: else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left] omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega) somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights)) start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights) end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right] x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x return x
# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def _expand_poles(self, x: torch.Tensor): def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape] repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1] repeats[-1] = x.shape[-1]
...@@ -171,8 +164,8 @@ class DistributedResampleS2(nn.Module): ...@@ -171,8 +164,8 @@ class DistributedResampleS2(nn.Module):
else: else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :] omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega) somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights)) start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights) end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :] x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x return x
......
...@@ -128,19 +128,12 @@ class ResampleS2(nn.Module): ...@@ -128,19 +128,12 @@ class ResampleS2(nn.Module):
else: else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left] omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega) somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights)) start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights) end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right] x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x return x
# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def _expand_poles(self, x: torch.Tensor): def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape] repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1] repeats[-1] = x.shape[-1]
...@@ -156,8 +149,8 @@ class ResampleS2(nn.Module): ...@@ -156,8 +149,8 @@ class ResampleS2(nn.Module):
else: else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :] omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega) somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights)) start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights) end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :] x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment