# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import unittest

import torch
import torch.distributed as dist
import torch.nn.functional as F
from parameterized import parameterized

import torch_harmonics as th
import torch_harmonics.distributed as thd


class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
    """
    Test the distributed spherical harmonic transform module.

    Every test case evaluates a transform twice: once with the serial
    ``torch_harmonics`` module on the full tensor, and once with the
    ``torch_harmonics.distributed`` module on this rank's slice of that tensor.
    The distributed output and input gradient are gathered back to full size
    and compared against the serial reference.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    batch_size : int
        Batch size
    num_chan : int
        Number of channels
    grid : str
        Grid type
    vector : bool
        Whether to use vector spherical harmonic transform
    tol : float
        Tolerance for numerical equivalence
    """

    # Cases shared by test_distributed_sht and test_distributed_isht, each row being
    # [nlat, nlon, batch_size, num_chan, grid, vector, tol].
    # NOTE(review): several rows are exact duplicates. They are kept so the generated
    # test names stay stable. Varying the shapes (e.g. odd nlat/nlon) would exercise
    # uneven domain splits instead of repeating the same case.
    _TEST_CASES = [
        [256, 512, 32, 8, "equiangular", False, 1e-9],
        [256, 512, 32, 8, "legendre-gauss", False, 1e-9],
        [256, 512, 32, 8, "equiangular", False, 1e-9],
        [256, 512, 32, 8, "legendre-gauss", False, 1e-9],
        [256, 512, 32, 8, "equiangular", False, 1e-9],
        [256, 512, 32, 8, "legendre-gauss", False, 1e-9],
        [361, 720, 1, 10, "equiangular", False, 1e-6],
        [361, 720, 1, 10, "legendre-gauss", False, 1e-6],
        [256, 512, 32, 8, "equiangular", True, 1e-9],
        [256, 512, 32, 8, "legendre-gauss", True, 1e-9],
        [256, 512, 32, 8, "equiangular", True, 1e-9],
        [256, 512, 32, 8, "legendre-gauss", True, 1e-9],
        [256, 512, 32, 8, "equiangular", True, 1e-9],
        [256, 512, 32, 8, "legendre-gauss", True, 1e-9],
        [361, 720, 1, 10, "equiangular", True, 1e-6],
        [361, 720, 1, 10, "legendre-gauss", True, 1e-6],
    ]

    @classmethod
    def setUpClass(cls):
        """
        Initialize torch.distributed and the row/column communicators used by the tests.

        The configuration is read from the environment variables ``WORLD_RANK``,
        ``GRID_H``, ``GRID_W``, ``MASTER_ADDR`` and ``MASTER_PORT``. The ranks form a
        ``GRID_H x GRID_W`` grid laid out row-major, so ``wrank = rank % GRID_W`` and
        ``hrank = rank // GRID_W``. A grid dimension of size one leaves the
        corresponding group (``w_group`` / ``h_group``) as ``None``.
        """
        # set up distributed
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
        cls.grid_size_w = int(os.getenv("GRID_W", 1))
        port = int(os.getenv("MASTER_PORT", "29501"))
        master_address = os.getenv("MASTER_ADDR", "localhost")
        cls.world_size = cls.grid_size_h * cls.grid_size_w

        if torch.cuda.is_available():
            if cls.world_rank == 0:
                print("Running test on GPU")
            local_rank = cls.world_rank % torch.cuda.device_count()
            cls.device = torch.device(f"cuda:{local_rank}")
            # Select this rank's GPU as the current device before NCCL is initialized.
            # NCCL uses the current CUDA device for some internal work, and without
            # this call it would be cuda:0 on every rank.
            torch.cuda.set_device(local_rank)
            torch.cuda.manual_seed(333)
            proc_backend = "nccl"
        else:
            if cls.world_rank == 0:
                print("Running test on CPU")
            cls.device = torch.device("cpu")
            proc_backend = "gloo"
        torch.manual_seed(333)

        dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)

        cls.wrank = cls.world_rank % cls.grid_size_w
        cls.hrank = cls.world_rank // cls.grid_size_w

        # now set up the comm groups; size-one dimensions keep the default None
        cls.w_group = None
        cls.h_group = None

        # w-groups are the grid rows: runs of grid_size_w consecutive ranks
        wgroups = [list(range(start, start + cls.grid_size_w)) for start in range(0, cls.world_size, cls.grid_size_w)]

        if cls.world_rank == 0:
            print("w-groups:", wgroups)
        # dist.new_group must be entered by every rank for every group, including
        # ranks that are not members. So all groups are created on all ranks, and
        # each rank keeps only the group it belongs to.
        for grp in wgroups:
            if len(grp) == 1:
                continue
            tmp_group = dist.new_group(ranks=grp)
            if cls.world_rank in grp:
                cls.w_group = tmp_group

        # h-groups are the grid columns, i.e. the transpose of the rows
        hgroups = [sorted(list(i)) for i in zip(*wgroups)]

        if cls.world_rank == 0:
            print("h-groups:", hgroups)
        for grp in hgroups:
            if len(grp) == 1:
                continue
            tmp_group = dist.new_group(ranks=grp)
            if cls.world_rank in grp:
                cls.h_group = tmp_group

        # set seed
        torch.manual_seed(333)

        if cls.world_rank == 0:
            print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

        # initializing sht
        thd.init(cls.h_group, cls.w_group)

    @classmethod
    def tearDownClass(cls):
        """Shut down torch_harmonics.distributed and destroy the default process group."""
        thd.finalize()
        dist.destroy_process_group(None)

    def _split_helper(self, tensor):
        """
        Split the tensor along the W and H dimensions.

        The last dimension is split into ``grid_size_w`` chunks and the
        second-to-last into ``grid_size_h`` chunks. The chunk at this rank's
        ``(hrank, wrank)`` grid position is returned. The split runs under
        ``torch.no_grad()``, so the result is detached from ``tensor``'s graph.

        Parameters
        ----------
        tensor : torch.Tensor
            The tensor to split

        Returns
        -------
        torch.Tensor
            The split tensor
        """
        with torch.no_grad():
            # split in W
            tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
            tensor_local = tensor_list_local[self.wrank]

            # split in H
            tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
            tensor_local = tensor_list_local[self.hrank]

        return tensor_local

    @staticmethod
    def _all_gather_cat(tensor, gather_shapes, rank, group, dim):
        """
        All-gather ``tensor`` across ``group`` and concatenate the pieces along ``dim``.

        Parameters
        ----------
        tensor : torch.Tensor
            This rank's piece
        gather_shapes : list of tuple
            Shape of the piece held by each rank of ``group``, in group-rank order.
            Pieces may differ in size, so one receive buffer is allocated per shape.
        rank : int
            Position of this rank within ``group``; ``tensor`` is used for that slot
        group : torch.distributed.ProcessGroup
            Communicator to gather over
        dim : int
            Dimension along which the gathered pieces are concatenated

        Returns
        -------
        torch.Tensor
            The concatenated tensor
        """
        olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
        olist[rank] = tensor
        dist.all_gather(olist, tensor, group=group)
        return torch.cat(olist, dim=dim)

    def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector):
        """
        Gather the tensor along the W and H dimensions.

        The gather uses the spectral chunk sizes of ``transform_dist``
        (``l_shapes`` / ``m_shapes`` / ``mmax``). It reassembles outputs of the
        forward SHT and input gradients of the inverse SHT.

        Parameters
        ----------
        tensor : torch.Tensor
            The tensor to gather
        B : int
            Batch size
        C : int
            Number of channels
        transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
            The distributed transform
        vector : bool
            Whether to use vector spherical harmonic transform

        Returns
        -------
        torch.Tensor
            The gathered tensor
        """
        # we need the shapes
        l_shapes = transform_dist.l_shapes
        m_shapes = transform_dist.m_shapes
        # vector-valued tensors carry an extra axis of size 2 after the channels
        lead = (B, C, 2) if vector else (B, C)

        # gather in W
        if self.grid_size_w > 1:
            shapes = [lead + (l_shapes[self.hrank], m) for m in m_shapes]
            tensor = self._all_gather_cat(tensor, shapes, self.wrank, self.w_group, dim=-1)

        # gather in H
        if self.grid_size_h > 1:
            shapes = [lead + (nl, transform_dist.mmax) for nl in l_shapes]
            tensor = self._all_gather_cat(tensor, shapes, self.hrank, self.h_group, dim=-2)

        return tensor

    def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
        """
        Gather the tensor along the W and H dimensions.

        The gather uses the spatial chunk sizes of ``transform_dist``
        (``lat_shapes`` / ``lon_shapes`` / ``nlon``). It reassembles input
        gradients of the forward SHT and outputs of the inverse SHT.

        Parameters
        ----------
        tensor : torch.Tensor
            The tensor to gather
        B : int
            Batch size
        C : int
            Number of channels
        transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
            The distributed transform
        vector : bool
            Whether to use vector spherical harmonic transform

        Returns
        -------
        torch.Tensor
            The gathered tensor
        """
        # we need the shapes
        lat_shapes = transform_dist.lat_shapes
        lon_shapes = transform_dist.lon_shapes
        # vector-valued tensors carry an extra axis of size 2 after the channels
        lead = (B, C, 2) if vector else (B, C)

        # gather in W
        if self.grid_size_w > 1:
            shapes = [lead + (lat_shapes[self.hrank], w) for w in lon_shapes]
            tensor = self._all_gather_cat(tensor, shapes, self.wrank, self.w_group, dim=-1)

        # gather in H
        if self.grid_size_h > 1:
            shapes = [lead + (h, transform_dist.nlon) for h in lat_shapes]
            tensor = self._all_gather_cat(tensor, shapes, self.hrank, self.h_group, dim=-2)

        return tensor

    @staticmethod
    def _relative_error(reference, result):
        """
        Return the mean relative Frobenius error of ``result`` w.r.t. ``reference``.

        The norm is taken over the last two dimensions. The per-slice relative
        errors are then averaged over all leading dimensions.

        Parameters
        ----------
        reference : torch.Tensor
            The serial reference tensor
        result : torch.Tensor
            The gathered distributed tensor

        Returns
        -------
        float
            The mean relative error
        """
        with torch.no_grad():
            diff_norm = torch.norm(reference - result, p="fro", dim=(-1, -2))
            ref_norm = torch.norm(reference, p="fro", dim=(-1, -2))
            return torch.mean(diff_norm / ref_norm).item()

    @parameterized.expand(_TEST_CASES)
    def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
        """
        Test the distributed spherical harmonic transform.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        batch_size : int
            Batch size
        num_chan : int
            Number of channels
        grid : str
            Grid type
        vector : bool
            Whether to use vector spherical harmonic transform
        tol : float
            Tolerance for numerical equivalence
        """
        B, C, H, W = batch_size, num_chan, nlat, nlon

        # set up handles
        if vector:
            forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
        else:
            forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)

        # create tensors
        if vector:
            inp_full = torch.randn((B, C, 2, H, W), dtype=torch.float32, device=self.device)
        else:
            inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local transform
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = forward_transform_local(inp_full)

        # create grad for backward
        with torch.no_grad():
            # create full grad
            ograd_full = torch.randn_like(out_full)

        # BWD pass
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed transform
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
        out_local = forward_transform_dist(inp_local)

        # BWD pass, reusing the graph built by the forward pass above
        ograd_local = self._split_helper(ograd_full)
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
        err = self._relative_error(out_full, out_gather_full)
        if self.world_rank == 0:
            print(f"final relative error of output: {err}")
        self.assertLessEqual(err, tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
        err = self._relative_error(igrad_full, igrad_gather_full)
        if self.world_rank == 0:
            print(f"final relative error of gradients: {err}")
        self.assertLessEqual(err, tol)

    @parameterized.expand(_TEST_CASES)
    def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
        """
        Test the distributed inverse spherical harmonic transform.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        batch_size : int
            Batch size
        num_chan : int
            Number of channels
        grid : str
            Grid type
        vector : bool
            Whether to use vector spherical harmonic transform
        tol : float
            Tolerance for numerical equivalence
        """
        B, C, H, W = batch_size, num_chan, nlat, nlon

        if vector:
            forward_transform_local = th.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            backward_transform_local = th.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
        else:
            forward_transform_local = th.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            backward_transform_local = th.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
            backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)

        # create tensors: the spectral input is the forward SHT of random spatial data
        if vector:
            dummy_full = torch.randn((B, C, 2, H, W), dtype=torch.float32, device=self.device)
        else:
            dummy_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
        inp_full = forward_transform_local(dummy_full)

        #############################################################
        # local transform
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = backward_transform_local(inp_full)

        # create grad for backward
        with torch.no_grad():
            # create full grad
            ograd_full = torch.randn_like(out_full)

        # BWD pass
        out_full.backward(ograd_full)

        # repeat once due to known irfft bug
        inp_full.grad = None
        out_full = backward_transform_local(inp_full)
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed transform
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
        out_local = backward_transform_dist(inp_local)

        # BWD pass
        ograd_local = self._split_helper(ograd_full)
        # NOTE(review): the distributed forward is evaluated a second time here,
        # presumably to mirror the irfft workaround above. Only this second graph is
        # back-propagated and evaluated. Confirm whether the repeat is still needed.
        out_local = backward_transform_dist(inp_local)
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
        err = self._relative_error(out_full, out_gather_full)
        if self.world_rank == 0:
            print(f"final relative error of output: {err}")
        self.assertLessEqual(err, tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
        err = self._relative_error(igrad_full, igrad_gather_full)
        if self.world_rank == 0:
            print(f"final relative error of gradients: {err}")
        self.assertLessEqual(err, tol)

if __name__ == "__main__":
    # run all test cases of this module when executed directly (one process per rank)
    unittest.main()