OpenDAS / torch-harmonics / Commits / 318fc76e

Unverified commit 318fc76e, authored May 26, 2025 by Thorsten Kurth, committed via GitHub on May 26, 2025.
Parent: 18f2c1cc

Commit message: fixing distributed resampling routine (#74)

Showing 6 changed files with 103 additions and 30 deletions (+103, -30):
tests/test_distributed_convolution.py                  +4   -0
tests/test_distributed_resample.py                      +5   -1
torch_harmonics/distributed/__init__.py                 +2   -0
torch_harmonics/distributed/distributed_resample.py    +29  -14
torch_harmonics/distributed/primitives.py              +54   -8
torch_harmonics/resample.py                             +9   -7
tests/test_distributed_convolution.py

@@ -195,6 +195,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
            [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [65, 128, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
        ]
    )
    def test_distributed_disco_conv(
    ...
tests/test_distributed_resample.py

@@ -187,6 +187,10 @@ class TestDistributedResampling(unittest.TestCase):
            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
            [129, 256, 65, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
            [65, 128, 129, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
            [129, 256, 65, 128, 32, 8, "equiangular", "legendre-gauss", "bilinear", 1e-7, False],
            [65, 128, 129, 256, 32, 8, "legendre-gauss", "equiangular", "bilinear", 1e-7, False],
        ]
    )
    def test_distributed_resampling(
    ...
@@ -248,7 +252,7 @@ class TestDistributedResampling(unittest.TestCase):
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)

            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
            if verbose and (self.world_rank == 0):
                print(f"final relative error of output: {err.item()}")

        self.assertTrue(err.item() <= tol)
    ...
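As a side note, a minimal standalone sketch of the error metric used in the check above: the per-sample, per-channel Frobenius norm of the difference between the gathered distributed output and the single-process reference, normalized by the norm of the reference and averaged. The tensors below are random stand-ins, not the library's actual outputs.

    import torch

    # stand-ins for the single-process reference and the gathered distributed result
    out_full = torch.randn(2, 4, 64, 128)                   # (batch, channels, nlat, nlon)
    out_gather_full = out_full + 1e-9 * torch.randn_like(out_full)

    # relative Frobenius error per (batch, channel) slice, averaged over batch and channels
    err = torch.mean(
        torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2))
        / torch.norm(out_full, p="fro", dim=(-1, -2))
    )

    tol = 1e-7
    assert err.item() <= tol, f"relative error {err.item()} exceeds tolerance {tol}"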
torch_harmonics/distributed/__init__.py

@@ -37,9 +37,11 @@ from .primitives import (
    distributed_transpose_azimuth,
    distributed_transpose_polar,
    reduce_from_polar_region,
    reduce_from_azimuth_region,
    scatter_to_polar_region,
    gather_from_polar_region,
    copy_to_polar_region,
    copy_to_azimuth_region,
    reduce_from_scatter_to_polar_region,
    gather_from_copy_to_polar_region
)
    ...
torch_harmonics/distributed/distributed_resample.py

@@ -37,6 +37,7 @@ import torch.nn as nn
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
+from torch_harmonics.distributed import reduce_from_azimuth_region, copy_to_azimuth_region
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes
    ...

@@ -92,8 +93,6 @@ class DistributedResampleS2(nn.Module):
            self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64), self.lats_in, torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
-           #self.lats_in = np.insert(self.lats_in, 0, 0.0)
-           #self.lats_in = np.append(self.lats_in, np.pi)

        # prepare the interpolation by computing indices to the left and right of each output latitude
        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
    ...
@@ -135,34 +134,50 @@
    def _upscale_longitudes(self, x: torch.Tensor):
        # do the interpolation
+       lwgt = self.lon_weights.to(x.dtype)
        if self.mode == "bilinear":
-           x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
+           x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)
        else:
            omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
            somega = torch.sin(omega)
-           start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
-           end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
+           start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
+           end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
            x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]

        return x

    def _expand_poles(self, x: torch.Tensor):
-       repeats = [1 for _ in x.shape]
-       repeats[-1] = x.shape[-1]
-       x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-       x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-       x = torch.concatenate((x_north, x, x_south), dim=-2)
+       x_north = x[..., 0, :].sum(dim=-1, keepdims=True)
+       x_south = x[..., -1, :].sum(dim=-1, keepdims=True)
+       x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)
+
+       if self.comm_size_azimuth > 1:
+           x_north = reduce_from_azimuth_region(x_north.contiguous())
+           x_south = reduce_from_azimuth_region(x_south.contiguous())
+           x_count = reduce_from_azimuth_region(x_count)
+
+       x_north = x_north / x_count
+       x_south = x_south / x_count
+
+       if self.comm_size_azimuth > 1:
+           x_north = copy_to_azimuth_region(x_north)
+           x_south = copy_to_azimuth_region(x_south)
+
+       x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant')
+       x[..., 0, :] = x_north[...]
+       x[..., -1, :] = x_south[...]

        return x

    def _upscale_latitudes(self, x: torch.Tensor):
        # do the interpolation
+       lwgt = self.lat_weights.to(x.dtype)
        if self.mode == "bilinear":
-           x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
+           x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)
        else:
            omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
            somega = torch.sin(omega)
-           start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
-           end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
+           start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
+           end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
            x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]

        return x
    ...
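A quick aside on the somega > 1e-4 guard in the spherical branches above: the spherical weights sin((1-w)·ω)/sin(ω) and sin(w·ω)/sin(ω) become numerically delicate as ω → 0, where they reduce to the ordinary linear weights (1-w) and w, which is exactly the fallback the code uses. The snippet below is a small standalone check of that limit; it only reuses the formula from the hunk above and is not part of the library.

    import torch

    w = torch.tensor(0.3, dtype=torch.float64)      # interpolation weight
    for om in [1.0, 1e-2, 1e-5, 1e-8]:              # angle between the two sample points
        omega = torch.tensor(om, dtype=torch.float64)
        somega = torch.sin(omega)
        # spherical weights with the same small-angle fallback as in the hunk above
        start = torch.where(somega > 1e-4, torch.sin((1.0 - w) * omega) / somega, 1.0 - w)
        end = torch.where(somega > 1e-4, torch.sin(w * omega) / somega, w)
        # as omega -> 0, the spherical weights approach the linear weights (0.7, 0.3)
        print(f"omega={om:.0e}  start={start.item():.6f}  end={end.item():.6f}")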
@@ -174,7 +189,7 @@ class DistributedResampleS2(nn.Module):
        # transpose data so that h is local, and channels are split
        num_chans = x.shape[-3]

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_polar > 1:
            channels_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
    ...
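The heart of the fix in _expand_poles above: when the longitude dimension is split across azimuth ranks, each rank only sees part of the longitude circle, so the pole value has to be a global zonal mean, i.e. a local sum plus a local point count, both all-reduced over the azimuth group, and only then divided. Below is a minimal standalone sketch of that pattern using raw torch.distributed calls in place of the library's autograd-aware reduce_from_azimuth_region / copy_to_azimuth_region wrappers; the function name and the process-group handle are illustrative only.

    import torch
    import torch.distributed as dist

    def pole_mean_over_azimuth(x: torch.Tensor, azimuth_group=None) -> torch.Tensor:
        """Mean of the first latitude ring over the *global* longitude circle.

        x has shape (..., nlat_local, nlon_local); nlon_local is only a shard of
        the full circle when the azimuth communicator has more than one rank.
        """
        ring_sum = x[..., 0, :].sum(dim=-1, keepdim=True)
        count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device)

        if dist.is_available() and dist.is_initialized():
            # sum the partial sums and the partial counts over the azimuth group
            dist.all_reduce(ring_sum, op=dist.ReduceOp.SUM, group=azimuth_group)
            dist.all_reduce(count, op=dist.ReduceOp.SUM, group=azimuth_group)

        return ring_sum / count

    # single-process fallback: reduces to a plain mean over the local longitudes
    x = torch.randn(2, 3, 8, 16)
    assert torch.allclose(pole_mean_over_azimuth(x), x[..., 0, :].mean(dim=-1, keepdim=True))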
torch_harmonics/distributed/primitives.py

@@ -35,7 +35,7 @@ import torch.distributed as dist
from torch.amp import custom_fwd, custom_bwd

from .utils import polar_group, azimuth_group, polar_group_size
-from .utils import is_initialized, is_distributed_polar
+from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth

# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    ...

@@ -262,8 +262,29 @@ class _CopyToPolarRegion(torch.autograd.Function):
            return _reduce(grad_output, group=polar_group())
        else:
            return grad_output, None

+class _CopyToAzimuthRegion(torch.autograd.Function):
+    """Split the input and keep only the corresponding chunk to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(ctx, input_):
+        return input_
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, grad_output):
+        if is_distributed_azimuth():
+            return _reduce(grad_output, group=azimuth_group())
+        else:
+            return grad_output, None

class _ScatterToPolarRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chunk to the rank."""
    ...

@@ -340,6 +361,30 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
    def backward(ctx, grad_output):
        return grad_output

+class _ReduceFromAzimuthRegion(torch.autograd.Function):
+    """All-reduce the input from the azimuth region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        if is_distributed_azimuth():
+            return _reduce(input_, group=azimuth_group())
+        else:
+            return input_
+
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(ctx, input_):
+        if is_distributed_azimuth():
+            return _reduce(input_, group=azimuth_group())
+        else:
+            return input_
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, grad_output):
+        return grad_output

class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
    """All-reduce the input from the polar region and scatter back to polar region."""
    ...

@@ -403,23 +448,24 @@ class _GatherFromCopyToPolarRegion(torch.autograd.Function):
def copy_to_polar_region(input_):
    return _CopyToPolarRegion.apply(input_)

def copy_to_azimuth_region(input_):
    return _CopyToAzimuthRegion.apply(input_)

def reduce_from_polar_region(input_):
    return _ReduceFromPolarRegion.apply(input_)

def reduce_from_azimuth_region(input_):
    return _ReduceFromAzimuthRegion.apply(input_)

def scatter_to_polar_region(input_, dim_):
    return _ScatterToPolarRegion.apply(input_, dim_)

def gather_from_polar_region(input_, dim_, shapes_):
    return _GatherFromPolarRegion.apply(input_, dim_, shapes_)

def reduce_from_scatter_to_polar_region(input_, dim_):
    return _ReduceFromScatterToPolarRegion.apply(input_, dim_)

def gather_from_copy_to_polar_region(input_, dim_, shapes_):
    return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)
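The two new autograd functions above follow the usual conjugate pattern for such primitives: copy_to_azimuth_region is an identity in the forward pass and an all-reduce in the backward pass, while reduce_from_azimuth_region is the reverse, which keeps gradients consistent across the azimuth ranks. Below is a minimal generic sketch of that pairing; _all_reduce_sum and the class names are hypothetical stand-ins for the library's _reduce helper and its classes, and the sketch simply falls back to the identity when torch.distributed is not initialized.

    import torch
    import torch.distributed as dist

    def _all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
        # stand-in for the library's _reduce helper; identity on a single process
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t

    class _CopyToGroup(torch.autograd.Function):
        """Identity forward, sum-reduce backward."""

        @staticmethod
        def forward(ctx, input_):
            return input_

        @staticmethod
        def backward(ctx, grad_output):
            return _all_reduce_sum(grad_output)

    class _ReduceFromGroup(torch.autograd.Function):
        """Sum-reduce forward, identity backward."""

        @staticmethod
        def forward(ctx, input_):
            return _all_reduce_sum(input_)

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    # single-process smoke test: both act as the identity and gradients flow through
    x = torch.randn(4, requires_grad=True)
    y = _ReduceFromGroup.apply(_CopyToGroup.apply(x)).sum()
    y.backward()
    assert torch.allclose(x.grad, torch.ones_like(x))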
torch_harmonics/resample.py

@@ -78,8 +78,6 @@ class ResampleS2(nn.Module):
            self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64), self.lats_in, torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
-           #self.lats_in = np.insert(self.lats_in, 0, 0.0)
-           #self.lats_in = np.append(self.lats_in, np.pi)

        # prepare the interpolation by computing indices to the left and right of each output latitude
        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
    ...

@@ -135,11 +133,12 @@ class ResampleS2(nn.Module):
        return x

    def _expand_poles(self, x: torch.Tensor):
-       repeats = [1 for _ in x.shape]
-       repeats[-1] = x.shape[-1]
-       x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-       x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-       x = torch.concatenate((x_north, x, x_south), dim=-2).contiguous()
+       x_north = x[..., 0, :].mean(dim=-1, keepdims=True)
+       x_south = x[..., -1, :].mean(dim=-1, keepdims=True)
+       x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant')
+       x[..., 0, :] = x_north[...]
+       x[..., -1, :] = x_south[...]

        return x

    def _upscale_latitudes(self, x: torch.Tensor):
    ...

@@ -162,6 +161,9 @@ class ResampleS2(nn.Module):
        if self.expand_poles:
            x = self._expand_poles(x)

        x = self._upscale_latitudes(x)
        x = self._upscale_longitudes(x)

        return x
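For intuition, the rewritten single-process _expand_poles above is value-equivalent to the removed concatenate-and-repeat version: both prepend and append a latitude ring filled with the zonal mean of the first and last existing ring. A small standalone check of that equivalence, reusing the two bodies from the hunk above on a random tensor (all names are local to this sketch):

    import torch
    import torch.nn as nn

    def expand_poles_old(x: torch.Tensor) -> torch.Tensor:
        # removed variant: build pole rings by repeating the zonal mean, then concatenate
        repeats = [1 for _ in x.shape]
        repeats[-1] = x.shape[-1]
        x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
        x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
        return torch.concatenate((x_north, x, x_south), dim=-2).contiguous()

    def expand_poles_new(x: torch.Tensor) -> torch.Tensor:
        # new variant: pad by one ring on each side and fill it with the zonal mean
        x_north = x[..., 0, :].mean(dim=-1, keepdim=True)
        x_south = x[..., -1, :].mean(dim=-1, keepdim=True)
        x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode="constant")
        x[..., 0, :] = x_north[...]
        x[..., -1, :] = x_south[...]
        return x

    x = torch.randn(2, 3, 8, 16)  # (batch, channels, nlat, nlon)
    assert torch.allclose(expand_poles_old(x), expand_poles_new(x))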