Unverified commit e8620a86, authored by Matthias Fey, committed by GitHub

Half-precision support (#119)

* half support

* deprecation

* typo

* test half

* fix test
parent 0d735d7e
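
Note: the recurring change below swaps ATen's AT_DISPATCH_ALL_TYPES / AT_DISPATCH_FLOATING_TYPES macros for their *_AND variants, which take one extra leading scalar type (here at::ScalarType::Half) to instantiate the dispatch lambda for. A minimal standalone sketch of that pattern; the function and the summation are illustrative only, not code from this repository:

```cpp
#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Illustrative sketch: the extra leading argument adds at::Half to the set
// of dtypes the templated lambda is compiled for, so float16 tensors take
// the same code path as float/double without any kernel changes.
torch::Tensor sum_elems(torch::Tensor t) {
  t = t.contiguous();
  auto out = torch::zeros({}, t.options().dtype(torch::kDouble));
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, t.scalar_type(), "sum_elems", [&] {
    auto data = t.data_ptr<scalar_t>(); // scalar_t may now be at::Half
    double acc = 0;
    for (int64_t i = 0; i < t.numel(); i++)
      acc += static_cast<double>(data[i]);
    out.fill_(acc);
  });
  return out;
}
```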
@@ -46,7 +46,8 @@ torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
     }
   } else {
     auto weight = optional_weight.value();
-    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
+    auto scalar_type = weight.scalar_type();
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
       auto weight_data = weight.data_ptr<scalar_t>();
       for (auto n = 0; n < num_nodes; n++) {
...
@@ -25,7 +25,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
   std::vector<size_t> out_vec = std::vector<size_t>();
-  AT_DISPATCH_ALL_TYPES(x.scalar_type(), "knn_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
     // See: nanoflann/examples/vector_of_vectors_example.cpp
     auto x_data = x.data_ptr<scalar_t>();
...
@@ -25,7 +25,7 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
   std::vector<size_t> out_vec = std::vector<size_t>();
-  AT_DISPATCH_ALL_TYPES(x.scalar_type(), "radius_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
     // See: nanoflann/examples/vector_of_vectors_example.cpp
     auto x_data = x.data_ptr<scalar_t>();
...
@@ -78,19 +78,19 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
   auto batch_size = ptr.numel() - 1;
   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
-  auto out_ptr = deg.toType(torch::kFloat) * ratio;
+  auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
   out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
   out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
   torch::Tensor start;
   if (random_start) {
     start = torch::rand(batch_size, src.options());
-    start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
+    start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
   } else {
     start = torch::zeros(batch_size, ptr.options());
   }
-  auto dist = torch::full(src.size(0), 1e38, src.options());
+  auto dist = torch::full(src.size(0), 5e4, src.options());
   auto out_size = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
@@ -98,7 +98,8 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
   auto out = torch::empty(out_size[0], out_ptr.options());
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
+  auto scalar_type = src.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
         src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
         out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
...
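
Why 1e38 becomes 5e4: float16's largest finite value is (2 - 2^-10) * 2^15 = 65504, so the old "effectively infinite" initializer would overflow to inf once dist can be allocated with half precision, while 5e4 still dominates any distance these kernels produce on typical inputs and remains representable. A quick check, assuming the std::numeric_limits specialization for c10::Half that ships with PyTorch:

```cpp
#include <c10/util/Half.h>
#include <iostream>
#include <limits>

int main() {
  // Largest finite binary16 value: (2 - 2^-10) * 2^15 = 65504.
  std::cout << static_cast<float>(std::numeric_limits<c10::Half>::max())
            << "\n"; // 65504
  // Neither 1e38 nor 1e10 is representable in binary16; both round to +inf.
  std::cout << static_cast<float>(c10::Half(1e10f)) << "\n"; // inf
}
```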
@@ -113,7 +113,8 @@ void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
         rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
   } else {
     auto weight = optional_weight.value();
-    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
+    auto scalar_type = weight.scalar_type();
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
       weighted_propose_kernel<scalar_t>
           <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
               out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
@@ -201,7 +202,8 @@ void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
         rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
   } else {
     auto weight = optional_weight.value();
-    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
+    auto scalar_type = weight.scalar_type();
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
       weighted_respond_kernel<scalar_t>
           <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
               out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
...
@@ -61,7 +61,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
   auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] {
     grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
         pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
         start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
...
@@ -45,7 +45,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
   int64_t best_idx[100];
   for (int e = 0; e < k; e++) {
-    best_dist[e] = 1e10;
+    best_dist[e] = 5e4;
     best_idx[e] = -1;
   }
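The 1e10 -> 5e4 change above follows the same reasoning as the dist initializer in fps_cuda: best_dist may now hold half-precision values, and 1e10 lies outside float16's finite range (max 65504), so it would degrade to +inf.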
@@ -121,7 +121,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
   dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
+  auto scalar_type = x.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
...
@@ -79,7 +79,8 @@ torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
   auto out = torch::empty({x.size(0)}, ptr_x.options());
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
+  auto scalar_type = x.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
...
@@ -80,7 +80,8 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
   dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
+  auto scalar_type = x.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
...
@@ -67,7 +67,7 @@ def test_knn_graph(dtype, device):
                           (3, 2), (0, 3), (2, 3)])
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
 def test_knn_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
...
@@ -66,7 +66,7 @@ def test_radius_graph(dtype, device):
                           (3, 2), (0, 3), (2, 3)])
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
 def test_radius_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
...
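Note that the large-graph tests now pin the dtype to torch.float instead of iterating over grad_dtypes, which gains torch.half below. Presumably this is because, with 1000 random points, float16's roughly three decimal digits of precision can reorder near-tied neighbor distances and make exact edge-set comparisons flaky, so half precision is exercised only by the small deterministic tests.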
 import torch
-dtypes = [torch.float, torch.double, torch.int, torch.long]
-grad_dtypes = [torch.float, torch.double]
+dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
+grad_dtypes = [torch.half, torch.float, torch.double]
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():
...