// degree_padding_cuda.h
#pragma once

#include <torch/extension.h>

// Declaration only; the definition is not part of this header.
//
// Parameters:
//   rowcount - tensor argument, passed by value.
//   binptr   - tensor argument, passed by value.
//
// Returns a tuple holding a vector of tensors and a vector of int64_t values.
//
// NOTE(review): judging by the names and the file name, this presumably
// assigns rows to degree bins delimited by `binptr` on the GPU - confirm
// against the implementation.
std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr);

// Declaration only; the definition is not part of this header.
//
// Parameters:
//   src        - tensor argument, passed by value.
//   rowptr     - tensor argument, passed by value.
//   col        - tensor argument, passed by value.
//   index      - tensor argument, passed by value.
//   length     - 64-bit signed integer argument.
//   fill_value - tensor argument, passed by value.
//
// Returns a tuple of two tensors.
//
// NOTE(review): `rowptr`/`col` look like a CSR-style sparse layout and
// `length`/`fill_value` like the padded width and padding value - these are
// assumptions based on the names only; confirm against the implementation.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index, int64_t length,
                         torch::Tensor fill_value);

// NOTE(review): the declaration below is commented out, so
// padded_index_select_cuda2 is not declared by this header.
// std::tuple<torch::Tensor, torch::Tensor> padded_index_select_cuda2(
//     torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
//     torch::Tensor bin, torch::Tensor index, std::vector<int64_t> node_counts,
//     std::vector<int64_t> lengths, torch::Tensor fill_value);