OpenDAS / torch-spline-conv

Commit d6a017ee, authored Aug 30, 2022 by yanbing-j

Enable bf16 support for basis_fw, basis_bw, weighting_fw and weighting_bw_x

Parent: fb3260be
Showing 4 changed files with 12 additions and 8 deletions (+12 −8):
csrc/cpu/basis_cpu.cpp      (+2 −2)
csrc/cpu/weighting_cpu.cpp  (+4 −4)
test/test_conv.py           (+5 −1)
test/utils.py               (+1 −1)
csrc/cpu/basis_cpu.cpp

@@ -75,7 +75,7 @@ spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
   auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
   auto weight_index_data = weight_index.data_ptr<int64_t>();

-  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, pseudo.scalar_type(), "basis_fw", [&] {
     auto pseudo_data = pseudo.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();

@@ -135,7 +135,7 @@ torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
   auto kernel_size_data = kernel_size.data_ptr<int64_t>();
   auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();

-  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, pseudo.scalar_type(), "basis_bw", [&] {
     auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
     auto pseudo_data = pseudo.data_ptr<scalar_t>();
     auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
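The entire change in this file is the dispatch macro: AT_DISPATCH_FLOATING_TYPES generates the kernel body only for float and double, while AT_DISPATCH_FLOATING_TYPES_AND takes one additional at::ScalarType to cover, here BFloat16, and binds scalar_t to the matching C++ type inside the lambda. A minimal sketch of the pattern (double_values is a made-up kernel for illustration, not code from this repository):

#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Minimal sketch, not from this repository: a toy kernel using the same
// dispatch pattern the commit adopts. The extra first argument adds
// BFloat16 to the float/double cases. Assumes a contiguous input tensor.
torch::Tensor double_values(torch::Tensor x) {
  auto out = torch::empty_like(x);
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::ScalarType::BFloat16, x.scalar_type(), "double_values", [&] {
        auto x_data = x.data_ptr<scalar_t>();    // scalar_t bound per dtype
        auto out_data = out.data_ptr<scalar_t>();
        for (int64_t i = 0; i < x.numel(); i++)
          out_data[i] = x_data[i] * scalar_t(2);
      });
  return out;
}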
csrc/cpu/weighting_cpu.cpp

@@ -21,7 +21,7 @@ torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
   auto weight_index_data = weight_index.data_ptr<int64_t>();

-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_fw", [&] {
     auto x_data = x.data_ptr<scalar_t>();
     auto weight_data = weight.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();

@@ -71,7 +71,7 @@ torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
   auto weight_index_data = weight_index.data_ptr<int64_t>();

-  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, grad_out.scalar_type(), "weighting_bw_x", [&] {
     auto grad_out_data = grad_out.data_ptr<scalar_t>();
     auto weight_data = weight.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();

@@ -117,7 +117,7 @@ torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
   auto weight_index_data = weight_index.data_ptr<int64_t>();

-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_bw_weight", [&] {
     auto grad_out_data = grad_out.data_ptr<scalar_t>();
     auto x_data = x.data_ptr<scalar_t>();
     auto basis_data = basis.data_ptr<scalar_t>();

@@ -163,7 +163,7 @@ torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
   auto weight_index_data = weight_index.data_ptr<int64_t>();

-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, x.scalar_type(), "weighting_bw_basis", [&] {
     auto grad_out_data = grad_out.data_ptr<scalar_t>();
     auto x_data = x.data_ptr<scalar_t>();
     auto weight_data = weight.data_ptr<scalar_t>();
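The four hunks above apply the same one-line substitution to each weighting kernel, covering bw_weight and bw_basis as well as the four functions named in the commit title. Conceptually, the macro expands to a switch over the tensor's runtime dtype that instantiates the lambda body once per supported type. A simplified hand-written sketch of that expansion (my illustration; the real ATen macro is more elaborate):

#include <torch/torch.h>

// Hand-written sketch of roughly what the dispatch expands to (simplified;
// not the real ATen macro). `scale` is a made-up kernel: the switch stamps
// out the body once per dtype, with scalar_t bound each time.
torch::Tensor scale(torch::Tensor x) {
  auto out = torch::empty_like(x);
  auto body = [&](auto tag) {
    using scalar_t = decltype(tag);
    auto x_data = x.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
    for (int64_t i = 0; i < x.numel(); i++)
      out_data[i] = x_data[i] * scalar_t(2);
  };
  switch (x.scalar_type()) {
  case at::ScalarType::Float:    body(float{});        break;
  case at::ScalarType::Double:   body(double{});       break;
  case at::ScalarType::BFloat16: body(at::BFloat16{}); break;  // the case this commit enables
  default: TORCH_CHECK(false, "scale: unsupported dtype");
  }
  return out;
}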
test/test_conv.py

@@ -54,7 +54,11 @@ def test_spline_conv_forward(test, dtype, device):
     out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                       is_open_spline, 1, True, root_weight, bias)

-    assert out.tolist() == test['expected']
+    if dtype == torch.bfloat16:
+        target = torch.tensor(test['expected'])
+        assert torch.allclose(out.to(torch.float), target, rtol=1e-2, atol=1e-2)
+    else:
+        assert out.tolist() == test['expected']


 @pytest.mark.parametrize('degree,device', product(degrees, devices))
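Exact list equality against the precomputed expected values is too strict for bfloat16: with only 8 mantissa bits, bf16 cannot represent most decimal values exactly, so the test upcasts the output to float and compares within rtol = atol = 1e-2. The same check expressed through the libtorch C++ API, as a self-contained sketch (my illustration, not repository code), for consistency with the kernel-side examples above:

#include <torch/torch.h>
#include <cassert>

// Sketch of the tolerance check the test adopts for bf16. Rounding 0.1,
// 0.2, 0.3 through bfloat16 changes the bits, so exact comparison fails,
// but the 1e-2 rtol/atol check passes.
int main() {
  auto target = torch::tensor({0.1f, 0.2f, 0.3f});
  auto out = target.to(torch::kBFloat16);  // stand-in for a bf16 kernel result
  assert(!torch::equal(out.to(torch::kFloat), target));
  assert(torch::allclose(out.to(torch::kFloat), target,
                         /*rtol=*/1e-2, /*atol=*/1e-2));
  return 0;
}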
test/utils.py

 import torch

-dtypes = [torch.float, torch.double]
+dtypes = [torch.float, torch.double, torch.bfloat16]
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():