attention_utils.cu 12.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "attention.cuh"
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>

#include <cuda_runtime.h>

#include <cub/cub.cuh>
#include <limits>

#include "cudamacro.h"
#include "attention_utils.cuh"

#define THREADS (64)

#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100    (4)

// BEGIN - CSR rows sorting kernels and functions
// Fills, for every CSR row r in [0, n):
//   rids[r] = r                      (identity row index)
//   rlen[r] = offs[r+1] - offs[r]    (row length, narrowed to int)
//
// offs is read at indices [0, n], so it must hold at least n+1 entries.
// Each thread starts at its global index and advances by the total number
// of launched threads, so any 1D grid/block configuration covers all rows.
__global__ void set_rlen_rids_k(const int n,
                                const int64_t *__restrict__ offs,
                                      int *__restrict__ rids,
                                      int *__restrict__ rlen) {

    const int stride = gridDim.x*blockDim.x;

    int row = blockIdx.x*blockDim.x + threadIdx.x;
    while (row < n) {
        const int64_t row_beg = offs[row];
        const int64_t row_end = offs[row+1];

        rlen[row] = static_cast<int>(row_end - row_beg);
        rids[row] = row;

        row += stride;
    }
}

// Returns the row indices of a CSR structure ordered by decreasing row length.
//
// nlat_out : number of CSR rows
// row_off  : int64 tensor of CSR row offsets; entries [0, nlat_out] are read
//            and row i has length row_off[i+1]-row_off[i]
// stream   : CUDA stream on which the setup kernel and the CUB sort are enqueued
//
// The result is a 1D int32 tensor with nlat_out entries, allocated on
// row_off's device. All work is only enqueued on `stream`; this function does
// not synchronize.
//
// NOTE(review): the scratch tensors allocated here are released when the
// function returns, while the sort may still be pending on `stream`. This
// presumably relies on `stream` being the stream the torch allocator
// associates with these allocations -- confirm against the callers.
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {

    auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());

    torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);

    // nothing to sort; also avoids launching a kernel with a zero-sized grid,
    // which is an invalid launch configuration
    if (nlat_out == 0) {
        return rids_sort_d;
    }

    // the typed accessor raises if row_off does not actually hold int64 data,
    // instead of silently reinterpreting the buffer
    const int64_t *_row_off_d = row_off.data_ptr<int64_t>();

    torch::Tensor rids_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);

    int *_rids_d = rids_d.data_ptr<int>();
    int *_rlen_d = rlen_d.data_ptr<int>();

    int *_rids_sort_d = rids_sort_d.data_ptr<int>();
    int *_rlen_sort_d = rlen_sort_d.data_ptr<int>();

    const int grid = DIV_UP(nlat_out, THREADS);
    const int block = THREADS;

    // keys = row lengths, values = row ids
    set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
                                                _row_off_d,
                                                _rids_d,
                                                _rlen_d);
    // kernel launches do not return a status: fetch launch errors explicitly
    CHECK_CUDA(cudaGetLastError());

    // first call with a NULL workspace only queries the required scratch size
    size_t temp_storage_bytes = 0;
    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));

    options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
    torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);

    void *_temp_storage_d = temp_storage_d.data_ptr();

    // second call performs the sort over all bits of the int keys
    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));
    return rids_sort_d;
}
// END - CSR rows sorting kernels and functions


// BEGIN - 4D tensor permutation kernels and functions
// Tiled permutation kernel: dst[b][h][w][c] = src[b][c][h][w].
//
// Expected launch layout:
//   blockDim = (BDIM_X, BDIM_Y), both powers of two with BDIM_Y <= BDIM_X
//   blockIdx.x : tile index along the channel dimension (tiles of BDIM_X)
//   blockIdx.y : tile index along the width dimension   (tiles of BDIM_X)
//   blockIdx.z : batch*nlat + h
//
// Each block stages a BDIM_X x BDIM_X tile in static shared memory and writes
// it back transposed. The extra column in the tile is the usual padding
// against shared-memory bank conflicts on the transposed access. Tiles that
// cross the channel or width extent take the bounds-checked path.
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void  permute_to0231_k(const int nchn,
                       const int nlat,
                       const int nlon,
                       const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
                             torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {

    static_assert((BDIM_X & (BDIM_X-1)) == 0);
    static_assert((BDIM_Y & (BDIM_Y-1)) == 0);
    static_assert(BDIM_Y <= BDIM_X);

    __shared__ VAL_T tile[BDIM_X][BDIM_X+1];

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int chn0 = blockIdx.x*BDIM_X;         // first channel of this tile
    const int lon0 = blockIdx.y*BDIM_X;         // first longitude of this tile
    const int b    = blockIdx.z / nlat;         // batch index, block-uniform
    const int lat  = blockIdx.z - (b * nlat);   // latitude index, block-uniform

    // depends on blockIdx only, hence identical for all threads of the block
    const bool interior = ((nchn-chn0) >= BDIM_X) && ((nlon-lon0) >= BDIM_X);

    // stage 1: global -> shared, consecutive tx read consecutive longitudes
    if (interior) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            tile[r][tx] = src[b][chn0 + r][lat][lon0 + tx];
        }
    } else if (lon0 + tx < nlon) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            tile[r][tx] = (chn0 + r < nchn) ? src[b][chn0 + r][lat][lon0 + tx] : 0.f;
        }
    }

    // reached by every thread of the block, whichever path was taken above
    __syncthreads();

    // stage 2: shared -> global, consecutive tx write consecutive channels
    if (interior) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            dst[b][lat][lon0 + r][chn0 + tx] = tile[tx][r];
        }
    } else if (chn0 + tx < nchn) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            if (lon0 + r < nlon) {
                dst[b][lat][lon0 + r][chn0 + tx] = tile[tx][r];
            }
        }
    }
}

// Kernel with no body; getPtxver() passes it to cudaFuncGetAttributes() to
// query the attributes this translation unit's device code was built with.
__global__ void empty_k() {}

// Returns the PTX virtual architecture version that empty_k (and hence the
// device code of this file) was compiled for, encoded as major*10 + minor,
// e.g. 80 for compute_80 and 100 for compute_100.
//
// cudaFuncAttributes::ptxVersion already uses that encoding, so it is returned
// unscaled. The callers in this file compare the result against 100 to choose
// between the GENERIC and SM100 tile configurations; multiplying by 10 here
// made that comparison false on every architecture.
static int getPtxver() {
    cudaFuncAttributes attrs;
    CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
    return attrs.ptxVersion;
}

// Returns a new float32 tensor holding `src` with its dimensions reordered
// from (0, 1, 2, 3) to (0, 2, 3, 1), i.e. dst[b][h][w][c] = src[b][c][h][w].
//
// src    : 4D float32 tensor (accessed through packed_accessor32<float, 4>)
// stream : CUDA stream on which the permutation kernel is enqueued
//
// The kernel is only enqueued; no synchronization is performed here.
at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);

    // an empty input would produce a zero-sized grid, which is an invalid
    // launch configuration; there is nothing to copy in that case
    if (src.numel() == 0) {
        return dst;
    }

    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    grid.x = DIV_UP(src.size(1), block.x); // channel tiles
    grid.y = DIV_UP(src.size(3), block.x); // width tiles
    grid.z = src.size(2)*src.size(0);      // one z-slice per (batch, height) pair

    // these are compiled out when NDEBUG is defined; the launch check at the
    // end of the function still reports an oversized grid in that case
    assert(grid.y < 65536);
    assert(grid.z < 65536);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        block.y = TRANSP_WARPS_X_TILE_GENERIC;
        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
                        <<<grid, block, 0, stream>>>(src.size(1),
                                                     src.size(2),
                                                     src.size(3),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
                        <<<grid, block, 0, stream>>>(src.size(1),
                                                     src.size(2),
                                                     src.size(3),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    }

    // kernel launches do not return a status: fetch launch errors explicitly
    CHECK_CUDA(cudaGetLastError());

    return dst;
}

// Tiled permutation kernel: dst[b][c][h][w] = src[b][h][w][c].
//
// Expected launch layout:
//   blockDim = (BDIM_X, BDIM_Y), both powers of two with BDIM_Y <= BDIM_X
//   blockIdx.x : tile index along the width dimension   (tiles of BDIM_X)
//   blockIdx.y : tile index along the channel dimension (tiles of BDIM_X)
//   blockIdx.z : batch*nlat + h
//
// Each block stages a BDIM_X x BDIM_X tile in static shared memory and writes
// it back transposed. The extra column in the tile is the usual padding
// against shared-memory bank conflicts on the transposed access. Tiles that
// cross the channel or width extent take the bounds-checked path.
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void  permute_to0312_k(const int nchn,
                       const int nlat,
                       const int nlon,
                       const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
                             torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {

    static_assert((BDIM_X & (BDIM_X-1)) == 0);
    static_assert((BDIM_Y & (BDIM_Y-1)) == 0);
    static_assert(BDIM_Y <= BDIM_X);

    __shared__ VAL_T tile[BDIM_X][BDIM_X+1];

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int lon0 = blockIdx.x*BDIM_X;         // first longitude of this tile
    const int chn0 = blockIdx.y*BDIM_X;         // first channel of this tile
    const int b    = blockIdx.z / nlat;         // batch index, block-uniform
    const int lat  = blockIdx.z - (b * nlat);   // latitude index, block-uniform

    // depends on blockIdx only, hence identical for all threads of the block
    const bool interior = ((nchn-chn0) >= BDIM_X) && ((nlon-lon0) >= BDIM_X);

    // stage 1: global -> shared, consecutive tx read consecutive channels
    if (interior) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            tile[r][tx] = src[b][lat][lon0 + r][chn0 + tx];
        }
    } else if (chn0 + tx < nchn) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            tile[r][tx] = (lon0 + r < nlon) ? src[b][lat][lon0 + r][chn0 + tx] : 0.f;
        }
    }

    // reached by every thread of the block, whichever path was taken above
    __syncthreads();

    // stage 2: shared -> global, consecutive tx write consecutive longitudes
    if (interior) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            dst[b][chn0 + r][lat][lon0 + tx] = tile[tx][r];
        }
    } else if (lon0 + tx < nlon) {
        #pragma unroll
        for(int k = 0; k < BDIM_X; k += BDIM_Y) {
            const int r = k + ty;
            if (chn0 + r < nchn) {
                dst[b][chn0 + r][lat][lon0 + tx] = tile[tx][r];
            }
        }
    }
}

// Returns a new float32 tensor holding `src` with its dimensions reordered
// from (0, 1, 2, 3) to (0, 3, 1, 2), i.e. dst[b][c][h][w] = src[b][h][w][c].
//
// src    : 4D float32 tensor (accessed through packed_accessor32<float, 4>)
// stream : CUDA stream on which the permutation kernel is enqueued
//
// The kernel is only enqueued; no synchronization is performed here.
at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);

    // an empty input would produce a zero-sized grid, which is an invalid
    // launch configuration; there is nothing to copy in that case
    if (src.numel() == 0) {
        return dst;
    }

    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    grid.x = DIV_UP(src.size(2), block.x); // width tiles
    grid.y = DIV_UP(src.size(3), block.x); // channel tiles
    grid.z = src.size(1)*src.size(0);      // one z-slice per (batch, height) pair

    // these are compiled out when NDEBUG is defined; the launch check at the
    // end of the function still reports an oversized grid in that case
    assert(grid.y < 65536);
    assert(grid.z < 65536);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        block.y = TRANSP_WARPS_X_TILE_GENERIC;
        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
                        <<<grid, block, 0, stream>>>(src.size(3),
                                                     src.size(1),
                                                     src.size(2),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
                        <<<grid, block, 0, stream>>>(src.size(3),
                                                     src.size(1),
                                                     src.size(2),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    }

    // kernel launches do not return a status: fetch launch errors explicitly
    CHECK_CUDA(cudaGetLastError());

    return dst;
}
// END - 4D tensor permutation kernels and functions

// BEGIN - general host-side functions
// Rounds x up to the nearest power of two; a power of two is returned unchanged.
//
// Works by smearing the highest set bit of x-1 into every lower bit position
// and then adding one. Because the arithmetic is unsigned and wraps, the
// result is 0 both for x == 0 and for any x above the largest representable
// power of two.
unsigned int next_pow2(unsigned int x) {

    unsigned int v = x - 1u;

    // shifts of 1, 2, 4, ... up to half the bit width of the type
    for(unsigned int shift = 1u; shift < sizeof(v)*8; shift <<= 1) {
        v |= (v >> shift);
    }
    return v + 1u;
}
// END - general host-side functions