disco_helpers.cpp 5.19 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
Thorsten Kurth's avatar
Thorsten Kurth committed
5
//
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "disco.h"

#include <cstdio>
#include <cstdlib>
#include <vector>

Thorsten Kurth's avatar
Thorsten Kurth committed
33
34
35
36
/// Sort the K stacked COO matrices of psi by kernel index and compute segment offsets.
///
/// The nnz entries (ker_h[i], row_h[i], col_h[i], val_h[i]) are reordered in place with a
/// stable counting sort on the kernel index, so entries of the same kernel keep their
/// relative order. Afterwards roff_h[r] is the position of the first entry of the r-th
/// (kernel, row) segment for r in [0, nrows), and roff_h[nrows] == nnz.
///
/// NOTE(review): assumes that, within each kernel, entries sharing a row index are already
/// adjacent (e.g. row-sorted input) -- confirm with the caller. Non-adjacent runs of the
/// same row are reported as separate segments.
///
/// @param nnz    number of non-zero entries
/// @param K      number of kernels; every ker_h[i] must lie in [0, K)
/// @param Ho     number of rows per kernel; at most Ho*K segments are accepted
/// @param ker_h  kernel index per entry (sorted in place)
/// @param row_h  row index per entry (permuted in place)
/// @param col_h  column index per entry (permuted in place)
/// @param roff_h output segment offsets; must provide room for Ho*K + 1 entries
/// @param val_h  value per entry (permuted in place at full REAL_T precision)
/// @param nrows  output: number of (kernel, row) segments found (0 when nnz == 0)
///
/// Terminates the process via exit(EXIT_FAILURE) on an out-of-range kernel index or when
/// more than Ho*K segments are found.
template <typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h,
                           int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
{
    const size_t n = static_cast<size_t>(nnz > 0 ? nnz : 0);

    // count entries per kernel, validating kernel indices before they are used to index Koff
    std::vector<int64_t> Koff(static_cast<size_t>(K > 0 ? K : 0), 0);
    for (int64_t i = 0; i < nnz; i++) {
        const int64_t ker = ker_h[i];
        if (ker < 0 || ker >= K) {
            fprintf(stderr, "%s:%d: error, kernel index %lld out of range [0, %lld)\n", __FILE__, __LINE__,
                    static_cast<long long>(ker), static_cast<long long>(K));
            exit(EXIT_FAILURE);
        }
        Koff[ker]++;
    }

    // exclusive prefix sum: Koff[k] becomes the first sorted slot of kernel k
    int64_t running = 0;
    for (int64_t k = 0; k < K; k++) {
        const int64_t count = Koff[k];
        Koff[k] = running;
        running += count;
    }

    // stable scatter into scratch buffers; val_sort uses REAL_T (not float) so that
    // double-precision values are not truncated
    std::vector<int64_t> ker_sort(n), row_sort(n), col_sort(n);
    std::vector<REAL_T> val_sort(n);
    for (int64_t i = 0; i < nnz; i++) {
        const int64_t ker = ker_h[i];
        const int64_t off = Koff[ker]++;

        ker_sort[off] = ker;
        row_sort[off] = row_h[i];
        col_sort[off] = col_h[i];
        val_sort[off] = val_h[i];
    }
    for (int64_t i = 0; i < nnz; i++) {
        ker_h[i] = ker_sort[i];
        row_h[i] = row_sort[i];
        col_h[i] = col_sort[i];
        val_h[i] = val_sort[i];
    }

    // compute segment offsets: a new segment starts at the first entry and whenever the row
    // OR the kernel index changes. Comparing the kernel too keeps adjacent kernels whose
    // boundary rows coincide (e.g. Ho == 1) from being merged into one segment.
    const int64_t max_rows = Ho * K;
    nrows = 0;
    for (int64_t i = 0; i < nnz; i++) {
        if (i > 0 && row_h[i - 1] == row_h[i] && ker_h[i - 1] == ker_h[i]) continue;

        // check before writing so roff_h (Ho*K + 1 entries) is never overrun
        if (nrows >= max_rows) {
            fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%lld)\n", __FILE__, __LINE__,
                    static_cast<long long>(max_rows));
            exit(EXIT_FAILURE);
        }
        roff_h[nrows++] = i;
    }
    roff_h[nrows] = nnz;
}

Thorsten Kurth's avatar
Thorsten Kurth committed
98
99
100
/// Sort the psi COO tensors by kernel index and return the (kernel, row) segment offsets.
///
/// ker_idx, row_idx, col_idx and val are permuted consistently, and the result is visible in
/// the caller's tensors regardless of their device. The sorting itself runs on the host.
///
/// @param K       number of kernels
/// @param Ho      number of rows per kernel
/// @param ker_idx int64 kernel index per non-zero
/// @param row_idx int64 row index per non-zero
/// @param col_idx int64 column index per non-zero
/// @param val     floating-point value per non-zero
/// @return int64 tensor of nrows + 1 segment offsets, on the same device as the inputs
/// @throws c10::Error if the inputs live on different device types or have mismatched lengths
torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val)
{
    CHECK_INPUT_TENSOR(ker_idx);
    CHECK_INPUT_TENSOR(row_idx);
    CHECK_INPUT_TENSOR(col_idx);
    CHECK_INPUT_TENSOR(val);

    // all inputs must be on the same device type (user error, hence TORCH_CHECK)
    const auto device = ker_idx.device();
    TORCH_CHECK(device.type() == row_idx.device().type() && device.type() == col_idx.device().type() &&
                    device.type() == val.device().type(),
                "preprocess_psi: all input tensors must be on the same device");

    const int64_t nnz = val.size(0);
    TORCH_CHECK(ker_idx.size(0) == nnz && row_idx.size(0) == nnz && col_idx.size(0) == nnz,
                "preprocess_psi: ker_idx, row_idx, col_idx and val must have the same length");

    // sorting happens on the host. Tensor::to() returns the tensor itself when no conversion is
    // needed, so for CPU inputs these alias the caller's tensors and are sorted in place.
    auto ker_cpu = ker_idx.to(torch::kCPU);
    auto row_cpu = row_idx.to(torch::kCPU);
    auto col_cpu = col_idx.to(torch::kCPU);
    auto val_cpu = val.to(torch::kCPU);

    // scratch buffer sized for the maximum number of segments (+1 terminator); RAII so it is
    // released even if the dispatch below throws
    std::vector<int64_t> roff_h(static_cast<size_t>(Ho * K + 1));
    int64_t nrows = 0;

    AT_DISPATCH_FLOATING_TYPES(val_cpu.scalar_type(), "preprocess_psi", ([&] {
                                   preprocess_psi_kernel<scalar_t>(
                                       nnz, K, Ho, ker_cpu.data_ptr<int64_t>(), row_cpu.data_ptr<int64_t>(),
                                       col_cpu.data_ptr<int64_t>(), roff_h.data(), val_cpu.data_ptr<scalar_t>(),
                                       nrows);
                               }));

    // create output tensor
    auto roff_idx = torch::empty({nrows + 1}, row_cpu.options());
    int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
    for (int64_t i = 0; i <= nrows; i++) { roff_out_h[i] = roff_h[i]; }

    // for non-CPU inputs the sort was applied to host copies; write it back into the caller's
    // tensors, otherwise the returned offsets would not match the (still unsorted) inputs
    if (device.type() != torch::kCPU) {
        ker_idx.copy_(ker_cpu);
        row_idx.copy_(row_cpu);
        col_idx.copy_(col_cpu);
        val.copy_(val_cpu);
    }

    // move to original device
    return roff_idx.to(device);
}

Thorsten Kurth's avatar
Thorsten Kurth committed
146
147
// Python bindings: exposes the host-side psi sorting/offset routine of this extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    const char *preprocess_psi_doc = "Sort psi matrix, required for using disco_cuda.";
    m.def("preprocess_psi", &preprocess_psi, preprocess_psi_doc);
}