Commit 96297fb8 authored by Xin Yao, committed by GitHub

[Feature] Add bfloat16 (bf16) support (#4648)

* add bf16 specializations

* remove SWITCH_BITS

* enable amp for bf16

* remove SWITCH_BITS for cpu kernels

* enable bf16 based on CUDART

* fix compiling for sm<80

* fix cpu build

* enable unit tests

* update doc

* disable test for CUDA < 11.0

* address comments

* address comments
parent 1d229194
@@ -28,7 +28,6 @@ dgl_option(USE_SYSTEM_NCCL "Build using system's NCCL library" OFF)
 dgl_option(USE_OPENMP "Build with OpenMP" ON)
 dgl_option(USE_AVX "Build with AVX optimization" OFF)
 dgl_option(USE_LIBXSMM "Build with LIBXSMM library optimization" ON)
-dgl_option(USE_FP16 "Build with fp16 support to enable mixed precision training" OFF)
 dgl_option(USE_TVM "Build with TVM kernels" OFF)
 dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
 dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
@@ -127,13 +126,6 @@ if ((NOT MSVC) AND USE_EPOLL)
   endif()
 endif ()
-# Build with fp16 to support mixed precision training.
-if(USE_FP16)
-  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_FP16")
-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_FP16")
-  message(STATUS "Build with fp16 to support mixed precision training")
-endif(USE_FP16)
 # To compile METIS correct for DGL.
 if(MSVC)
   set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /DIDXTYPEWIDTH=64 /DREALTYPEWIDTH=32")
......
@@ -9,9 +9,9 @@ consumption. This feature requires DGL 0.9+.
 Message-Passing with Half Precision
 -----------------------------------
-DGL allows message-passing on ``float16 (fp16)`` features for both
-UDFs (User Defined Functions) and built-in functions (e.g., ``dgl.function.sum``,
-``dgl.function.copy_u``).
+DGL allows message-passing on ``float16 (fp16)`` / ``bfloat16 (bf16)`` (requires CUDA >= 11.0)
+features for both UDFs (User Defined Functions) and built-in functions
+(e.g., ``dgl.function.sum``, ``dgl.function.copy_u``).
 The following example shows how to use DGL's message-passing APIs on half-precision
 features:
@@ -63,8 +63,9 @@ efficient, most operators on half precision tensors are faster as they leverage
     import torch.nn.functional as F
     from torch.cuda.amp import autocast

-    def forward(g, feat, label, mask, model, use_fp16):
-        with autocast(enabled=use_fp16):
+    def forward(g, feat, label, mask, model, amp_dtype):
+        amp_enabled = amp_dtype in (torch.float16, torch.bfloat16)
+        with autocast(enabled=amp_enabled, dtype=amp_dtype):
             logit = model(g, feat)
             loss = F.cross_entropy(logit[mask], label[mask])
             return loss
@@ -74,6 +75,7 @@ PyTorch provides a ``GradScaler`` module to address this issue. It multiplies
 the loss by a factor and invokes backward pass on the scaled loss to prevent
 the underflow problem. It then unscales the computed gradients before the optimizer
 updates the parameters. The scale factor is determined automatically.
+Note that ``bfloat16`` doesn't require a ``GradScaler``.

 .. code::
@@ -87,7 +89,7 @@ updates the parameters. The scale factor is determined automatically.
         scaler.update()

 The following example trains a 3-layer GAT on the Reddit dataset (w/ 114 million edges).
-Pay attention to the differences in the code when ``use_fp16`` is activated or not.
+Pay attention to the differences in the code when AMP is activated or not.

 .. code::
@@ -98,7 +100,7 @@ Pay attention to the differences in the code when ``use_fp16`` is activated or not.
     from dgl.nn import GATConv
     from dgl.transforms import AddSelfLoop

-    use_fp16 = True
+    amp_dtype = torch.float16  # or torch.bfloat16

     class GAT(nn.Module):
         def __init__(self,
@@ -145,9 +147,9 @@ Pay attention to the differences in the code when ``use_fp16`` is activated or not.
     for epoch in range(100):
         optimizer.zero_grad()
-        loss = forward(g, feat, label, train_mask, model, use_fp16)
-        if use_fp16:
+        loss = forward(g, feat, label, train_mask, model, amp_dtype)
+        if amp_dtype == torch.float16:
             # Backprop w/ gradient scaling
             backward(scaler, loss, optimizer)
         else:
......
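To make the documentation change above concrete, here is a minimal sketch (illustrative only, not part of the patch) of the ``bfloat16`` training step implied by the new note: the names ``train_step_bf16``, ``model``, ``g``, ``feat``, ``label``, ``mask``, and ``optimizer`` are placeholders, and a CUDA-enabled PyTorch build is assumed.

    import torch
    import torch.nn.functional as F
    from torch.cuda.amp import autocast

    def train_step_bf16(g, feat, label, mask, model, optimizer):
        # bfloat16 keeps the fp32 exponent range, so gradient underflow is not a
        # concern and no GradScaler is needed (unlike the fp16 path above).
        optimizer.zero_grad()
        with autocast(enabled=True, dtype=torch.bfloat16):
            logit = model(g, feat)
            loss = F.cross_entropy(logit[mask], label[mask])
        loss.backward()
        optimizer.step()
        return loss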
@@ -123,26 +123,67 @@
     typedef double FloatType;                                       \
     {__VA_ARGS__}                                                    \
   } else {                                                           \
-    LOG(FATAL) << (val_name) << " can only be float32 or float64";   \
+    LOG(FATAL) << (val_name)                                         \
+               << " can only be float32 or float64";                 \
   }                                                                   \
 } while (0)

-#define ATEN_FLOAT_BITS_SWITCH(val, bits, val_name, ...) do {         \
-  CHECK_EQ((val).code, kDGLFloat)                                     \
-    << (val_name) << " must be float type";                           \
-  if ((val).bits == 16) {                                             \
-    constexpr int bits = 16;                                          \
-    {__VA_ARGS__}                                                     \
-  } else if ((val).bits == 32) {                                      \
-    constexpr int bits = 32;                                          \
-    {__VA_ARGS__}                                                     \
-  } else if ((val).bits == 64) {                                      \
-    constexpr int bits = 64;                                          \
-    {__VA_ARGS__}                                                     \
-  } else {                                                            \
-    LOG(FATAL) << (val_name) << " can only be float32 or float64";    \
-  }                                                                   \
-} while (0)
+/*
+ * Dispatch according to float type, including 16bits (float16/bfloat16/float32/float64).
+ */
+#ifdef DGL_USE_CUDA
+#if BF16_ENABLED
+#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do {   \
+  CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))                     \
+    << (val_name) << " must be float type";                                      \
+  if ((val).bits == 32) {                                                        \
+    typedef float FloatType;                                                     \
+    {__VA_ARGS__}                                                                \
+  } else if ((val).bits == 64) {                                                 \
+    typedef double FloatType;                                                    \
+    {__VA_ARGS__}                                                                \
+  } else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) {   \
+    typedef __half FloatType;                                                    \
+    {__VA_ARGS__}                                                                \
+  } else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) {  \
+    typedef __nv_bfloat16 FloatType;                                             \
+    {__VA_ARGS__}                                                                \
+  } else if (XPU == kDGLCPU) {                                                   \
+    LOG(FATAL) << (val_name)                                                     \
+               << " can only be float32 or float64 on CPU";                      \
+  } else {                                                                        \
+    LOG(FATAL) << (val_name)                                                     \
+               << " can only be float16/bfloat16/float32/float64 on GPU";        \
+  }                                                                               \
+} while (0)
#else // BF16_ENABLED
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) \
<< " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/float32/float64 on GPU"; \
} \
} while (0)
#endif // BF16_ENABLED
#else // DGL_USE_CUDA
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
#endif // DGL_USE_CUDA
/*
 * Dispatch according to data type (int32, int64, float32 or float64):
......
@@ -18,8 +18,15 @@
 #include "shared_mem.h"

 #ifdef DGL_USE_CUDA
+#include <cuda_runtime.h>
+#define BF16_ENABLED (defined(CUDART_VERSION) && CUDART_VERSION >= 11000)
 #include <cuda_fp16.h>
-#endif
+#if BF16_ENABLED
+#include <cuda_bf16.h>
+#endif  // BF16_ENABLED
+#endif  // DGL_USE_CUDA

 // forward declaration
 inline std::ostream& operator << (std::ostream& os, DGLDataType t);
@@ -51,10 +58,11 @@ GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
 GEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);
 GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
 #ifdef DGL_USE_CUDA
-#ifdef USE_FP16
 GEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16);
-#endif
-#endif
+#if BF16_ENABLED
+GEN_DGLDATATYPETRAITS_FOR(__nv_bfloat16, kDGLBfloat, 16);
+#endif  // BF16_ENABLED
+#endif  // DGL_USE_CUDA
 GEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 32);
 GEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64);
 #undef GEN_DGLDATATYPETRAITS_FOR
......
 import torch as th
-from torch.cuda.amp import custom_bwd, custom_fwd

 from ...base import ALL, is_all
 from ...heterograph_index import create_unitgraph_from_csr
@@ -39,7 +38,6 @@ __all__ = [
     "segment_mm",
 ]

 def _reduce_grad(grad, shape):
     """Reduce gradient on the broadcast dimension
     If there is broadcast in forward pass, gradients need to be reduced on
@@ -127,10 +125,26 @@ def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
             return True
     return False
+class empty_context():
+    """Empty context that does nothing"""
+
+    def __init__(self, *args, **kargs):
+        return
+
+    def __enter__(self, *args, **kargs):
+        return self
+
+    def __exit__(self, *args, **kargs):
+        return
+
+
+# This is to avoid warnings in cpu-only dgl. We don't enable autocast for CPU ops
+autocast = th.cuda.amp.autocast if th.cuda.is_available() else empty_context
+
+
+def _cast_if_autocast_enabled(*args):
+    if not th.is_autocast_enabled() or not th.cuda.is_available():
+        return args
+    else:
+        return th.cuda.amp.autocast_mode._cast(args, th.get_autocast_gpu_dtype())
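For orientation, the following sketch (illustrative only, not part of the diff; ``MyFunc`` and ``my_op`` are hypothetical names) shows how the ``autocast`` alias and ``_cast_if_autocast_enabled`` defined above are meant to be combined around a custom ``th.autograd.Function``, mirroring the wrapper pattern applied to ``gspmm``/``gsddmm`` further down in this file.

    import torch as th

    class MyFunc(th.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            # Runs in whatever dtype the wrapper cast the inputs to (fp16/bf16/fp32).
            return x + y

        @staticmethod
        def backward(ctx, dz):
            return dz, dz

    def my_op(x, y):
        # Cast the inputs once according to the active autocast dtype, then disable
        # autocast so the ops inside the Function are not re-cast one by one.
        args = _cast_if_autocast_enabled(x, y)
        with autocast(enabled=False):
            return MyFunc.apply(*args)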
class GSpMM(th.autograd.Function): class GSpMM(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, reduce_op, X, Y): def forward(ctx, gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y) out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
reduce_last = _need_reduce_last_dim(X, Y) reduce_last = _need_reduce_last_dim(X, Y)
...@@ -162,7 +176,6 @@ class GSpMM(th.autograd.Function): ...@@ -162,7 +176,6 @@ class GSpMM(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, dZ): def backward(ctx, dZ):
( (
gidx, gidx,
...@@ -221,7 +234,6 @@ class GSpMM(th.autograd.Function): ...@@ -221,7 +234,6 @@ class GSpMM(th.autograd.Function):
class GSpMM_hetero(th.autograd.Function): class GSpMM_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward( def forward(
ctx, gidx, op, reduce_op, X_len, *feats ctx, gidx, op, reduce_op, X_len, *feats
): # feats = lhs_data + rhs_data ): # feats = lhs_data + rhs_data
...@@ -278,7 +290,6 @@ class GSpMM_hetero(th.autograd.Function): ...@@ -278,7 +290,6 @@ class GSpMM_hetero(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, *dZ): def backward(ctx, *dZ):
( (
gidx, gidx,
...@@ -415,7 +426,6 @@ def sddmm_cache_Y(op, req_grad_X, req_grad_Y): ...@@ -415,7 +426,6 @@ def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
class GSDDMM(th.autograd.Function): class GSDDMM(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target): def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target) out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
X_shape = X.shape if X is not None else None X_shape = X.shape if X is not None else None
...@@ -431,7 +441,6 @@ class GSDDMM(th.autograd.Function): ...@@ -431,7 +441,6 @@ class GSDDMM(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, dZ): def backward(ctx, dZ):
gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
X, Y = ctx.saved_tensors X, Y = ctx.saved_tensors
...@@ -480,7 +489,6 @@ class GSDDMM(th.autograd.Function): ...@@ -480,7 +489,6 @@ class GSDDMM(th.autograd.Function):
class GSDDMM_hetero(th.autograd.Function): class GSDDMM_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward( def forward(
ctx, gidx, op, X_len, lhs_target, rhs_target, *feats ctx, gidx, op, X_len, lhs_target, rhs_target, *feats
): # feats = X+Y ): # feats = X+Y
...@@ -517,7 +525,6 @@ class GSDDMM_hetero(th.autograd.Function): ...@@ -517,7 +525,6 @@ class GSDDMM_hetero(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
# TODO(Israt): Implement the complete backward operator # TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ): def backward(ctx, *dZ):
( (
...@@ -661,7 +668,6 @@ class GSDDMM_hetero(th.autograd.Function): ...@@ -661,7 +668,6 @@ class GSDDMM_hetero(th.autograd.Function):
class EdgeSoftmax(th.autograd.Function): class EdgeSoftmax(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, score, eids, norm_by): def forward(ctx, gidx, score, eids, norm_by):
"""Forward function. """Forward function.
...@@ -696,7 +702,6 @@ class EdgeSoftmax(th.autograd.Function): ...@@ -696,7 +702,6 @@ class EdgeSoftmax(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, grad_out): def backward(ctx, grad_out):
"""Backward function. """Backward function.
...@@ -728,7 +733,6 @@ class EdgeSoftmax(th.autograd.Function): ...@@ -728,7 +733,6 @@ class EdgeSoftmax(th.autograd.Function):
class EdgeSoftmax_hetero(th.autograd.Function): class EdgeSoftmax_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, eids, norm_by, *score): def forward(ctx, gidx, eids, norm_by, *score):
"""Forward function. """Forward function.
...@@ -774,7 +778,6 @@ class EdgeSoftmax_hetero(th.autograd.Function): ...@@ -774,7 +778,6 @@ class EdgeSoftmax_hetero(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, *grad_out): def backward(ctx, *grad_out):
"""Backward function. """Backward function.
...@@ -808,7 +811,6 @@ class EdgeSoftmax_hetero(th.autograd.Function): ...@@ -808,7 +811,6 @@ class EdgeSoftmax_hetero(th.autograd.Function):
class SegmentReduce(th.autograd.Function): class SegmentReduce(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, op, x, offsets): def forward(ctx, op, x, offsets):
y, arg = _segment_reduce(op, x, offsets) y, arg = _segment_reduce(op, x, offsets)
ctx.save_for_backward(arg, offsets) ctx.save_for_backward(arg, offsets)
...@@ -816,7 +818,6 @@ class SegmentReduce(th.autograd.Function): ...@@ -816,7 +818,6 @@ class SegmentReduce(th.autograd.Function):
return y return y
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, dy): def backward(ctx, dy):
op = ctx.backward_cache op = ctx.backward_cache
arg, offsets = ctx.saved_tensors arg, offsets = ctx.saved_tensors
...@@ -838,14 +839,12 @@ class SegmentReduce(th.autograd.Function): ...@@ -838,14 +839,12 @@ class SegmentReduce(th.autograd.Function):
class ScatterAdd(th.autograd.Function): class ScatterAdd(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, x, idx, m): def forward(ctx, x, idx, m):
y = _scatter_add(x, idx, m) y = _scatter_add(x, idx, m)
ctx.save_for_backward(idx) ctx.save_for_backward(idx)
return y return y
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, dy): def backward(ctx, dy):
idx = ctx.saved_tensors idx = ctx.saved_tensors
return dy[idx], None, None return dy[idx], None, None
...@@ -952,7 +951,6 @@ class CSRMask(th.autograd.Function): ...@@ -952,7 +951,6 @@ class CSRMask(th.autograd.Function):
class SEGMENTMM(th.autograd.Function): class SEGMENTMM(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, seglen_A): def forward(ctx, A, B, seglen_A):
if B.dim() != 3: if B.dim() != 3:
raise ValueError("segment_mm expects B to be a 3D tensor.") raise ValueError("segment_mm expects B to be a 3D tensor.")
...@@ -978,7 +976,6 @@ class SEGMENTMM(th.autograd.Function): ...@@ -978,7 +976,6 @@ class SEGMENTMM(th.autograd.Function):
class GATHERMM(th.autograd.Function): class GATHERMM(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, idx_a, idx_b): def forward(ctx, A, B, idx_a, idx_b):
if B.dim() != 3: if B.dim() != 3:
raise ValueError( raise ValueError(
...@@ -1014,7 +1011,9 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): ...@@ -1014,7 +1011,9 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
if op == "div": if op == "div":
op = "mul" op = "mul"
rhs_data = 1.0 / rhs_data rhs_data = 1.0 / rhs_data
-    return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
+    args = _cast_if_autocast_enabled(gidx, op, reduce_op, lhs_data, rhs_data)
+    with autocast(enabled=False):
+        return GSpMM.apply(*args)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"): def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
...@@ -1024,7 +1023,9 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"): ...@@ -1024,7 +1023,9 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
if op == "div": if op == "div":
op = "mul" op = "mul"
rhs_data = 1.0 / rhs_data rhs_data = 1.0 / rhs_data
-    return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
+    args = _cast_if_autocast_enabled(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
+    with autocast(enabled=False):
+        return GSDDMM.apply(*args)
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple): def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
...@@ -1050,7 +1051,10 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple): ...@@ -1050,7 +1051,10 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
) )
if op in ["add", "mul"]: if op in ["add", "mul"]:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple)) lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
-    return GSpMM_hetero.apply(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
+    args = _cast_if_autocast_enabled(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
+    with autocast(enabled=False):
+        return GSpMM_hetero.apply(*args)
def gsddmm_hetero( def gsddmm_hetero(
...@@ -1078,25 +1082,34 @@ def gsddmm_hetero( ...@@ -1078,25 +1082,34 @@ def gsddmm_hetero(
) )
if op in ["add", "mul"]: if op in ["add", "mul"]:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple)) lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
-    return GSDDMM_hetero.apply(
-        g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple
-    )
+    args = _cast_if_autocast_enabled(g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple)
+    with autocast(enabled=False):
+        return GSDDMM_hetero.apply(*args)
 def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
-    return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
+    args = _cast_if_autocast_enabled(gidx, logits, eids, norm_by)
+    with autocast(enabled=False):
+        return EdgeSoftmax.apply(*args)
 def edge_softmax_hetero(gidx, eids=ALL, norm_by="dst", *logits):
-    return EdgeSoftmax_hetero.apply(gidx, eids, norm_by, *logits)
+    args = _cast_if_autocast_enabled(gidx, eids, norm_by, *logits)
+    with autocast(enabled=False):
+        return EdgeSoftmax_hetero.apply(*args)
 def segment_reduce(op, x, offsets):
-    return SegmentReduce.apply(op, x, offsets)
+    args = _cast_if_autocast_enabled(op, x, offsets)
+    with autocast(enabled=False):
+        return SegmentReduce.apply(*args)
 def scatter_add(x, idx, m):
-    return ScatterAdd.apply(x, idx, m)
+    args = _cast_if_autocast_enabled(x, idx, m)
+    with autocast(enabled=False):
+        return ScatterAdd.apply(*args)
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes): def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
...@@ -1144,7 +1157,9 @@ def segment_mm(A, B, seglen_A): ...@@ -1144,7 +1157,9 @@ def segment_mm(A, B, seglen_A):
off += seglen_A[i] off += seglen_A[i]
return th.cat(C) return th.cat(C)
else: else:
-        return SEGMENTMM.apply(A, B, seglen_A)
+        args = _cast_if_autocast_enabled(A, B, seglen_A)
+        with autocast(enabled=False):
+            return SEGMENTMM.apply(*args)
def gather_mm(A, B, idx_A=None, idx_B=None): def gather_mm(A, B, idx_A=None, idx_B=None):
...@@ -1153,4 +1168,6 @@ def gather_mm(A, B, idx_A=None, idx_B=None): ...@@ -1153,4 +1168,6 @@ def gather_mm(A, B, idx_A=None, idx_B=None):
B = B[idx_B] if idx_B is not None else B B = B[idx_B] if idx_B is not None else B
return th.bmm(A.unsqueeze(1), B).squeeze(1) return th.bmm(A.unsqueeze(1), B).squeeze(1)
else: else:
-        return GATHERMM.apply(A, B, idx_A, idx_B)
+        args = _cast_if_autocast_enabled(A, B, idx_A, idx_B)
+        with autocast(enabled=False):
+            return GATHERMM.apply(*args)
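Taken together, the wrappers above let built-in message passing run directly on half-precision features. A minimal usage sketch, assuming a CUDA build of DGL that includes the bf16 kernels from this PR (CUDA >= 11.0) and an available GPU:

    import torch
    import dgl
    import dgl.function as fn

    g = dgl.rand_graph(30, 100).to("cuda")
    g.ndata["h"] = torch.randn(30, 16, device="cuda").to(torch.bfloat16)
    # copy_u/sum dispatch through the gspmm wrapper above and reach the bf16 CUDA kernels.
    g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_sum"))
    print(g.ndata["h_sum"].dtype)  # expected: torch.bfloat16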
...@@ -9,22 +9,8 @@ ...@@ -9,22 +9,8 @@
namespace dgl { namespace dgl {
namespace aten { namespace aten {
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16 || (bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
/*! \brief Generalized SegmentMM. */ /*! \brief Generalized SegmentMM. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A, void SegmentMM(const NDArray A,
const NDArray B, const NDArray B,
NDArray C, NDArray C,
...@@ -33,7 +19,7 @@ void SegmentMM(const NDArray A, ...@@ -33,7 +19,7 @@ void SegmentMM(const NDArray A,
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM."; LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
} }
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A, void SegmentMMBackwardB(const NDArray A,
const NDArray dC, const NDArray dC,
NDArray dB, NDArray dB,
...@@ -42,7 +28,7 @@ void SegmentMMBackwardB(const NDArray A, ...@@ -42,7 +28,7 @@ void SegmentMMBackwardB(const NDArray A,
} }
/*! \brief Generalized GatherMM. */ /*! \brief Generalized GatherMM. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A, void GatherMM(const NDArray A,
const NDArray B, const NDArray B,
NDArray C, NDArray C,
...@@ -52,7 +38,7 @@ void GatherMM(const NDArray A, ...@@ -52,7 +38,7 @@ void GatherMM(const NDArray A,
} }
/*! \brief Generalized GatherMM_scatter. */ /*! \brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A, void GatherMMScatter(const NDArray A,
const NDArray B, const NDArray B,
NDArray C, NDArray C,
...@@ -62,74 +48,52 @@ void GatherMMScatter(const NDArray A, ...@@ -62,74 +48,52 @@ void GatherMMScatter(const NDArray A,
LOG(FATAL) << "Unsupported CPU kernel for GatherMM."; LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
} }
template void GatherMM<kDGLCPU, int32_t, 16>( template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, 32>( template void GatherMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, 32>( template void GatherMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, 64>( template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDGLCPU, int32_t, 16>( template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, 32>( template void GatherMMScatter<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, 32>( template void GatherMMScatter<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, 64>( template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCPU, int32_t, 16>( template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, 32>( template void SegmentMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, 32>( template void SegmentMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, 64>( template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCPU, int32_t, 16>( template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, 32>( template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, 64>( template void SegmentMMBackwardB<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, 64>( template void SegmentMMBackwardB<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten } // namespace aten
......
...@@ -41,22 +41,9 @@ namespace aten { ...@@ -41,22 +41,9 @@ namespace aten {
} \ } \
} while (0) } while (0)
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16 || (bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
/*! \brief Generalized SDDMM on Csr format. */ /*! \brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op, void SDDMMCsr(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
...@@ -65,17 +52,15 @@ void SDDMMCsr(const std::string& op, ...@@ -65,17 +52,15 @@ void SDDMMCsr(const std::string& op,
NDArray out, NDArray out,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
});
}); });
}); });
} }
/*! \brief Generalized SDDMM on Csr format with Heterograph support. */ /*! \brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op, void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
...@@ -86,83 +71,59 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -86,83 +71,59 @@ void SDDMMCsrHetero(const std::string& op,
int rhs_target, int rhs_target,
const std::vector<dgl_type_t>& lhs_nid, const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) { const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { /* Call SDDMM for each relation type */
/* Call SDDMM for each relation type */ for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) { CSRMatrix csr = vec_csr[etype];
CSRMatrix csr = vec_csr[etype]; NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray lhs = vec_lhs[lhs_nid[etype]]; NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]]; NDArray out = vec_out[etype];
NDArray out = vec_out[etype]; cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out); }
}
});
}); });
}); });
} }
template void SDDMMCsr<kDGLCPU, int32_t, 16>( template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, 16>( template void SDDMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, 32>( template void SDDMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, 32>( template void SDDMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDGLCPU, int32_t, 16>( template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, 32>( template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, 64>( template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>( template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
...@@ -171,7 +132,7 @@ template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>( ...@@ -171,7 +132,7 @@ template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
/*! \brief Generalized SDDMM on Coo format. */ /*! \brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op, void SDDMMCoo(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo,
...@@ -180,17 +141,15 @@ void SDDMMCoo(const std::string& op, ...@@ -180,17 +141,15 @@ void SDDMMCoo(const std::string& op,
NDArray out, NDArray out,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
});
}); });
}); });
} }
/*! \brief Generalized SDDMM on Coo format with Heterograph support. */ /*! \brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op, void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
...@@ -201,90 +160,65 @@ void SDDMMCooHetero(const std::string& op, ...@@ -201,90 +160,65 @@ void SDDMMCooHetero(const std::string& op,
int rhs_target, int rhs_target,
const std::vector<dgl_type_t>& lhs_nid, const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) { const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { /* Call SDDMM for each relation type */
/* Call SDDMM for each relation type */ for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) { COOMatrix coo = vec_coo[etype];
COOMatrix coo = vec_coo[etype]; NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray lhs = vec_lhs[lhs_nid[etype]]; NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]]; NDArray out = vec_out[etype];
NDArray out = vec_out[etype]; cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out); }
}
});
}); });
}); });
} }
template void SDDMMCoo<kDGLCPU, int32_t, 16>( template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, 32>( template void SDDMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, 32>( template void SDDMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, 64>( template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDGLCPU, int32_t, 16>( template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, 16>( template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, 32>( template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, 32>( template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -12,7 +12,7 @@ namespace dgl { ...@@ -12,7 +12,7 @@ namespace dgl {
namespace aten { namespace aten {
/*! \brief Segment Reduce operator. */ /*! \brief Segment Reduce operator. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentReduce( void SegmentReduce(
const std::string& op, const std::string& op,
NDArray feat, NDArray feat,
...@@ -20,20 +20,14 @@ void SegmentReduce( ...@@ -20,20 +20,14 @@ void SegmentReduce(
NDArray out, NDArray out,
NDArray arg) { NDArray arg) {
if (op == "sum") { if (op == "sum") {
SWITCH_BITS(bits, DType, { cpu::SegmentSum<IdType, DType>(feat, offsets, out);
cpu::SegmentSum<IdType, DType>(feat, offsets, out);
});
} else if (op == "max" || op == "min") { } else if (op == "max" || op == "min") {
if (op == "max") { if (op == "max") {
SWITCH_BITS(bits, DType, { cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>(
cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>( feat, offsets, out, arg);
feat, offsets, out, arg);
});
} else { } else {
SWITCH_BITS(bits, DType, { cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>(
cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>( feat, offsets, out, arg);
feat, offsets, out, arg);
});
} }
} else { } else {
LOG(FATAL) << "Unsupported reduce function " << op; LOG(FATAL) << "Unsupported reduce function " << op;
...@@ -41,146 +35,105 @@ void SegmentReduce( ...@@ -41,146 +35,105 @@ void SegmentReduce(
} }
/*! \brief Scatter Add.*/ /*! \brief Scatter Add.*/
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, void ScatterAdd(NDArray feat,
NDArray idx, NDArray idx,
NDArray out) { NDArray out) {
SWITCH_BITS(bits, DType, { cpu::ScatterAdd<IdType, DType>(feat, idx, out);
cpu::ScatterAdd<IdType, DType>(feat, idx, out);
});
} }
/*! \brief Update gradients for reduce operator max/min on heterogeneous graph.*/ /*! \brief Update gradients for reduce operator max/min on heterogeneous graph.*/
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) { std::vector<NDArray>* out) {
SWITCH_BITS(bits, DType, { cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
});
} }
/*! \brief Backward function of segment cmp.*/ /*! \brief Backward function of segment cmp.*/
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp( void BackwardSegmentCmp(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out) { NDArray out) {
SWITCH_BITS(bits, DType, { cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
});
} }
template void SegmentReduce<kDGLCPU, int32_t, 16>( template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, 16>( template void SegmentReduce<kDGLCPU, int64_t, float>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, 32>( template void SegmentReduce<kDGLCPU, int32_t, double>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, 32>( template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, 64>(
const std::string &op, template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDGLCPU, int32_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, 32>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, 32>( template void ScatterAdd<kDGLCPU, int64_t, float>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, 64>( template void ScatterAdd<kDGLCPU, int32_t, double>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, 64>( template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 16>( template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 32>( template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, float>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 64>( template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, double>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 64>( template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, 16>( template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, 32>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, 32>( template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, 64>( template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, 64>( template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
......
...@@ -10,7 +10,7 @@ namespace dgl { ...@@ -10,7 +10,7 @@ namespace dgl {
namespace aten { namespace aten {
/*! \brief Generalized SpMM on Csr format. */ /*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce, void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
...@@ -20,25 +20,21 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -20,25 +20,21 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
const int64_t dim = bcast.out_len; const int64_t dim = bcast.out_len;
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
}); });
} else if (reduce == "max" || reduce == "min") { } else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { DType *out_off = out.Ptr<DType>();
DType *out_off = out.Ptr<DType>(); if (reduce == "max") {
if (reduce == "max") { std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero); cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>( bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]); } else {
} else { std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero); cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>( bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]); }
}
});
}); });
} else { } else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce; LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
...@@ -46,7 +42,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -46,7 +42,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
} }
/*! \brief Generalized SpMM on Csr format. */ /*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce, void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
...@@ -58,130 +54,104 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -58,130 +54,104 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<dgl_type_t>& out_node_tids) { const std::vector<dgl_type_t>& out_node_tids) {
const int64_t dim = bcast.out_len; const int64_t dim = bcast.out_len;
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { /* Call SpMM for each relation type */
/* Call SpMM for each relation type */ for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) { const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t src_id = ufeat_node_tids[etype]; const dgl_type_t dst_id = out_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype]; CSRMatrix csr = vec_csr[etype];
CSRMatrix csr = vec_csr[etype]; NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id]; NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype]; NDArray out = (*vec_out)[dst_id];
NDArray out = (*vec_out)[dst_id]; cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out); }
}
});
}); });
} else if (reduce == "max" || reduce == "min") { } else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { std::vector<bool> updated((*vec_out).size(), false);
std::vector<bool> updated((*vec_out).size(), false); // TODO(Israt): use vector updated to fill(out...) too
// TODO(Israt): use vector updated to fill(out...) too for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) { DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>(); if (reduce == "max")
if (reduce == "max") std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero); else
else std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero); const dgl_type_t dst_id = out_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype]; if (!updated[dst_id]) {
if (!updated[dst_id]) { updated[dst_id] = true;
updated[dst_id] = true; if (Op::use_lhs) {
if (Op::use_lhs) { IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>(); std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
}
if (Op::use_rhs) {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
}
} }
} if (Op::use_rhs) {
/* Call SpMM for each relation type */ IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) { std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id];
if (reduce == "max") {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
} else {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
} }
} }
}); }
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id];
if (reduce == "max") {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
} else {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
}
}
}); });
} else { } else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce; LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
} }
} }
template void SpMMCsr<kDGLCPU, int32_t, 16>( template void SpMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, 16>( template void SpMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, 32>( template void SpMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, 32>( template void SpMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDGLCPU, int32_t, 16>( template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, 16>( template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, 32>( template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, 32>( template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
...@@ -190,87 +160,67 @@ template void SpMMCsrHetero<kDGLCPU, int64_t, 64>( ...@@ -190,87 +160,67 @@ template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
/*! \brief Edge_softmax_csr forward op on Csr format. */ /*! \brief Edge_softmax_csr forward op on Csr format. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(const std::string& op, void Edge_softmax_csr_forward(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
NDArray ufeat, NDArray ufeat,
NDArray efeat, NDArray efeat,
NDArray out) { NDArray out) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out); });
});
});
} }
/*! \brief Edge_softmax_csr backward op on Csr format. */ /*! \brief Edge_softmax_csr backward op on Csr format. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(const std::string& op, void Edge_softmax_csr_backward(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
NDArray out, NDArray out,
NDArray sds, NDArray sds,
NDArray back_out) { NDArray back_out) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out);
cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out);
});
}); });
} }
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 16>( template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 16>( template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 32>( template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 32>( template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 16>( template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 32>( template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 64>( template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 64>( template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
/*! \brief Generalized SpMM on Coo format. */ /*! \brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce, void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo,
...@@ -279,52 +229,39 @@ void SpMMCoo(const std::string& op, const std::string& reduce, ...@@ -279,52 +229,39 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
NDArray out, NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
});
}); });
} else if (reduce == "max" || reduce == "min") { } else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { if (reduce == "max")
if (reduce == "max") cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>( bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]); else
else cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>( bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
}); });
} else { } else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce; LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
} }
} }
template void SpMMCoo<kDGLCPU, int32_t, 16>( template void SpMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, 16>( template void SpMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, 32>( template void SpMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, 32>( template void SpMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -146,21 +146,6 @@ constexpr DType Min<DType>::zero; ...@@ -146,21 +146,6 @@ constexpr DType Min<DType>::zero;
} \ } \
} while (0) } while (0)
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16) { \
LOG(FATAL) << "FP16 not supported on CPU"; \
} else if ((bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
} // namespace op } // namespace op
} // namespace cpu } // namespace cpu
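With SWITCH_BITS removed, the concrete floating-point type is now chosen before the CPU kernels are entered, so they are only instantiated for float and double. A minimal sketch of what that call-site dispatch can look like, assuming only the NDArray dtype fields already used elsewhere in this diff (the dispatcher name and switch are illustrative, not DGL's actual dispatch macro):

// Illustrative dispatcher (not part of the commit): picks the DType template
// argument from the feature array's bit width, mirroring what the removed
// SWITCH_BITS macro used to do inside the kernel.
template <typename IdType>
void SpMMCsrDispatch(const std::string& op, const std::string& reduce,
                     const BcastOff& bcast, const CSRMatrix& csr,
                     NDArray ufeat, NDArray efeat, NDArray out,
                     std::vector<NDArray> out_aux) {
  switch (out->dtype.bits) {
    case 32:
      SpMMCsr<kDGLCPU, IdType, float>(
          op, reduce, bcast, csr, ufeat, efeat, out, out_aux);
      break;
    case 64:
      SpMMCsr<kDGLCPU, IdType, double>(
          op, reduce, bcast, csr, ufeat, efeat, out, out_aux);
      break;
    default:
      LOG(FATAL) << "Only float32/float64 features are supported by the CPU kernels";
  }
}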
......
...@@ -60,10 +60,12 @@ template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray); ...@@ -60,10 +60,12 @@ template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
#ifdef USE_FP16
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#endif #if BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(NDArray, IdArray);
#endif // BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
...@@ -72,30 +74,21 @@ template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray); ...@@ -72,30 +74,21 @@ template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
template <DGLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) { DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx); auto device = runtime::DeviceAPI::Get(array->ctx);
#ifdef USE_FP16 DType ret = static_cast<DType>(0.0f);
// The initialization constructor for __half is apparently a device-
// only function in some setups, but the current function, IndexSelect,
// isn't run on the device, so it doesn't have access to that constructor.
using SafeDType = typename std::conditional<
std::is_same<DType, __half>::value, uint16_t, DType>::type;
SafeDType ret = 0;
#else
DType ret = 0;
#endif
device->CopyDataFromTo( device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, static_cast<DType*>(array->data) + index, 0, &ret, 0,
reinterpret_cast<DType*>(&ret), 0, sizeof(DType), array->ctx, sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
DGLContext{kDGLCPU, 0}, array->dtype); return ret;
return reinterpret_cast<DType&>(ret);
} }
template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index); template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index); template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index); template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index); template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
#ifdef USE_FP16
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index); template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#endif #if BF16_ENABLED
template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(NDArray array, int64_t index);
#endif // BF16_ENABLED
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index); template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index); template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
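For context, a hedged usage sketch of the two overloads instantiated above (the helper name and call pattern are hypothetical; `array` is a bf16 NDArray on the GPU and `idx` an int64 IdArray):

#if BF16_ENABLED
// Hypothetical helper, not part of the commit: gather rows of a bf16 GPU
// array, then copy its first element back to the host via the scalar overload.
__nv_bfloat16 GatherAndPeek(NDArray array, IdArray idx) {
  NDArray rows = IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(array, idx);
  return IndexSelect<kDGLCUDA, __nv_bfloat16>(rows, 0);
}
#endif  // BF16_ENABLED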
......
...@@ -224,9 +224,10 @@ NDArray Full(DType val, int64_t length, DGLContext ctx) { ...@@ -224,9 +224,10 @@ NDArray Full(DType val, int64_t length, DGLContext ctx) {
template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx); template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx); template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx);
#ifdef USE_FP16
template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx); template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx);
#endif #if BF16_ENABLED
template IdArray Full<kDGLCUDA, __nv_bfloat16>(__nv_bfloat16 val, int64_t length, DGLContext ctx);
#endif // BF16_ENABLED
template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx); template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx); template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx);
......
...@@ -39,16 +39,18 @@ void Scatter_(IdArray index, NDArray value, NDArray out) { ...@@ -39,16 +39,18 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#endif #if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#endif #if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cassert> #include <cassert>
#include "fp16.cuh" #include "fp16.cuh"
#include "bf16.cuh"
#if __CUDA_ARCH__ >= 600 #if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h> #include <cuda_fp16.h>
...@@ -44,7 +45,6 @@ template <typename T> struct Cast { ...@@ -44,7 +45,6 @@ template <typename T> struct Cast {
} }
}; };
#ifdef USE_FP16
template <> struct Cast<half> { template <> struct Cast<half> {
typedef Code<sizeof(half)>::Type Type; typedef Code<sizeof(half)>::Type Type;
static __device__ __forceinline__ Type Encode(half val) { static __device__ __forceinline__ Type Encode(half val) {
...@@ -54,7 +54,32 @@ template <> struct Cast<half> { ...@@ -54,7 +54,32 @@ template <> struct Cast<half> {
return __ushort_as_half(code); return __ushort_as_half(code);
} }
}; };
#if BF16_ENABLED
template <> struct Cast<__nv_bfloat16> {
typedef Code<sizeof(__nv_bfloat16)>::Type Type;
static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __bfloat16_as_ushort(val);
#else
printf("Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return static_cast<Type>(0);
#endif
}
static __device__ __forceinline__ __nv_bfloat16 Decode(Type code) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ushort_as_bfloat16(code);
#else
printf("Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return static_cast<__nv_bfloat16>(0.0f);
#endif #endif
}
};
#endif // BF16_ENABLED
template <> struct Cast<float> { template <> struct Cast<float> {
typedef Code<sizeof(float)>::Type Type; typedef Code<sizeof(float)>::Type Type;
...@@ -109,9 +134,9 @@ static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT ...@@ -109,9 +134,9 @@ static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT
return Cast<T>::Decode(old); \ return Cast<T>::Decode(old); \
} }
#define DEFINE_ATOMIC_HALF(NAME) \ #define DEFINE_ATOMIC_16BIT(NAME, dtype) \
template <> \ template <> \
__device__ __forceinline__ half Atomic##NAME<half>(half* addr, half val) { \ __device__ __forceinline__ dtype Atomic##NAME<dtype>(dtype* addr, dtype val) { \
typedef uint16_t CT; \ typedef uint16_t CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \ CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \ CT old = *addr_as_ui; \
...@@ -119,23 +144,25 @@ static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT ...@@ -119,23 +144,25 @@ static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT
do { \ do { \
assumed = old; \ assumed = old; \
old = atomicCASshort(addr_as_ui, assumed, \ old = atomicCASshort(addr_as_ui, assumed, \
Cast<half>::Encode(OP(val, Cast<half>::Decode(old)))); \ Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
} while (assumed != old); \ } while (assumed != old); \
return Cast<half>::Decode(old); \ return Cast<dtype>::Decode(old); \
} }
#define OP(a, b) max(a, b) #define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max) DEFINE_ATOMIC(Max)
#ifdef USE_FP16 DEFINE_ATOMIC_16BIT(Max, half)
DEFINE_ATOMIC_HALF(Max) #if BF16_ENABLED
#endif // USE_FP16 DEFINE_ATOMIC_16BIT(Max, __nv_bfloat16)
#endif // BF16_ENABLED
#undef OP #undef OP
#define OP(a, b) min(a, b) #define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min) DEFINE_ATOMIC(Min)
#ifdef USE_FP16 DEFINE_ATOMIC_16BIT(Min, half)
DEFINE_ATOMIC_HALF(Min) #if BF16_ENABLED
#endif // USE_FP16 DEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)
#endif // BF16_ENABLED
#undef OP #undef OP
#define OP(a, b) a + b #define OP(a, b) a + b
...@@ -258,11 +285,10 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) { ...@@ -258,11 +285,10 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#endif #endif
} }
#ifdef USE_FP16
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000 #if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
template <> template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) { __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
// half make sure we have half support // make sure we have half support
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
return atomicAdd(addr, val); return atomicAdd(addr, val);
#else #else
...@@ -275,7 +301,24 @@ __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) { ...@@ -275,7 +301,24 @@ __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#endif // __CUDA_ARCH__ >= 700 #endif // __CUDA_ARCH__ >= 700
} }
#endif // defined(CUDART_VERSION) && CUDART_VERSION >= 10000 #endif // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#endif // USE_FP16
#if BF16_ENABLED
template <>
__device__ __forceinline__ __nv_bfloat16 AtomicAdd<__nv_bfloat16>(
__nv_bfloat16* addr, __nv_bfloat16 val) {
// make sure we have bfloat16 support
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return atomicAdd(addr, val);
#else
(void)addr;
(void)val;
printf("Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return val;
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
#endif // BF16_ENABLED
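For readability, here is roughly what DEFINE_ATOMIC_16BIT(Max, __nv_bfloat16) above expands to; this is only a worked example of the compare-and-swap emulation, not additional code in the tree:

// Emulate a 16-bit atomic max: re-encode the bf16 value as uint16_t and retry
// atomicCASshort until no other thread has modified the location in between.
template <>
__device__ __forceinline__ __nv_bfloat16 AtomicMax<__nv_bfloat16>(
    __nv_bfloat16* addr, __nv_bfloat16 val) {
  typedef uint16_t CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCASshort(
        addr_as_ui, assumed,
        Cast<__nv_bfloat16>::Encode(max(val, Cast<__nv_bfloat16>::Decode(old))));
  } while (assumed != old);
  return Cast<__nv_bfloat16>::Decode(old);  // value observed before this update
}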
} // namespace cuda } // namespace cuda
......
/*!
* Copyright (c) 2022 by Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file array/cuda/bf16.cuh
* \brief bfloat16 related functions.
*/
#ifndef DGL_ARRAY_CUDA_BF16_CUH_
#define DGL_ARRAY_CUDA_BF16_CUH_
#if BF16_ENABLED
#include <cuda_bf16.h>
#include <algorithm>
static __device__ __forceinline__ __nv_bfloat16 max(__nv_bfloat16 a,
__nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmax(a, b);
#else
return __nv_bfloat16(max(float(a), float(b))); // NOLINT
#endif
}
static __device__ __forceinline__ __nv_bfloat16 min(__nv_bfloat16 a,
__nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmin(a, b);
#else
return __nv_bfloat16(min(float(a), float(b))); // NOLINT
#endif
}
#ifdef __CUDACC__
// Arithmetic BF16 operations for architectures >= 8.0 (sm_80) are already defined in cuda_bf16.h
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
__device__ __forceinline__ __nv_bfloat16 operator+(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) + float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator-(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) - float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator*(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) * float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator/(
const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
return __nv_bfloat16(float(lh) / float(rh)); // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator+=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) + float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator-=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) - float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator*=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) * float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator/=(
__nv_bfloat16& lh, const __nv_bfloat16& rh) { // NOLINT
lh = __nv_bfloat16(float(lh) / float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator++(__nv_bfloat16& h) { // NOLINT
h = __nv_bfloat16(float(h) + 1.0f); return h; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16& operator--(__nv_bfloat16& h) { // NOLINT
h = __nv_bfloat16(float(h) - 1.0f); return h; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator++(__nv_bfloat16& h, int) { // NOLINT
__nv_bfloat16 ret = h; h = __nv_bfloat16(float(h) + 1.0f); return ret; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator--(__nv_bfloat16& h, int) { // NOLINT
__nv_bfloat16 ret = h; h = __nv_bfloat16(float(h) - 1.0f); return ret; // NOLINT
}
__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) {
return h;
}
__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {
return __nv_bfloat16(-float(h)); // NOLINT
}
__device__ __forceinline__ bool operator==(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) == float(rh); // NOLINT
}
__device__ __forceinline__ bool operator!=(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) != float(rh); // NOLINT
}
__device__ __forceinline__ bool operator> (const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) > float(rh); // NOLINT
}
__device__ __forceinline__ bool operator< (const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) < float(rh); // NOLINT
}
__device__ __forceinline__ bool operator>=(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) >= float(rh); // NOLINT
}
__device__ __forceinline__ bool operator<=(const __nv_bfloat16& lh,
const __nv_bfloat16& rh) {
return float(lh) <= float(rh); // NOLINT
}
#endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#endif // __CUDACC__
#endif // BF16_ENABLED
#endif // DGL_ARRAY_CUDA_BF16_CUH_
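As a quick illustration (hypothetical kernel, not part of the commit), device code that includes this header can treat __nv_bfloat16 as an ordinary arithmetic type: on sm_80 and newer the operators come from cuda_bf16.h, while on older architectures they resolve to the float round-trip overloads defined above:

#if BF16_ENABLED
// Elementwise out[i] = a * x[i] + y[i] on bf16 buffers.
__global__ void AxpyBF16(const __nv_bfloat16* x, const __nv_bfloat16* y,
                         __nv_bfloat16* out, __nv_bfloat16 a, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = a * x[i] + y[i];  // uses the operators in this header on < sm_80
  }
}
#endif  // BF16_ENABLED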
...@@ -65,12 +65,18 @@ NDArray CSRGetData( ...@@ -65,12 +65,18 @@ NDArray CSRGetData(
return rst; return rst;
} }
#ifdef USE_FP16
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>( template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>( template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
#endif #if BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, __nv_bfloat16>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, __nv_bfloat16 filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __nv_bfloat16>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, __nv_bfloat16 filler);
#endif // BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, float>( template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, float>( template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
......
...@@ -255,12 +255,16 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -255,12 +255,16 @@ std::pair<CSRMatrix, NDArray> CSRMM(
} }
} }
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif #if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
......
...@@ -167,12 +167,16 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -167,12 +167,16 @@ std::pair<CSRMatrix, NDArray> CSRSum(
} }
} }
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif #if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <cusparse.h> #include <cusparse.h>
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include "fp16.cuh"
#include "bf16.cuh"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -34,7 +36,6 @@ struct CSRGEMM { ...@@ -34,7 +36,6 @@ struct CSRGEMM {
} }
}; };
#ifdef USE_FP16
template <> template <>
struct CSRGEMM<__half> { struct CSRGEMM<__half> {
template <typename... Args> template <typename... Args>
...@@ -58,7 +59,32 @@ struct CSRGEMM<__half> { ...@@ -58,7 +59,32 @@ struct CSRGEMM<__half> {
return static_cast<cusparseStatus_t>(0); return static_cast<cusparseStatus_t>(0);
} }
}; };
#endif
#if BF16_ENABLED
template <>
struct CSRGEMM<__nv_bfloat16> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): cuSPARSE has no bfloat16 csrgemm2_bufferSizeExt routine, so a
// different implementation would be required.
LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgemm2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
// TODO(ndickson): cuSPARSE has no bfloat16 csrgemm2 routine, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEMM::compute does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
};
#endif // BF16_ENABLED
template <> template <>
struct CSRGEMM<float> { struct CSRGEMM<float> {
...@@ -117,7 +143,6 @@ struct CSRGEAM { ...@@ -117,7 +143,6 @@ struct CSRGEAM {
} }
}; };
#ifdef USE_FP16
template <> template <>
struct CSRGEAM<__half> { struct CSRGEAM<__half> {
template <typename... Args> template <typename... Args>
...@@ -141,7 +166,32 @@ struct CSRGEAM<__half> { ...@@ -141,7 +166,32 @@ struct CSRGEAM<__half> {
return static_cast<cusparseStatus_t>(0); return static_cast<cusparseStatus_t>(0);
} }
}; };
#endif
#if BF16_ENABLED
template <>
struct CSRGEAM<__nv_bfloat16> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): cuSPARSE has no bfloat16 csrgeam2_bufferSizeExt routine, so a
// different implementation would be required.
LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgeam2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
// TODO(ndickson): cuSPARSE has no bfloat16 csrgeam2 routine, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEAM::compute does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
};
#endif // BF16_ENABLED
template <> template <>
struct CSRGEAM<float> { struct CSRGEAM<float> {
......
/*! /*!
* Copyright (c) 2020 by Contributors * Copyright (c) 2020-2022 by Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file array/cuda/fp16.cuh * \file array/cuda/fp16.cuh
* \brief float16 related functions. * \brief float16 related functions.
* \note this file is modified from TVM project: * \note this file is modified from TVM project:
...@@ -9,7 +22,6 @@ ...@@ -9,7 +22,6 @@
#define DGL_ARRAY_CUDA_FP16_CUH_ #define DGL_ARRAY_CUDA_FP16_CUH_
#ifdef USE_FP16
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <algorithm> #include <algorithm>
...@@ -97,6 +109,4 @@ __device__ __forceinline__ bool operator<=(const __half& lh, const __half& rh) { ...@@ -97,6 +109,4 @@ __device__ __forceinline__ bool operator<=(const __half& lh, const __half& rh) {
#endif // __CUDA_ARCH__ < 530 #endif // __CUDA_ARCH__ < 530
#endif // __CUDACC__ #endif // __CUDACC__
#endif // USE_FP16
#endif // DGL_ARRAY_CUDA_FP16_CUH_ #endif // DGL_ARRAY_CUDA_FP16_CUH_