# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
from parameterized import parameterized

# import math
import numpy as np
import torch
import torch.nn as nn

# from torch.autograd import gradcheck
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2

from torch_harmonics._neighborhood_attention import (
    _neighborhood_attention_s2_torch,
    _neighborhood_attention_s2_fwd_torch,
    _neighborhood_attention_s2_bwd_dv_torch,
    _neighborhood_attention_s2_bwd_dk_torch,
    _neighborhood_attention_s2_bwd_dq_torch,
)

# The custom C++/CUDA attention extension is optional: when it cannot be
# imported, warn, bind the module name to None and record its absence.
try:
    import attention_cuda_extension
except ImportError as err:
    print(f"Warning: Couldn't Import cuda attention: {err}")
    attention_cuda_extension = None
    _cuda_extension_available = False
else:
    _cuda_extension_available = True


62
class TestNeighborhoodAttentionS2(unittest.TestCase):
    """Tests for the ``NeighborhoodAttentionS2`` module.

    The suite compares the module against the pure-torch reference function,
    compares it against the global ``AttentionS2`` module when the neighborhood
    cutoff is set to ``2 * pi``, and runs a CUDA-only timing check.
    """

    def setUp(self):
        """Select the test device (``cuda:0`` if available, otherwise CPU) and seed the RNGs."""
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            torch.cuda.set_device(self.device.index)
            torch.cuda.manual_seed(333)
        else:
            self.device = torch.device("cpu")
        torch.manual_seed(333)

    @parameterized.expand(
        [
            # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
            [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
        ],
        skip_on_empty=True,
    )
    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
        """Tests numerical equivalence between the module's forward/backward pass and the reference torch implementation.

        The reference output is assembled by hand from ``_neighborhood_attention_s2_torch``
        followed by the output projection, using the parameters and buffers of a second,
        identically initialized module. The other output comes from calling the module
        directly. Forward outputs, input gradients and parameter gradients are compared
        with ``torch.allclose`` using ``atol`` and ``rtol``. ``verbose`` is currently unused.
        """

        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # Two independent leaf copies of the same random inputs, one per code path,
        # so that each path accumulates its own gradients.
        inputs_ref = {
            "k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
            "v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
            "q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
        }
        inputs = {name: tensor.detach().clone().requires_grad_() for name, tensor in inputs_ref.items()}

        # reference model: only its parameters and buffers are fed to the reference function
        model_ref = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True).to(
            self.device
        )

        # model under test
        model = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True)

        # Synchronize parameters of both models and verify the copy
        model.load_state_dict(model_ref.state_dict())
        model = model.to(self.device)
        for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
            self.assertTrue(torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}")

        # reference forward pass: attention via the torch reference, then the output projection
        out_ref = _neighborhood_attention_s2_torch(
            inputs_ref["k"],
            inputs_ref["v"],
            inputs_ref["q"] * model_ref.scale,
            model_ref.k_weights,
            model_ref.v_weights,
            model_ref.q_weights,
            model_ref.k_bias,
            model_ref.v_bias,
            model_ref.q_bias,
            model_ref.quad_weights,
            model_ref.psi_col_idx,
            model_ref.psi_roff_idx,
            model_ref.num_heads,
            model_ref.nlon_in,
            model_ref.nlat_out,
            model_ref.nlon_out,
        )
        out_ref = nn.functional.conv2d(out_ref, model_ref.proj_weights, bias=model_ref.proj_bias)
        out = model(inputs["q"], inputs["k"], inputs["v"])

        # Check forward equivalence
        self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")

        # Backward passes with one shared upstream gradient
        grad_out = torch.randn_like(out_ref)
        out_ref.backward(grad_out)
        out.backward(grad_out)

        # Check input gradient equivalence
        for inp in ["q", "k", "v"]:
            self.assertTrue(torch.allclose(inputs[inp].grad, inputs_ref[inp].grad, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")

        # Check parameter gradient equivalence; report the parameter name on failure
        for (name_ref, p_ref), (_, p) in zip(model_ref.named_parameters(), model.named_parameters()):
            self.assertTrue(torch.allclose(p.grad, p_ref.grad, atol=atol, rtol=rtol), f"Parameter gradient mismatch: {name_ref}")

    # Caution: the multi-head implementations of full and neighborhood attention still seem to differ,
    # so the equivalence test below only covers a single head.
    @parameterized.expand(
        [
            # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
            [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
            # [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
        ],
        skip_on_empty=True,
    )
    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
        """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set to the whole sphere.

        Both modules are built without bias and share one state dict. Forward outputs,
        input gradients and the gradients of the q/k/v weights are compared with
        ``torch.allclose`` using ``atol`` and ``rtol``. ``verbose`` is currently unused.
        """

        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # Two independent leaf copies of the same random inputs, one per module
        inputs_ref = {
            "k": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
            "v": torch.randn(batch_size, channels, nlat_in, nlon_in, requires_grad=True, device=self.device, dtype=torch.float32),
            "q": torch.randn(batch_size, channels, nlat_out, nlon_out, requires_grad=True, device=self.device, dtype=torch.float32),
        }
        inputs = {name: tensor.detach().clone().requires_grad_() for name, tensor in inputs_ref.items()}

        # reference model: global attention
        model_ref = AttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False).to(self.device)

        # model under test: neighborhood attention with the cutoff angle set to 2 * pi
        model = NeighborhoodAttentionS2(
            in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=False, theta_cutoff=2 * torch.pi
        )

        # Synchronize parameters of both models and verify the copy
        model.load_state_dict(model_ref.state_dict())
        model = model.to(self.device)
        for (name_ref, p_ref), (name, p) in zip(model_ref.named_parameters(), model.named_parameters()):
            self.assertTrue(torch.allclose(p_ref, p), f"Parameter mismatch: {name_ref} vs {name}")

        # forward passes
        out_ref = model_ref(inputs_ref["q"], inputs_ref["k"], inputs_ref["v"])
        out = model(inputs["q"], inputs["k"], inputs["v"])

        # Check forward equivalence
        self.assertTrue(torch.allclose(out, out_ref, atol=atol, rtol=rtol), "Forward outputs differ between torch reference and custom implementation")

        # Backward passes with one shared upstream gradient
        grad_out = torch.randn_like(out_ref)
        out_ref.backward(grad_out)
        out.backward(grad_out)

        # Check input gradient equivalence
        for inp in ["q", "k", "v"]:
            self.assertTrue(torch.allclose(inputs[inp].grad, inputs_ref[inp].grad, atol=atol, rtol=rtol), f"Input gradient mismatch in {inp}")

        # Check parameter gradient equivalence - check only q, k, v weights
        for key in ["q_weights", "k_weights", "v_weights"]:
            grad_ref = getattr(model_ref, key).grad
            grad = getattr(model, key).grad
            self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch: {key}")

    # skipIf is placed below parameterized.expand so that it wraps the function
    # before expansion and therefore applies to every generated test case.
    @parameterized.expand(
        [
            # self attention
            [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
        ],
        skip_on_empty=True,
    )
    @unittest.skipIf((not torch.cuda.is_available()) or (not _cuda_extension_available), "skipping performance test because CUDA or the CUDA attention extension is not available")
    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
        """Times one forward and one backward pass of ``NeighborhoodAttentionS2`` on the GPU.

        The forward pass must finish in under 150 ms and the backward pass in under
        400 ms, each measured with CUDA events after two warm-up runs. ``atol`` and
        ``rtol`` are accepted for signature parity with the other tests but unused.
        Timings are printed when ``verbose`` is true.
        """

        # extract some parameters
        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # NOTE: timing relies on torch.cuda.Event, so this test can only run on a CUDA device.
        k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
        v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
        q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)

        # set up layers
        time_layer_setup_start = torch.cuda.Event(enable_timing=True)
        time_layer_setup_end = torch.cuda.Event(enable_timing=True)
        time_layer_setup_start.record()
        att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
        time_layer_setup_end.record()
        torch.cuda.synchronize()
        if verbose:
            print(f"Layer setup time: {time_layer_setup_start.elapsed_time(time_layer_setup_end)} ms")

        with torch.no_grad():
            # random weights
            for param in (att_gpu.q_weights, att_gpu.k_weights, att_gpu.v_weights, att_gpu.q_bias, att_gpu.k_bias, att_gpu.v_bias):
                param.normal_()

            # warmup
            for _ in range(2):
                att_gpu(q_gpu, k_gpu, v_gpu)

            # time forward pass
            time_forward_start = torch.cuda.Event(enable_timing=True)
            time_forward_end = torch.cuda.Event(enable_timing=True)
            time_forward_start.record()
            att_gpu(q_gpu, k_gpu, v_gpu)
            time_forward_end.record()
            torch.cuda.synchronize()
            elapsed_time = time_forward_start.elapsed_time(time_forward_end)
            if verbose:
                print(f"Forward execution time: {elapsed_time} ms")
            self.assertLess(elapsed_time, 150)

        # fresh leaf inputs that track gradients for the backward timing
        q_gpu = q_gpu.detach().clone().requires_grad_()
        k_gpu = k_gpu.detach().clone().requires_grad_()
        v_gpu = v_gpu.detach().clone().requires_grad_()

        out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
        out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device=self.device)
        time_backward_start = torch.cuda.Event(enable_timing=True)
        time_backward_end = torch.cuda.Event(enable_timing=True)

        # warmup; retain the graph so the timed backward pass below can reuse it
        for _ in range(2):
            out_gpu.backward(out_grad, retain_graph=True)

        # time backward pass
        time_backward_start.record()
        out_gpu.backward(out_grad)
        time_backward_end.record()
        torch.cuda.synchronize()
        elapsed_time = time_backward_start.elapsed_time(time_backward_end)
        if verbose:
            print(f"Backward execution time: {elapsed_time} ms")
        self.assertLess(elapsed_time, 400)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()