Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
8680e023
Commit
8680e023
authored
Jan 14, 2025
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
formating changes to resample module
parent
4d8755b5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
22 deletions
+8
-22
torch_harmonics/distributed/distributed_resample.py
torch_harmonics/distributed/distributed_resample.py
+4
-11
torch_harmonics/resample.py
torch_harmonics/resample.py
+4
-11
No files found.
torch_harmonics/distributed/distributed_resample.py
View file @
8680e023
...
...
@@ -143,19 +143,12 @@ class DistributedResampleS2(nn.Module):
else
:
omega
=
x
[...,
self
.
lon_idx_right
]
-
x
[...,
self
.
lon_idx_left
]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
((
1.
-
self
.
lon_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lon_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
(
self
.
lon_weights
*
omega
)
/
somega
,
self
.
lon_weights
)
start_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
((
1.
0
-
self
.
lon_weights
)
*
omega
)
/
somega
,
(
1.
0
-
self
.
lon_weights
))
end_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
(
self
.
lon_weights
*
omega
)
/
somega
,
self
.
lon_weights
)
x
=
start_prefac
*
x
[...,
self
.
lon_idx_left
]
+
end_prefac
*
x
[...,
self
.
lon_idx_right
]
return
x
# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def
_expand_poles
(
self
,
x
:
torch
.
Tensor
):
repeats
=
[
1
for
_
in
x
.
shape
]
repeats
[
-
1
]
=
x
.
shape
[
-
1
]
...
...
@@ -171,8 +164,8 @@ class DistributedResampleS2(nn.Module):
else
:
omega
=
x
[...,
self
.
lat_idx
+
1
,
:]
-
x
[...,
self
.
lat_idx
,
:]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
((
1.
-
self
.
lat_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lat_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
(
self
.
lat_weights
*
omega
)
/
somega
,
self
.
lat_weights
)
start_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
((
1.
0
-
self
.
lat_weights
)
*
omega
)
/
somega
,
(
1.
0
-
self
.
lat_weights
))
end_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
(
self
.
lat_weights
*
omega
)
/
somega
,
self
.
lat_weights
)
x
=
start_prefac
*
x
[...,
self
.
lat_idx
,
:]
+
end_prefac
*
x
[...,
self
.
lat_idx
+
1
,
:]
return
x
...
...
torch_harmonics/resample.py
View file @
8680e023
...
...
@@ -128,19 +128,12 @@ class ResampleS2(nn.Module):
else
:
omega
=
x
[...,
self
.
lon_idx_right
]
-
x
[...,
self
.
lon_idx_left
]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
((
1.
-
self
.
lon_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lon_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
(
self
.
lon_weights
*
omega
)
/
somega
,
self
.
lon_weights
)
start_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
((
1.
0
-
self
.
lon_weights
)
*
omega
)
/
somega
,
(
1.
0
-
self
.
lon_weights
))
end_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
(
self
.
lon_weights
*
omega
)
/
somega
,
self
.
lon_weights
)
x
=
start_prefac
*
x
[...,
self
.
lon_idx_left
]
+
end_prefac
*
x
[...,
self
.
lon_idx_right
]
return
x
# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def
_expand_poles
(
self
,
x
:
torch
.
Tensor
):
repeats
=
[
1
for
_
in
x
.
shape
]
repeats
[
-
1
]
=
x
.
shape
[
-
1
]
...
...
@@ -156,8 +149,8 @@ class ResampleS2(nn.Module):
else
:
omega
=
x
[...,
self
.
lat_idx
+
1
,
:]
-
x
[...,
self
.
lat_idx
,
:]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
((
1.
-
self
.
lat_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lat_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.
e-4
,
torch
.
sin
(
self
.
lat_weights
*
omega
)
/
somega
,
self
.
lat_weights
)
start_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
((
1.
0
-
self
.
lat_weights
)
*
omega
)
/
somega
,
(
1.
0
-
self
.
lat_weights
))
end_prefac
=
torch
.
where
(
somega
>
1
e-4
,
torch
.
sin
(
self
.
lat_weights
*
omega
)
/
somega
,
self
.
lat_weights
)
x
=
start_prefac
*
x
[...,
self
.
lat_idx
,
:]
+
end_prefac
*
x
[...,
self
.
lat_idx
+
1
,
:]
return
x
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment