Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
18d37590
Unverified
Commit
18d37590
authored
Aug 18, 2022
by
YanbingJiang
Committed by
GitHub
Aug 18, 2022
Browse files
Add scatter/segment bf16 support (#316)
parent
fc1b1394
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
7 additions
and
6 deletions
+7
-6
csrc/cpu/scatter_cpu.cpp
csrc/cpu/scatter_cpu.cpp
+1
-1
csrc/cpu/segment_coo_cpu.cpp
csrc/cpu/segment_coo_cpu.cpp
+2
-2
csrc/cpu/segment_csr_cpu.cpp
csrc/cpu/segment_csr_cpu.cpp
+2
-2
test/utils.py
test/utils.py
+2
-1
No files found.
csrc/cpu/scatter_cpu.cpp
View file @
18d37590
...
@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
...
@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
auto
N
=
out
.
size
(
dim
);
auto
N
=
out
.
size
(
dim
);
auto
index_info
=
getTensorInfo
<
int64_t
>
(
index
);
auto
index_info
=
getTensorInfo
<
int64_t
>
(
index
);
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
src
.
scalar_type
(),
"
_
"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES_AND
2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
src
.
scalar_type
(),
"
scatter_cpu
"
,
[
&
]
{
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
...
...
csrc/cpu/segment_coo_cpu.cpp
View file @
18d37590
...
@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
...
@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
auto
index_info
=
getTensorInfo
<
int64_t
>
(
index
);
auto
index_info
=
getTensorInfo
<
int64_t
>
(
index
);
auto
stride
=
index_info
.
strides
[
index_info
.
dims
-
1
];
auto
stride
=
index_info
.
strides
[
index_info
.
dims
-
1
];
std
::
vector
<
int64_t
>
args
(
K
);
std
::
vector
<
int64_t
>
args
(
K
);
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
src
.
scalar_type
(),
"
_
"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES_AND
2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
src
.
scalar_type
(),
"
segment_coo_cpu
"
,
[
&
]
{
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
scalar_t
*
count_data
=
nullptr
;
scalar_t
*
count_data
=
nullptr
;
...
@@ -178,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
...
@@ -178,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
auto
index_info
=
getTensorInfo
<
int64_t
>
(
index
);
auto
index_info
=
getTensorInfo
<
int64_t
>
(
index
);
auto
stride
=
index_info
.
strides
[
index_info
.
dims
-
1
];
auto
stride
=
index_info
.
strides
[
index_info
.
dims
-
1
];
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
src
.
scalar_type
(),
"
_
"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES_AND
2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
src
.
scalar_type
(),
"
gather_coo_cpu
"
,
[
&
]
{
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
...
...
csrc/cpu/segment_csr_cpu.cpp
View file @
18d37590
...
@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
...
@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
auto
indptr_info
=
getTensorInfo
<
int64_t
>
(
indptr
);
auto
indptr_info
=
getTensorInfo
<
int64_t
>
(
indptr
);
auto
stride
=
indptr_info
.
strides
[
indptr_info
.
dims
-
1
];
auto
stride
=
indptr_info
.
strides
[
indptr_info
.
dims
-
1
];
std
::
vector
<
int64_t
>
args
(
K
);
std
::
vector
<
int64_t
>
args
(
K
);
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
src
.
scalar_type
(),
"
_
"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES_AND
2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
src
.
scalar_type
(),
"
segment_csr_cpu
"
,
[
&
]
{
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
...
@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
...
@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
auto
indptr_info
=
getTensorInfo
<
int64_t
>
(
indptr
);
auto
indptr_info
=
getTensorInfo
<
int64_t
>
(
indptr
);
auto
stride
=
indptr_info
.
strides
[
indptr_info
.
dims
-
1
];
auto
stride
=
indptr_info
.
strides
[
indptr_info
.
dims
-
1
];
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
src
.
scalar_type
(),
"
_
"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES_AND
2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
src
.
scalar_type
(),
"
gather_csr_cpu
"
,
[
&
]
{
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
src_data
=
src
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
...
...
test/utils.py
View file @
18d37590
...
@@ -2,7 +2,8 @@ import torch
...
@@ -2,7 +2,8 @@ import torch
reductions
=
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
reductions
=
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
dtypes
=
[
torch
.
half
,
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
dtypes
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
grad_dtypes
=
[
torch
.
float
,
torch
.
double
]
grad_dtypes
=
[
torch
.
float
,
torch
.
double
]
devices
=
[
torch
.
device
(
'cpu'
)]
devices
=
[
torch
.
device
(
'cpu'
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment