OpenDAS / torch-harmonics

Commit e5a9c4af

adding distributed resampling and test routines

Authored Jan 13, 2025 by Thorsten Kurth; committed Jan 14, 2025 by Boris Bonev
Parent: 3350099a

Showing 4 changed files with 467 additions and 2 deletions
notebooks/plot_spherical_harmonics.ipynb               +2   -2
tests/test_distributed_resample.py                     +265 -0
torch_harmonics/distributed/__init__.py                +3   -0
torch_harmonics/distributed/distributed_resample.py    +197 -0
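For orientation, here is a minimal usage sketch of the new distributed resampling module, assembled from the constructor arguments and the thd.init/thd.finalize calls that appear in the files below. The single-rank setup (grid H x W = 1 x 1), backend choice, and tensor sizes are illustrative assumptions, not part of the commit.

import os
import torch
import torch.distributed as dist
import torch_harmonics.distributed as thd

# single-process setup (grid H x W = 1 x 1); with only one rank the polar and
# azimuthal process groups are simply None, mirroring the test below
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
thd.init(None, None)

# resample from a 64 x 128 equiangular grid to 128 x 256 (sizes taken from the test cases)
resample = thd.DistributedResampleS2(
    nlat_in=64, nlon_in=128, nlat_out=128, nlon_out=256,
    grid_in="equiangular", grid_out="equiangular",
)

x = torch.randn(4, 8, 64, 128)   # (batch, channels, nlat, nlon); with one rank this is the full tensor
y = resample(x)                  # -> (4, 8, 128, 256)

thd.finalize()
dist.destroy_process_group()

With more than one rank, the polar and azimuthal process groups passed to thd.init would be built as in tests/test_distributed_resample.py below.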
notebooks/plot_spherical_harmonics.ipynb

@@ -205,7 +205,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "dace",
    "language": "python",
    "name": "python3"
   },
@@ -219,7 +219,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.12"
+  "version": "3.8.18"
  }
 },
 "nbformat": 4,
tests/test_distributed_resample.py (new file, mode 100644)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import unittest
from parameterized import parameterized

import torch
import torch.nn.functional as F
import torch.distributed as dist

import torch_harmonics as harmonics
import torch_harmonics.distributed as thd


class TestDistributedResampling(unittest.TestCase):

    @classmethod
    def setUpClass(cls):

        # set up distributed
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
        cls.grid_size_w = int(os.getenv("GRID_W", 1))
        port = int(os.getenv("MASTER_PORT", "29501"))
        master_address = os.getenv("MASTER_ADDR", "localhost")
        cls.world_size = cls.grid_size_h * cls.grid_size_w

        if torch.cuda.is_available():
            if cls.world_rank == 0:
                print("Running test on GPU")
            local_rank = cls.world_rank % torch.cuda.device_count()
            cls.device = torch.device(f"cuda:{local_rank}")
            torch.cuda.set_device(local_rank)
            torch.cuda.manual_seed(333)
            proc_backend = "nccl"
        else:
            if cls.world_rank == 0:
                print("Running test on CPU")
            cls.device = torch.device("cpu")
            proc_backend = "gloo"
        torch.manual_seed(333)

        dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)

        cls.wrank = cls.world_rank % cls.grid_size_w
        cls.hrank = cls.world_rank // cls.grid_size_w

        # now set up the comm groups:
        # set default
        cls.w_group = None
        cls.h_group = None

        # do the init
        wgroups = []
        for w in range(0, cls.world_size, cls.grid_size_w):
            start = w
            end = w + cls.grid_size_w
            wgroups.append(list(range(start, end)))

        if cls.world_rank == 0:
            print("w-groups:", wgroups)
        for grp in wgroups:
            if len(grp) == 1:
                continue
            tmp_group = dist.new_group(ranks=grp)
            if cls.world_rank in grp:
                cls.w_group = tmp_group

        # transpose:
        hgroups = [sorted(list(i)) for i in zip(*wgroups)]

        if cls.world_rank == 0:
            print("h-groups:", hgroups)
        for grp in hgroups:
            if len(grp) == 1:
                continue
            tmp_group = dist.new_group(ranks=grp)
            if cls.world_rank in grp:
                cls.h_group = tmp_group

        if cls.world_rank == 0:
            print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

        # initialize the distributed torch-harmonics environment
        thd.init(cls.h_group, cls.w_group)

    @classmethod
    def tearDownClass(cls):
        thd.finalize()
        dist.destroy_process_group(None)

    def _split_helper(self, tensor):
        with torch.no_grad():
            # split in W
            tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
            tensor_local = tensor_list_local[self.wrank]

            # split in H
            tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
            tensor_local = tensor_list_local[self.hrank]

        return tensor_local

    def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
        # we need the shapes
        lat_shapes = convolution_dist.lat_out_shapes
        lon_shapes = convolution_dist.lon_out_shapes

        # gather in W
        tensor = tensor.contiguous()
        if self.grid_size_w > 1:
            gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
            olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
            olist[self.wrank] = tensor
            dist.all_gather(olist, tensor, group=self.w_group)
            tensor_gather = torch.cat(olist, dim=-1)
        else:
            tensor_gather = tensor

        # gather in H
        tensor_gather = tensor_gather.contiguous()
        if self.grid_size_h > 1:
            gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
            olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
            olist[self.hrank] = tensor_gather
            dist.all_gather(olist, tensor_gather, group=self.h_group)
            tensor_gather = torch.cat(olist, dim=-2)

        return tensor_gather

    def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
        # we need the shapes
        lat_shapes = resampling_dist.lat_in_shapes
        lon_shapes = resampling_dist.lon_in_shapes

        # gather in W
        if self.grid_size_w > 1:
            gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
            olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
            olist[self.wrank] = tensor
            dist.all_gather(olist, tensor, group=self.w_group)
            tensor_gather = torch.cat(olist, dim=-1)
        else:
            tensor_gather = tensor

        # gather in H
        if self.grid_size_h > 1:
            gather_shapes = [(B, C, h, resampling_dist.nlon_in) for h in lat_shapes]
            olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
            olist[self.hrank] = tensor_gather
            dist.all_gather(olist, tensor_gather, group=self.h_group)
            tensor_gather = torch.cat(olist, dim=-2)

        return tensor_gather

    @parameterized.expand(
        [
            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", 1e-7],
            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", 1e-7],
        ]
    )
    def test_distributed_resampling(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, tol):

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

        res_args = dict(
            nlat_in=nlat_in,
            nlon_in=nlon_in,
            nlat_out=nlat_out,
            nlon_out=nlon_out,
            grid_in=grid_in,
            grid_out=grid_out,
        )

        # set up handles
        res_local = harmonics.ResampleS2(**res_args).to(self.device)
        res_dist = thd.DistributedResampleS2(**res_args).to(self.device)

        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local resampling
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = res_local(inp_full)

        # create grad for backward
        with torch.no_grad():
            # create full grad
            ograd_full = torch.randn_like(out_full)

        # BWD pass
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed resampling
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
        out_local = res_dist(inp_local)

        # BWD pass
        ograd_local = self._split_helper(ograd_full)
        out_local = res_dist(inp_local)
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)

            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
            if self.world_rank == 0:
                print(f"final relative error of output: {err.item()}")

        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)

            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
            if self.world_rank == 0:
                print(f"final relative error of gradients: {err.item()}")

        self.assertTrue(err.item() <= tol)


if __name__ == "__main__":
    unittest.main()
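To make the comm-group construction in setUpClass concrete, here is a small worked sketch (not part of the commit) of the rank-to-grid mapping for an assumed 2 x 2 process grid (GRID_H=2, GRID_W=2):

# worked example of the rank layout used in setUpClass, for GRID_H = 2, GRID_W = 2
grid_size_h, grid_size_w = 2, 2
world_size = grid_size_h * grid_size_w  # 4 ranks

# each rank's coordinates in the process grid
for world_rank in range(world_size):
    wrank = world_rank % grid_size_w   # position along W (azimuth)
    hrank = world_rank // grid_size_w  # position along H (polar)
    print(world_rank, "->", (hrank, wrank))
# 0 -> (0, 0), 1 -> (0, 1), 2 -> (1, 0), 3 -> (1, 1)

# w-groups: blocks of consecutive ranks that share the same hrank
wgroups = [list(range(w, w + grid_size_w)) for w in range(0, world_size, grid_size_w)]
print("w-groups:", wgroups)  # [[0, 1], [2, 3]]

# h-groups: the transpose, i.e. ranks that share the same wrank
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print("h-groups:", hgroups)  # [[0, 2], [1, 3]]

Each rank then creates a dist.new_group for every w-group and h-group, keeps handles only to the two groups it belongs to, and passes them to thd.init.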
torch_harmonics/distributed/__init__.py

@@ -51,3 +51,6 @@ from .distributed_sht import DistributedRealVectorSHT, DistributedInverseRealVec
 # import DISCO
 from .distributed_convolution import DistributedDiscreteContinuousConvS2
 from .distributed_convolution import DistributedDiscreteContinuousConvTransposeS2
+
+# import resampling
+from .distributed_resample import DistributedResampleS2
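With this export in place, the new module is available from the distributed subpackage in the same way as the DISCO convolutions above:

from torch_harmonics.distributed import DistributedResampleS2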
torch_harmonics/distributed/distributed_resample.py (new file, mode 100644)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
import math

import numpy as np

import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes

from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes


class DistributedResampleS2(nn.Module):
    def __init__(
        self,
        nlat_in: int,
        nlon_in: int,
        nlat_out: int,
        nlon_out: int,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        mode: Optional[str] = "bilinear",
    ):
        super().__init__()

        # currently only bilinear is supported
        if mode == "bilinear":
            self.mode = mode
        else:
            raise NotImplementedError(f"unknown interpolation mode {mode}")

        self.nlat_in, self.nlon_in = nlat_in, nlon_in
        self.nlat_out, self.nlon_out = nlat_out, nlon_out

        self.grid_in = grid_in
        self.grid_out = grid_out

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # compute splits: is this correct even when expanding the poles?
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # for upscaling the latitudes we will use interpolation
        self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
        self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False)
        self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
        self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False)

        # in the case where some points lie outside of the range spanned by lats_in,
        # we need to expand the solution to the poles before interpolating
        self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
        if self.expand_poles:
            self.lats_in = np.insert(self.lats_in, 0, 0.0)
            self.lats_in = np.append(self.lats_in, np.pi)

        # prepare the interpolation by computing indices to the left and right of each output latitude
        lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1

        # make sure that we properly treat the last point if it coincides with the pole
        lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
        # lat_idx = np.where(self.lats_out > self.lats_in[-1], lat_idx - 1, lat_idx)
        # lat_idx = np.where(self.lats_out < self.lats_in[0], 0, lat_idx)

        # compute the interpolation weights along the latitude
        lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float()
        lat_weights = lat_weights.unsqueeze(-1)

        # convert to tensor
        lat_idx = torch.LongTensor(lat_idx)

        # register buffers
        self.register_buffer("lat_idx", lat_idx, persistent=False)
        self.register_buffer("lat_weights", lat_weights, persistent=False)

        # get left and right indices but this time make sure periodicity in the longitude is handled
        lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1
        lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1)

        # get the difference
        diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
        diff = np.where(diff < 0.0, diff + 2 * math.pi, diff)

        lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float()

        # convert to tensor
        lon_idx_left = torch.LongTensor(lon_idx_left)
        lon_idx_right = torch.LongTensor(lon_idx_right)

        # register buffers
        self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
        self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
        self.register_buffer("lon_weights", lon_weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    def _upscale_longitudes(self, x: torch.Tensor):
        # do the interpolation
        x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
        return x

    # old deprecated method with repeat_interleave
    # def _upscale_longitudes(self, x: torch.Tensor):
    #     # for artifact-free upsampling in the longitudinal direction
    #     x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
    #     x = torch.roll(x, - self.lon_shift, dims=-1)
    #     return x

    def _expand_poles(self, x: torch.Tensor):
        repeats = [1 for _ in x.shape]
        repeats[-1] = x.shape[-1]
        x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
        x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
        x = torch.concatenate((x_north, x, x_south), dim=-2)
        return x

    def _upscale_latitudes(self, x: torch.Tensor):
        # do the interpolation
        x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
        return x

    def forward(self, x: torch.Tensor):

        # transpose data so that h is local, and channels are split
        num_chans = x.shape[-3]

        # h and w are split; first make h local by moving the split onto the channel dim
        if self.comm_size_polar > 1:
            channels_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_in_shapes)

        # expand poles if requested
        if self.expand_poles:
            x = self._expand_poles(x)

        # upscaling
        x = self._upscale_latitudes(x)

        # now, transpose back
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-2, -3), channels_shapes)

        # now, transpose in w:
        if self.comm_size_azimuth > 1:
            channels_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_in_shapes)

        # upscale
        x = self._upscale_longitudes(x)

        # transpose back
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-1, -3), channels_shapes)

        return x
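To illustrate the latitude interpolation scheme used above (searchsorted bracketing indices plus torch.lerp), here is a small self-contained sketch; the 1-D grids and the sine test signal are made up for illustration and are not part of the commit:

import numpy as np
import torch

# hypothetical 1-D latitude grids (radians), standing in for lats_in / lats_out
lats_in = np.linspace(0.1, np.pi - 0.1, 33)
lats_out = np.linspace(0.1, np.pi - 0.1, 65)

# index of the input latitude just below (or equal to) each output latitude
lat_idx = np.searchsorted(lats_in, lats_out, side="right") - 1
# clamp the last point so that lat_idx + 1 stays in range when an output
# latitude coincides with the last input latitude
lat_idx = np.where(lats_out == lats_in[-1], lat_idx - 1, lat_idx)

# fractional position of each output latitude between its two neighbours
lat_weights = torch.from_numpy((lats_out - lats_in[lat_idx]) / np.diff(lats_in)[lat_idx]).float()

# interpolate a signal sampled on lats_in onto lats_out, as _upscale_latitudes does
x = torch.sin(torch.from_numpy(lats_in)).float()
lat_idx = torch.from_numpy(lat_idx)
y = torch.lerp(x[lat_idx], x[lat_idx + 1], lat_weights)

err = (y - torch.sin(torch.from_numpy(lats_out)).float()).abs().max()
print(f"max interpolation error: {err.item():.2e}")

The longitude direction follows the same pattern, except that the right neighbour wraps around to index 0 and the spacing is corrected by 2*pi to preserve periodicity.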