"docs/vscode:/vscode.git/clone" did not exist on "d7117b95ab120230bb7dc6e69c7c4c800397fcbf"
Unverified Commit 27387388 authored by YanbingJiang's avatar YanbingJiang Committed by GitHub
Browse files

Add bf16 support for knn_cpu, radius_cpu and graclus_cpu (#144)

parent eea2fc58
......@@ -47,7 +47,7 @@ torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
} else {
auto weight = optional_weight.value();
auto scalar_type = weight.scalar_type();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "graclus_cpu", [&] {
auto weight_data = weight.data_ptr<scalar_t>();
for (auto n = 0; n < num_nodes; n++) {
......
......@@ -25,7 +25,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
std::vector<size_t> out_vec = std::vector<size_t>();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "knn_cpu", [&] {
// See: nanoflann/examples/vector_of_vectors_example.cpp
auto x_data = x.data_ptr<scalar_t>();
......
......@@ -25,7 +25,7 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
std::vector<size_t> out_vec = std::vector<size_t>();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "radius_cpu", [&] {
// See: nanoflann/examples/vector_of_vectors_example.cpp
auto x_data = x.data_ptr<scalar_t>();
......
import torch
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
dtypes = [torch.half, torch.bfloat16, torch.float, torch.double,
torch.int, torch.long]
grad_dtypes = [torch.half, torch.float, torch.double]
devices = [torch.device('cpu')]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment