# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import unittest
from parameterized import parameterized

import torch
import torch.nn.functional as F
import torch.distributed as dist
Boris Bonev's avatar
Boris Bonev committed
39
import torch_harmonics as th
40
41
42
import torch_harmonics.distributed as thd


43
class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
    """
    Test the distributed discrete-continuous convolution module.

    The class itself takes no arguments. The values listed below are the
    arguments of the parameterized test method ``test_distributed_disco_conv``;
    each parameter row supplies one value per entry, in this order.

    Parameters
    ----------
    nlat_in : int
        Number of latitude points in input
    nlon_in : int
        Number of longitude points in input
    nlat_out : int
        Number of latitude points in output
    nlon_out : int
        Number of longitude points in output
    batch_size : int
        Batch size
    num_chan : int
        Number of channels
    kernel_shape : int or tuple
        Kernel shape
    basis_type : str
        Basis type
    basis_norm_mode : str
        Basis normalization mode
    groups : int
        Number of groups
    grid_in : str
        Grid type for input
    grid_out : str
        Grid type for output
    transpose : bool
        Whether to transpose the convolution
    tol : float
        Tolerance for numerical equivalence
    """
78
79
80

    @classmethod
    def setUpClass(cls):
        """
        Initialize the process group and the H/W communicators used by the tests.

        The process layout is read from the environment: ``WORLD_RANK``
        (default 0), ``GRID_H`` and ``GRID_W`` (default 1 each), ``MASTER_ADDR``
        (default ``localhost``) and ``MASTER_PORT`` (default 29501). The world
        size is ``GRID_H * GRID_W``. Ranks are laid out row by row on the grid,
        so ``hrank = world_rank // GRID_W`` and ``wrank = world_rank % GRID_W``.
        """

        # process layout from the environment
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
        cls.grid_size_w = int(os.getenv("GRID_W", 1))
        port = int(os.getenv("MASTER_PORT", "29501"))
        master_address = os.getenv("MASTER_ADDR", "localhost")
        cls.world_size = cls.grid_size_h * cls.grid_size_w

        # only rank 0 reports progress
        is_root = cls.world_rank == 0

        # pick device and backend; every rank uses the same seed
        use_cuda = torch.cuda.is_available()
        if is_root:
            print("Running test on GPU" if use_cuda else "Running test on CPU")
        if use_cuda:
            local_rank = cls.world_rank % torch.cuda.device_count()
            cls.device = torch.device(f"cuda:{local_rank}")
            torch.cuda.set_device(local_rank)
            torch.cuda.manual_seed(333)
            proc_backend = "nccl"
        else:
            cls.device = torch.device("cpu")
            proc_backend = "gloo"
        torch.manual_seed(333)

        dist.init_process_group(
            backend=proc_backend,
            init_method=f"tcp://{master_address}:{port}",
            rank=cls.world_rank,
            world_size=cls.world_size,
        )

        # position of this rank on the H x W grid
        cls.hrank, cls.wrank = divmod(cls.world_rank, cls.grid_size_w)

        def _create_groups(rank_lists):
            # All ranks walk over all groups in the same order and call new_group
            # for each of them; only the group containing this rank is kept.
            # Single-rank groups are skipped and leave the result at None.
            own_group = None
            for ranks in rank_lists:
                if len(ranks) == 1:
                    continue
                candidate = dist.new_group(ranks=ranks)
                if cls.world_rank in ranks:
                    own_group = candidate
            return own_group

        # w-groups: consecutive ranks, i.e. the rows of the grid
        wgroups = [list(range(first, first + cls.grid_size_w)) for first in range(0, cls.world_size, cls.grid_size_w)]
        if is_root:
            print("w-groups:", wgroups)
        cls.w_group = _create_groups(wgroups)

        # h-groups: the transpose of the w-groups, i.e. the columns of the grid
        hgroups = [sorted(column) for column in zip(*wgroups)]
        if is_root:
            print("h-groups:", hgroups)
        cls.h_group = _create_groups(hgroups)

        if is_root:
            print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

        # hand the communicators over to torch_harmonics.distributed
        thd.init(cls.h_group, cls.w_group)

157
158
    @classmethod
    def tearDownClass(cls):
        """
        Shut down the distributed state after all tests of the class have run.
        """

        # counterpart of the thd.init call in setUpClass, followed by
        # the teardown of the default process group
        thd.finalize()
        dist.destroy_process_group()
170

171
    def _split_helper(self, tensor):
        """
        Extract the part of a full tensor that belongs to this rank.

        The last dimension is split into ``grid_size_w`` chunks, of which chunk
        ``wrank`` is kept; the result is then split along the second to last
        dimension into ``grid_size_h`` chunks, of which chunk ``hrank`` is kept.

        Parameters
        ----------
        tensor : torch.Tensor
            The full tensor to split

        Returns
        -------
        torch.Tensor
            The chunk of ``tensor`` selected by ``(hrank, wrank)``
        """

        # (dimension, number of chunks, index of the chunk to keep), W first, then H
        split_plan = (
            (-1, self.grid_size_w, self.wrank),
            (-2, self.grid_size_h, self.hrank),
        )

        shard = tensor
        with torch.no_grad():
            for dim, num_chunks, index in split_plan:
                shard = thd.split_tensor_along_dim(shard, dim=dim, num_chunks=num_chunks)[index]

        return shard
Boris Bonev's avatar
Boris Bonev committed
191

192
    def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
        """
        Reassemble the full output tensor from the local outputs of all ranks.

        The local tensors are first gathered and concatenated along the last
        dimension over the W group, then along the second to last dimension
        over the H group. A direction with grid size 1 is passed through.

        Parameters
        ----------
        tensor : torch.Tensor
            The local tensor to gather
        B : int
            Batch size
        C : int
            Number of channels
        convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
            The distributed convolution object; its ``lat_out_shapes``,
            ``lon_out_shapes`` and ``nlon_out`` attributes provide the shapes
            of the receive buffers

        Returns
        -------
        torch.Tensor
            The gathered tensor
        """

        # per-rank local sizes in latitude and longitude on the output side
        lat_sizes = convolution_dist.lat_out_shapes
        lon_sizes = convolution_dist.lon_out_shapes

        def _all_gather_cat(local, shapes, rank, group, dim):
            # one receive buffer per rank in the group; the slot of this rank
            # is the local tensor itself
            parts = [torch.empty(shape, dtype=local.dtype, device=local.device) for shape in shapes]
            parts[rank] = local
            dist.all_gather(parts, local, group=group)
            return torch.cat(parts, dim=dim)

        # gather in W
        gathered = tensor.contiguous()
        if self.grid_size_w > 1:
            w_shapes = [(B, C, lat_sizes[self.hrank], w) for w in lon_sizes]
            gathered = _all_gather_cat(gathered, w_shapes, self.wrank, self.w_group, -1)

        # gather in H
        gathered = gathered.contiguous()
        if self.grid_size_h > 1:
            h_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_sizes]
            gathered = _all_gather_cat(gathered, h_shapes, self.hrank, self.h_group, -2)

        return gathered

234
    def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
        """
        Reassemble the full input gradient from the local gradients of all ranks.

        The local tensors are first gathered and concatenated along the last
        dimension over the W group, then along the second to last dimension
        over the H group. A direction with grid size 1 is passed through.

        Parameters
        ----------
        tensor : torch.Tensor
            The local tensor to gather
        B : int
            Batch size
        C : int
            Number of channels
        convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
            The distributed convolution object; its ``lat_in_shapes``,
            ``lon_in_shapes`` and ``nlon_in`` attributes provide the shapes
            of the receive buffers

        Returns
        -------
        torch.Tensor
            The gathered tensor
        """

        # we need the shapes
        lat_shapes = convolution_dist.lat_in_shapes
        lon_shapes = convolution_dist.lon_in_shapes

        # gather in W
        # make the send buffer contiguous before handing it to the collective,
        # as the forward gather helper does
        tensor = tensor.contiguous()
        if self.grid_size_w > 1:
            gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
            olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
            # the slot of this rank is the local tensor itself
            olist[self.wrank] = tensor
            dist.all_gather(olist, tensor, group=self.w_group)
            tensor_gather = torch.cat(olist, dim=-1)
        else:
            tensor_gather = tensor

        # gather in H
        tensor_gather = tensor_gather.contiguous()
        if self.grid_size_h > 1:
            gather_shapes = [(B, C, h, convolution_dist.nlon_in) for h in lat_shapes]
            olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
            olist[self.hrank] = tensor_gather
            dist.all_gather(olist, tensor_gather, group=self.h_group)
            tensor_gather = torch.cat(olist, dim=-2)

        return tensor_gather

Boris Bonev's avatar
Boris Bonev committed
273
274
    @parameterized.expand(
        [
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [65, 128, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
        ]
    )
    def test_distributed_disco_conv(
        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
    ):
        """
        Compare the distributed discrete-continuous convolution with the local one.

        A local and a distributed convolution are built from the same arguments
        and given the same weights and bias. The forward output and the input
        gradient of the distributed module are gathered from all ranks and
        compared with the local results; the mean relative Frobenius-norm error
        of each must not exceed ``tol``.

        Parameters
        ----------
        nlat_in : int
            Number of latitude points in input
        nlon_in : int
            Number of longitude points in input
        nlat_out : int
            Number of latitude points in output
        nlon_out : int
            Number of longitude points in output
        batch_size : int
            Batch size
        num_chan : int
            Number of channels
        kernel_shape : int or tuple
            Kernel shape
        basis_type : str
            Basis type
        basis_norm_mode : str
            Basis normalization mode
        groups : int
            Number of groups
        grid_in : str
            Grid type for input
        grid_out : str
            Grid type for output
        transpose : bool
            Whether to transpose the convolution
        tol : float
            Tolerance for numerical equivalence
        """

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

        disco_args = dict(
            in_channels=C,
            out_channels=C,
            in_shape=(nlat_in, nlon_in),
            out_shape=(nlat_out, nlon_out),
            basis_type=basis_type,
            basis_norm_mode=basis_norm_mode,
            kernel_shape=kernel_shape,
            groups=groups,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=True,
        )

        # set up handles
        if transpose:
            conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
            conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
        else:
            conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device)
            conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)

        # copy the weights from the local conv into the dist conv
        with torch.no_grad():
            conv_dist.weight.copy_(conv_local.weight)
            if disco_args["bias"]:
                conv_dist.bias.copy_(conv_local.bias)

        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local conv
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = conv_local(inp_full)

        # create grad for backward
        with torch.no_grad():
            # create full grad
            ograd_full = torch.randn_like(out_full)

        # BWD pass
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed conv
        #############################################################
        # FWD pass (run once; its graph is reused by the backward pass below)
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
        out_local = conv_dist(inp_local)

        # BWD pass
        ograd_local = self._split_helper(ograd_full)
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
            if self.world_rank == 0:
                print(f"final relative error of output: {err.item()}")
        self.assertLessEqual(err.item(), tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)

            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
            if self.world_rank == 0:
                print(f"final relative error of gradients: {err.item()}")
        self.assertLessEqual(err.item(), tol)


Boris Bonev's avatar
Boris Bonev committed
417
# run the test suite when this file is executed as a script
if __name__ == "__main__":
    unittest.main()