Unverified commit 01dc5042 authored by YanbingJiang, committed by GitHub

Add `spmm` bf16 support (#269)



* Add spmm bf16 support and add unit tests

* Add bf16 support for spspmm

* Disable bf16 tests on torch_scatter 2.0.9 and earlier

* Refactor version compare

* Update test/utils.py
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
parent f7c74ec0
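
Before the diff, a minimal, self-contained sketch of what the `AT_DISPATCH_ALL_TYPES_AND2` macro used below does. The extension function `double_values` is a made-up illustration, not code from this commit: the macro switches on the runtime dtype and instantiates the lambda once per supported type, binding `scalar_t` to the matching C++ type, so adding `at::ScalarType::BFloat16` as a second extra type is all it takes to give an existing kernel body bf16 coverage.

// Hypothetical example, not from this commit: a tiny CPU op built as a
// PyTorch C++ extension, dispatched over all standard types plus two extras.
#include <torch/extension.h>

torch::Tensor double_values(torch::Tensor t) {
  auto out = torch::empty_like(t);
  // AT_DISPATCH_ALL_TYPES_AND2(extra1, extra2, dtype, name, lambda) switches
  // on t.scalar_type() and stamps out the lambda once per dtype, with
  // scalar_t bound to the concrete C++ type (at::BFloat16 for bf16).
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, t.scalar_type(),
      "double_values", [&] {
        auto t_data = t.data_ptr<scalar_t>();
        auto out_data = out.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); i++)
          out_data[i] = t_data[i] + t_data[i];
      });
  return out;
}
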
@@ -44,7 +44,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "spmm_cpu", [&] {
     scalar_t *value_data = nullptr;
     auto mat_data = mat.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -123,7 +123,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
   auto row_data = row.data_ptr<int64_t>();
   auto rowptr_data = rowptr.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "spmm_value_bw_cpu", [&] {
     auto mat_data = mat.data_ptr<scalar_t>();
     auto grad_data = grad.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -55,7 +55,7 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
   torch::Tensor colC;
   torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
-  AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, scalar_type, "spspmm", [&] {
     AT_DISPATCH_HAS_VALUE(optional_valueA, [&] {
       scalar_t *valA_data = nullptr, *valB_data = nullptr;
       if (HAS_VALUE) {
@@ -77,7 +77,7 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
           if (HAS_VALUE)
             tmp_vals[cB] += valA_data[eA] * valB_data[eB];
           else
-            tmp_vals[cB]++;
+            tmp_vals[cB] += 1;
         }
       }
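
The `++` to `+= 1` switch in the last hunk is easy to miss but is plausibly what makes the bf16 instantiation compile: as far as I can tell, `c10::BFloat16` ships arithmetic and compound-assignment operators (computed by round-tripping through `float`) but no `operator++`. A standalone sketch under that assumption:

// Sketch only; assumes PyTorch's c10 headers are on the include path.
#include <c10/util/BFloat16.h>
#include <iostream>

int main() {
  c10::BFloat16 v(0.0f);
  v += 1;  // OK: 1 converts to BFloat16, then operator+= does float math
  // v++;  // would not compile if, as assumed above, no operator++ exists
  v += c10::BFloat16(0.5f) * c10::BFloat16(2.0f);  // operator* is defined
  std::cout << static_cast<float>(v) << "\n";      // prints 2
  return 0;
}
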
@@ -35,7 +35,7 @@ def test_sparse_tensor_spspmm(dtype, device):
         ], dtype=dtype, device=device),
     )
-    expected = torch.eye(10, dtype=dtype, device=device)
+    expected = torch.eye(10, device=device).to(dtype)
     out = x @ x.to_dense().t()
     assert torch.allclose(out, expected, atol=1e-2)
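
Worth spelling out why `atol=1e-2` is so loose (my inference, not stated in the diff): the gap between 1.0 and the next representable bfloat16 value is 2^-7 ≈ 7.8e-3 (8 significand bits, one implicit), and for float16 it is 2^-10 ≈ 9.8e-4, so 1e-2 is roughly the tightest absolute tolerance that passes for every dtype in the parametrization.
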
import torch
import torch_scatter
from packaging import version

reductions = ['sum', 'add', 'mean', 'min', 'max']

dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.half, torch.float, torch.double]
if version.parse(torch_scatter.__version__) > version.parse("2.0.9"):
    dtypes.append(torch.bfloat16)
    grad_dtypes.append(torch.bfloat16)

devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
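
A note on the "Refactor version compare" commit above: `packaging.version.parse` is used rather than plain string comparison because strings misorder patch releases, e.g. `'2.0.10' > '2.0.9'` is False lexicographically (the character `'1'` sorts before `'9'`), while `version.parse('2.0.10') > version.parse('2.0.9')` is True.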