"src/turbomind/vscode:/vscode.git/clone" did not exist on "06125966d7054a53458086f342734ea01dc2faf4"
_disco_convolution.py 8.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch
35
from torch.amp import custom_fwd, custom_bwd
36

Boris Bonev's avatar
Boris Bonev committed
37
38
39
40
try:
    import disco_cuda_extension
except ImportError as err:
    disco_cuda_extension = None
41
42


Boris Bonev's avatar
Boris Bonev committed
43
class _DiscoS2ContractionCuda(torch.autograd.Function):
    """Autograd wrapper around the ``disco_cuda_extension`` DISCO contraction.

    ``forward`` calls ``disco_cuda_extension.forward``; ``backward`` calls
    ``disco_cuda_extension.backward`` with the same saved index/value tensors,
    so the gradient is obtained by applying the extension's transpose operator.
    The extension is always invoked on contiguous float32 data and results are
    cast back to the incoming dtype.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        """Run the contraction kernel on ``x``.

        Args:
            x: input tensor; its last two dims are treated as (nlat_in, nlon_in).
            roff_idx, ker_idx, row_idx, col_idx, vals: index/value tensors
                describing the filter. Their exact layout is defined by
                ``disco_cuda_extension`` (not visible here) -- presumably a
                compressed sparse representation, TODO confirm.
            kernel_size: size of the kernel dimension, forwarded to the extension.
            nlat_out, nlon_out: output grid size, forwarded to the extension.

        Returns:
            The extension's output, cast back to ``x.dtype``.
        """
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        # The input grid size is needed in backward to size the gradient.
        ctx.nlat_in = x.shape[-2]
        ctx.nlon_in = x.shape[-1]

        # The extension is always called with contiguous float32 data
        # (presumably the only dtype/layout the kernel supports -- TODO confirm).
        xtype = x.dtype
        x = x.to(torch.float32).contiguous()
        output = disco_cuda_extension.forward(x, roff_idx, ker_idx, row_idx, col_idx, vals,
                                              kernel_size, nlat_out, nlon_out)
        return output.to(xtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x``; every other input receives ``None``.

        The backward kernel is skipped entirely when ``x`` does not require a
        gradient (e.g. when backward is only triggered by another input).
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
            gtype = grad_output.dtype
            grad_output = grad_output.to(torch.float32).contiguous()
            grad_input = disco_cuda_extension.backward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                                       ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
            grad_input = grad_input.to(gtype)

        # One entry per forward input: x, the five index/value tensors and the three ints.
        return grad_input, None, None, None, None, None, None, None, None
71

Boris Bonev's avatar
Boris Bonev committed
72
73

class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    """Autograd wrapper around the transpose of the ``disco_cuda_extension`` contraction.

    The roles of the two extension entry points are swapped relative to
    ``_DiscoS2ContractionCuda``: ``forward`` calls ``disco_cuda_extension.backward``
    and ``backward`` calls ``disco_cuda_extension.forward``. The extension is always
    invoked on contiguous float32 data and results are cast back to the incoming dtype.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                kernel_size: int, nlat_out: int, nlon_out: int):
        """Run the transpose contraction kernel on ``x``.

        Args:
            x: input tensor; its last two dims are treated as (nlat_in, nlon_in).
            roff_idx, ker_idx, row_idx, col_idx, vals: index/value tensors
                describing the filter. Their exact layout is defined by
                ``disco_cuda_extension`` (not visible here) -- TODO confirm.
            kernel_size: size of the kernel dimension, forwarded to the extension.
            nlat_out, nlon_out: output grid size, forwarded to the extension.

        Returns:
            The extension's output, cast back to ``x.dtype``.
        """
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        # The input grid size is needed in backward to size the gradient.
        ctx.nlat_in = x.shape[-2]
        ctx.nlon_in = x.shape[-1]

        # The extension is always called with contiguous float32 data
        # (presumably the only dtype/layout the kernel supports -- TODO confirm).
        xtype = x.dtype
        x = x.to(torch.float32).contiguous()
        # The transpose operation is the extension's "backward" entry point.
        output = disco_cuda_extension.backward(x, roff_idx, ker_idx, row_idx, col_idx, vals,
                                               kernel_size, nlat_out, nlon_out)
        return output.to(xtype)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x``; every other input receives ``None``.

        The kernel is skipped entirely when ``x`` does not require a gradient.
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
            gtype = grad_output.dtype
            grad_output = grad_output.to(torch.float32).contiguous()
            # The adjoint of the transpose is the extension's "forward" entry point.
            grad_input = disco_cuda_extension.forward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
                                                      ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
            grad_input = grad_input.to(gtype)

        # One entry per forward input: x, the five index/value tensors and the three ints.
        return grad_input, None, None, None, None, None, None, None, None
101

Boris Bonev's avatar
Boris Bonev committed
102
103
104
105
106
107
# CUDA
def _disco_s2_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                               row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                               kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Differentiable CUDA contraction; dispatches to ``_DiscoS2ContractionCuda.apply``.

    All arguments are passed through unchanged, in the same order.
    """
    filter_tensors = (roff_idx, ker_idx, row_idx, col_idx, vals)
    sizes = (kernel_size, nlat_out, nlon_out)
    result = _DiscoS2ContractionCuda.apply(x, *filter_tensors, *sizes)
    return result
108

Boris Bonev's avatar
Boris Bonev committed
109
110
111
112
113
def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
                                         row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
                                         kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Differentiable CUDA transpose contraction; dispatches to
    ``_DiscoS2TransposeContractionCuda.apply``.

    All arguments are passed through unchanged, in the same order.
    """
    filter_tensors = (roff_idx, ker_idx, row_idx, col_idx, vals)
    sizes = (kernel_size, nlat_out, nlon_out)
    result = _DiscoS2TransposeContractionCuda.apply(x, *filter_tensors, *sizes)
    return result
114
115
116
117
118
119


def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in CUDA.

    NOTE(review): reference [1] is not listed in this module.

    Args:
        x: input of shape (batch_size, n_chans, nlat_in, nlon_in).
        psi: filter of shape (kernel_size, nlat_out, nlat_in * nlon_in). It is moved to
            ``x.device`` and contracted with ``torch.bmm``.
        nlon_out: number of output longitudes; must divide ``nlon_in``.

    Returns:
        Tensor of shape (batch_size, n_chans, kernel_size, nlat_out, nlon_out) with ``x``'s dtype.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    # The longitude grid may only be downsampled. This compares longitudes with
    # longitudes; the number of output latitudes is not constrained here.
    assert nlon_in >= nlon_out
    # Each output longitude advances the input by this many input longitudes.
    pscale = nlon_in // nlon_out

    # Add a dummy dimension for the kernel and move the batch and channel dims to the end:
    # (1, nlat_in, nlon_in, batch_size * n_chans), broadcast to kernel_size copies.
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # Contract psi with the flattened (nlat_in, nlon_in) grid.
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # Roll the input so the next output longitude sees it shifted by pscale.
        x = torch.roll(x, -pscale, dims=2)

    # Reshape y back to expose the correct dimensions.
    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)

    return y


def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in CUDA.

    NOTE(review): reference [1] is not listed in this module.

    Args:
        x: input of shape (batch_size, n_chans, kernel_size, nlat_in, nlon_in).
        psi: filter of shape (kernel_size, nlat_out, nlat_in * nlon_out). It is moved to
            ``x.device`` and contracted with ``torch.bmm``.
        nlon_out: number of output longitudes; must be a multiple of ``nlon_in``.

    Returns:
        Tensor of shape (batch_size, n_chans, nlat_out, nlon_out) with ``x``'s dtype,
        summed over the kernel dimension.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 5
    psi = psi.to(x.device)

    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
    psi_kernel_size, nlat_out, n_out = psi.shape

    # x and psi must agree on the kernel dimension.
    assert psi_kernel_size == kernel_size
    # psi contracts over the zero-interleaved (nlat_in, nlon_out) input grid.
    assert n_out == nlat_in * nlon_out
    assert nlon_out >= nlon_in
    # The strided scatter below requires an integer upsampling factor.
    assert nlon_out % nlon_in == 0
    pscale = nlon_out // nlon_in

    # Interleave zeros along the longitude dimension to allow for fractional offsets to be considered.
    x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
    # After this permute x has shape kernel_size x nlat_in x nlon_in x (batch_size * n_chans).
    x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)

    # We only need to apply the nlon stride here, since the nlat stride is taken care of by the kernel.
    x_ext[:, :, ::pscale, :] = x

    # Create output tensor.
    y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # We need to repeatedly roll the input tensor to facilitate the shifted multiplication.
        # TODO: double-check why this has to happen first
        x_ext = torch.roll(x_ext, -1, dims=2)
        # Contract psi with the flattened (nlat_in, nlon_out) extended input.
        y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))

    # Sum over the kernel dimension and reshape to the correct output size.
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()

    return y