Commit d3a94a25 authored by rusty1s's avatar rusty1s
Browse files

add relabeling

parent ed3de958
#include "relabel_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
torch::Tensor idx) {
CHECK_CPU(col);
CHECK_CPU(idx);
CHECK_INPUT(idx.dim() == 1);
auto col_data = col.data_ptr<int64_t>();
auto idx_data = idx.data_ptr<int64_t>();
std::vector<int64_t> cols;
std::vector<int64_t> n_ids;
std::unordered_map<int64_t, int64_t> n_id_map;
int64_t i;
for (int64_t n = 0; n < idx.size(0); n++) {
i = idx_data[n];
n_id_map[i] = n;
n_ids.push_back(i);
}
int64_t c;
for (int64_t e = 0; e < col.size(0); e++) {
c = col_data[e];
if (n_id_map.count(c) == 0) {
n_id_map[c] = n_ids.size();
n_ids.push_back(c);
}
cols.push_back(n_id_map[c]);
}
int64_t n_len = n_ids.size(), e_len = cols.size();
auto out_col = torch::from_blob(cols.data(), {e_len}, col.options()).clone();
auto out_idx = torch::from_blob(n_ids.data(), {n_len}, col.options()).clone();
return std::make_tuple(out_col, out_idx);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
torch::Tensor idx);
#include <Python.h>
#include <torch/script.h>
#include "cpu/relabel_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__relabel(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
torch::Tensor idx) {
if (col.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return relabel_cpu(col, idx);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::relabel", &relabel);
......@@ -7,7 +7,7 @@ __version__ = '0.6.6'
for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
'_saint', '_sample'
'_saint', '_sample', '_relabel'
]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment