"examples/sampling/vscode:/vscode.git/clone" did not exist on "c0ac2f60b7e6622bae3a5b8a79686f55bc7b4ae3"
Commit d6a017ee authored by yanbing-j

Enable bf16 support for basis_fw, basis_bw, weighting_fw and weighting_bw_x

parent fb3260be
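
Every kernel change in this commit follows the same pattern: AT_DISPATCH_FLOATING_TYPES, which generates a type switch over float and double only, is replaced by AT_DISPATCH_FLOATING_TYPES_AND, whose extra ScalarType argument adds one more case, so the templated lambda is also instantiated with scalar_t = at::BFloat16. A minimal standalone sketch of the pattern (the scale_by_two kernel is hypothetical, not part of this repository):

#include <torch/torch.h>

// Hypothetical standalone kernel, not part of this repository; it only
// illustrates the dispatch change applied throughout this commit.
torch::Tensor scale_by_two(torch::Tensor x) {
  auto out = torch::empty_like(x);
  // The extra first argument appends a BFloat16 case to the float/double
  // switch that AT_DISPATCH_FLOATING_TYPES would generate, so the lambda
  // below is also compiled with scalar_t = at::BFloat16.
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::BFloat16, x.scalar_type(), "scale_by_two", [&] {
        auto x_data = x.data_ptr<scalar_t>();
        auto out_data = out.data_ptr<scalar_t>();
        for (int64_t i = 0; i < x.numel(); i++)
          out_data[i] = x_data[i] * static_cast<scalar_t>(2);
      });
  return out;
}
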
@@ -75,7 +75,7 @@ spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
   auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
   auto weight_index_data = weight_index.data_ptr<int64_t>();
-  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, pseudo.scalar_type(), "basis_fw", [&] {
     auto pseudo_data = pseudo.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();
@@ -135,7 +135,7 @@ torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
   auto kernel_size_data = kernel_size.data_ptr<int64_t>();
   auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
-  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, pseudo.scalar_type(), "basis_bw", [&] {
     auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
     auto pseudo_data = pseudo.data_ptr<scalar_t>();
     auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
...
@@ -21,7 +21,7 @@ torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
   auto weight_index_data = weight_index.data_ptr<int64_t>();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_fw", [&] {
     auto x_data = x.data_ptr<scalar_t>();
     auto weight_data = weight.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();
@@ -71,7 +71,7 @@ torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
   auto weight_index_data = weight_index.data_ptr<int64_t>();
-  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, grad_out.scalar_type(), "weighting_bw_x", [&] {
     auto grad_out_data = grad_out.data_ptr<scalar_t>();
     auto weight_data = weight.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();
@@ -117,7 +117,7 @@ torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
   auto weight_index_data = weight_index.data_ptr<int64_t>();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_bw_weight", [&] {
     auto grad_out_data = grad_out.data_ptr<scalar_t>();
     auto x_data = x.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();
@@ -163,7 +163,7 @@ torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
   auto weight_index_data = weight_index.data_ptr<int64_t>();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_bw_basis", [&] {
     auto grad_out_data = grad_out.data_ptr<scalar_t>();
     auto x_data = x.data_ptr<scalar_t>();
     auto weight_data = weight.data_ptr<scalar_t>();
...
@@ -54,7 +54,11 @@ def test_spline_conv_forward(test, dtype, device):
     out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                       is_open_spline, 1, True, root_weight, bias)
-    assert out.tolist() == test['expected']
+    if dtype == torch.bfloat16:
+        target = torch.tensor(test['expected'])
+        assert torch.allclose(out.to(torch.float), target, rtol=1e-2, atol=1e-2)
+    else:
+        assert out.tolist() == test['expected']


 @pytest.mark.parametrize('degree,device', product(degrees, devices))
...
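
The test change is needed because bfloat16 keeps only 7 mantissa bits, roughly three significant decimal digits, so the kernel output cannot match the reference values exactly and is instead compared with torch.allclose under loose tolerances. A minimal standalone C++ sketch of the rounding, using the c10::BFloat16 scalar type directly (illustration only, not part of the commit):

#include <iostream>
#include <c10/util/BFloat16.h>

int main() {
  // bfloat16 stores an 8-bit exponent but only 7 mantissa bits, so most
  // float values are rounded: 1.2345f is stored as 1.234375.
  c10::BFloat16 v(1.2345f);
  std::cout << static_cast<float>(v) << std::endl;  // prints 1.23438
  return 0;
}
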
 import torch
-dtypes = [torch.float, torch.double]
+dtypes = [torch.float, torch.double, torch.bfloat16]
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():
...