Commit 0a56d652 authored by xiang song(charlie.song), committed by Zihao Ye

[Feature] x_dot_x builtin kernel support (#831)

* upd

* fix edgebatch edges

* add test

* trigger

* Update README.md for pytorch PinSage example.

Add a note that the PinSage model example under
example/pytorch/recommendation only works with Python 3.6+,
as its dataset loader depends on the stanfordnlp package,
which works only with Python 3.6+.

* Provide a framework-agnostic API to test nn modules on both the CPU and CUDA side.

1. make dgl.nn.xxx frame agnostic
2. make test.backend include dgl.nn modules
3. modify test_edge_softmax of test/mxnet/test_nn.py and
    test/pytorch/test_nn.py to work on both CPU and GPU

* Fix style

* Delete unused code

* Make agnostic test only related to tests/backend

1. clear all agnostic related code in dgl.nn
2. make test_graph_conv agnostic to cpu/gpu

* Fix code style

* fix

* doc

* Make all test code in tests/mxnet/test_nn.py and tests/pytorch/test_nn.py
work on both CPU and GPU.

* Fix syntax

* Remove rand

* Start implementing masked-mm kernel.

Add base control flow code.

* Add masked dot declare

* Update func/variable name

* Skeleton compile OK

* Update implementation. Unify BinaryDot with BinaryReduce.

* New implementation of x_dot_x, reusing the binary reduce template

* Compile OK.

TODO:
1. make sure x_add_x, x_sub_x, x_mul_x, x_div_x work
2. let x_dot_x work
3. make sure backward of x_add_x, x_sub_x, x_mul_x, x_div_x work
4. let x_dot_x backward work

* Fix code style

* Now tests/compute/test_kernel.py passes for add/sub/mul/div forward and backward

* Fix mxnet test code

* Add u_dot_v, u_dot_e, v_dot_e unit tests.

* Update doc

* Now also support v_dot_u, e_dot_u, e_dot_v

* Add unrolling for some loops

* Add some optimizations for the CUDA backward of the dot builtin.

Backward is still slow for dot.

* Apply the UnravelRavel optimization for broadcast backward

* update docstring
parent 4297e6d6
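For orientation, the new builtins are used like the existing binary message functions. A minimal sketch of what this feature enables, assuming a DGL build that contains this commit and the standard `dgl.function` / `update_all` API (the toy graph and feature sizes below are made up for illustration):

```python
import dgl
import dgl.function as fn
import torch as th

# Toy graph: 3 nodes, 3 edges (illustrative only).
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])

# Each node carries a 4-dimensional feature vector.
g.ndata['h'] = th.randn(3, 4)

# u_dot_v computes, per edge, the dot product of the source and destination
# vectors (the trailing feature dimension is reduced away), and fn.sum then
# aggregates the resulting scalar messages onto the destination nodes.
g.update_all(fn.u_dot_v('h', 'h', 'score'),
             fn.sum('score', 'score_sum'))
```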
@@ -40,6 +40,12 @@ Message functions
     e_sub_v
     e_mul_v
     e_div_v
+    u_dot_v
+    u_dot_e
+    v_dot_e
+    v_dot_u
+    e_dot_u
+    e_dot_v

 Reduce functions
 ----------------
......
@@ -98,6 +98,10 @@ Here is a cheatsheet of all the DGL builtins.
 |                         +----------------------------------------------------+-----------------------+
 |                         | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v`` |                       |
 |                         +----------------------------------------------------+-----------------------+
+|                         | ``u_dot_v``, ``u_dot_e``, ``v_dot_e``              |                       |
+|                         +----------------------------------------------------+-----------------------+
+|                         | ``v_dot_u``, ``e_dot_u``, ``e_dot_v``              |                       |
+|                         +----------------------------------------------------+-----------------------+
 |                         | ``src_mul_edge``                                   | alias of ``u_mul_e``  |
 +-------------------------+----------------------------------------------------+-----------------------+
 | Reduce function         | ``max``                                            |                       |
......
@@ -376,8 +376,11 @@ class BinaryReduce(mx.autograd.Function):
     def forward(self, lhs_data, rhs_data):
         lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
         rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
-        feat_shape = K.infer_binary_feature_shape(lhs_data_nd, rhs_data_nd)
-        out_data = nd.empty((self.out_size,) + feat_shape,
+        feat_shape = K.infer_binary_feature_shape(self.binary_op, lhs_data_nd, rhs_data_nd)
+        out_shape = feat_shape
+        if self.binary_op == 'dot':
+            out_shape = feat_shape[:-1]
+        out_data = nd.empty((self.out_size,) + out_shape,
                             ctx=lhs_data.context, dtype=lhs_data.dtype)
         out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
         K.binary_op_reduce(
......
@@ -288,8 +288,11 @@ class BinaryReduce(th.autograd.Function):
                 out_size, lhs_map, rhs_map, out_map):
         lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
         rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
-        feat_shape = K.infer_binary_feature_shape(lhs_data_nd, rhs_data_nd)
-        out_data = lhs_data.new_empty((out_size,) + feat_shape)
+        feat_shape = K.infer_binary_feature_shape(binary_op, lhs_data_nd, rhs_data_nd)
+        out_shape = feat_shape
+        if binary_op == 'dot':
+            out_shape = feat_shape[:-1]
+        out_data = lhs_data.new_empty((out_size,) + out_shape)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         K.binary_op_reduce(
             reducer if reducer != 'mean' else 'sum',
......
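Both frontends apply the same rule: for 'dot' the inferred feature shape ends with the vector length that the dot product reduces over, so the output tensor drops that last dimension. A small illustration of the rule in plain Python (not DGL code; the helper name is made up):

```python
# Hypothetical helper mirroring the branch above: for 'dot', the per-element
# output shape is the inferred feature shape minus its trailing (vector) axis.
def out_feat_shape(binary_op, feat_shape):
    if binary_op == 'dot':
        return feat_shape[:-1]
    return feat_shape

print(out_feat_shape('dot', (5, 3, 4, 10)))  # (5, 3, 4)
print(out_feat_shape('mul', (5, 3, 4)))      # (5, 3, 4)
```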
@@ -147,12 +147,17 @@ def copy_e(e, out):

 ###############################################################################
 # Generate all following builtin message functions:
+# element-wise message functions:
 # u_add_v, u_sub_v, u_mul_v, u_div_v
 # u_add_e, u_sub_e, u_mul_e, u_div_e
 # v_add_u, v_sub_u, v_mul_u, v_div_u
 # v_add_e, v_sub_e, v_mul_e, v_div_e
 # e_add_u, e_sub_u, e_mul_u, e_div_u
 # e_add_v, e_sub_v, e_mul_v, e_div_v
+#
+# dot message functions:
+# u_dot_v, u_dot_e, v_dot_e
+# v_dot_u, e_dot_u, e_dot_v

 _TARGET_MAP = {
     "u": TargetCode.SRC,
@@ -200,12 +205,11 @@ def _register_builtin_message_func():
     target = ["u", "v", "e"]
     for lhs, rhs in product(target, target):
         if lhs != rhs:
-            for binary_op in ["add", "sub", "mul", "div"]:
+            for binary_op in ["add", "sub", "mul", "div", "dot"]:
                 func = _gen_message_builtin(lhs, rhs, binary_op)
                 setattr(sys.modules[__name__], func.__name__, func)
                 __all__.append(func.__name__)

 _register_builtin_message_func()
......
@@ -4,11 +4,14 @@ from __future__ import absolute_import
 from ._ffi.function import _init_api
 from .ndarray import empty

-def infer_binary_feature_shape(lhs, rhs):
+# pylint: disable=invalid-name
+def infer_binary_feature_shape(op, lhs, rhs):
     """Infer the output feature shape after a binary operation between lhs and rhs.

     Parameter
     ---------
+    op : string
+        The binary_op name.
     lhs : dgl.ndarray.NDArray
         The lhs tensor.
     rhs : dgl.ndarray.NDArray
@@ -19,7 +22,7 @@ def infer_binary_feature_shape(lhs, rhs):
     tuple of int
         The output feature shape.
     """
-    ret = _CAPI_DGLKernelInferBinaryFeatureShape(lhs, rhs)
+    ret = _CAPI_DGLKernelInferBinaryFeatureShape(op, lhs, rhs)
     return tuple(ret.asnumpy())

 # pylint: disable=invalid-name
......
@@ -93,11 +93,25 @@ bool HasBcast(NDArray lhs, NDArray rhs) {
 // e.g. (4, 1, 3, 3) and (4, 5, 3, 3) become (4, 1, 9) and (4, 5, 9)
 //
 // See also: BcastInfo (kernel/binary_reduce.h)
-BcastInfo CalcBcastInfo(NDArray lhs, NDArray rhs) {
+BcastInfo CalcBcastInfo(const std::string& op, NDArray lhs, NDArray rhs) {
   BcastInfo ret;
   const int max_ndim = std::max(lhs->ndim, rhs->ndim) - 1;
   int64_t accum = 0;
-  for (int j = 0; j < max_ndim; ++j) {
+  int j = 0;
+  // for dot operation: vector [dot] vector
+  //   lhs_shape[ndim-1] == rhs_shape[ndim-1] = sizeof(vector)
+  //   out_shape[ndim-1] = 1
+  if (op == binary_op::kDot) {
+    // get size of vector
+    ret.data_len = lhs->shape[lhs->ndim - 1];
+    // skip vector size dim
+    ++j;
+    ret.real_out_shape.push_back(ret.data_len);
+  } else {  // op != binary_op::kDot
+    ret.data_len = 1;
+  }
+  for (; j < max_ndim; ++j) {
     const int dl = (lhs->ndim - 1 - j < 1)? 1 : lhs->shape[lhs->ndim - 1 - j];
     const int dr = (rhs->ndim - 1 - j < 1)? 1 : rhs->shape[rhs->ndim - 1 - j];
     if (dl != dr) {
@@ -258,16 +272,18 @@ class BipartiteCSRWrapper : public CSRWrapper {

 std::vector<int64_t> InferBinaryFeatureShape(
+    const std::string& op,
     NDArray lhs,
     NDArray rhs) {
-  return CalcBcastInfo(lhs, rhs).real_out_shape;
+  return CalcBcastInfo(op, lhs, rhs).real_out_shape;
 }

 DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelInferBinaryFeatureShape")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
-    NDArray lhs = args[0];
-    NDArray rhs = args[1];
-    const auto& shape = InferBinaryFeatureShape(lhs, rhs);
+    std::string op = args[0];
+    NDArray lhs = args[1];
+    NDArray rhs = args[2];
+    const auto& shape = InferBinaryFeatureShape(op, lhs, rhs);
     const int64_t len = shape.size();
     NDArray ret = NDArray::Empty(
         {len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
@@ -300,7 +316,7 @@ void BinaryOpReduce(
         rhs_mapping, lhs_mapping, out_mapping);
   } else {
     if (HasBcast(lhs_data, rhs_data)) {
-      BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
+      BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
       DGL_XPU_SWITCH(ctx.device_type, BinaryReduceBcastImpl,
           info, reducer, op, graph,
           lhs, rhs,
@@ -394,7 +410,7 @@ void BackwardLhsBinaryOpReduce(
         grad_out_data, grad_lhs_data);
   } else {
     if (HasBcast(lhs_data, rhs_data)) {
-      BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
+      BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
       DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceBcastImpl,
           info, reducer, op, graph,
           lhs, rhs,
@@ -468,7 +484,7 @@ void BackwardRhsBinaryOpReduce(
         grad_out_data, grad_rhs_data);
   } else {
     if (HasBcast(lhs_data, rhs_data)) {
-      BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
+      BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
       DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceBcastImpl,
           info, reducer, op, graph,
           lhs, rhs,
......
@@ -30,6 +30,8 @@ struct BcastInfo {
   std::vector<int64_t> lhs_shape, lhs_stride;
   std::vector<int64_t> rhs_shape, rhs_stride;
   std::vector<int64_t> out_shape, out_stride;
+  int64_t data_len;
 };

 /*
......
@@ -29,6 +29,7 @@ static const char kAdd[] = "add";
 static const char kSub[] = "sub";
 static const char kMul[] = "mul";
 static const char kDiv[] = "div";
+static const char kDot[] = "dot";
 static const char kUseLhs[] = "use_lhs";

 /*!
@@ -129,8 +130,8 @@ struct SwitchSrcDst<SelectDst> {
 // common binary functors
 template <typename DType>
 struct BinaryAdd {
-  static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
-    return lhs + rhs;
+  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+    return lhs[0] + rhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
     return 1;
@@ -142,8 +143,8 @@ struct BinaryAdd {

 template <typename DType>
 struct BinaryMul {
-  static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
-    return lhs * rhs;
+  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+    return lhs[0] * rhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
     return rhs;
@@ -155,8 +156,8 @@ struct BinaryMul {

 template <typename DType>
 struct BinarySub {
-  static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
-    return lhs - rhs;
+  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+    return lhs[0] - rhs[0];
  }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
     return 1;
@@ -168,8 +169,8 @@ struct BinarySub {

 template <typename DType>
 struct BinaryDiv {
-  static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
-    return lhs / rhs;
+  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+    return lhs[0] / rhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
     return static_cast<DType>(1) / rhs;
@@ -181,8 +182,8 @@ struct BinaryDiv {

 template <typename DType>
 struct BinaryUseLhs {
-  static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
-    return lhs;
+  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+    return lhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
     return 1;
@@ -192,6 +193,25 @@ struct BinaryUseLhs {
   }
 };

+template <typename DType>
+struct BinaryDot {
+  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+    DType out = 0;
+    // simple vector dot vector
+#pragma unroll
+    for (int i = 0; i < len; i ++)
+      out += lhs[i] * rhs[i];
+    return out;
+  }
+  static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
+    return rhs;
+  }
+  static DGLDEVICE DGLINLINE DType BackwardRhs(DType lhs, DType rhs, DType out) {
+    return lhs;
+  }
+};
+
 // Macro for dispatching op enum code and target code into template arguments.
 // The macro dispatches following combinations:
 // - Add(Src, Dst), Add(Src, Edge), Add(Dst, Edge)
@@ -201,6 +221,8 @@ struct BinaryUseLhs {
 // - Div(Src, Dst), Div(Src, Edge), Div(Dst, Edge)
 //   Div(Dst, Src), Div(Edge, Src), Div(Edge, Dst)
 // - UseLhs(Src, None), UseLhs(Edge, None)
+// - Dot(Src, Dst), Dot(Src, Edge), Dot(Dst, Edge)
+//   Dot(Dst, Src), Dot(Edge, Src), Dot(Edge, Dst)
 // Note that for commutative operators (e.g. Add and Mul), we only generate
 // kernels for lhs code smaller than rhs code.
 #define OP_TARGET_SWITCH(op, lhs, rhs, DType, OpType, LeftType, RightType, ...) \
@@ -306,6 +328,36 @@ struct BinaryUseLhs {
     typedef SelectEdge LeftType; \
     typedef SelectNone RightType; \
     {__VA_ARGS__} \
+  } else if (op == kDot && lhs == kSrc && rhs == kDst) { \
+    typedef BinaryDot<DType> OpType; \
+    typedef SelectSrc LeftType; \
+    typedef SelectDst RightType; \
+    {__VA_ARGS__} \
+  } else if (op == kDot && lhs == kSrc && rhs == kEdge) { \
+    typedef BinaryDot<DType> OpType; \
+    typedef SelectSrc LeftType; \
+    typedef SelectEdge RightType; \
+    {__VA_ARGS__} \
+  } else if (op == kDot && lhs == kDst && rhs == kEdge) { \
+    typedef BinaryDot<DType> OpType; \
+    typedef SelectDst LeftType; \
+    typedef SelectEdge RightType; \
+    {__VA_ARGS__} \
+  } else if (op == kDot && lhs == kDst && rhs == kSrc) { \
+    typedef BinaryDot<DType> OpType; \
+    typedef SelectDst LeftType; \
+    typedef SelectSrc RightType; \
+    {__VA_ARGS__} \
+  } else if (op == kDot && lhs == kEdge && rhs == kSrc) { \
+    typedef BinaryDot<DType> OpType; \
+    typedef SelectEdge LeftType; \
+    typedef SelectSrc RightType; \
+    {__VA_ARGS__} \
+  } else if (op == kDot && lhs == kEdge && rhs == kDst) { \
+    typedef BinaryDot<DType> OpType; \
+    typedef SelectEdge LeftType; \
+    typedef SelectDst RightType; \
+    {__VA_ARGS__} \
   } else { \
     LOG(FATAL) << "Unsupported operation: op=" << op \
                << " lhs=" << lhs << " rhs=" << rhs; \
@@ -333,7 +385,13 @@ struct BinaryUseLhs {
   MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectEdge, BinaryDiv)) \
   MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectDst, BinaryDiv)) \
   MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectNone, BinaryUseLhs)) \
-  MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectNone, BinaryUseLhs))
+  MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectNone, BinaryUseLhs)) \
+  MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectDst, BinaryDot)) \
+  MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectEdge, BinaryDot)) \
+  MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectEdge, BinaryDot)) \
+  MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectSrc, BinaryDot)) \
+  MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectSrc, BinaryDot)) \
+  MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectDst, BinaryDot))

 //////////////////////////////////////////////////////////////////////////
 // Defines reducer category. Each category is an empty structure.
......
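In BinaryDot above, BackwardLhs and BackwardRhs simply return the opposite operand element, reflecting the rule d(u·v)/du_i = v_i and d(u·v)/dv_i = u_i; this is why the backward kernels below loop over the vector length and accumulate one partial per element. A quick sanity check of that rule, as a hedged NumPy sketch (illustration only, not DGL code):

```python
import numpy as np

# For e = dot(u, v): de/du[i] = v[i] and de/dv[i] = u[i], so an upstream
# gradient grad_e contributes grad_e * v to u (and grad_e * u to v).
rng = np.random.default_rng(0)
u, v = rng.standard_normal(10), rng.standard_normal(10)
grad_e = 2.0

grad_u = grad_e * v  # analytic gradient w.r.t. u

# Finite-difference check of the same quantity.
eps = 1e-6
num = grad_e * np.array([
    (np.dot(u + eps * np.eye(10)[i], v) - np.dot(u, v)) / eps
    for i in range(10)
])
assert np.allclose(grad_u, num, atol=1e-4)
```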
@@ -28,7 +28,7 @@ namespace kernel {
 ///////////////////////////////////////////////////////////////////////////////

 template <int XPU, typename Idx, typename DType, typename Reducer>
-GData<Idx, DType> AllocGData(
+GData<Idx, DType> AllocGData(const std::string& op,
     const DLContext& ctx, int64_t x_len,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
@@ -49,6 +49,15 @@ GData<Idx, DType> AllocGData(
   if (!utils::IsNoneArray(out_mapping)) {
     gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
   }
+
+  // for dot operation: vector [dot] vector
+  if (op == binary_op::kDot) {
+    // get size of vector
+    gdata.data_len = lhs_data->shape[lhs_data->ndim - 1];
+  } else {
+    gdata.data_len = 1;
+  }
+
   // fill out data with zero values
   utils::Fill<XPU>(ctx, gdata.out_data, utils::NElements(out_data), Zero<Reducer>::value);
   return gdata;
@@ -92,7 +101,7 @@ void BinaryReduceImpl(
   DGL_DTYPE_SWITCH(dtype, DType, {
     DGL_IDX_TYPE_SWITCH(bits, Idx, {
       REDUCER_SWITCH(reducer, XPU, DType, Reducer, {
-        auto gdata = AllocGData<XPU, Idx, DType, Reducer>(
+        auto gdata = AllocGData<XPU, Idx, DType, Reducer>(op,
             rtcfg.ctx, x_len, lhs_mapping, rhs_mapping,
             lhs_data, rhs_data, out_mapping, out_data);
         OP_TARGET_SWITCH(op, lhs, rhs, DType, BinaryOp, LeftTarget, RightTarget, {
@@ -110,7 +119,7 @@ void BinaryReduceImpl(

 template <int XPU, typename Idx, typename DType>
 BackwardGData<Idx, DType> AllocBackwardGData(
-    const DLContext& ctx, int64_t x_len,
+    const std::string& op, const DLContext& ctx, int64_t x_len,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,
     runtime::NDArray grad_out_data,
@@ -144,6 +153,14 @@ BackwardGData<Idx, DType> AllocBackwardGData(
   if (!utils::IsNoneArray(out_mapping)) {
     gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
   }
+
+  // for dot operation: vector [dot] vector
+  if (op == binary_op::kDot) {
+    // get size of vector
+    gdata.data_len = lhs_data->shape[lhs_data->ndim - 1];
+  } else {
+    gdata.data_len = 1;
+  }
   return gdata;
 }
@@ -182,13 +199,14 @@ void BackwardBinaryReduceImpl(
   const bool req_lhs = !utils::IsNoneArray(grad_lhs_data);
   const bool req_rhs = !utils::IsNoneArray(grad_rhs_data);
   const auto bits = graph.NumBits();
+
   if (reducer == binary_op::kReduceMean) {
     // TODO(minjie): divide
     LOG(FATAL) << "reduce mean is not supported.";
   }
   DGL_DTYPE_SWITCH(dtype, DType, {
     DGL_IDX_TYPE_SWITCH(bits, Idx, {
-      auto gdata = AllocBackwardGData<XPU, Idx, DType>(
+      auto gdata = AllocBackwardGData<XPU, Idx, DType>(op,
           rtcfg.ctx, x_len, lhs_mapping, rhs_mapping, out_mapping,
           lhs_data, rhs_data, out_data, grad_out_data,
           grad_lhs_data, grad_rhs_data);
@@ -240,6 +258,8 @@ BcastGData<NDim, Idx, DType> AllocBcastGData(
   if (!utils::IsNoneArray(out_mapping)) {
     gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
   }
+
+  gdata.data_len = info.data_len;
   // fill out data with zero values
   utils::Fill<XPU>(ctx, gdata.out_data, utils::NElements(out_data), Zero<Reducer>::value);
   return gdata;
@@ -280,6 +300,7 @@ void BinaryReduceBcastImpl(
   const DLDataType& dtype = out_data->dtype;
   const int bcast_ndim = info.out_shape.size();
   const auto bits = graph.NumBits();
+
   if (reducer == binary_op::kReduceMean) {
     // TODO(minjie): divide
     LOG(FATAL) << "reduce mean is not supported.";
@@ -334,6 +355,8 @@ BackwardBcastGData<NDim, Idx, DType> AllocBackwardBcastGData(
   if (!utils::IsNoneArray(out_mapping)) {
     gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
   }
+
+  gdata.data_len = info.data_len;
   // data
   gdata.lhs_data = static_cast<DType*>(lhs->data);
   gdata.rhs_data = static_cast<DType*>(rhs->data);
@@ -387,6 +410,7 @@ void BackwardBinaryReduceBcastImpl(
   const bool req_lhs = !utils::IsNoneArray(grad_lhs);
   const bool req_rhs = !utils::IsNoneArray(grad_rhs);
   const auto bits = graph.NumBits();
+
   if (reducer == binary_op::kReduceMean) {
     // TODO(minjie): divide
     LOG(FATAL) << "reduce mean is not supported.";
......
@@ -36,6 +36,8 @@ template <typename Idx, typename DType>
 struct GData {
   // length along x(feature) dimension
   int64_t x_length{0};
+  // size of data, can be single value or a vector
+  int64_t data_len;
   // number of rows of the output tensor
   int64_t out_size{0};
   // input data
@@ -119,6 +121,8 @@ template <typename Idx, typename DType>
 struct BackwardGData {
   // length along x(feature) dimension
   int64_t x_length{0};
+  // size of data, can be single value or a vector
+  int64_t data_len;
   // number of rows of the output tensor
   int64_t out_size{0};
   // input data
@@ -222,6 +226,8 @@ struct BcastGData {
   int64_t lhs_len{0}, rhs_len{0};
   int64_t lhs_shape[NDim]{0}, lhs_stride[NDim]{0};
   int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
+  // size of data, can be single value or a vector
+  int64_t data_len;
   // input data
   DType *lhs_data{nullptr}, *rhs_data{nullptr};
   // input id mappings
@@ -326,6 +332,8 @@ struct BackwardBcastGData {
   int64_t lhs_shape[NDim]{0}, lhs_stride[NDim]{0};
   int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
   int64_t out_shape[NDim]{0}, out_stride[NDim]{0};
+  // size of data, can be single value or a vector
+  int64_t data_len;
   // input id mappings
   Idx *lhs_mapping{nullptr}, *rhs_mapping{nullptr}, *out_mapping{nullptr};
   // input data
......
@@ -27,6 +27,7 @@ struct BackwardBinaryReduce {
   static inline void ApplyEdge(
       Idx src, Idx dst, Idx eid, BackwardGData<Idx, DType>* gdata) {
     const int64_t D = gdata->x_length;
+    const int64_t len = gdata->data_len;
     Idx lid = Functors::SelectLeft(src, eid, dst);
     Idx rid = Functors::SelectRight(src, eid, dst);
     Idx oid = Functors::SelectOut(src, eid, dst);
@@ -39,28 +40,46 @@ struct BackwardBinaryReduce {
     if (gdata->out_mapping) {
       oid = Functors::GetId(oid, gdata->out_mapping);
     }
-    DType* lhsoff = gdata->lhs_data + lid * D;
-    DType* rhsoff = gdata->rhs_data + rid * D;
+    DType* lhsoff = gdata->lhs_data + lid * D * len;
+    DType* rhsoff = gdata->rhs_data + rid * D * len;
     DType* outoff = gdata->out_data + oid * D;
-    DType* gradlhsoff = gdata->grad_lhs_data + lid * D;
-    DType* gradrhsoff = gdata->grad_rhs_data + rid * D;
+    DType* gradlhsoff = gdata->grad_lhs_data + lid * D * len;
+    DType* gradrhsoff = gdata->grad_rhs_data + rid * D * len;
     DType* gradoutoff = gdata->grad_out_data + oid * D;
     for (int64_t tx = 0; tx < D; ++tx) {
-      DType lhs = Functors::Read(lhsoff + tx);
-      DType rhs = Functors::Read(rhsoff + tx);
       DType out = Functors::Read(outoff + tx);
       DType grad_out = Functors::Read(gradoutoff + tx);
-      DType e = Functors::Op(lhs, rhs);
+      DType e = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
       DType grad_e = grad_out * Functors::BackwardWrite(e, out);
-      if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
-        DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
-#pragma omp atomic
-        gradlhsoff[tx] += grad_lhs;
-      }
-      if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
-        DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
-#pragma omp atomic
-        gradrhsoff[tx] += grad_rhs;
+      DType* lhs_base = lhsoff + tx * len;
+      DType* rhs_base = rhsoff + tx * len;
+      if (Mode == binary_op::kGradBoth) {
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+          DType grad = grad_lhs + grad_rhs;
+#pragma omp atomic
+          gradlhsoff[tx * len + i] += grad;
+        }
+      } else if (Mode == binary_op::kGradLhs) {
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+#pragma omp atomic
+          gradlhsoff[tx * len + i] += grad_lhs;
+        }
+      } else if (Mode == binary_op::kGradRhs) {
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+#pragma omp atomic
+          gradrhsoff[tx * len + i] += grad_rhs;
+        }
       }
     }
   }
@@ -76,6 +95,7 @@ struct BackwardBinaryReduceBcast {
   }
   static inline void ApplyEdge(
       Idx src, Idx dst, Idx eid, BackwardBcastGData<NDim, Idx, DType>* gdata) {
+    const int64_t len = gdata->data_len;
     Idx lid = Functors::SelectLeft(src, eid, dst);
     Idx rid = Functors::SelectRight(src, eid, dst);
     Idx oid = Functors::SelectOut(src, eid, dst);
@@ -88,32 +108,53 @@ struct BackwardBinaryReduceBcast {
     if (gdata->out_mapping) {
      oid = Functors::GetId(oid, gdata->out_mapping);
     }
-    DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
-    DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
+    DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len;
+    DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
     DType* outoff = gdata->out_data + oid * gdata->out_len;
-    DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len;
-    DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len;
+    DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len * len;
+    DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len * len;
     DType* gradoutoff = gdata->grad_out_data + oid * gdata->out_len;
     int64_t tmp[NDim];  // store unraveled idx.
     for (int64_t tx = 0; tx < gdata->out_len; ++tx) {
       Unravel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, tmp);
-      DType lhs = Functors::Read(lhsoff +
-          Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride));
-      DType rhs = Functors::Read(rhsoff +
-          Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride));
       DType out = Functors::Read(outoff + tx);
       DType grad_out = Functors::Read(gradoutoff + tx);
-      DType e = Functors::Op(lhs, rhs);
+      DType e = Functors::Op(
+          lhsoff + Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride) * len,
+          rhsoff + Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride) * len,
+          len);
       DType grad_e = grad_out * Functors::BackwardWrite(e, out);
-      if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
-        DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
-#pragma omp atomic
-        gradlhsoff[tx] += grad_lhs;
-      }
-      if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
-        DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
-#pragma omp atomic
-        gradrhsoff[tx] += grad_rhs;
+      DType* lhs_base = lhsoff +
+          Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride) * len;
+      DType* rhs_base = rhsoff +
+          Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride) * len;
+      if (Mode == binary_op::kGradBoth) {
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+          DType grad = grad_lhs + grad_rhs;
+#pragma omp atomic
+          gradlhsoff[tx * len + i] += grad;
+        }
+      } else if (Mode == binary_op::kGradLhs) {
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+#pragma omp atomic
+          gradlhsoff[tx * len + i] += grad_lhs;
+        }
+      } else if (Mode == binary_op::kGradRhs) {
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+#pragma omp atomic
+          gradrhsoff[tx * len + i] += grad_rhs;
+        }
       }
     }
   }
@@ -137,8 +178,8 @@ struct BackwardFunctorsTempl {
       Idx src, Idx edge, Idx dst) {
     return RightSelector::Call(src, edge, dst);
   }
-  static inline DType Op(DType lhs, DType rhs) {
-    return BinaryOp::Call(lhs, rhs);
+  static inline DType Op(DType* lhs, DType* rhs, int64_t len) {
+    return BinaryOp::Call(lhs, rhs, len);
   }
   static inline DType Read(DType* addr) {
     return *addr;
......
@@ -29,6 +29,7 @@ struct BinaryReduce {
   static inline void ApplyEdge(
       Idx src, Idx dst, Idx eid, GData<Idx, DType>* gdata) {
     const int64_t D = gdata->x_length;
+    const int64_t len = gdata->data_len;
     Idx lid = Functors::SelectLeft(src, eid, dst);
     Idx rid = Functors::SelectRight(src, eid, dst);
     Idx oid = Functors::SelectOut(src, eid, dst);
@@ -41,13 +42,11 @@ struct BinaryReduce {
     if (gdata->out_mapping) {
       oid = Functors::GetId(oid, gdata->out_mapping);
     }
-    DType* lhsoff = gdata->lhs_data + lid * D;
-    DType* rhsoff = gdata->rhs_data + rid * D;
+    DType* lhsoff = gdata->lhs_data + lid * D * len;
+    DType* rhsoff = gdata->rhs_data + rid * D * len;
     DType* outoff = gdata->out_data + oid * D;
     for (int64_t tx = 0; tx < D; ++tx) {
-      DType lhs = Functors::Read(lhsoff + tx);
-      DType rhs = Functors::Read(rhsoff + tx);
-      DType out = Functors::Op(lhs, rhs);
+      DType out = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
       Functors::Write(outoff + tx, out);
     }
   }
@@ -80,6 +79,7 @@ struct BinaryReduceBcast {
   }
   static inline void ApplyEdge(
       Idx src, Idx dst, Idx eid, BcastGData<NDim, Idx, DType>* gdata) {
+    const int64_t len = gdata->data_len;
     Idx lid = Functors::SelectLeft(src, eid, dst);
     Idx rid = Functors::SelectRight(src, eid, dst);
     Idx oid = Functors::SelectOut(src, eid, dst);
@@ -92,17 +92,17 @@ struct BinaryReduceBcast {
     if (gdata->out_mapping) {
       oid = Functors::GetId(oid, gdata->out_mapping);
     }
-    DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
-    DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
+    DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len;  // data with len size
+    DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
     DType* outoff = gdata->out_data + oid * gdata->out_len;
     int64_t tmp[NDim];  // store unraveled idx.
     for (int64_t tx = 0; tx < gdata->out_len; ++tx) {
       Unravel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, tmp);
-      DType lhs = Functors::Read(lhsoff +
-          Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride));
-      DType rhs = Functors::Read(rhsoff +
-          Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride));
-      DType out = Functors::Op(lhs, rhs);
+      DType out = Functors::Op(
+          lhsoff + Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride) * len,
+          rhsoff + Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride) * len,
+          len);
       Functors::Write(outoff + tx, out);
     }
   }
@@ -125,11 +125,8 @@ struct FunctorsTempl {
       Idx src, Idx edge, Idx dst) {
     return RightSelector::Call(src, edge, dst);
   }
-  static inline DType Op(DType lhs, DType rhs) {
-    return BinaryOp::Call(lhs, rhs);
-  }
-  static inline DType Read(DType* addr) {
-    return *addr;
+  static inline DType Op(DType *lhs, DType *rhs, int64_t len) {
+    return BinaryOp::Call(lhs, rhs, len);
   }
   static inline void Write(DType* addr, DType val) {
     Reducer::Call(addr, val);
......
@@ -29,6 +29,7 @@ struct BackwardBinaryReduce {
     const int64_t D = gdata->x_length;
     int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
     int stride_x = blockDim.x * gridDim.x;
+    const int64_t len = gdata->data_len;
     Idx lid = Functors::SelectLeft(src, eid, dst);
     Idx rid = Functors::SelectRight(src, eid, dst);
     Idx oid = Functors::SelectOut(src, eid, dst);
@@ -41,50 +42,52 @@ struct BackwardBinaryReduce {
     if (gdata->out_mapping) {
       oid = Functors::GetId(oid, gdata->out_mapping);
     }
-    DType* lhsoff = gdata->lhs_data + lid * D;
-    DType* rhsoff = gdata->rhs_data + rid * D;
+    DType* lhsoff = gdata->lhs_data + lid * D * len;
+    DType* rhsoff = gdata->rhs_data + rid * D * len;
     DType* outoff = gdata->out_data + oid * D;
-    DType* gradlhsoff = gdata->grad_lhs_data + lid * D;
-    DType* gradrhsoff = gdata->grad_rhs_data + rid * D;
+    DType* gradlhsoff = gdata->grad_lhs_data + lid * D * len;
+    DType* gradrhsoff = gdata->grad_rhs_data + rid * D * len;
    DType* gradoutoff = gdata->grad_out_data + oid * D;
     while (tx < D) {
-      DType lhs = Functors::Read(lhsoff + tx);
-      DType rhs = Functors::Read(rhsoff + tx);
       DType out = Functors::Read(outoff + tx);
       DType grad_out = Functors::Read(gradoutoff + tx);
-      DType e = Functors::Op(lhs, rhs);
+      DType e = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
       DType grad_e = grad_out * Functors::BackwardWrite(e, out);
-      if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
-        DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
-        AtomicAdd(gradlhsoff + tx, grad_lhs);
-      }
-      if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
-        DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
-        AtomicAdd(gradrhsoff + tx, grad_rhs);
+      DType* lhs_base = lhsoff + tx * len;
+      DType* rhs_base = rhsoff + tx * len;
+      if (Mode == binary_op::kGradBoth) {
+#pragma unroll
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+          DType grad = grad_lhs + grad_rhs;
+          AtomicAdd(gradlhsoff + tx * len + i, grad);
+        }
+      } else if (Mode == binary_op::kGradLhs) {
+#pragma unroll
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+          AtomicAdd(gradlhsoff + tx * len + i, grad_lhs);
+        }
+      } else if (Mode == binary_op::kGradRhs) {
+#pragma unroll
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+          AtomicAdd(gradrhsoff + tx * len + i, grad_rhs);
+        }
       }
       tx += stride_x;
     }
   }
 };

-// Convert flattened index to multi-dimension index (assume row-major).
-__device__ __forceinline__ void Unravel(
-    int64_t idx, int ndim, const int64_t* shape, const int64_t* stride, int64_t* out) {
-  for (int d = 0; d < ndim; ++d) {
-    out[d] = (idx / stride[d]) % shape[d];
-  }
-}
-
-// Convert multi-dimension index to flattened index (assume row-major).
-__device__ __forceinline__ int64_t Ravel(
-    const int64_t* idx, int ndim, const int64_t* shape, const int64_t* stride) {
-  int64_t out = 0;
-  for (int d = 0; d < ndim; ++d) {
-    out += min(idx[d], shape[d] - 1) * stride[d];
-  }
-  return out;
-}
-
 // Minigun UDF to compute backward binary reduce with broadcasting.
 template <int Mode, int NDim, typename Idx, typename DType, typename Functors>
 struct BackwardBinaryReduceBcast {
@@ -96,6 +99,7 @@ struct BackwardBinaryReduceBcast {
       Idx src, Idx dst, Idx eid, BackwardBcastGData<NDim, Idx, DType>* gdata) {
     int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
     int stride_x = blockDim.x * gridDim.x;
+    const int64_t len = gdata->data_len;
     Idx lid = Functors::SelectLeft(src, eid, dst);
     Idx rid = Functors::SelectRight(src, eid, dst);
     Idx oid = Functors::SelectOut(src, eid, dst);
@@ -108,30 +112,51 @@ struct BackwardBinaryReduceBcast {
     if (gdata->out_mapping) {
       oid = Functors::GetId(oid, gdata->out_mapping);
     }
-    DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
-    DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
+    DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len;
+    DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
     DType* outoff = gdata->out_data + oid * gdata->out_len;
-    DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len;
-    DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len;
+    DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len * len;
+    DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len * len;
     DType* gradoutoff = gdata->grad_out_data + oid * gdata->out_len;
-    int64_t tmp[NDim];  // store unraveled idx.
     while (tx < gdata->out_len) {
-      Unravel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, tmp);
-      DType lhs = Functors::Read(lhsoff +
-          Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride));
-      DType rhs = Functors::Read(rhsoff +
-          Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride));
+      int64_t lhs_add = 0;
+      int64_t rhs_add = 0;
+      UnravelRavel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride,
+          gdata->lhs_shape, gdata->lhs_stride,
+          gdata->rhs_shape, gdata->rhs_stride, &lhs_add, &rhs_add);
       DType out = Functors::Read(outoff + tx);
       DType grad_out = Functors::Read(gradoutoff + tx);
-      DType e = Functors::Op(lhs, rhs);
+      DType e = Functors::Op(lhsoff + lhs_add * len, rhsoff + rhs_add * len, len);
       DType grad_e = grad_out * Functors::BackwardWrite(e, out);
-      if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
-        DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
-        AtomicAdd(gradlhsoff + tx, grad_lhs);
-      }
-      if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
-        DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
-        AtomicAdd(gradrhsoff + tx, grad_rhs);
+      DType* lhs_base = lhsoff + lhs_add * len;
+      DType* rhs_base = rhsoff + rhs_add * len;
+      if (Mode == binary_op::kGradBoth) {
+#pragma unroll
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+          DType grad = grad_lhs + grad_rhs;
+          AtomicAdd(gradlhsoff + tx * len + i, grad);
+        }
+      } else if (Mode == binary_op::kGradLhs) {
+#pragma unroll
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
+          AtomicAdd(gradlhsoff + tx * len + i, grad_lhs);
+        }
+      } else if (Mode == binary_op::kGradRhs) {
+#pragma unroll
+        for (int64_t i = 0; i < len; ++i) {
+          DType lhs = Functors::Read(lhs_base + i);
+          DType rhs = Functors::Read(rhs_base + i);
+          DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
+          AtomicAdd(gradrhsoff + tx * len + i, grad_rhs);
+        }
       }
       tx += stride_x;
     }
@@ -156,8 +181,8 @@ struct BackwardFunctorsTempl {
       Idx src, Idx edge, Idx dst) {
     return RightSelector::Call(src, edge, dst);
   }
-  static __device__ __forceinline__ DType Op(DType lhs, DType rhs) {
-    return BinaryOp::Call(lhs, rhs);
+  static __device__ __forceinline__ DType Op(DType* lhs, DType* rhs, int64_t len) {
+    return BinaryOp::Call(lhs, rhs, len);
   }
   static __device__ __forceinline__ DType Read(DType* addr) {
     return LDGReader<DType>::Call(addr);
......
@@ -28,7 +28,8 @@ struct BinaryReduce {
       Idx src, Idx dst, Idx eid, GData<Idx, DType>* gdata) {
     const int64_t D = gdata->x_length;
     int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
-    int stride_x = blockDim.x * gridDim.x;
+    const int64_t stride_x = blockDim.x * gridDim.x;
+    const int64_t len = gdata->data_len;
     Idx lid = Functors::SelectLeft(src, eid, dst);
     Idx rid = Functors::SelectRight(src, eid, dst);
     Idx oid = Functors::SelectOut(src, eid, dst);
@@ -41,13 +42,11 @@ struct BinaryReduce {
     if (gdata->out_mapping) {
       oid = Functors::GetId(oid, gdata->out_mapping);
     }
-    DType* lhsoff = gdata->lhs_data + lid * D;
-    DType* rhsoff = gdata->rhs_data + rid * D;
+    DType* lhsoff = gdata->lhs_data + lid * D * len;
+    DType* rhsoff = gdata->rhs_data + rid * D * len;
     DType* outoff = gdata->out_data + oid * D;
     while (tx < D) {
-      DType lhs = Functors::Read(lhsoff + tx);
-      DType rhs = Functors::Read(rhsoff + tx);
-      DType out = Functors::Op(lhs, rhs);
+      DType out = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
       Functors::Write(outoff + tx, out);
       tx += stride_x;
     }
@@ -72,7 +71,6 @@ __device__ __forceinline__ void UnravelRavel(
     int64_t o_st = out_stride[d];
     int64_t rhs_sh = rhs_shape[d];
     int64_t rhs_st = rhs_stride[d];
-
     int64_t i = (idx / o_st) % o_sh;
     /*
      * Simplfied for rhs_out += min(i, rhs_sh - 1) * rhs_st;
...@@ -114,7 +112,8 @@ struct BinaryReduceBcast { ...@@ -114,7 +112,8 @@ struct BinaryReduceBcast {
static __device__ __forceinline__ void ApplyEdge( static __device__ __forceinline__ void ApplyEdge(
Idx src, Idx dst, Idx eid, BcastGData<NDim, Idx, DType>* gdata) { Idx src, Idx dst, Idx eid, BcastGData<NDim, Idx, DType>* gdata) {
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = blockDim.x * gridDim.x; const int64_t stride_x = blockDim.x * gridDim.x;
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst); Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst); Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst); Idx oid = Functors::SelectOut(src, eid, dst);
...@@ -127,8 +126,8 @@ struct BinaryReduceBcast { ...@@ -127,8 +126,8 @@ struct BinaryReduceBcast {
if (gdata->out_mapping) { if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping); oid = Functors::GetId(oid, gdata->out_mapping);
} }
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len; DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len; //data with len size
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len; DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
DType* outoff = gdata->out_data + oid * gdata->out_len; DType* outoff = gdata->out_data + oid * gdata->out_len;
while (tx < gdata->out_len) { while (tx < gdata->out_len) {
int64_t lhs_add = 0; int64_t lhs_add = 0;
...@@ -136,9 +135,8 @@ struct BinaryReduceBcast { ...@@ -136,9 +135,8 @@ struct BinaryReduceBcast {
UnravelRavel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, UnravelRavel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride,
gdata->lhs_shape, gdata->lhs_stride, gdata->lhs_shape, gdata->lhs_stride,
gdata->rhs_shape, gdata->rhs_stride, &lhs_add, &rhs_add); gdata->rhs_shape, gdata->rhs_stride, &lhs_add, &rhs_add);
DType lhs = Functors::Read(lhsoff + lhs_add); DType out = Functors::Op(lhsoff + lhs_add * len, rhsoff + rhs_add * len, len);
DType rhs = Functors::Read(rhsoff + rhs_add);
DType out = Functors::Op(lhs, rhs);
Functors::Write(outoff + tx, out); Functors::Write(outoff + tx, out);
tx += stride_x; tx += stride_x;
} }
@@ -162,11 +160,8 @@ struct FunctorsTempl {
       Idx src, Idx edge, Idx dst) {
     return RightSelector::Call(src, edge, dst);
   }
-  static __device__ __forceinline__ DType Op(DType lhs, DType rhs) {
-    return BinaryOp::Call(lhs, rhs);
-  }
-  static __device__ __forceinline__ DType Read(DType* addr) {
-    return LDGReader<DType>::Call(addr);
+  static __device__ __forceinline__ DType Op(DType *lhs, DType *rhs, int64_t len) {
+    return BinaryOp::Call(lhs, rhs, len);
   }
   static __device__ __forceinline__ void Write(DType* addr, DType val) {
     Reducer::Call(addr, val);
......
@@ -115,6 +115,10 @@ def matmul(a, b):
     """Compute Matrix Multiplication between a and b"""
     pass

+def dot(a, b):
+    """Compute Dot between a and b"""
+    pass
+
 ###############################################################################
 # Tensor functions used *only* on index tensor
 # ----------------
......
@@ -86,6 +86,9 @@ def prod(x, dim):
 def matmul(a, b):
     return nd.dot(a, b)

+def dot(a, b):
+    return nd.sum(mul(a, b), axis=-1)
+
 record_grad = autograd.record
......
@@ -82,6 +82,9 @@ def prod(x, dim):
 def matmul(a, b):
     return a @ b

+def dot(a, b):
+    return sum(mul(a, b), dim=-1)
+
 class record_grad(object):
     def __init__(self):
         pass
......
@@ -26,17 +26,36 @@ def udf_max(nodes):
 D1 = 5
 D2 = 3
 D3 = 4
+D4 = 10  # NOTE(xiang): used to dot feature vector

 builtin = {'sum': fn.sum, 'max': fn.max, 'mean': fn.mean}
 udf_reduce = {'sum': udf_sum, 'max': udf_max, 'mean': udf_mean}
 fill_value = {'sum': 0, 'max': float("-inf")}

-def generate_feature(g, broadcast='none'):
+def generate_feature(g, broadcast='none', binary_op='none'):
     """Create graph with src, edge, dst feature. broadcast can be 'u',
     'e', 'v', 'none'
     """
     nv = g.number_of_nodes()
     ne = g.number_of_edges()
-    if broadcast == 'e':
-        u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
-        e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1)))
+    if binary_op == 'dot':
+        if broadcast == 'e':
+            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
+            e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1, D4)))
+            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
+        elif broadcast == 'u':
+            u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))
+            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
+            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
+        elif broadcast == 'v':
+            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
+            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
+            v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))
+        else:
+            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
+            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
+            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
+    else:
+        if broadcast == 'e':
+            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
+            e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1)))
@@ -62,7 +81,7 @@ def test_copy_src_reduce():
     # NOTE(zihao): add self-loop to avoid zero-degree nodes.
     # https://github.com/dmlc/dgl/issues/761
     g.add_edges(g.nodes(), g.nodes())
-    hu, hv, he = generate_feature(g, 'none')
+    hu, hv, he = generate_feature(g, 'none', 'none')
     if partial:
         nid = F.tensor(list(range(0, 100, 2)))
@@ -123,7 +142,7 @@ def test_copy_edge_reduce():
     g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
     # NOTE(zihao): add self-loop to avoid zero-degree nodes.
     g.add_edges(g.nodes(), g.nodes())
-    hu, hv, he = generate_feature(g, 'none')
+    hu, hv, he = generate_feature(g, 'none', 'none')
     if partial:
         nid = F.tensor(list(range(0, 100, 2)))
@@ -182,7 +201,7 @@ def test_copy_edge_reduce():
 def test_all_binary_builtins():
     def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast='none'):
         # initialize node/edge features with uniform(-1, 1)
-        hu, hv, he = generate_feature(g, broadcast)
+        hu, hv, he = generate_feature(g, broadcast, binary_op)
         if binary_op == 'div':
             # op = div
             # lhs range: [-1, 1]
@@ -321,7 +340,7 @@ def test_all_binary_builtins():
     for lhs, rhs in product(target, target):
         if lhs == rhs:
             continue
-        for binary_op in ["add", "sub", "mul", "div"]:
+        for binary_op in ["add", "sub", "mul", "div", "dot"]:
             for reducer in ["sum", "max", "min", "prod", "mean"]:
                 for broadcast in ["none", lhs, rhs]:
                     for partial in [False, True]:
@@ -332,3 +351,4 @@ if __name__ == '__main__':
     #test_copy_src_reduce()
     #test_copy_edge_reduce()
     test_all_binary_builtins()
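For intuition about what the new 'dot' cases in this test must reproduce: a u_dot_v message is an elementwise product followed by a sum over the trailing vector axis, which is exactly what BinaryDot::Call computes per edge. A hedged, self-contained reference in plain PyTorch (function name and shapes are made up; this is not the test's actual UDF):

```python
import torch as th

def u_dot_v_reference(src_feat, dst_feat):
    # Elementwise product, then reduce over the trailing vector dimension,
    # mirroring BinaryDot::Call's loop over `len`.
    return (src_feat * dst_feat).sum(dim=-1)

# Per-edge features shaped (num_edges, D1, D2, D3, D4); the dot reduces D4.
src_feat = th.randn(7, 5, 3, 4, 10)
dst_feat = th.randn(7, 5, 3, 4, 10)
print(u_dot_v_reference(src_feat, dst_feat).shape)  # torch.Size([7, 5, 3, 4])
```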