# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import unittest
from parameterized import parameterized

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as th
import torch_harmonics.distributed as thd


class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
    """Test the distributed discrete-continuous convolution module."""

    @classmethod
    def setUpClass(cls):
        """Initialize the process group and the H/W communicator groups.

        The process grid is configured through environment variables:
        ``WORLD_RANK``, ``GRID_H``, ``GRID_W``, ``MASTER_ADDR`` and
        ``MASTER_PORT``. Ranks are laid out row-major on a ``GRID_H x GRID_W``
        grid: ``wrank = rank % GRID_W`` and ``hrank = rank // GRID_W``.
        Uses NCCL on GPU and Gloo on CPU.
        """
        # set up distributed
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
        cls.grid_size_w = int(os.getenv("GRID_W", 1))
        port = int(os.getenv("MASTER_PORT", "29501"))
        master_address = os.getenv("MASTER_ADDR", "localhost")
        cls.world_size = cls.grid_size_h * cls.grid_size_w

        if torch.cuda.is_available():
            if cls.world_rank == 0:
                print("Running test on GPU")
            local_rank = cls.world_rank % torch.cuda.device_count()
            cls.device = torch.device(f"cuda:{local_rank}")
            torch.cuda.set_device(local_rank)
            torch.cuda.manual_seed(333)
            proc_backend = "nccl"
        else:
            if cls.world_rank == 0:
                print("Running test on CPU")
            cls.device = torch.device("cpu")
            proc_backend = "gloo"
        torch.manual_seed(333)

        dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)

        cls.wrank = cls.world_rank % cls.grid_size_w
        cls.hrank = cls.world_rank // cls.grid_size_w

        # now set up the comm groups; a size-1 dimension keeps its group as None
        cls.w_group = None
        cls.h_group = None

        # W-groups: each row of the grid holds grid_size_w consecutive ranks
        wgroups = [list(range(start, start + cls.grid_size_w)) for start in range(0, cls.world_size, cls.grid_size_w)]

        if cls.world_rank == 0:
            print("w-groups:", wgroups)
        # dist.new_group must be called by every rank for every group, in the same order,
        # even for groups the calling rank does not belong to
        for grp in wgroups:
            if len(grp) == 1:
                continue
            tmp_group = dist.new_group(ranks=grp)
            if cls.world_rank in grp:
                cls.w_group = tmp_group

        # H-groups: transpose of the W-groups (one per grid column)
        hgroups = [sorted(list(i)) for i in zip(*wgroups)]

        if cls.world_rank == 0:
            print("h-groups:", hgroups)
        for grp in hgroups:
            if len(grp) == 1:
                continue
            tmp_group = dist.new_group(ranks=grp)
            if cls.world_rank in grp:
                cls.h_group = tmp_group

        if cls.world_rank == 0:
            print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

        # initialize the torch_harmonics distributed environment with the H/W communicators
        thd.init(cls.h_group, cls.w_group)

    @classmethod
    def tearDownClass(cls):
        """Shut down torch_harmonics distributed state and the default process group."""
        thd.finalize()
        dist.destroy_process_group(None)

    def _split_helper(self, tensor):
        """Return this rank's shard of a full tensor.

        The tensor is split along W (last dim) into ``grid_size_w`` chunks, and
        chunk ``wrank`` is kept. That chunk is then split along H (dim -2) into
        ``grid_size_h`` chunks, and chunk ``hrank`` is kept. The split runs under
        ``torch.no_grad()``.
        """
        with torch.no_grad():
            # split in W
            tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
            tensor_local = tensor_list_local[self.wrank]

            # split in H
            tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
            tensor_local = tensor_list_local[self.hrank]

        return tensor_local

    def _gather_helper(self, tensor, B, C, lat_shapes, lon_shapes, nlon):
        """Reassemble a full tensor from the per-rank shards on the process grid.

        The shards are gathered first along W (last dim) within ``self.w_group``,
        then along H (dim -2) within ``self.h_group``.

        Args:
            tensor: local shard of shape ``(B, C, lat_shapes[hrank], lon_shapes[wrank])``.
            B, C: batch and channel sizes of the shard.
            lat_shapes: per-H-rank latitude extents.
            lon_shapes: per-W-rank longitude extents.
            nlon: full longitude extent after the W gather.

        Returns:
            The gathered tensor. It has full size along every distributed
            dimension whose grid size is > 1.
        """
        # collective send buffers must be contiguous
        tensor = tensor.contiguous()
        if self.grid_size_w > 1:
            gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
            olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
            olist[self.wrank] = tensor
            dist.all_gather(olist, tensor, group=self.w_group)
            tensor = torch.cat(olist, dim=-1)

        # torch.cat output is contiguous already; this is a no-op in that case
        tensor = tensor.contiguous()
        if self.grid_size_h > 1:
            gather_shapes = [(B, C, h, nlon) for h in lat_shapes]
            olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
            olist[self.hrank] = tensor
            dist.all_gather(olist, tensor, group=self.h_group)
            tensor = torch.cat(olist, dim=-2)

        return tensor

    def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
        """Gather a shard shaped like the conv *output* back into the full tensor."""
        return self._gather_helper(
            tensor, B, C, convolution_dist.lat_out_shapes, convolution_dist.lon_out_shapes, convolution_dist.nlon_out
        )

    def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
        """Gather a shard shaped like the conv *input* (e.g. an input gradient) back into the full tensor."""
        return self._gather_helper(
            tensor, B, C, convolution_dist.lat_in_shapes, convolution_dist.lon_in_shapes, convolution_dist.nlon_in
        )

    @staticmethod
    def _relative_error(reference, approx):
        """Relative Frobenius error over the last two dims, averaged over all leading dims."""
        return torch.mean(torch.norm(reference - approx, p="fro", dim=(-1, -2)) / torch.norm(reference, p="fro", dim=(-1, -2)))

    # columns: nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape,
    #          basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
    # note: `(3)` below is the plain int 3, not a tuple
    @parameterized.expand(
        [
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
            [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [65, 128, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
            [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
        ]
    )
    def test_distributed_disco_conv(
        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
    ):
        """Compare the distributed DISCO convolution against the single-process one.

        A local (``th``) and a distributed (``thd``) convolution (or transposed
        convolution) are built from the same arguments and share weights and
        bias. The distributed forward output and input gradient are gathered
        back to full size and compared with the local reference. The check uses
        the relative Frobenius error over the last two dims, averaged over batch
        and channels. Both errors must be ``<= tol``.
        """
        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

        disco_args = dict(
            in_channels=C,
            out_channels=C,
            in_shape=(nlat_in, nlon_in),
            out_shape=(nlat_out, nlon_out),
            basis_type=basis_type,
            basis_norm_mode=basis_norm_mode,
            kernel_shape=kernel_shape,
            groups=groups,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=True,
        )

        # set up handles
        if transpose:
            conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
            conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
        else:
            conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device)
            conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)

        # copy the weights from the local conv into the dist conv
        with torch.no_grad():
            conv_dist.weight.copy_(conv_local.weight)
            if disco_args["bias"]:
                conv_dist.bias.copy_(conv_local.bias)

        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        # local conv: FWD pass
        inp_full.requires_grad = True
        out_full = conv_local(inp_full)

        # create the full upstream gradient for the backward pass
        with torch.no_grad():
            ograd_full = torch.randn_like(out_full)

        # local conv: BWD pass
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        # distributed conv: FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
        out_local = conv_dist(inp_local)

        # distributed conv: BWD pass, reusing the graph built by the forward above
        ograd_local = self._split_helper(ograd_full)
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        # evaluate FWD pass
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
            err = self._relative_error(out_full, out_gather_full)
            if self.world_rank == 0:
                print(f"final relative error of output: {err.item()}")
        self.assertLessEqual(err.item(), tol)

        # evaluate BWD pass
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
            err = self._relative_error(igrad_full, igrad_gather_full)
            if self.world_rank == 0:
                print(f"final relative error of gradients: {err.item()}")
        self.assertLessEqual(err.item(), tol)


# run the test suite when executed directly (distributed settings come from env vars)
if __name__ == "__main__":
    unittest.main()