Commit 65374fb6 authored by rusty1s

multi gpu update

parent 3a4c67c0
...
@@ -162,6 +162,7 @@ fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
 }()
 at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
+  cudaSetDevice(x.get_device());
   auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(batch_sizes, batch[-1].data<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
...
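The added cudaSetDevice(x.get_device()) calls are what make the extension multi-GPU safe: without them, every kernel launch and cudaMemcpy in these functions runs on the process-wide default device (GPU 0), even when the input tensors live on another GPU. Below is a minimal sketch of the pattern, assuming the same ATen API generation used in this repository (x.type() dispatch, data<scalar_t>()); my_op_cuda and copy_kernel are hypothetical names used only for illustration, not part of this commit.

// Sketch only (not from the commit): select the input's GPU before launching.
#include <ATen/ATen.h>
#include <cuda_runtime.h>

template <typename scalar_t>
__global__ void copy_kernel(const scalar_t *__restrict__ src,
                            scalar_t *__restrict__ dst, int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel)
    dst[i] = src[i]; // placeholder work
}

at::Tensor my_op_cuda(at::Tensor x) {
  // Switch the calling thread to the GPU that owns x, so the allocation,
  // dispatch and kernel launch below all target that device.
  cudaSetDevice(x.get_device());

  auto out = at::empty_like(x);
  const int64_t numel = x.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((numel + threads - 1) / threads);

  AT_DISPATCH_FLOATING_TYPES(x.type(), "copy_kernel", [&] {
    copy_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), out.data<scalar_t>(), numel);
  });
  return out;
}

Note that cudaSetDevice changes the calling thread's current device without restoring it afterwards; depending on the ATen version available, an RAII guard such as at::cuda::CUDAGuard could be used instead so the previously selected device is restored when the function returns.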
...
@@ -6,6 +6,7 @@
 #include "utils.cuh"
 at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
+  cudaSetDevice(row.get_device());
   std::tie(row, col) = remove_self_loops(row, col);
   std::tie(row, col) = rand(row, col);
   std::tie(row, col) = to_csr(row, col, num_nodes);
@@ -23,6 +24,7 @@ at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
 at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
                                  at::Tensor weight, int64_t num_nodes) {
+  cudaSetDevice(row.get_device());
   std::tie(row, col, weight) = remove_self_loops(row, col, weight);
   std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
...
...
@@ -26,6 +26,7 @@ __global__ void grid_kernel(int64_t *cluster,
 at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                      at::Tensor end) {
+  cudaSetDevice(pos.get_device());
   auto cluster = at::empty(pos.size(0), pos.options().dtype(at::kLong));
   AT_DISPATCH_ALL_TYPES(pos.type(), "grid_kernel", [&] {
...
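Selecting the device up front also matters for the allocation that follows: at::empty(pos.size(0), pos.options().dtype(at::kLong)) places the cluster tensor on whatever device pos lives on, because the options carry pos's device along with the dtype override. A hedged one-liner illustrating just that allocation; make_cluster_output is a hypothetical helper, not part of the commit.

// Sketch only: one int64 cluster id per point, allocated on the same GPU as
// pos because pos.options() carries its device; only the dtype is overridden.
#include <ATen/ATen.h>

at::Tensor make_cluster_output(at::Tensor pos) {
  return at::empty(pos.size(0), pos.options().dtype(at::kLong));
}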
...
@@ -52,6 +52,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
 at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                     at::Tensor batch_y) {
+  cudaSetDevice(x.get_device());
   auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
...
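fps_cuda, knn_cuda, nearest_cuda and radius_cuda all recover the number of examples the same way: the batch vector is a sorted int64 assignment tensor on the GPU, so its last entry is the largest batch index, and a single synchronous cudaMemcpy brings it to the host before sizing the launch. A minimal sketch of that readback under the same assumption; read_batch_size is a hypothetical helper used only for illustration.

// Sketch only: read the last element of a sorted batch vector on the host.
#include <ATen/ATen.h>
#include <cuda_runtime.h>
#include <cstdlib>

int64_t read_batch_size(at::Tensor batch) {
  cudaSetDevice(batch.get_device());
  auto last = (int64_t *)malloc(sizeof(int64_t));
  // batch[-1] selects the last (largest) batch index; copy it to host memory.
  cudaMemcpy(last, batch[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  int64_t batch_size = last[0] + 1; // indices are zero-based
  free(last);
  return batch_size;
}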
...
@@ -60,6 +60,7 @@ __global__ void nearest_kernel(const scalar_t *__restrict__ x,
 at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
                         at::Tensor batch_y) {
+  cudaSetDevice(x.get_device());
   auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
...
...
@@ -48,6 +48,7 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
 at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
                        at::Tensor batch_x, at::Tensor batch_y,
                        size_t max_num_neighbors) {
+  cudaSetDevice(x.get_device());
   auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
...
...
@@ -27,7 +27,7 @@ __global__ void uniform_rw_kernel(
 at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
                    size_t walk_length, float p, float q, size_t num_nodes) {
+  cudaSetDevice(row.get_device());
   auto deg = degree(row, num_nodes);
   row = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
...
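The two context lines in rw_cuda build the CSR row pointer array in place: per-node degree counts are turned into row pointers by prepending a zero and taking the cumulative sum, so entries i and i+1 then bound the neighbors of node i in the column array. The same step in isolation, as a hedged sketch; degrees_to_rowptr is a hypothetical name for illustration.

// Sketch only: exclusive prefix sum over per-node degrees, giving CSR
// row pointers on the same device as deg.
#include <ATen/ATen.h>

at::Tensor degrees_to_rowptr(at::Tensor deg) {
  return at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
}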
...
@@ -27,7 +27,7 @@ if CUDA_HOME is not None:
                   ['cuda/rw.cpp', 'cuda/rw_kernel.cu']),
 ]
-__version__ = '1.2.3'
+__version__ = '1.2.4'
 url = 'https://github.com/rusty1s/pytorch_cluster'
 install_requires = ['scipy']
...
...
@@ -6,7 +6,7 @@ from .knn import knn, knn_graph
 from .radius import radius, radius_graph
 from .rw import random_walk
-__version__ = '1.2.3'
+__version__ = '1.2.4'
 __all__ = [
     'graclus_cluster',
...