Commit 547759a6 authored by rusty1s's avatar rusty1s
Browse files

GPU build

parent 4e2e69be
......@@ -59,7 +59,7 @@ install:
- conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
- source script/torch.sh
- pip install flake8 codecov
- python setup.py install
- pip install .[test]
script:
- flake8 .
- python setup.py test
......
cmake_minimum_required(VERSION 3.0)
project(torchcluster)
set(CMAKE_CXX_STANDARD 14)
set(TORCHCLUSTER_VERSION 1.5.4)
set(TORCHCLUSTER_VERSION 1.5.5)
option(WITH_CUDA "Enable CUDA support" OFF)
......
......@@ -101,12 +101,12 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
auto dist = torch::full(y.size(0) * k, 1e38, y.options());
auto row = torch::empty(y.size(0) * k, ptr_y.options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.options());
auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
knn_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
......
......@@ -75,7 +75,7 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
......
......@@ -3,6 +3,6 @@
#include <torch/extension.h>
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::optiona<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors);
......@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov', 'scipy']
setup(
name='torch_cluster',
version='1.5.4',
version='1.5.5',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_cluster',
......@@ -80,6 +80,7 @@ setup(
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
extras_require={'test': tests_require},
ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass={
'build_ext':
......
......@@ -3,7 +3,7 @@ import os.path as osp
import torch
__version__ = '1.5.4'
__version__ = '1.5.5'
for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment