Commit 547759a6 authored by rusty1s's avatar rusty1s
Browse files

GPU build

parent 4e2e69be
...@@ -59,7 +59,7 @@ install: ...@@ -59,7 +59,7 @@ install:
- conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes - conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
- source script/torch.sh - source script/torch.sh
- pip install flake8 codecov - pip install flake8 codecov
- python setup.py install - pip install .[test]
script: script:
- flake8 . - flake8 .
- python setup.py test - python setup.py test
......
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
project(torchcluster) project(torchcluster)
set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD 14)
set(TORCHCLUSTER_VERSION 1.5.4) set(TORCHCLUSTER_VERSION 1.5.5)
option(WITH_CUDA "Enable CUDA support" OFF) option(WITH_CUDA "Enable CUDA support" OFF)
......
...@@ -101,12 +101,12 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, ...@@ -101,12 +101,12 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
auto dist = torch::full(y.size(0) * k, 1e38, y.options()); auto dist = torch::full(y.size(0) * k, 1e38, y.options());
auto row = torch::empty(y.size(0) * k, ptr_y.options()); auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.options()); auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>( knn_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(), ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(), dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
......
...@@ -75,7 +75,7 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, ...@@ -75,7 +75,7 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>( radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(), ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors, row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
......
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::optiona<torch::Tensor> ptr_x, torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r, torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors); int64_t max_num_neighbors);
...@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov', 'scipy'] ...@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov', 'scipy']
setup( setup(
name='torch_cluster', name='torch_cluster',
version='1.5.4', version='1.5.5',
author='Matthias Fey', author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de', author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_cluster', url='https://github.com/rusty1s/pytorch_cluster',
...@@ -80,6 +80,7 @@ setup( ...@@ -80,6 +80,7 @@ setup(
install_requires=install_requires, install_requires=install_requires,
setup_requires=setup_requires, setup_requires=setup_requires,
tests_require=tests_require, tests_require=tests_require,
extras_require={'test': tests_require},
ext_modules=get_extensions() if not BUILD_DOCS else [], ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass={ cmdclass={
'build_ext': 'build_ext':
......
...@@ -3,7 +3,7 @@ import os.path as osp ...@@ -3,7 +3,7 @@ import os.path as osp
import torch import torch
__version__ = '1.5.4' __version__ = '1.5.5'
for library in [ for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest', '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment