OpenDAS / dgl, commit acb4eb7e (unverified)
Authored Apr 06, 2023 by Ilia Taraban; committed by GitHub on Apr 06, 2023

[Feature] Add bfloat16 support for CPU (#5497)

Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
Parent: 29e66615

Showing 16 changed files with 337 additions and 56 deletions (+337 −56)
CMakeLists.txt                          +2   −2
include/dgl/aten/macro.h                +53  −26
include/dgl/runtime/bfloat16.h          +68  −0
include/dgl/runtime/ndarray.h           +1   −0
python/setup.py                         +1   −1
src/array/cpu/gather_mm.cc              +22  −0
src/array/cpu/sddmm.cc                  +36  −0
src/array/cpu/segment_reduce.cc         +34  −0
src/array/cpu/segment_reduce.h          +2   −0
src/array/cpu/spmm.cc                   +42  −1
src/array/cpu/spmm.h                    +25  −8
src/array/cpu/spmm_binary_ops.h         +18  −2
src/array/cpu/spmm_blocking_libxsmm.h   +6   −0
tests/cpp/test_aten.cc                  +9   −0
tests/cpp/test_spmm.cc                  +6   −0
tests/python/common/ops/test_ops.py     +12  −16
CMakeLists.txt

@@ -203,11 +203,11 @@ endif(NOT MSVC)
 # Compile LIBXSMM
 if((NOT MSVC) AND USE_LIBXSMM)
   if(REBUILD_LIBXSMM)
-    add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
+    add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
       WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm)
   else(REBUILD_LIBXSMM)
-    add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
+    add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
       WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm)
   endif(REBUILD_LIBXSMM)
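(Passing CC=${CMAKE_C_COMPILER} pins the libxsmm sub-make to the same C compiler CMake selected for the rest of the build, rather than whatever `cc` happens to resolve to on the build host.)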
include/dgl/aten/macro.h

@@ -152,42 +152,69 @@
       XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) {  \
     typedef __nv_bfloat16 FloatType;                                      \
     { __VA_ARGS__ }                                                       \
-  } else if (XPU == kDGLCPU) {                                            \
-    LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
+  } else if (                                                             \
+      XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) {    \
+    LOG(FATAL) << (val_name) << " can't be float16 on CPU";               \
+  } else if (                                                             \
+      XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) {   \
+    typedef BFloat16 FloatType;                                           \
+    { __VA_ARGS__ }                                                       \
   } else {                                                                \
     LOG(FATAL) << (val_name)                                              \
                << " can only be float16/bfloat16/float32/float64 on GPU"; \
   }                                                                       \
 } while (0)
 #else  // BF16_ENABLED
 #define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
   do {                                                                    \
     CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))            \
         << (val_name) << " must be float type";                           \
     if ((val).bits == 32) {                                               \
       typedef float FloatType;                                            \
       { __VA_ARGS__ }                                                     \
     } else if ((val).bits == 64) {                                        \
       typedef double FloatType;                                           \
       { __VA_ARGS__ }                                                     \
     } else if (                                                           \
         XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
       typedef __half FloatType;                                           \
       { __VA_ARGS__ }                                                     \
     } else if (                                                           \
         XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
       LOG(FATAL) << "bfloat16 requires CUDA >= 11.0";                     \
-    } else if (XPU == kDGLCPU) {                                          \
-      LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
+    } else if (                                                           \
+        XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) {  \
+      LOG(FATAL) << (val_name) << " can't be float16 on CPU";             \
+    } else if (                                                           \
+        XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+      typedef BFloat16 FloatType;                                         \
+      { __VA_ARGS__ }                                                     \
     } else {                                                              \
       LOG(FATAL) << (val_name)                                            \
                  << " can only be float16/float32/float64 on GPU";        \
     }                                                                     \
   } while (0)
 #endif  // BF16_ENABLED
 #else  // DGL_USE_CUDA
-#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
-  ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
+#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
+  do {                                                                    \
+    CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))            \
+        << (val_name) << " must be float type";                           \
+    if ((val).bits == 32) {                                               \
+      typedef float FloatType;                                            \
+      { __VA_ARGS__ }                                                     \
+    } else if ((val).bits == 64) {                                        \
+      typedef double FloatType;                                           \
+      { __VA_ARGS__ }                                                     \
+    } else if (                                                           \
+        XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+      typedef BFloat16 FloatType;                                         \
+      { __VA_ARGS__ }                                                     \
+    } else {                                                              \
+      LOG(FATAL) << (val_name)                                            \
+                 << " can only be bfloat16/float32/float64 on CPU";       \
+    }                                                                     \
+  } while (0)
 #endif  // DGL_USE_CUDA
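For orientation, here is a minimal, hypothetical caller of the switch macro above. It is not part of this diff; `CpuFillIllustration` and its body are illustrative only. The macro typedefs FloatType to float, double, __half, or BFloat16 from the array's dtype and runs the body once with that type:

    // Hypothetical caller of ATEN_FLOAT_TYPE_SWITCH_16BITS (illustration only).
    #include <cstdint>

    void CpuFillIllustration(NDArray out, double scalar) {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, FloatType, kDGLCPU, "out", {
        FloatType* data = static_cast<FloatType*>(out->data);
        int64_t n = 1;
        for (int d = 0; d < out->ndim; ++d) n *= out->shape[d];
        for (int64_t i = 0; i < n; ++i)
          data[i] = static_cast<FloatType>(scalar);  // BFloat16 rounds via float
      });
    }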
include/dgl/runtime/bfloat16.h (new file, mode 100644)

/**
 * Copyright (c) 2023 by Contributors
 * @file dgl/runtime/bfloat16.h
 * @brief BFloat16 CPU header
 */
#ifndef DGL_RUNTIME_BFLOAT16_H_
#define DGL_RUNTIME_BFLOAT16_H_

#include <cmath>
#include <cstdint>  // uint16_t / uint32_t used below

class BFloat16 {
  uint16_t val;

 public:
  constexpr BFloat16() : val(0) {}

  // Disable lint "explicit" warning, since implicit usage on constructor is
  // expected.
  BFloat16(float f) {  // NOLINT
    if (std::isnan(f)) {
      val = 0x7FC0;  // canonical bf16 quiet NaN
    } else {
      union {
        uint16_t iraw16[2];
        uint32_t iraw32;
        float f32;
      };
      f32 = f;
      // Round to nearest, ties to even: bias by 0x7FFF plus the lowest
      // mantissa bit that will be kept, then truncate to the high 16 bits.
      const uint32_t rounding_bias = 0x00007FFF + (iraw16[1] & 0x1);
      val = static_cast<uint16_t>((iraw32 + rounding_bias) >> 16);
    }
  }

  static constexpr BFloat16 Min() {
    BFloat16 min;
    min.val = 0xFF80;  // -infinity
    return min;
  }

  static constexpr BFloat16 Max() {
    BFloat16 max;
    max.val = 0x7F80;  // +infinity
    return max;
  }

  BFloat16& operator-=(const float& rhs) {
    float lhs = (*this);
    (*this) = lhs - rhs;
    return *this;
  }

  BFloat16& operator+=(const float& rhs) {
    float lhs = (*this);
    (*this) = lhs + rhs;
    return *this;
  }

  operator float() const {
    // Widening to float is exact: place the 16 stored bits in the high half.
    union {
      float f;
      uint16_t raw[2];
    };
    raw[0] = 0;
    raw[1] = val;
    return f;
  }
};

#endif  // DGL_RUNTIME_BFLOAT16_H_
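A minimal standalone check of that rounding rule (hypothetical test program, not from the diff; assumes the header above is on the include path):

    #include <cassert>
    #include "dgl/runtime/bfloat16.h"  // path is an assumption

    int main() {
      // 1.0f is exactly representable, so the round trip is exact.
      assert(static_cast<float>(BFloat16(1.0f)) == 1.0f);
      // 1 + 2^-8 lies exactly between the bf16 neighbors 1.0 and 1.0078125;
      // the tie goes to the even mantissa, i.e. back down to 1.0.
      assert(static_cast<float>(BFloat16(1.00390625f)) == 1.0f);
      // 1 + 2^-7 + 2^-8 is the next tie; the even neighbor is above,
      // so it rounds up to 1.015625.
      assert(static_cast<float>(BFloat16(1.01171875f)) == 1.015625f);
      return 0;
    }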
include/dgl/runtime/ndarray.h

@@ -12,6 +12,7 @@
 #include <utility>
 #include <vector>
 
+#include "bfloat16.h"
 #include "c_runtime_api.h"
 #include "serializer.h"
 #include "shared_mem.h"
python/setup.py

@@ -153,7 +153,7 @@ def config_cython():
             library_dirs=library_dirs,
             libraries=libraries,
             # Crashes without this flag with GCC 5.3.1
-            extra_compile_args=["-std=c++11"],
+            extra_compile_args=["-std=c++14"],
             language="c++",
         )
     )
src/array/cpu/gather_mm.cc

@@ -40,6 +40,12 @@ void GatherMMScatter(
   LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
 }
 
+template void GatherMM<kDGLCPU, int32_t, BFloat16>(
+    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
+    const NDArray idx_b);
+template void GatherMM<kDGLCPU, int64_t, BFloat16>(
+    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
+    const NDArray idx_b);
 template void GatherMM<kDGLCPU, int32_t, float>(
     const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
     const NDArray idx_b);

@@ -53,6 +59,12 @@ template void GatherMM<kDGLCPU, int64_t, double>(
     const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
     const NDArray idx_b);
+template void GatherMMScatter<kDGLCPU, int32_t, BFloat16>(
+    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
+    const NDArray idx_b, const NDArray idx_c);
+template void GatherMMScatter<kDGLCPU, int64_t, BFloat16>(
+    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
+    const NDArray idx_b, const NDArray idx_c);
 template void GatherMMScatter<kDGLCPU, int32_t, float>(
     const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
     const NDArray idx_b, const NDArray idx_c);

@@ -66,6 +78,12 @@ template void GatherMMScatter<kDGLCPU, int64_t, double>(
     const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
     const NDArray idx_b, const NDArray idx_c);
+template void SegmentMM<kDGLCPU, int32_t, BFloat16>(
+    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
+    bool a_trans, bool b_trans);
+template void SegmentMM<kDGLCPU, int64_t, BFloat16>(
+    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
+    bool a_trans, bool b_trans);
 template void SegmentMM<kDGLCPU, int32_t, float>(
     const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
     bool a_trans, bool b_trans);

@@ -79,6 +97,10 @@ template void SegmentMM<kDGLCPU, int64_t, double>(
     const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
     bool a_trans, bool b_trans);
+template void SegmentMMBackwardB<kDGLCPU, int32_t, BFloat16>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
+template void SegmentMMBackwardB<kDGLCPU, int64_t, BFloat16>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
 template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
     const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
 template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
src/array/cpu/sddmm.cc

@@ -78,6 +78,12 @@ void SDDMMCsrHetero(
   });
 }
 
+template void SDDMMCsr<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
+    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
+template void SDDMMCsr<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
+    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
 template void SDDMMCsr<kDGLCPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
     NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

@@ -91,6 +97,18 @@ template void SDDMMCsr<kDGLCPU, int64_t, double>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
     NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
+template void SDDMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast,
+    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
+    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
+    int rhs_target, const std::vector<dgl_type_t>& in_eid,
+    const std::vector<dgl_type_t>& out_eid);
+template void SDDMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast,
+    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
+    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
+    int rhs_target, const std::vector<dgl_type_t>& in_eid,
+    const std::vector<dgl_type_t>& out_eid);
 template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast,
     const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,

@@ -152,6 +170,12 @@ void SDDMMCooHetero(
   });
 }
 
+template void SDDMMCoo<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
+    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
+template void SDDMMCoo<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
+    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
 template void SDDMMCoo<kDGLCPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
     NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

@@ -165,6 +189,18 @@ template void SDDMMCoo<kDGLCPU, int64_t, double>(
     const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
     NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
+template void SDDMMCooHetero<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast,
+    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
+    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
+    int rhs_target, const std::vector<dgl_type_t>& in_eid,
+    const std::vector<dgl_type_t>& out_eid);
+template void SDDMMCooHetero<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast,
+    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
+    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
+    int rhs_target, const std::vector<dgl_type_t>& in_eid,
+    const std::vector<dgl_type_t>& out_eid);
 template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast,
     const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
src/array/cpu/segment_reduce.cc

@@ -56,6 +56,12 @@ void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
   cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
 }
 
+template void SegmentReduce<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
+    NDArray arg);
+template void SegmentReduce<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
+    NDArray arg);
 template void SegmentReduce<kDGLCPU, int32_t, float>(
     const std::string& op, NDArray feat, NDArray offsets, NDArray out,
     NDArray arg);

@@ -69,6 +75,16 @@ template void SegmentReduce<kDGLCPU, int64_t, double>(
     const std::string& op, NDArray feat, NDArray offsets, NDArray out,
     NDArray arg);
+template <>
+void ScatterAdd<kDGLCPU, int32_t, BFloat16>(
+    NDArray feat, NDArray idx, NDArray out) {
+  LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
+}
+template <>
+void ScatterAdd<kDGLCPU, int64_t, BFloat16>(
+    NDArray feat, NDArray idx, NDArray out) {
+  LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
+}
 template void ScatterAdd<kDGLCPU, int32_t, float>(
     NDArray feat, NDArray idx, NDArray out);
 template void ScatterAdd<kDGLCPU, int64_t, float>(

@@ -78,6 +94,20 @@ template void ScatterAdd<kDGLCPU, int32_t, double>(
 template void ScatterAdd<kDGLCPU, int64_t, double>(
     NDArray feat, NDArray arg, NDArray out);
+template <>
+void UpdateGradMinMax_hetero<kDGLCPU, int32_t, BFloat16>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
+  LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
+}
+template <>
+void UpdateGradMinMax_hetero<kDGLCPU, int64_t, BFloat16>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
+  LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
+}
 template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
     const HeteroGraphPtr& g, const std::string& op,
     const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,

@@ -95,6 +125,10 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
     const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
     const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void BackwardSegmentCmp<kDGLCPU, int32_t, BFloat16>(
+    NDArray feat, NDArray arg, NDArray out);
+template void BackwardSegmentCmp<kDGLCPU, int64_t, BFloat16>(
+    NDArray feat, NDArray arg, NDArray out);
 template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
     NDArray feat, NDArray arg, NDArray out);
 template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
src/array/cpu/segment_reduce.h

@@ -25,6 +25,8 @@ namespace cpu {
  */
 template <typename IdType, typename DType>
 void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
+  if (std::is_same<DType, BFloat16>::value)
+    LOG(FATAL) << "Unsupported CPU kernel for SegmentSum for BF16.";
   int n = out->shape[0];
   int dim = 1;
   for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
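Note the two different blocking patterns in this change: header kernels like SegmentSum above reject BF16 with a runtime std::is_same check, while SpMMSumCoo in spmm.h below is split via std::enable_if into a real overload and a compile-time-selected BF16 stub; both routes end in the same LOG(FATAL).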
src/array/cpu/spmm.cc

@@ -124,6 +124,14 @@ void SpMMCsrHetero(
   }
 }
 
+template void SpMMCsr<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const std::string& reduce, const BcastOff& bcast,
+    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
+    std::vector<NDArray> out_aux);
+template void SpMMCsr<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const std::string& reduce, const BcastOff& bcast,
+    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
+    std::vector<NDArray> out_aux);
 template void SpMMCsr<kDGLCPU, int32_t, float>(
     const std::string& op, const std::string& reduce, const BcastOff& bcast,
     const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,

@@ -141,6 +149,20 @@ template void SpMMCsr<kDGLCPU, int64_t, double>(
     const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
     std::vector<NDArray> out_aux);
+template void SpMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const std::string& reduce, const BcastOff& bcast,
+    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
+    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
+    std::vector<std::vector<NDArray>>* out_aux,
+    const std::vector<dgl_type_t>& ufeat_node_tids,
+    const std::vector<dgl_type_t>& out_node_tids);
+template void SpMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const std::string& reduce, const BcastOff& bcast,
+    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
+    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
+    std::vector<std::vector<NDArray>>* out_aux,
+    const std::vector<dgl_type_t>& ufeat_node_tids,
+    const std::vector<dgl_type_t>& out_node_tids);
 template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
     const std::string& op, const std::string& reduce, const BcastOff& bcast,
     const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,

@@ -191,7 +213,12 @@ void Edge_softmax_csr_backward(
         bcast, csr, out, sds, back_out);
   });
 }
+
+template void Edge_softmax_csr_forward<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
+    NDArray ufeat, NDArray efeat, NDArray out);
+template void Edge_softmax_csr_forward<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
+    NDArray ufeat, NDArray efeat, NDArray out);
 template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
     NDArray ufeat, NDArray efeat, NDArray out);

@@ -205,6 +232,12 @@ template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
     NDArray ufeat, NDArray efeat, NDArray out);
+template void Edge_softmax_csr_backward<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
+    NDArray ufeat, NDArray efeat, NDArray out);
+template void Edge_softmax_csr_backward<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
+    NDArray ufeat, NDArray efeat, NDArray out);
 template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
     NDArray ufeat, NDArray efeat, NDArray out);

@@ -242,6 +275,14 @@ void SpMMCoo(
   }
 }
 
+template void SpMMCoo<kDGLCPU, int32_t, BFloat16>(
+    const std::string& op, const std::string& reduce, const BcastOff& bcast,
+    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
+    std::vector<NDArray> out_aux);
+template void SpMMCoo<kDGLCPU, int64_t, BFloat16>(
+    const std::string& op, const std::string& reduce, const BcastOff& bcast,
+    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
+    std::vector<NDArray> out_aux);
 template void SpMMCoo<kDGLCPU, int32_t, float>(
     const std::string& op, const std::string& reduce, const BcastOff& bcast,
     const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
src/array/cpu/spmm.h

@@ -27,6 +27,10 @@ namespace dgl {
 namespace aten {
 namespace cpu {
 
+template <typename DType>
+using AccType = typename std::conditional<
+    std::is_same<DType, BFloat16>::value, float, DType>::type;
+
 /**
  * @brief Naive CPU kernel of SpMM on Csr format.
  * @param cpu_spec JIT'ed kernel

@@ -51,18 +55,20 @@ void SpMMSumCsrNaive(
     for (auto rid = b; rid < e; ++rid) {
       const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
       DType* out_off = O + rid * dim;
-      for (IdType j = row_start; j < row_end; ++j) {
-        const IdType cid = indices[j];
-        const IdType eid = has_idx ? edges[j] : j;
-        for (int64_t k = 0; k < dim; ++k) {
+      for (int64_t k = 0; k < dim; ++k) {
+        AccType<DType> acc = 0.;
+        for (IdType j = row_start; j < row_end; ++j) {
+          const IdType cid = indices[j];
+          const IdType eid = has_idx ? edges[j] : j;
           const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
           const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
           const DType* lhs_off =
               Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
           const DType* rhs_off =
               Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
-          out_off[k] += Op::Call(lhs_off, rhs_off);
+          acc += Op::Call(lhs_off, rhs_off);
         }
+        out_off[k] += acc;
       }
     }
   });

@@ -129,7 +135,8 @@ void SpMMSumCsr(
  * we use atomic operators in the reduction phase.
  */
 template <typename IdType, typename DType, typename Op>
-void SpMMSumCoo(
+typename std::enable_if<!std::is_same<DType, BFloat16>::value, void>::type
+SpMMSumCoo(
     const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
     NDArray out) {
   const bool has_idx = !IsNullArray(coo.data);

@@ -166,6 +173,14 @@ void SpMMSumCoo(
   }
 }
 
+template <typename IdType, typename DType, typename Op>
+typename std::enable_if<std::is_same<DType, BFloat16>::value, void>::type
+SpMMSumCoo(
+    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
+    NDArray out) {
+  LOG(FATAL) << "Unsupported CPU kernel for SpMMSumCoo for BF16.";
+}
+
 /**
  * @brief CPU kernel of SpMM-Min/Max on Csr format.
  * @param bcast Broadcast information.

@@ -442,7 +457,7 @@ void Edge_softmax_csr_forward(
   runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
     for (auto rid = b; rid < e; ++rid) {
       const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
-      std::vector<DType> data_e(row_end - row_start, 0);
+      std::vector<AccType<DType>> data_e(row_end - row_start, 0);
       std::vector<IdType> num(row_end - row_start, 0);
       for (int64_t k = 0; k < dim; ++k) {
         DType max_v = -std::numeric_limits<DType>::infinity();

@@ -481,6 +496,8 @@ template <typename IdType, typename DType, typename Op>
 void Edge_softmax_csr_backward(
     const BcastOff& bcast, const CSRMatrix& csr, NDArray out, NDArray sds,
     NDArray back_out) {
+  typedef typename std::conditional<
+      std::is_same<DType, BFloat16>::value, float, DType>::type AccType;
   const bool has_idx = !IsNullArray(csr.data);
   const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
   const IdType* edges =

@@ -492,7 +509,7 @@ void Edge_softmax_csr_backward(
     for (auto rid = b; rid < e; ++rid) {
       const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
       for (int64_t k = 0; k < dim; ++k) {
-        DType sum_sds = 0;
+        AccType sum_sds = 0;
         for (IdType j = row_start; j < row_end; ++j) {
           const IdType eid = has_idx ? edges[j] : j;
           const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
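The AccType<DType> alias carries the main numeric fix: with only 7 stored mantissa bits, a bf16 accumulator stops absorbing small addends once the running sum is large, so the kernels above now accumulate in float and round once at the end. A small standalone illustration (not from the diff; assumes the BFloat16 header is included):

    #include <cstdio>

    int main() {
      // Accumulating in bf16 stalls: at 256 the bf16 spacing is 2, and
      // 256 + 1 = 257 rounds back down to 256 (tie to even), forever.
      BFloat16 naive = 0.0f;
      for (int i = 0; i < 512; ++i) naive += 1.0f;

      // Accumulating in float and rounding once keeps the full sum.
      float acc = 0.0f;
      for (int i = 0; i < 512; ++i) acc += 1.0f;
      BFloat16 fixed = acc;

      std::printf("naive=%g fixed=%g\n",
                  static_cast<float>(naive), static_cast<float>(fixed));
      // Expected output: naive=256 fixed=512
    }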
src/array/cpu/spmm_binary_ops.h

@@ -102,20 +102,36 @@ constexpr bool CopyRhs<DType>::use_rhs;
 //////////////////////////////// Reduce operators on CPU ////////////////////
 
+template <typename DType>
+constexpr DType MinDType() {
+  if (std::is_same<DType, BFloat16>::value)
+    return BFloat16::Min();
+  else
+    return -std::numeric_limits<DType>::infinity();
+}
+
 template <typename DType>
 struct Max {
   typedef DType type;
-  static constexpr DType zero = -std::numeric_limits<DType>::infinity();
+  static constexpr DType zero = MinDType<DType>();
   // return true if accum should be replaced
   inline static DType Call(DType accum, DType val) { return accum < val; }
 };
 template <typename DType>
 constexpr DType Max<DType>::zero;
 
+template <typename DType>
+constexpr DType MaxDType() {
+  if (std::is_same<DType, BFloat16>::value)
+    return BFloat16::Max();
+  else
+    return std::numeric_limits<DType>::infinity();
+}
+
 template <typename DType>
 struct Min {
   typedef DType type;
-  static constexpr DType zero = std::numeric_limits<DType>::infinity();
+  static constexpr DType zero = MaxDType<DType>();
   // return true if accum should be replaced
   inline static DType Call(DType accum, DType val) { return accum > val; }
 };
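The MinDType/MaxDType helpers exist because std::numeric_limits is not specialized for BFloat16: the primary template's infinity() returns a value-initialized BFloat16, i.e. zero, which would be a broken identity element for the min/max reductions. A minimal check (hypothetical, assuming the BFloat16 header):

    #include <cstdio>
    #include <limits>

    int main() {
      // Unspecialized numeric_limits falls back to the primary template,
      // whose infinity() returns BFloat16(), i.e. 0.0f.
      float broken = std::numeric_limits<BFloat16>::infinity();
      // The class's own sentinels carry real bf16 infinity bit patterns.
      float lo = BFloat16::Min();  // 0xFF80 -> -inf
      float hi = BFloat16::Max();  // 0x7F80 -> +inf
      std::printf("broken=%g lo=%g hi=%g\n", broken, lo, hi);
    }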
src/array/cpu/spmm_blocking_libxsmm.h

@@ -257,7 +257,13 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
         N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
         (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
         opredop_flags);
+  } else {
+    // assume bf16
+    kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(
+        N, &_ld, &_ld, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16,
+        (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
+        opredop_flags);
   }
   if (kernel == nullptr) {
     LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation."
                   "To disable libxsmm, use dgl.use_libxsmm(false).";
tests/cpp/test_aten.cc

@@ -1426,3 +1426,12 @@ TEST(ArrayTest, Sort) {
   _TestSort<int64_t>(GPU);
 #endif
 }
+
+TEST(ArrayTest, BFloatCast) {
+  for (int i = -100; i < 100; ++i) {
+    float a = i;
+    BFloat16 b = a;
+    float a_casted = b;
+    ASSERT_FLOAT_EQ(a, a_casted);
+  }
+}
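(The round trip is exact here because bf16 keeps 8 significant bits, 7 stored mantissa bits plus the implicit leading one, so every integer with magnitude up to 256, well beyond the ±100 tested, is exactly representable.)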
tests/cpp/test_spmm.cc

@@ -105,6 +105,7 @@ void _TestSpmmCopyLhs() {
 TEST(SpmmTest, TestSpmmCopyLhs) {
   _TestSpmmCopyLhs<float>();
   _TestSpmmCopyLhs<double>();
+  _TestSpmmCopyLhs<BFloat16>();
 }
 
 template <typename IDX>

@@ -130,6 +131,7 @@ void _TestSpmmCopyRhs() {
 TEST(SpmmTest, TestSpmmCopyRhs) {
   _TestSpmmCopyRhs<float>();
   _TestSpmmCopyRhs<double>();
+  _TestSpmmCopyRhs<BFloat16>();
 }
 
 template <typename IDX>

@@ -156,6 +158,7 @@ void _TestSpmmAdd() {
 TEST(SpmmTest, TestSpmmAdd) {
   _TestSpmmAdd<float>();
   _TestSpmmAdd<double>();
+  _TestSpmmAdd<BFloat16>();
 }
 
 template <typename IDX>

@@ -182,6 +185,7 @@ void _TestSpmmSub() {
 TEST(SpmmTest, TestSpmmSub) {
   _TestSpmmSub<float>();
   _TestSpmmSub<double>();
+  _TestSpmmSub<BFloat16>();
 }
 
 template <typename IDX>

@@ -208,6 +212,7 @@ void _TestSpmmMul() {
 TEST(SpmmTest, TestSpmmMul) {
   _TestSpmmMul<float>();
   _TestSpmmMul<double>();
+  _TestSpmmMul<BFloat16>();
 }
 
 template <typename IDX>

@@ -234,5 +239,6 @@ void _TestSpmmDiv() {
 TEST(SpmmTest, TestSpmmDiv) {
   _TestSpmmDiv<float>();
   _TestSpmmDiv<double>();
+  _TestSpmmDiv<BFloat16>();
 }
 
 #endif  // _WIN32
tests/python/common/ops/test_ops.py

@@ -176,17 +176,19 @@ def test_spmm(idtype, g, shp, msg, reducer):
     dgl.backend.backend_name != "pytorch",
     reason="Only support PyTorch for now.",
 )
-@unittest.skipIf(
-    F._default_context_str == "cpu",
-    reason="Don't support half precision on CPU.",
-)
 @parametrize_idtype
 @pytest.mark.parametrize(
     "dtype, rtol, atol",
     [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)],
 )
 def test_half_spmm(idtype, dtype, rtol, atol):
-    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+    if F._default_context_str == "cpu" and dtype == torch.float16:
+        pytest.skip("float16 is not supported on CPU.")
+    if (
+        F._default_context_str == "gpu"
+        and dtype == torch.bfloat16
+        and not torch.cuda.is_bf16_supported()
+    ):
         pytest.skip("BF16 is not supported.")
     # make sure the spmm result is < 512 to match the rtol/atol we set.

@@ -195,7 +197,7 @@ def test_half_spmm(idtype, dtype, rtol, atol):
         idtype=idtype,
         device=F.ctx(),
     )
-    feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(0)
+    feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(F.ctx())
     feat_half = feat_fp32.to(dtype)
 
     # test SpMMCSR

@@ -337,11 +339,8 @@ def test_segment_reduce(reducer):
     ],
 )
 def test_segment_mm(idtype, feat_size, dtype, tol):
-    if F._default_context_str == "cpu" and dtype in (
-        torch.float16,
-        torch.bfloat16,
-    ):
-        pytest.skip("Only support float32 and float64 on CPU.")
+    if F._default_context_str == "cpu" and dtype == torch.float16:
+        pytest.skip("float16 is not supported on CPU.")
     if (
         F._default_context_str == "gpu"
         and dtype == torch.bfloat16

@@ -397,11 +396,8 @@ def test_segment_mm(idtype, feat_size, dtype, tol):
     ],
 )
 def test_gather_mm_idx_b(feat_size, dtype, tol):
-    if F._default_context_str == "cpu" and dtype in (
-        torch.float16,
-        torch.bfloat16,
-    ):
-        pytest.skip("Only support float32 and float64 on CPU.")
+    if F._default_context_str == "cpu" and dtype == torch.float16:
+        pytest.skip("float16 is not supported on CPU.")
     if (
         F._default_context_str == "gpu"
        and dtype == torch.bfloat16