Unverified Commit 96297fb8 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Feature] Add bfloat16 (bf16) support (#4648)

* add bf16 specializations

* remove SWITCH_BITS

* enable amp for bf16

* remove SWITCH_BITS for cpu kernels

* enbale bf16 based on CUDART

* fix compiling for sm<80

* fix cpu build

* enable unit tests

* update doc

* disable test for CUDA < 11.0

* address comments

* address comments
parent 1d229194
......@@ -28,7 +28,6 @@ dgl_option(USE_SYSTEM_NCCL "Build using system's NCCL library" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(USE_AVX "Build with AVX optimization" OFF)
dgl_option(USE_LIBXSMM "Build with LIBXSMM library optimization" ON)
dgl_option(USE_FP16 "Build with fp16 support to enable mixed precision training" OFF)
dgl_option(USE_TVM "Build with TVM kernels" OFF)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
......@@ -127,13 +126,6 @@ if ((NOT MSVC) AND USE_EPOLL)
endif()
endif ()
# Build with fp16 to support mixed precision training.
if(USE_FP16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_FP16")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_FP16")
message(STATUS "Build with fp16 to support mixed precision training")
endif(USE_FP16)
# To compile METIS correct for DGL.
if(MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /DIDXTYPEWIDTH=64 /DREALTYPEWIDTH=32")
......
......@@ -9,9 +9,9 @@ consumption. This feature requires DGL 0.9+.
Message-Passing with Half Precision
-----------------------------------
DGL allows message-passing on ``float16 (fp16)`` features for both
UDFs (User Defined Functions) and built-in functions (e.g., ``dgl.function.sum``,
``dgl.function.copy_u``).
DGL allows message-passing on ``float16 (fp16)`` / ``bfloat16 (bf16)`` (requires CUDA >= 11.0)
features for both UDFs (User Defined Functions) and built-in functions
(e.g., ``dgl.function.sum``, ``dgl.function.copy_u``).
The following example shows how to use DGL's message-passing APIs on half-precision
features:
......@@ -63,8 +63,9 @@ efficient, most operators on half precision tensors are faster as they leverage
import torch.nn.functional as F
from torch.cuda.amp import autocast
def forward(g, feat, label, mask, model, use_fp16):
with autocast(enabled=use_fp16):
def forward(g, feat, label, mask, model, amp_dtype):
amp_enabled = amp_dtype in (torch.float16, torch.bfloat16)
with autocast(enabled=amp_enabled, dtype=amp_dtype):
logit = model(g, feat)
loss = F.cross_entropy(logit[mask], label[mask])
return loss
......@@ -74,6 +75,7 @@ PyTorch provides a ``GradScaler`` module to address this issue. It multiplies
the loss by a factor and invokes backward pass on the scaled loss to prevent
the underflow problem. It then unscales the computed gradients before the optimizer
updates the parameters. The scale factor is determined automatically.
Note that ``bfloat16`` doesn't require a ``GradScaler``.
.. code::
......@@ -87,7 +89,7 @@ updates the parameters. The scale factor is determined automatically.
scaler.update()
The following example trains a 3-layer GAT on the Reddit dataset (w/ 114 million edges).
Pay attention to the differences in the code when ``use_fp16`` is activated or not.
Pay attention to the differences in the code when AMP is activated or not.
.. code::
......@@ -98,7 +100,7 @@ Pay attention to the differences in the code when ``use_fp16`` is activated or n
from dgl.nn import GATConv
from dgl.transforms import AddSelfLoop
use_fp16 = True
amp_dtype = torch.float16 # or torch.bfloat16
class GAT(nn.Module):
def __init__(self,
......@@ -145,9 +147,9 @@ Pay attention to the differences in the code when ``use_fp16`` is activated or n
for epoch in range(100):
optimizer.zero_grad()
loss = forward(g, feat, label, train_mask, model, use_fp16)
loss = forward(g, feat, label, train_mask, model, amp_dtype)
if use_fp16:
if amp_dtype == torch.float16:
# Backprop w/ gradient scaling
backward(scaler, loss, optimizer)
else:
......
......@@ -123,26 +123,67 @@
typedef double FloatType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
LOG(FATAL) << (val_name) \
<< " can only be float32 or float64"; \
} \
} while (0)
#define ATEN_FLOAT_BITS_SWITCH(val, bits, val_name, ...) do { \
CHECK_EQ((val).code, kDGLFloat) \
/*
* Dispatch according to float type, including 16bits (float16/bfloat16/float32/float64).
*/
#ifdef DGL_USE_CUDA
#if BF16_ENABLED
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 16) { \
constexpr int bits = 16; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 32) { \
constexpr int bits = 32; \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef __nv_bfloat16 FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) \
<< " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/bfloat16/float32/float64 on GPU"; \
} \
} while (0)
#else // BF16_ENABLED
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
constexpr int bits = 64; \
typedef double FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) \
<< " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
LOG(FATAL) << (val_name) \
<< " can only be float16/float32/float64 on GPU"; \
} \
} while (0)
#endif // BF16_ENABLED
#else // DGL_USE_CUDA
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
#endif // DGL_USE_CUDA
/*
* Dispatch according to data type (int32, int64, float32 or float64):
......
......@@ -18,8 +18,15 @@
#include "shared_mem.h"
#ifdef DGL_USE_CUDA
#include <cuda_runtime.h>
#define BF16_ENABLED (defined(CUDART_VERSION) && CUDART_VERSION >= 11000)
#include <cuda_fp16.h>
#endif
#if BF16_ENABLED
#include <cuda_bf16.h>
#endif // BF16_ENABLED
#endif // DGL_USE_CUDA
// forward declaration
inline std::ostream& operator << (std::ostream& os, DGLDataType t);
......@@ -51,10 +58,11 @@ GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
GEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
#ifdef DGL_USE_CUDA
#ifdef USE_FP16
GEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16);
#endif
#endif
#if BF16_ENABLED
GEN_DGLDATATYPETRAITS_FOR(__nv_bfloat16, kDGLBfloat, 16);
#endif // BF16_ENABLED
#endif // DGL_USE_CUDA
GEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 32);
GEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64);
#undef GEN_DGLDATATYPETRAITS_FOR
......
import torch as th
from torch.cuda.amp import custom_bwd, custom_fwd
from ...base import ALL, is_all
from ...heterograph_index import create_unitgraph_from_csr
......@@ -39,7 +38,6 @@ __all__ = [
"segment_mm",
]
def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on
......@@ -127,10 +125,26 @@ def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
return True
return False
class empty_context():
"""Empty context that does nothing"""
def __init__(self, *args, **kargs):
return
def __enter__(self, *args, **kargs):
return self
def __exit__(self, *args, **kargs):
return
# This is to avoid warnings in cpu-only dgl. We don't enable autocast for CPU ops
autocast = th.cuda.amp.autocast if th.cuda.is_available() else empty_context
def _cast_if_autocast_enabled(*args):
if not th.is_autocast_enabled() or not th.cuda.is_available():
return args
else:
return th.cuda.amp.autocast_mode._cast(args, th.get_autocast_gpu_dtype())
class GSpMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
reduce_last = _need_reduce_last_dim(X, Y)
......@@ -162,7 +176,6 @@ class GSpMM(th.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx, dZ):
(
gidx,
......@@ -221,7 +234,6 @@ class GSpMM(th.autograd.Function):
class GSpMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(
ctx, gidx, op, reduce_op, X_len, *feats
): # feats = lhs_data + rhs_data
......@@ -278,7 +290,6 @@ class GSpMM_hetero(th.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx, *dZ):
(
gidx,
......@@ -415,7 +426,6 @@ def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
class GSDDMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
X_shape = X.shape if X is not None else None
......@@ -431,7 +441,6 @@ class GSDDMM(th.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx, dZ):
gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
X, Y = ctx.saved_tensors
......@@ -480,7 +489,6 @@ class GSDDMM(th.autograd.Function):
class GSDDMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(
ctx, gidx, op, X_len, lhs_target, rhs_target, *feats
): # feats = X+Y
......@@ -517,7 +525,6 @@ class GSDDMM_hetero(th.autograd.Function):
return out
@staticmethod
@custom_bwd
# TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ):
(
......@@ -661,7 +668,6 @@ class GSDDMM_hetero(th.autograd.Function):
class EdgeSoftmax(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, score, eids, norm_by):
"""Forward function.
......@@ -696,7 +702,6 @@ class EdgeSoftmax(th.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx, grad_out):
"""Backward function.
......@@ -728,7 +733,6 @@ class EdgeSoftmax(th.autograd.Function):
class EdgeSoftmax_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, eids, norm_by, *score):
"""Forward function.
......@@ -774,7 +778,6 @@ class EdgeSoftmax_hetero(th.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx, *grad_out):
"""Backward function.
......@@ -808,7 +811,6 @@ class EdgeSoftmax_hetero(th.autograd.Function):
class SegmentReduce(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, op, x, offsets):
y, arg = _segment_reduce(op, x, offsets)
ctx.save_for_backward(arg, offsets)
......@@ -816,7 +818,6 @@ class SegmentReduce(th.autograd.Function):
return y
@staticmethod
@custom_bwd
def backward(ctx, dy):
op = ctx.backward_cache
arg, offsets = ctx.saved_tensors
......@@ -838,14 +839,12 @@ class SegmentReduce(th.autograd.Function):
class ScatterAdd(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, x, idx, m):
y = _scatter_add(x, idx, m)
ctx.save_for_backward(idx)
return y
@staticmethod
@custom_bwd
def backward(ctx, dy):
idx = ctx.saved_tensors
return dy[idx], None, None
......@@ -952,7 +951,6 @@ class CSRMask(th.autograd.Function):
class SEGMENTMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, seglen_A):
if B.dim() != 3:
raise ValueError("segment_mm expects B to be a 3D tensor.")
......@@ -978,7 +976,6 @@ class SEGMENTMM(th.autograd.Function):
class GATHERMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, idx_a, idx_b):
if B.dim() != 3:
raise ValueError(
......@@ -1014,7 +1011,9 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
if op == "div":
op = "mul"
rhs_data = 1.0 / rhs_data
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
args = _cast_if_autocast_enabled(gidx, op, reduce_op, lhs_data, rhs_data)
with autocast(enabled=False):
return GSpMM.apply(*args)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
......@@ -1024,7 +1023,9 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
if op == "div":
op = "mul"
rhs_data = 1.0 / rhs_data
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
args = _cast_if_autocast_enabled(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
with autocast(enabled=False):
return GSDDMM.apply(*args)
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
......@@ -1050,7 +1051,10 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
)
if op in ["add", "mul"]:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
return GSpMM_hetero.apply(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
args = _cast_if_autocast_enabled(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
with autocast(enabled=False):
return GSpMM_hetero.apply(*args)
def gsddmm_hetero(
......@@ -1078,25 +1082,34 @@ def gsddmm_hetero(
)
if op in ["add", "mul"]:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
return GSDDMM_hetero.apply(
g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple
)
args = _cast_if_autocast_enabled(g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple)
with autocast(enabled=False):
return GSDDMM_hetero.apply(*args)
def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
args = _cast_if_autocast_enabled(gidx, logits, eids, norm_by)
with autocast(enabled=False):
return EdgeSoftmax.apply(*args)
def edge_softmax_hetero(gidx, eids=ALL, norm_by="dst", *logits):
return EdgeSoftmax_hetero.apply(gidx, eids, norm_by, *logits)
args = _cast_if_autocast_enabled(gidx, eids, norm_by, *logits)
with autocast(enabled=False):
return EdgeSoftmax_hetero.apply(*args)
def segment_reduce(op, x, offsets):
return SegmentReduce.apply(op, x, offsets)
args = _cast_if_autocast_enabled(op, x, offsets)
with autocast(enabled=False):
return SegmentReduce.apply(*args)
def scatter_add(x, idx, m):
return ScatterAdd.apply(x, idx, m)
args = _cast_if_autocast_enabled(x, idx, m)
with autocast(enabled=False):
return ScatterAdd.apply(*args)
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
......@@ -1144,7 +1157,9 @@ def segment_mm(A, B, seglen_A):
off += seglen_A[i]
return th.cat(C)
else:
return SEGMENTMM.apply(A, B, seglen_A)
args = _cast_if_autocast_enabled(A, B, seglen_A)
with autocast(enabled=False):
return SEGMENTMM.apply(*args)
def gather_mm(A, B, idx_A=None, idx_B=None):
......@@ -1153,4 +1168,6 @@ def gather_mm(A, B, idx_A=None, idx_B=None):
B = B[idx_B] if idx_B is not None else B
return th.bmm(A.unsqueeze(1), B).squeeze(1)
else:
return GATHERMM.apply(A, B, idx_A, idx_B)
args = _cast_if_autocast_enabled(A, B, idx_A, idx_B)
with autocast(enabled=False):
return GATHERMM.apply(*args)
......@@ -9,22 +9,8 @@
namespace dgl {
namespace aten {
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16 || (bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
/*! \brief Generalized SegmentMM. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
const NDArray B,
NDArray C,
......@@ -33,7 +19,7 @@ void SegmentMM(const NDArray A,
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A,
const NDArray dC,
NDArray dB,
......@@ -42,7 +28,7 @@ void SegmentMMBackwardB(const NDArray A,
}
/*! \brief Generalized GatherMM. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A,
const NDArray B,
NDArray C,
......@@ -52,7 +38,7 @@ void GatherMM(const NDArray A,
}
/*! \brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray C,
......@@ -62,74 +48,52 @@ void GatherMMScatter(const NDArray A,
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
template void GatherMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, 16>(
template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, 32>(
template void GatherMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, 32>(
template void GatherMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, 64>(
template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, 16>(
template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, 32>(
template void GatherMMScatter<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, 32>(
template void GatherMMScatter<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, 64>(
template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, 16>(
template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, 32>(
template void SegmentMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, 32>(
template void SegmentMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, 64>(
template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, 32>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, 32>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, 64>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, 64>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten
......
......@@ -41,22 +41,9 @@ namespace aten {
} \
} while (0)
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16 || (bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
/*! \brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
......@@ -65,17 +52,15 @@ void SDDMMCsr(const std::string& op,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
});
});
});
}
/*! \brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
......@@ -86,7 +71,6 @@ void SDDMMCsrHetero(const std::string& op,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */
......@@ -99,70 +83,47 @@ void SDDMMCsrHetero(const std::string& op,
}
});
});
});
}
template void SDDMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, 16>(
template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, 32>(
template void SDDMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, 32>(
template void SDDMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, 64>(
template void SDDMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, 32>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, 32>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, 64>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......@@ -171,7 +132,7 @@ template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::vector<dgl_type_t>& out_eid);
/*! \brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
......@@ -180,17 +141,15 @@ void SDDMMCoo(const std::string& op,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
});
});
});
}
/*! \brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
......@@ -201,7 +160,6 @@ void SDDMMCooHetero(const std::string& op,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */
......@@ -214,70 +172,47 @@ void SDDMMCooHetero(const std::string& op,
}
});
});
});
}
template void SDDMMCoo<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, 32>(
template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, 32>(
template void SDDMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, 64>(
template void SDDMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, 64>(
template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, 32>(
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, 32>(
template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, 64>(
template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......@@ -285,6 +220,5 @@ template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
} // namespace aten
} // namespace dgl
......@@ -12,7 +12,7 @@ namespace dgl {
namespace aten {
/*! \brief Segment Reduce operator. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SegmentReduce(
const std::string& op,
NDArray feat,
......@@ -20,20 +20,14 @@ void SegmentReduce(
NDArray out,
NDArray arg) {
if (op == "sum") {
SWITCH_BITS(bits, DType, {
cpu::SegmentSum<IdType, DType>(feat, offsets, out);
});
} else if (op == "max" || op == "min") {
if (op == "max") {
SWITCH_BITS(bits, DType, {
cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>(
feat, offsets, out, arg);
});
} else {
SWITCH_BITS(bits, DType, {
cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>(
feat, offsets, out, arg);
});
}
} else {
LOG(FATAL) << "Unsupported reduce function " << op;
......@@ -41,146 +35,105 @@ void SegmentReduce(
}
/*! \brief Scatter Add.*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat,
NDArray idx,
NDArray out) {
SWITCH_BITS(bits, DType, {
cpu::ScatterAdd<IdType, DType>(feat, idx, out);
});
}
/*! \brief Update gradients for reduce operator max/min on heterogeneous graph.*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
SWITCH_BITS(bits, DType, {
cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
});
}
/*! \brief Backward function of segment cmp.*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(
NDArray feat,
NDArray arg,
NDArray out) {
SWITCH_BITS(bits, DType, {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
});
}
template void SegmentReduce<kDGLCPU, int32_t, 16>(
template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, 16>(
template void SegmentReduce<kDGLCPU, int64_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, 32>(
template void SegmentReduce<kDGLCPU, int32_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, 32>(
template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDGLCPU, int32_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, 32>(
template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, 32>(
template void ScatterAdd<kDGLCPU, int64_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, 64>(
template void ScatterAdd<kDGLCPU, int32_t, double>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, 64>(
template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, double>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, 32>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, 32>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, 64>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, 64>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
......
......@@ -10,7 +10,7 @@ namespace dgl {
namespace aten {
/*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const CSRMatrix& csr,
......@@ -20,13 +20,10 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
std::vector<NDArray> out_aux) {
const int64_t dim = bcast.out_len;
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>();
if (reduce == "max") {
......@@ -39,14 +36,13 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
}
});
});
} else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
}
}
/*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
......@@ -58,7 +54,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<dgl_type_t>& out_node_tids) {
const int64_t dim = bcast.out_len;
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
......@@ -71,9 +66,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
}
});
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
std::vector<bool> updated((*vec_out).size(), false);
// TODO(Israt): use vector updated to fill(out...) too
......@@ -115,73 +108,50 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
}
}
});
});
} else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
}
}
template void SpMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, 32>(
template void SpMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, 32>(
template void SpMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, 64>(
template void SpMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, 64>(
template void SpMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDGLCPU, int32_t, 16>(
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, 16>(
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, 32>(
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
......@@ -190,87 +160,67 @@ template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::vector<dgl_type_t>& out_node_tids);
/*! \brief Edge_softmax_csr forward op on Csr format. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
});
}
/*! \brief Edge_softmax_csr backward op on Csr format. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray out,
NDArray sds,
NDArray back_out) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out);
});
});
}
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 16>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 32>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 32>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 64>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 32>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 32>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 64>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 64>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
/*! \brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const COOMatrix& coo,
......@@ -279,13 +229,10 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
});
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
if (reduce == "max")
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
......@@ -294,37 +241,27 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
});
} else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
}
}
template void SpMMCoo<kDGLCPU, int32_t, 16>(
template void SpMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, 16>(
template void SpMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, 32>(
template void SpMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, 32>(
template void SpMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
} // namespace aten
} // namespace dgl
......@@ -146,21 +146,6 @@ constexpr DType Min<DType>::zero;
} \
} while (0)
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16) { \
LOG(FATAL) << "FP16 not supported on CPU"; \
} else if ((bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
} // namespace op
} // namespace cpu
......
......@@ -60,10 +60,12 @@ template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
#ifdef USE_FP16
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#endif
#if BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(NDArray, IdArray);
#endif // BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
......@@ -72,30 +74,21 @@ template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx);
#ifdef USE_FP16
// The initialization constructor for __half is apparently a device-
// only function in some setups, but the current function, IndexSelect,
// isn't run on the device, so it doesn't have access to that constructor.
using SafeDType = typename std::conditional<
std::is_same<DType, __half>::value, uint16_t, DType>::type;
SafeDType ret = 0;
#else
DType ret = 0;
#endif
DType ret = static_cast<DType>(0.0f);
device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0,
reinterpret_cast<DType*>(&ret), 0, sizeof(DType), array->ctx,
DGLContext{kDGLCPU, 0}, array->dtype);
return reinterpret_cast<DType&>(ret);
static_cast<DType*>(array->data) + index, 0, &ret, 0,
sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
return ret;
}
template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
#ifdef USE_FP16
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#endif
#if BF16_ENABLED
template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(NDArray array, int64_t index);
#endif // BF16_ENABLED
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
......
......@@ -224,9 +224,10 @@ NDArray Full(DType val, int64_t length, DGLContext ctx) {
template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx);
#ifdef USE_FP16
template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx);
#endif
#if BF16_ENABLED
template IdArray Full<kDGLCUDA, __nv_bfloat16>(__nv_bfloat16 val, int64_t length, DGLContext ctx);
#endif // BF16_ENABLED
template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx);
......
......@@ -39,16 +39,18 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#endif
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#endif
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
......
......@@ -9,6 +9,7 @@
#include <cuda_runtime.h>
#include <cassert>
#include "fp16.cuh"
#include "bf16.cuh"
#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
......@@ -44,7 +45,6 @@ template <typename T> struct Cast {
}
};
#ifdef USE_FP16
template <> struct Cast<half> {
typedef Code<sizeof(half)>::Type Type;
static __device__ __forceinline__ Type Encode(half val) {
......@@ -54,7 +54,32 @@ template <> struct Cast<half> {
return __ushort_as_half(code);
}
};
#if BF16_ENABLED
template <> struct Cast<__nv_bfloat16> {
typedef Code<sizeof(__nv_bfloat16)>::Type Type;
static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __bfloat16_as_ushort(val);
#else
printf("Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return static_cast<Type>(0);
#endif
}
static __device__ __forceinline__ __nv_bfloat16 Decode(Type code) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ushort_as_bfloat16(code);
#else
printf("Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return static_cast<__nv_bfloat16>(0.0f);
#endif
}
};
#endif // BF16_ENABLED
template <> struct Cast<float> {
typedef Code<sizeof(float)>::Type Type;
......@@ -109,9 +134,9 @@ static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT
return Cast<T>::Decode(old); \
}
#define DEFINE_ATOMIC_HALF(NAME) \
#define DEFINE_ATOMIC_16BIT(NAME, dtype) \
template <> \
__device__ __forceinline__ half Atomic##NAME<half>(half* addr, half val) { \
__device__ __forceinline__ dtype Atomic##NAME<dtype>(dtype* addr, dtype val) { \
typedef uint16_t CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \
......@@ -119,23 +144,25 @@ static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT
do { \
assumed = old; \
old = atomicCASshort(addr_as_ui, assumed, \
Cast<half>::Encode(OP(val, Cast<half>::Decode(old)))); \
Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
} while (assumed != old); \
return Cast<half>::Decode(old); \
return Cast<dtype>::Decode(old); \
}
#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Max)
#endif // USE_FP16
DEFINE_ATOMIC_16BIT(Max, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Max, __nv_bfloat16)
#endif // BF16_ENABLED
#undef OP
#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Min)
#endif // USE_FP16
DEFINE_ATOMIC_16BIT(Min, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)
#endif // BF16_ENABLED
#undef OP
#define OP(a, b) a + b
......@@ -258,11 +285,10 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#endif
}
#ifdef USE_FP16
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
// half make sure we have half support
// make sure we have half support
#if __CUDA_ARCH__ >= 700
return atomicAdd(addr, val);
#else
......@@ -275,7 +301,24 @@ __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#endif // __CUDA_ARCH__ >= 700
}
#endif // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#endif // USE_FP16
#if BF16_ENABLED
template <>
__device__ __forceinline__ __nv_bfloat16 AtomicAdd<__nv_bfloat16>(
__nv_bfloat16* addr, __nv_bfloat16 val) {
// make sure we have bfloat16 support
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return atomicAdd(addr, val);
#else
(void)addr;
(void)val;
printf("Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return val;
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
#endif // BF16_ENABLED
} // namespace cuda
......
/*!
* Copyright (c) 2022 by Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file array/cuda/bf16.cuh
* \brief bfloat16 related functions.
*/
#ifndef DGL_ARRAY_CUDA_BF16_CUH_
#define DGL_ARRAY_CUDA_BF16_CUH_
#if BF16_ENABLED
#include <cuda_bf16.h>
#include <algorithm>
static __device__ __forceinline__ __nv_bfloat16 max(__nv_bfloat16 a,
__nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmax(a, b);
#else
return __nv_bfloat16(max(float(a), float(b))); // NOLINT
#endif
}
static __device__ __forceinline__ __nv_bfloat16 min(__nv_bfloat16 a,
__nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmin(a, b);
#else
return __nv_bfloat16(min(float(a), float(b))); // NOLINT
#endif
}
#ifdef __CUDACC__
// Arithmetic BF16 operations for architecture >= 8.0 are already defined in cuda_bf16.h
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
__device__ __forceinline__ __nv_bfloat16 operator+(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) + float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator-(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) - float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator*(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) * float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator/(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) / float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator+=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) + float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator-=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) - float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator*=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) * float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator/=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) / float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator++(__nv_bfloat16& h) { // NOLINT
h = __nv_bfloat16(float(h) + 1.0f); return h; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator--(__nv_bfloat16& h) { // NOLINT
h = __nv_bfloat16(float(h) - 1.0f); return h; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator++(__nv_bfloat16& h, int) { // NOLINT
__nv_bfloat16 ret = h; h = __nv_bfloat16(float(h) + 1.0f); return ret; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator--(__nv_bfloat16& h, int) { // NOLINT
__nv_bfloat16 ret = h; h = __nv_bfloat16(float(h) - 1.0f); return ret; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) {
return h;
}
__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {
return __nv_bfloat16(-float(h)); // NOLINT
}
__device__ __forceinline__ bool operator==(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) == float(rh); // NOLINT
}
__device__ __forceinline__ bool operator!=(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) != float(rh); // NOLINT
}
__device__ __forceinline__ bool operator> (const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) > float(rh); // NOLINT
}
__device__ __forceinline__ bool operator< (const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) < float(rh); // NOLINT
}
__device__ __forceinline__ bool operator>=(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) >= float(rh); // NOLINT
}
__device__ __forceinline__ bool operator<=(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) <= float(rh); // NOLINT
}
#endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#endif // __CUDACC__
#endif // BF16_ENABLED
#endif // DGL_ARRAY_CUDA_BF16_CUH_
......@@ -65,12 +65,18 @@ NDArray CSRGetData(
return rst;
}
#ifdef USE_FP16
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
#endif
#if BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, __nv_bfloat16>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, __nv_bfloat16 filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __nv_bfloat16>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, __nv_bfloat16 filler);
#endif // BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
......
......@@ -255,12 +255,16 @@ std::pair<CSRMatrix, NDArray> CSRMM(
}
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
......
......@@ -167,12 +167,16 @@ std::pair<CSRMatrix, NDArray> CSRSum(
}
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
......
......@@ -9,6 +9,8 @@
#include <cusparse.h>
#include <dgl/runtime/c_runtime_api.h>
#include "fp16.cuh"
#include "bf16.cuh"
namespace dgl {
namespace aten {
......@@ -34,7 +36,6 @@ struct CSRGEMM {
}
};
#ifdef USE_FP16
template <>
struct CSRGEMM<__half> {
template <typename... Args>
......@@ -58,7 +59,32 @@ struct CSRGEMM<__half> {
return static_cast<cusparseStatus_t>(0);
}
};
#endif
#if BF16_ENABLED
template <>
struct CSRGEMM<__nv_bfloat16> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgemm2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgemm2, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEMM::compute does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
};
#endif // BF16_ENABLED
template <>
struct CSRGEMM<float> {
......@@ -117,7 +143,6 @@ struct CSRGEAM {
}
};
#ifdef USE_FP16
template <>
struct CSRGEAM<__half> {
template <typename... Args>
......@@ -141,7 +166,32 @@ struct CSRGEAM<__half> {
return static_cast<cusparseStatus_t>(0);
}
};
#endif
#if BF16_ENABLED
template <>
struct CSRGEAM<__nv_bfloat16> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgeam2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgeam2, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEAM::compute does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
};
#endif // BF16_ENABLED
template <>
struct CSRGEAM<float> {
......
/*!
* Copyright (c) 2020 by Contributors
* Copyright (c) 2020-2022 by Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file array/cuda/fp16.cuh
* \brief float16 related functions.
* \note this file is modified from TVM project:
......@@ -9,7 +22,6 @@
#define DGL_ARRAY_CUDA_FP16_CUH_
#ifdef USE_FP16
#include <cuda_fp16.h>
#include <algorithm>
......@@ -97,6 +109,4 @@ __device__ __forceinline__ bool operator<=(const __half& lh, const __half& rh) {
#endif // __CUDA_ARCH__ < 530
#endif // __CUDACC__
#endif // USE_FP16
#endif // DGL_ARRAY_CUDA_FP16_CUH_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment