"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "df6556049da332d978799a7f67d60c4f21484133"
Commit 07a6928b authored by rusty1s's avatar rusty1s
Browse files

graph saint sampling complete

parent 61f01b59
#include "rw_cpu.h"
#include "utils.h"
// Uniform random walk over a CSR graph on CPU.
//
// rowptr/col: CSR representation of the adjacency matrix (both 1-D int64).
// start:      1-D int64 tensor of root nodes, one walk per entry.
// walk_length: number of steps to take from each root.
//
// Returns a [start.size(0), walk_length + 1] int64 tensor; row n holds the
// node sequence of the walk rooted at start[n], beginning with the root.
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(start);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  // Pre-draw one uniform sample per (walk, step); used to pick a neighbor.
  auto rand = torch::rand({start.size(0), walk_length},
                          start.options().dtype(torch::kFloat));

  auto L = walk_length + 1;
  auto out = torch::full({start.size(0), L}, -1, start.options());

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto start_data = start.data_ptr<int64_t>();
  auto rand_data = rand.data_ptr<float>();
  auto out_data = out.data_ptr<int64_t>();

  // int64_t counters: start.size(0) is int64_t, avoid narrowing via `int`.
  for (int64_t n = 0; n < start.size(0); n++) {
    int64_t cur = start_data[n];
    out_data[n * L] = cur;
    for (int64_t l = 0; l < walk_length; l++) {
      int64_t row_start = rowptr_data[cur];
      int64_t row_end = rowptr_data[cur + 1];
      // Guard isolated nodes: with zero out-degree the original index
      // col_data[row_start] would read out of bounds. Stay at `cur` instead.
      if (row_end > row_start) {
        cur = col_data[row_start + int64_t(rand_data[n * walk_length + l] *
                                           (row_end - row_start))];
      }
      out_data[n * L + l + 1] = cur;
    }
  }
  return out;
}
#pragma once
#include <torch/extension.h>
// Uniform random walk on a CSR graph (CPU backend). Returns a
// [start.size(0), walk_length + 1] int64 tensor of node sequences,
// each beginning with its root node from `start`.
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include "rw_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// One thread per walk. `rand` and `out` are laid out step-major
// ([walk_length, numel] and [walk_length + 1, numel]) so that threads in a
// warp touch consecutive addresses at each step (coalesced access).
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
                                           const int64_t *col,
                                           const int64_t *start,
                                           const float *rand, int64_t *out,
                                           int64_t walk_length, int64_t numel) {
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (thread_idx < numel) {
    int64_t cur = start[thread_idx];
    out[thread_idx] = cur;
    for (int64_t l = 0; l < walk_length; l++) {
      const int64_t row_start = rowptr[cur];
      const int64_t row_end = rowptr[cur + 1];
      // Guard isolated nodes: with zero out-degree the original index
      // col[row_start] would read out of bounds. Stay at `cur` instead
      // (mirrors the CPU implementation).
      if (row_end > row_start) {
        cur = col[row_start +
                  int64_t(rand[l * numel + thread_idx] * (row_end - row_start))];
      }
      out[(l + 1) * numel + thread_idx] = cur;
    }
  }
}
// Uniform random walk over a CSR graph on GPU.
//
// rowptr/col: CSR adjacency (1-D int64, CUDA). start: 1-D int64 roots.
// Returns a [start.size(0), walk_length + 1] int64 tensor of node sequences.
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(start);
  cudaSetDevice(rowptr.get_device());
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  const auto num_walks = start.size(0);

  // Step-major layout ([steps, walks]) so the kernel's per-step accesses are
  // coalesced across threads; transposed back before returning.
  auto rand_vals = torch::rand({walk_length, num_walks},
                               start.options().dtype(torch::kFloat));
  auto walks = torch::full({walk_length + 1, num_walks}, -1, start.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
      rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
      start.data_ptr<int64_t>(), rand_vals.data_ptr<float>(),
      walks.data_ptr<int64_t>(), walk_length, start.numel());

  // Callers expect [num_walks, walk_length + 1], contiguous.
  return walks.t().contiguous();
}
#pragma once
#include <torch/extension.h>
// Uniform random walk on a CSR graph (CUDA backend). Returns a
// [start.size(0), walk_length + 1] int64 tensor of node sequences,
// each beginning with its root node from `start`.
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include <Python.h>
#include <torch/script.h>
#include "cpu/rw_cpu.h"
#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif
#ifdef _WIN32
// Stub module-init symbol so the Windows loader accepts this extension as a
// Python module; the real operators are registered through torch::RegisterOperators.
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif
// Device dispatcher for the random-walk op: routes to the CPU or CUDA
// implementation based on where `rowptr` lives.
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
  // Guard clause: CPU tensors take the CPU path immediately.
  if (!rowptr.device().is_cuda()) {
    return random_walk_cpu(rowptr, col, start, walk_length);
  }
#ifdef WITH_CUDA
  return random_walk_cuda(rowptr, col, start, walk_length);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Expose the dispatcher to Python as `torch.ops.torch_sparse.random_walk`.
static auto registry =
torch::RegisterOperators().op("torch_sparse::random_walk", &random_walk);
...@@ -8,7 +8,8 @@ expected_torch_version = (1, 4) ...@@ -8,7 +8,8 @@ expected_torch_version = (1, 4)
try:
    for library in [
            '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
            '_rw'
    ]:
        torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
            library, [osp.dirname(__file__)]).origin)
......
...@@ -46,13 +46,16 @@ def sample_edge(src: SparseTensor, ...@@ -46,13 +46,16 @@ def sample_edge(src: SparseTensor,
def sample_rw(src: SparseTensor, num_root_nodes: int,
              walk_length: int) -> Tuple[torch.Tensor, torch.Tensor]:
    rowptr, col, _ = src.csr()
    start = np.random.choice(src.size(0), size=num_root_nodes, replace=False)
    start = torch.from_numpy(start).to(src.device())
    # Random walks of length `walk_length` => `out.size(1) == walk_length + 1`.
    out = torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
    node_idx = out.flatten().unique()
    return src.permute(node_idx), node_idx

SparseTensor.sample_node = sample_node
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment