Unverified Commit 18d37590 authored by YanbingJiang's avatar YanbingJiang Committed by GitHub
Browse files

Add scatter/segment bf16 support (#316)

parent fc1b1394
...@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
auto N = out.size(dim); auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index); auto index_info = getTensorInfo<int64_t>(index);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "scatter_cpu", [&] {
auto src_data = src.data_ptr<scalar_t>(); auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
......
...@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
auto index_info = getTensorInfo<int64_t>(index); auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1]; auto stride = index_info.strides[index_info.dims - 1];
std::vector<int64_t> args(K); std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_coo_cpu", [&] {
auto src_data = src.data_ptr<scalar_t>(); auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
scalar_t *count_data = nullptr; scalar_t *count_data = nullptr;
...@@ -178,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -178,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
auto index_info = getTensorInfo<int64_t>(index); auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1]; auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "gather_coo_cpu", [&] {
auto src_data = src.data_ptr<scalar_t>(); auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
......
...@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = getTensorInfo<int64_t>(indptr); auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1]; auto stride = indptr_info.strides[indptr_info.dims - 1];
std::vector<int64_t> args(K); std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_csr_cpu", [&] {
auto src_data = src.data_ptr<scalar_t>(); auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
...@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = getTensorInfo<int64_t>(indptr); auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1]; auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "gather_csr_cpu", [&] {
auto src_data = src.data_ptr<scalar_t>(); auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
......
...@@ -2,7 +2,8 @@ import torch ...@@ -2,7 +2,8 @@ import torch
reductions = ['sum', 'add', 'mean', 'min', 'max'] reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long] dtypes = [torch.half, torch.bfloat16, torch.float, torch.double,
torch.int, torch.long]
grad_dtypes = [torch.float, torch.double] grad_dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')] devices = [torch.device('cpu')]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment