Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
55bbcb25
Commit
55bbcb25
authored
Jan 14, 2025
by
Thorsten Kurth
Committed by
Boris Bonev
Jan 14, 2025
Browse files
implemented slerp
parent
87d9bfdc
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
77 additions
and
49 deletions
+77
-49
notebooks/resample_sphere.ipynb
notebooks/resample_sphere.ipynb
+32
-39
tests/test_distributed_resample.py
tests/test_distributed_resample.py
+6
-3
torch_harmonics/__init__.py
torch_harmonics/__init__.py
+1
-1
torch_harmonics/distributed/distributed_resample.py
torch_harmonics/distributed/distributed_resample.py
+19
-3
torch_harmonics/resample.py
torch_harmonics/resample.py
+19
-3
No files found.
notebooks/resample_sphere.ipynb
View file @
55bbcb25
This diff is collapsed.
Click to expand it.
tests/test_distributed_resample.py
View file @
55bbcb25
...
...
@@ -183,12 +183,14 @@ class TestDistributedResampling(unittest.TestCase):
@
parameterized
.
expand
(
[
[
64
,
128
,
128
,
256
,
32
,
8
,
"equiangular"
,
"equiangular"
,
1e-7
],
[
128
,
256
,
64
,
128
,
32
,
8
,
"equiangular"
,
"equiangular"
,
1e-7
],
[
64
,
128
,
128
,
256
,
32
,
8
,
"equiangular"
,
"equiangular"
,
"bilinear"
,
1e-7
],
[
128
,
256
,
64
,
128
,
32
,
8
,
"equiangular"
,
"equiangular"
,
"bilinear"
,
1e-7
],
[
64
,
128
,
128
,
256
,
32
,
8
,
"equiangular"
,
"equiangular"
,
"bilinear-spherical"
,
1e-7
],
[
128
,
256
,
64
,
128
,
32
,
8
,
"equiangular"
,
"equiangular"
,
"bilinear-spherical"
,
1e-7
],
]
)
def
test_distributed_resampling
(
self
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
batch_size
,
num_chan
,
grid_in
,
grid_out
,
tol
self
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
batch_size
,
num_chan
,
grid_in
,
grid_out
,
mode
,
tol
):
B
,
C
,
H
,
W
=
batch_size
,
num_chan
,
nlat_in
,
nlon_in
...
...
@@ -200,6 +202,7 @@ class TestDistributedResampling(unittest.TestCase):
nlon_out
=
nlon_out
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
mode
=
mode
,
)
# set up handlesD
...
...
torch_harmonics/__init__.py
View file @
55bbcb25
...
...
@@ -33,7 +33,7 @@ __version__ = "0.7.4a"
from
.sht
import
RealSHT
,
InverseRealSHT
,
RealVectorSHT
,
InverseRealVectorSHT
from
.convolution
import
DiscreteContinuousConvS2
,
DiscreteContinuousConvTransposeS2
from
.resampl
ing
import
ResampleS2
from
.resampl
e
import
ResampleS2
from
.
import
quadrature
from
.
import
random_fields
from
.
import
examples
torch_harmonics/distributed/distributed_resample.py
View file @
55bbcb25
...
...
@@ -57,7 +57,7 @@ class DistributedResampleS2(nn.Module):
super
().
__init__
()
# currently only bilinear is supported
if
mode
==
"bilinear"
:
if
mode
in
[
"bilinear"
,
"bilinear-spherical"
]
:
self
.
mode
=
mode
else
:
raise
NotImplementedError
(
f
"unknown interpolation mode
{
mode
}
"
)
...
...
@@ -138,7 +138,15 @@ class DistributedResampleS2(nn.Module):
def
_upscale_longitudes
(
self
,
x
:
torch
.
Tensor
):
# do the interpolation
x
=
torch
.
lerp
(
x
[...,
self
.
lon_idx_left
],
x
[...,
self
.
lon_idx_right
],
self
.
lon_weights
)
if
self
.
mode
==
"bilinear"
:
x
=
torch
.
lerp
(
x
[...,
self
.
lon_idx_left
],
x
[...,
self
.
lon_idx_right
],
self
.
lon_weights
)
else
:
omega
=
x
[...,
self
.
lon_idx_right
]
-
x
[...,
self
.
lon_idx_left
]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
((
1.
-
self
.
lon_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lon_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
(
self
.
lon_weights
*
omega
)
/
somega
,
self
.
lon_weights
)
x
=
start_prefac
*
x
[...,
self
.
lon_idx_left
]
+
end_prefac
*
x
[...,
self
.
lon_idx_right
]
return
x
# old deprecated method with repeat_interleave
...
...
@@ -158,7 +166,15 @@ class DistributedResampleS2(nn.Module):
def
_upscale_latitudes
(
self
,
x
:
torch
.
Tensor
):
# do the interpolation
x
=
torch
.
lerp
(
x
[...,
self
.
lat_idx
,
:],
x
[...,
self
.
lat_idx
+
1
,
:],
self
.
lat_weights
)
if
self
.
mode
==
"bilinear"
:
x
=
torch
.
lerp
(
x
[...,
self
.
lat_idx
,
:],
x
[...,
self
.
lat_idx
+
1
,
:],
self
.
lat_weights
)
else
:
omega
=
x
[...,
self
.
lat_idx
+
1
,
:]
-
x
[...,
self
.
lat_idx
,
:]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
((
1.
-
self
.
lat_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lat_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
(
self
.
lat_weights
*
omega
)
/
somega
,
self
.
lat_weights
)
x
=
start_prefac
*
x
[...,
self
.
lat_idx
,
:]
+
end_prefac
*
x
[...,
self
.
lat_idx
+
1
,
:]
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
):
...
...
torch_harmonics/resampl
ing
.py
→
torch_harmonics/resampl
e
.py
View file @
55bbcb25
...
...
@@ -54,7 +54,7 @@ class ResampleS2(nn.Module):
super
().
__init__
()
# currently only bilinear is supported
if
mode
==
"bilinear"
:
if
mode
in
[
"bilinear"
,
"bilinear-spherical"
]
:
self
.
mode
=
mode
else
:
raise
NotImplementedError
(
f
"unknown interpolation mode
{
mode
}
"
)
...
...
@@ -123,7 +123,15 @@ class ResampleS2(nn.Module):
def
_upscale_longitudes
(
self
,
x
:
torch
.
Tensor
):
# do the interpolation
x
=
torch
.
lerp
(
x
[...,
self
.
lon_idx_left
],
x
[...,
self
.
lon_idx_right
],
self
.
lon_weights
)
if
self
.
mode
==
"bilinear"
:
x
=
torch
.
lerp
(
x
[...,
self
.
lon_idx_left
],
x
[...,
self
.
lon_idx_right
],
self
.
lon_weights
)
else
:
omega
=
x
[...,
self
.
lon_idx_right
]
-
x
[...,
self
.
lon_idx_left
]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
((
1.
-
self
.
lon_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lon_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
(
self
.
lon_weights
*
omega
)
/
somega
,
self
.
lon_weights
)
x
=
start_prefac
*
x
[...,
self
.
lon_idx_left
]
+
end_prefac
*
x
[...,
self
.
lon_idx_right
]
return
x
# old deprecated method with repeat_interleave
...
...
@@ -143,7 +151,15 @@ class ResampleS2(nn.Module):
def
_upscale_latitudes
(
self
,
x
:
torch
.
Tensor
):
# do the interpolation
x
=
torch
.
lerp
(
x
[...,
self
.
lat_idx
,
:],
x
[...,
self
.
lat_idx
+
1
,
:],
self
.
lat_weights
)
if
self
.
mode
==
"bilinear"
:
x
=
torch
.
lerp
(
x
[...,
self
.
lat_idx
,
:],
x
[...,
self
.
lat_idx
+
1
,
:],
self
.
lat_weights
)
else
:
omega
=
x
[...,
self
.
lat_idx
+
1
,
:]
-
x
[...,
self
.
lat_idx
,
:]
somega
=
torch
.
sin
(
omega
)
start_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
((
1.
-
self
.
lat_weights
)
*
omega
)
/
somega
,
(
1.
-
self
.
lat_weights
))
end_prefac
=
torch
.
where
(
somega
>
1.e-4
,
torch
.
sin
(
self
.
lat_weights
*
omega
)
/
somega
,
self
.
lat_weights
)
x
=
start_prefac
*
x
[...,
self
.
lat_idx
,
:]
+
end_prefac
*
x
[...,
self
.
lat_idx
+
1
,
:]
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment