Unverified Commit 18f48b73 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #30 from yanbing-j/yanbing/bf16

Enable bf16 support for basis_fw, basis_bw, weighting_fw and weightin…
parents fb3260be 4c16e41b
......@@ -75,7 +75,7 @@ spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
......@@ -135,7 +135,7 @@ torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
......
......@@ -21,7 +21,7 @@ torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_fw", [&] {
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
......@@ -71,7 +71,7 @@ torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
......@@ -117,7 +117,7 @@ torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_bw_weight", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
......@@ -163,7 +163,7 @@ torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_bw_basis", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
......
......@@ -54,7 +54,12 @@ def test_spline_conv_forward(test, dtype, device):
out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, True, root_weight, bias)
assert out.tolist() == test['expected']
if dtype == torch.bfloat16:
target = torch.tensor(test['expected'])
assert torch.allclose(out.to(torch.float), target,
rtol=1e-2, atol=1e-2)
else:
assert out.tolist() == test['expected']
@pytest.mark.parametrize('degree,device', product(degrees, devices))
......
import torch
dtypes = [torch.float, torch.double]
dtypes = [torch.float, torch.double, torch.bfloat16]
devices = [torch.device('cpu')]
if torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment