# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import unittest
from parameterized import parameterized

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as th
import torch_harmonics.distributed as thd


class TestDistributedResampling(unittest.TestCase):
    """Check ``thd.DistributedResampleS2`` against the single-process ``th.ResampleS2``.

    Each rank seeds its RNG identically, builds the full input, runs the local
    module on it, runs the distributed module on its own shard, and compares the
    gathered distributed output and input gradient with the local results.
    """

    @classmethod
    def setUpClass(cls):
        """Initialize the process group and the H/W communicator groups.

        The layout is read from environment variables: ``WORLD_RANK`` (default 0),
        ``GRID_H`` and ``GRID_W`` (default 1 each), ``MASTER_ADDR`` (default
        ``localhost``) and ``MASTER_PORT`` (default 29501). The world size is
        ``GRID_H * GRID_W``.
        """

        # set up distributed
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
        cls.grid_size_w = int(os.getenv("GRID_W", 1))
        port = int(os.getenv("MASTER_PORT", "29501"))
        master_address = os.getenv("MASTER_ADDR", "localhost")
        cls.world_size = cls.grid_size_h * cls.grid_size_w

        if torch.cuda.is_available():
            if cls.world_rank == 0:
                print("Running test on GPU")
            local_rank = cls.world_rank % torch.cuda.device_count()
            cls.device = torch.device(f"cuda:{local_rank}")
            torch.cuda.set_device(local_rank)
            torch.cuda.manual_seed(333)
            proc_backend = "nccl"
        else:
            if cls.world_rank == 0:
                print("Running test on CPU")
            cls.device = torch.device("cpu")
            proc_backend = "gloo"
        torch.manual_seed(333)

        dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)

        # position of this rank in the H x W grid (row-major rank layout)
        cls.wrank = cls.world_rank % cls.grid_size_w
        cls.hrank = cls.world_rank // cls.grid_size_w

        # w-groups: consecutive runs of grid_size_w ranks
        wgroups = [list(range(start, start + cls.grid_size_w)) for start in range(0, cls.world_size, cls.grid_size_w)]
        if cls.world_rank == 0:
            print("w-groups:", wgroups)
        cls.w_group = cls._new_group_for_rank(wgroups)

        # h-groups: transpose of the w-groups
        hgroups = [sorted(ranks) for ranks in zip(*wgroups)]
        if cls.world_rank == 0:
            print("h-groups:", hgroups)
        cls.h_group = cls._new_group_for_rank(hgroups)

        if cls.world_rank == 0:
            print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

        # initializing sht
        thd.init(cls.h_group, cls.w_group)

    @classmethod
    def _new_group_for_rank(cls, rank_groups):
        """Create a process group per entry of ``rank_groups`` and return this rank's one.

        Every rank calls ``dist.new_group`` for every multi-rank entry, in the
        same order. Single-rank entries are skipped; ``None`` is returned when
        this rank belongs to no created group.
        """
        own_group = None
        for ranks in rank_groups:
            if len(ranks) == 1:
                continue
            new_group = dist.new_group(ranks=ranks)
            if cls.world_rank in ranks:
                own_group = new_group
        return own_group

    @classmethod
    def tearDownClass(cls):
        """Finalize ``torch_harmonics.distributed`` and destroy the process group."""
        thd.finalize()
        dist.destroy_process_group(None)

    def _split_helper(self, tensor):
        """Return this rank's shard of ``tensor``.

        The tensor is chunked along the last dim into ``grid_size_w`` pieces
        (piece ``wrank`` is kept), then along the second-to-last dim into
        ``grid_size_h`` pieces (piece ``hrank`` is kept). Runs under ``no_grad``.
        """
        with torch.no_grad():
            # split in W
            tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
            tensor_local = tensor_list_local[self.wrank]

            # split in H
            tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
            tensor_local = tensor_list_local[self.hrank]

        return tensor_local

    def _gather_helper(self, tensor, B, C, lat_shapes, lon_shapes, nlon):
        """Gather a sharded ``(B, C, h_local, w_local)`` tensor over W, then over H.

        ``lat_shapes`` / ``lon_shapes`` are indexed by H / W rank and give the
        local extent on each rank; ``nlon`` is the full extent of the last dim
        after the W gather. The tensor is made contiguous before each
        ``dist.all_gather`` call.
        """
        # gather in W
        tensor_gather = tensor.contiguous()
        if self.grid_size_w > 1:
            gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
            olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
            olist[self.wrank] = tensor_gather
            dist.all_gather(olist, tensor_gather, group=self.w_group)
            tensor_gather = torch.cat(olist, dim=-1)

        # gather in H
        tensor_gather = tensor_gather.contiguous()
        if self.grid_size_h > 1:
            gather_shapes = [(B, C, h, nlon) for h in lat_shapes]
            olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
            olist[self.hrank] = tensor_gather
            dist.all_gather(olist, tensor_gather, group=self.h_group)
            tensor_gather = torch.cat(olist, dim=-2)

        return tensor_gather

    def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
        """Gather a local output, using the ``*_out`` shape attributes of ``convolution_dist``."""
        return self._gather_helper(tensor, B, C, convolution_dist.lat_out_shapes, convolution_dist.lon_out_shapes, convolution_dist.nlon_out)

    def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
        """Gather a local input gradient, using the ``*_in`` shape attributes of ``resampling_dist``."""
        return self._gather_helper(tensor, B, C, resampling_dist.lat_in_shapes, resampling_dist.lon_in_shapes, resampling_dist.nlon_in)

    @staticmethod
    def _relative_error(reference, value):
        """Mean over leading dims of the Frobenius-norm relative error over the last two dims."""
        diff_norm = torch.norm(reference - value, p="fro", dim=(-1, -2))
        ref_norm = torch.norm(reference, p="fro", dim=(-1, -2))
        return torch.mean(diff_norm / ref_norm)

    @parameterized.expand(
        [
            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
        ]
    )
    def test_distributed_resampling(
        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
    ):
        """Compare forward outputs and input gradients of local vs. distributed resampling.

        Both relative errors must be ``<= tol``. When ``verbose`` is true, rank 0
        prints the two errors.
        """

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

        res_args = dict(
            nlat_in=nlat_in,
            nlon_in=nlon_in,
            nlat_out=nlat_out,
            nlon_out=nlon_out,
            grid_in=grid_in,
            grid_out=grid_out,
            mode=mode,
        )

        # set up handles
        res_local = th.ResampleS2(**res_args).to(self.device)
        res_dist = thd.DistributedResampleS2(**res_args).to(self.device)

        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local resampling
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = res_local(inp_full)

        # create grad for backward
        with torch.no_grad():
            # create full grad
            ograd_full = torch.randn_like(out_full)

        # BWD pass
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed resampling
        #############################################################
        # FWD pass (run once; its graph is reused for the backward pass below)
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
        out_local = res_dist(inp_local)

        # BWD pass
        ograd_local = self._split_helper(ograd_full)
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
            err = self._relative_error(out_full, out_gather_full)
            if verbose and (self.world_rank == 0):
                print(f"final relative error of output: {err.item()}")
        self.assertLessEqual(err.item(), tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
            err = self._relative_error(igrad_full, igrad_gather_full)
            if verbose and (self.world_rank == 0):
                print(f"final relative error of gradients: {err.item()}")
        self.assertLessEqual(err.item(), tol)


# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()