// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "disco.h"

#include <cstdio>
#include <cstdlib>
#include <vector>

// Sorts the K concatenated COO matrices of the psi tensor by kernel index (in place) and
// computes CSR-style segment offsets over the sorted entries.
//
// The sort is a stable counting sort on ker_h, so entries sharing a kernel index keep their
// relative input order. NOTE(review): segments are only contiguous if the caller already
// orders entries by row within each kernel -- this is assumed, not enforced here.
//
// Parameters:
//   nnz    - number of entries in ker_h / row_h / col_h / val_h
//   K      - number of kernels; every ker_h[i] must lie in [0, K)
//   Ho     - output rows per kernel; at most Ho * K segments are accepted
//   ker_h, row_h, col_h, val_h - COO entries, permuted in place into kernel-sorted order
//   roff_h - receives nrows + 1 offsets; must hold at least Ho * K + 1 values
//   nrows  - output: number of segments, where a segment is a maximal run of consecutive
//            sorted entries sharing the same (kernel, row) pair; 0 when nnz == 0
//
// Invalid input (kernel index out of range, or more than Ho * K segments) prints an error
// to stderr and terminates the process, matching the original error handling.
template <typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h,
                           int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
{
  // Histogram of entries per kernel. Indices are validated first because an out-of-range
  // value would otherwise write outside the histogram.
  std::vector<int64_t> Koff(static_cast<size_t>(K > 0 ? K : 0), 0);
  for (int64_t i = 0; i < nnz; i++) {
    const int64_t ker = ker_h[i];
    if (ker < 0 || ker >= K) {
      fprintf(stderr, "%s:%d: error, kernel index %lld at position %lld is outside [0, %lld)\n", __FILE__, __LINE__,
              static_cast<long long>(ker), static_cast<long long>(i), static_cast<long long>(K));
      exit(EXIT_FAILURE);
    }
    Koff[ker]++;
  }

  // Exclusive prefix sum: Koff[k] becomes the first sorted position of kernel k.
  // Safe for K == 0 (empty histogram).
  int64_t running = 0;
  for (auto &off : Koff) {
    const int64_t count = off;
    off = running;
    running += count;
  }

  // Scatter into temporaries. val_sort uses REAL_T (not float) so double inputs keep
  // their full precision.
  const size_t n = static_cast<size_t>(nnz > 0 ? nnz : 0);
  std::vector<int64_t> ker_sort(n);
  std::vector<int64_t> row_sort(n);
  std::vector<int64_t> col_sort(n);
  std::vector<REAL_T> val_sort(n);

  for (int64_t i = 0; i < nnz; i++) {
    const int64_t ker = ker_h[i];
    const int64_t off = Koff[ker]++;

    ker_sort[off] = ker;
    row_sort[off] = row_h[i];
    col_sort[off] = col_h[i];
    val_sort[off] = val_h[i];
  }

  for (int64_t i = 0; i < nnz; i++) {
    ker_h[i] = ker_sort[i];
    row_h[i] = row_sort[i];
    col_h[i] = col_sort[i];
    val_h[i] = val_sort[i];
  }

  // Compute segment offsets. A new segment starts at the first entry and whenever the
  // (kernel, row) pair changes. Comparing the kernel as well keeps the last row of kernel
  // k apart from the first row of kernel k+1 when both have the same row index.
  const int64_t max_rows = Ho * K;
  nrows = 0;
  roff_h[0] = 0;
  for (int64_t i = 0; i < nnz; i++) {
    const bool new_segment = (i == 0) || (ker_h[i] != ker_h[i - 1]) || (row_h[i] != row_h[i - 1]);
    if (!new_segment) continue;

    // Check before writing so roff_h (sized Ho * K + 1) is never overrun.
    if (nrows >= max_rows) {
      fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%lld)\n", __FILE__, __LINE__,
              static_cast<long long>(max_rows));
      exit(EXIT_FAILURE);
    }
    roff_h[nrows++] = i;
  }
  roff_h[nrows] = nnz;
}

// Python-facing entry point. Sorts the psi COO tensors by kernel index IN PLACE (via
// preprocess_psi_kernel) and returns the resulting segment offsets.
//
// Parameters:
//   K, Ho   - number of kernels and output rows per kernel (Ho * K bounds the segment count)
//   ker_idx, row_idx, col_idx - index tensors of length nnz, read as int64; permuted in place
//   val     - floating-point values of length nnz (dispatched via AT_DISPATCH_FLOATING_TYPES);
//             permuted in place
// Returns:
//   A 1-D tensor of nrows + 1 offsets, created with row_idx's dtype.
// NOTE(review): CHECK_INPUT_TENSOR is defined in disco.h; presumably it ensures the tensors
// are host-resident/contiguous so the raw pointers below are valid -- confirm there.
torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val)
{
  CHECK_INPUT_TENSOR(ker_idx);
  CHECK_INPUT_TENSOR(row_idx);
  CHECK_INPUT_TENSOR(col_idx);
  CHECK_INPUT_TENSOR(val);

  const int64_t nnz = val.size(0);
  // All four arrays are walked over [0, nnz) below; a shorter index tensor would be read
  // out of bounds.
  TORCH_CHECK(ker_idx.size(0) == nnz && row_idx.size(0) == nnz && col_idx.size(0) == nnz,
              "preprocess_psi: ker_idx, row_idx, col_idx and val must have the same length");
  TORCH_CHECK(K >= 0 && Ho >= 0, "preprocess_psi: K and Ho must be non-negative");

  int64_t *ker_h = ker_idx.data_ptr<int64_t>();
  int64_t *row_h = row_idx.data_ptr<int64_t>();
  int64_t *col_h = col_idx.data_ptr<int64_t>();

  // Scratch buffer sized for the largest segment count the kernel accepts. A vector is
  // released even if the dispatch below throws (e.g. for an unsupported dtype).
  std::vector<int64_t> roff_h(static_cast<size_t>(Ho * K + 1));
  int64_t nrows = 0;

  AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
                               preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h.data(),
                                                               val.data_ptr<scalar_t>(), nrows);
                             }));

  // create output tensor holding only the nrows + 1 offsets actually produced
  auto options = torch::TensorOptions().dtype(row_idx.dtype());
  auto roff_idx = torch::empty({nrows + 1}, options);
  int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
  for (int64_t i = 0; i <= nrows; i++) { roff_out_h[i] = roff_h[i]; }

  return roff_idx;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod)
{
  // Expose the CPU-side psi preprocessing routine. Per its docstring, it must be run
  // before the psi tensors are handed to disco_cuda.
  mod.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}