Commit 0a56d652 authored by xiang song(charlie.song), committed by Zihao Ye

[Feature] x_dot_x builtin kernel support (#831)

* upd

* fix edgebatch edges

* add test

* trigger

* Update README.md for pytorch PinSage example.

Add a note that the PinSage model example under
example/pytorch/recommendation only works with Python 3.6+,
as its dataset loader depends on the stanfordnlp package,
which works only with Python 3.6+.

* Provide a framework-agnostic API to test nn modules on both the CPU and CUDA side.

1. make dgl.nn.xxx framework agnostic
2. make test.backend include dgl.nn modules
3. modify test_edge_softmax of test/mxnet/test_nn.py and
    test/pytorch/test_nn.py to work on both CPU and GPU

* Fix style

* Delete unused code

* Make the agnostic tests depend only on tests/backend

1. clear all agnostic-related code out of dgl.nn
2. make test_graph_conv agnostic to CPU/GPU

* Fix code style

* fix

* doc

* Make all test code under tests/mxnet/test_nn.py and tests/pytorch/test_nn.py
work on both CPU and GPU.

* Fix syntax

* Remove rand

* Start implementing masked-mm kernel.

A...
parent 4297e6d6
......@@ -40,6 +40,12 @@ Message functions
e_sub_v
e_mul_v
e_div_v
u_dot_v
u_dot_e
v_dot_e
v_dot_u
e_dot_u
e_dot_v
Reduce functions
----------------
......
......@@ -98,6 +98,10 @@ Here is a cheatsheet of all the DGL builtins.
| +----------------------------------------------------+-----------------------+
| | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v`` | |
| +----------------------------------------------------+-----------------------+
| | ``u_dot_v``, ``u_dot_e``, ``v_dot_e`` | |
| +----------------------------------------------------+-----------------------+
| | ``v_dot_u``, ``e_dot_u``, ``e_dot_v`` | |
| +----------------------------------------------------+-----------------------+
| | ``src_mul_edge`` | alias of ``u_mul_e`` |
+-------------------------+----------------------------------------------------+-----------------------+
| Reduce function | ``max`` | |
......
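
The new dot builtins plug into update_all / apply_edges exactly like the existing
binary builtins. A minimal usage sketch, assuming the PyTorch backend and a toy
graph with a node feature 'h' (graph construction and field names are illustrative):

    import torch
    import dgl
    import dgl.function as fn

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 0])
    g.ndata['h'] = torch.randn(3, 4)

    # per-edge dot product of source and destination feature vectors
    g.apply_edges(fn.u_dot_v('h', 'h', 'score'))

    # dot-product message followed by a sum reduction, analogous to u_mul_v etc.
    g.update_all(fn.u_dot_v('h', 'h', 'm'), fn.sum('m', 'h_dot'))
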
......@@ -376,8 +376,11 @@ class BinaryReduce(mx.autograd.Function):
def forward(self, lhs_data, rhs_data):
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
feat_shape = K.infer_binary_feature_shape(lhs_data_nd, rhs_data_nd)
out_data = nd.empty((self.out_size,) + feat_shape,
feat_shape = K.infer_binary_feature_shape(self.binary_op, lhs_data_nd, rhs_data_nd)
out_shape = feat_shape
if self.binary_op == 'dot':
out_shape = feat_shape[:-1]
out_data = nd.empty((self.out_size,) + out_shape,
ctx=lhs_data.context, dtype=lhs_data.dtype)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.binary_op_reduce(
......@@ -402,10 +405,10 @@ class BinaryReduce(mx.autograd.Function):
in_ones = nd.ones((n,), ctx=lhs_data.context, dtype=lhs_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce(
'sum', self.graph, target, in_ones_nd, degs_nd,
'sum', self.graph, target, in_ones_nd, degs_nd,
in_map, self.out_map[0])
# reshape
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
out_data = out_data / degs
else:
degs = None
......
......@@ -288,11 +288,14 @@ class BinaryReduce(th.autograd.Function):
out_size, lhs_map, rhs_map, out_map):
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
feat_shape = K.infer_binary_feature_shape(lhs_data_nd, rhs_data_nd)
out_data = lhs_data.new_empty((out_size,) + feat_shape)
feat_shape = K.infer_binary_feature_shape(binary_op, lhs_data_nd, rhs_data_nd)
out_shape = feat_shape
if binary_op == 'dot':
out_shape = feat_shape[:-1]
out_data = lhs_data.new_empty((out_size,) + out_shape)
out_data_nd = zerocopy_to_dgl_ndarray(out_data)
K.binary_op_reduce(
reducer if reducer != 'mean' else 'sum',
reducer if reducer != 'mean' else 'sum',
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
out_data_nd, lhs_map[0], rhs_map[0], out_map[0])
# normalize if mean reducer
......@@ -311,7 +314,7 @@ class BinaryReduce(th.autograd.Function):
in_ones = lhs_data.new_ones((n,))
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce(
'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
# reshape
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
out_data = out_data / degs
......
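
Both backends now apply the same allocation rule: infer_binary_feature_shape returns
the broadcast feature shape with the dot vector length as its trailing axis, and for
the dot op that trailing axis is dropped from the output tensor. A plain-Python sketch
of the rule (illustrative, not the kernel code):

    def expected_out_shape(binary_op, feat_shape, out_size):
        # mirrors out_shape = feat_shape[:-1] for 'dot' in the forward pass
        out_feat = feat_shape[:-1] if binary_op == 'dot' else feat_shape
        return (out_size,) + out_feat

    assert expected_out_shape('mul', (3, 4), 5) == (5, 3, 4)
    assert expected_out_shape('dot', (3, 4), 5) == (5, 3)  # length-4 vectors reduced away
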
......@@ -147,12 +147,17 @@ def copy_e(e, out):
###############################################################################
# Generate all following builtin message functions:
# element-wise message functions:
# u_add_v, u_sub_v, u_mul_v, u_div_v
# u_add_e, u_sub_e, u_mul_e, u_div_e
# v_add_u, v_sub_u, v_mul_u, v_div_u
# v_add_e, v_sub_e, v_mul_e, v_div_e
# e_add_u, e_sub_u, e_mul_u, e_div_u
# e_add_v, e_sub_v, e_mul_v, e_div_v
#
# dot message functions:
# u_dot_v, u_dot_e, v_dot_e
# v_dot_u, e_dot_u, e_dot_v
_TARGET_MAP = {
"u": TargetCode.SRC,
......@@ -200,12 +205,11 @@ def _register_builtin_message_func():
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
if lhs != rhs:
for binary_op in ["add", "sub", "mul", "div"]:
for binary_op in ["add", "sub", "mul", "div", "dot"]:
func = _gen_message_builtin(lhs, rhs, binary_op)
setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__)
_register_builtin_message_func()
......
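
With 'dot' added to the loop, the registration above emits six additional message
builtins. A small sketch of the generated names (same product over targets as in
the code):

    from itertools import product

    targets = ['u', 'v', 'e']
    dot_builtins = ['%s_dot_%s' % (lhs, rhs)
                    for lhs, rhs in product(targets, targets) if lhs != rhs]
    # ['u_dot_v', 'u_dot_e', 'v_dot_u', 'v_dot_e', 'e_dot_u', 'e_dot_v']
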
......@@ -4,11 +4,14 @@ from __future__ import absolute_import
from ._ffi.function import _init_api
from .ndarray import empty
def infer_binary_feature_shape(lhs, rhs):
# pylint: disable=invalid-name
def infer_binary_feature_shape(op, lhs, rhs):
"""Infer the output feature shape after a binary operation between lhs and rhs.
Parameter
---------
op : string
The binary_op name.
lhs : dgl.ndarray.NDArray
The lhs tensor.
rhs : dgl.ndarray.NDArray
......@@ -19,7 +22,7 @@ def infer_binary_feature_shape(lhs, rhs):
tuple of int
The output feature shape.
"""
ret = _CAPI_DGLKernelInferBinaryFeatureShape(lhs, rhs)
ret = _CAPI_DGLKernelInferBinaryFeatureShape(op, lhs, rhs)
return tuple(ret.asnumpy())
# pylint: disable=invalid-name
......
......@@ -93,11 +93,25 @@ bool HasBcast(NDArray lhs, NDArray rhs) {
// e.g. (4, 1, 3, 3) and (4, 5, 3, 3) become (4, 1, 9) and (4, 5, 9)
//
// See also: BcastInfo (kernel/binary_reduce.h)
BcastInfo CalcBcastInfo(NDArray lhs, NDArray rhs) {
BcastInfo CalcBcastInfo(const std::string& op, NDArray lhs, NDArray rhs) {
BcastInfo ret;
const int max_ndim = std::max(lhs->ndim, rhs->ndim) - 1;
int64_t accum = 0;
for (int j = 0; j < max_ndim; ++j) {
int j = 0;
// for dot operation: vector [dot] vector
// lhs_shape[ndim-1] == rhs_shape[ndim-1] = sizeof(vector)
// out_shape[ndim-1] = 1
if (op == binary_op::kDot) {
// get size of vector
ret.data_len = lhs->shape[lhs->ndim - 1];
// skip vector size dim
++j;
ret.real_out_shape.push_back(ret.data_len);
} else { // op != binary_op::kDot
ret.data_len = 1;
}
for (; j < max_ndim; ++j) {
const int dl = (lhs->ndim - 1 - j < 1)? 1 : lhs->shape[lhs->ndim - 1 - j];
const int dr = (rhs->ndim - 1 - j < 1)? 1 : rhs->shape[rhs->ndim - 1 - j];
if (dl != dr) {
......@@ -258,16 +272,18 @@ class BipartiteCSRWrapper : public CSRWrapper {
std::vector<int64_t> InferBinaryFeatureShape(
const std::string& op,
NDArray lhs,
NDArray rhs) {
return CalcBcastInfo(lhs, rhs).real_out_shape;
return CalcBcastInfo(op, lhs, rhs).real_out_shape;
}
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelInferBinaryFeatureShape")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray lhs = args[0];
NDArray rhs = args[1];
const auto& shape = InferBinaryFeatureShape(lhs, rhs);
std::string op = args[0];
NDArray lhs = args[1];
NDArray rhs = args[2];
const auto& shape = InferBinaryFeatureShape(op, lhs, rhs);
const int64_t len = shape.size();
NDArray ret = NDArray::Empty(
{len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -300,7 +316,7 @@ void BinaryOpReduce(
rhs_mapping, lhs_mapping, out_mapping);
} else {
if (HasBcast(lhs_data, rhs_data)) {
BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
DGL_XPU_SWITCH(ctx.device_type, BinaryReduceBcastImpl,
info, reducer, op, graph,
lhs, rhs,
......@@ -394,7 +410,7 @@ void BackwardLhsBinaryOpReduce(
grad_out_data, grad_lhs_data);
} else {
if (HasBcast(lhs_data, rhs_data)) {
BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceBcastImpl,
info, reducer, op, graph,
lhs, rhs,
......@@ -468,7 +484,7 @@ void BackwardRhsBinaryOpReduce(
grad_out_data, grad_rhs_data);
} else {
if (HasBcast(lhs_data, rhs_data)) {
BcastInfo info = CalcBcastInfo(lhs_data, rhs_data);
BcastInfo info = CalcBcastInfo(op, lhs_data, rhs_data);
DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceBcastImpl,
info, reducer, op, graph,
lhs, rhs,
......
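
The dot branch of CalcBcastInfo treats the trailing axis as the vector length
(data_len) and broadcasts only the remaining feature axes. A rough Python analogue,
operating on per-row feature shapes (a sketch of the semantics, not the C++ code):

    import numpy as np

    def dot_feature_shape(lhs_feat, rhs_feat):
        # the last axis is the dot vector and must match on both sides
        assert lhs_feat[-1] == rhs_feat[-1], 'dot requires equal vector lengths'
        data_len = lhs_feat[-1]
        # the remaining feature axes follow ordinary broadcasting rules
        rest = np.broadcast_shapes(lhs_feat[:-1], rhs_feat[:-1])
        # real_out_shape keeps data_len as its last axis; the Python wrappers
        # strip it again via out_shape = feat_shape[:-1]
        return tuple(rest) + (data_len,), data_len

    # e.g. lhs feature (2, 1, 10) vs rhs feature (5, 2, 3, 10) -> ((5, 2, 3, 10), 10)
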
......@@ -30,6 +30,8 @@ struct BcastInfo {
std::vector<int64_t> lhs_shape, lhs_stride;
std::vector<int64_t> rhs_shape, rhs_stride;
std::vector<int64_t> out_shape, out_stride;
int64_t data_len;
};
/*
......
......@@ -29,6 +29,7 @@ static const char kAdd[] = "add";
static const char kSub[] = "sub";
static const char kMul[] = "mul";
static const char kDiv[] = "div";
static const char kDot[] = "dot";
static const char kUseLhs[] = "use_lhs";
/*!
......@@ -129,8 +130,8 @@ struct SwitchSrcDst<SelectDst> {
// common binary functors
template <typename DType>
struct BinaryAdd {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs + rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] + rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return 1;
......@@ -142,8 +143,8 @@ struct BinaryAdd {
template <typename DType>
struct BinaryMul {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs * rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] * rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return rhs;
......@@ -155,8 +156,8 @@ struct BinaryMul {
template <typename DType>
struct BinarySub {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs - rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] - rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return 1;
......@@ -168,8 +169,8 @@ struct BinarySub {
template <typename DType>
struct BinaryDiv {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs / rhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0] / rhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return static_cast<DType>(1) / rhs;
......@@ -181,8 +182,8 @@ struct BinaryDiv {
template <typename DType>
struct BinaryUseLhs {
static DGLDEVICE DGLINLINE DType Call(DType lhs, DType rhs) {
return lhs;
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
return lhs[0];
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return 1;
......@@ -192,6 +193,25 @@ struct BinaryUseLhs {
}
};
template <typename DType>
struct BinaryDot {
static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
DType out = 0;
// simple vector dot vector
#pragma unroll
for (int i = 0; i < len; ++i)
out += lhs[i] * rhs[i];
return out;
}
static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
return rhs;
}
static DGLDEVICE DGLINLINE DType BackwardRhs(DType lhs, DType rhs, DType out) {
return lhs;
}
};
// Macro for dispatching op enum code and target code into template arguments.
// The macro dispatches following combinations:
// - Add(Src, Dst), Add(Src, Edge), Add(Dst, Edge)
......@@ -201,6 +221,8 @@ struct BinaryUseLhs {
// - Div(Src, Dst), Div(Src, Edge), Div(Dst, Edge)
// Div(Dst, Src), Div(Edge, Src), Div(Edge, Dst)
// - UseLhs(Src, None), UseLhs(Edge, None)
// - Dot(Src, Dst), Dot(Src, Edge), Dot(Dst, Edge)
// - Dot(Dst, Src), Dot(Edge, Src), Dot(Edge, Dst)
// Note that for commutative operators (e.g. Add and Mul), we only generate
// kernels for lhs code smaller than rhs code.
#define OP_TARGET_SWITCH(op, lhs, rhs, DType, OpType, LeftType, RightType, ...) \
......@@ -306,6 +328,36 @@ struct BinaryUseLhs {
typedef SelectEdge LeftType; \
typedef SelectNone RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kSrc && rhs == kDst) { \
typedef BinaryDot<DType> OpType; \
typedef SelectSrc LeftType; \
typedef SelectDst RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kSrc && rhs == kEdge) { \
typedef BinaryDot<DType> OpType; \
typedef SelectSrc LeftType; \
typedef SelectEdge RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kDst && rhs == kEdge) { \
typedef BinaryDot<DType> OpType; \
typedef SelectDst LeftType; \
typedef SelectEdge RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kDst && rhs == kSrc) { \
typedef BinaryDot<DType> OpType; \
typedef SelectDst LeftType; \
typedef SelectSrc RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kEdge && rhs == kSrc) { \
typedef BinaryDot<DType> OpType; \
typedef SelectEdge LeftType; \
typedef SelectSrc RightType; \
{__VA_ARGS__} \
} else if (op == kDot && lhs == kEdge && rhs == kDst) { \
typedef BinaryDot<DType> OpType; \
typedef SelectEdge LeftType; \
typedef SelectDst RightType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Unsupported operation: op=" << op \
<< " lhs=" << lhs << " rhs=" << rhs; \
......@@ -333,7 +385,13 @@ struct BinaryUseLhs {
MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectEdge, BinaryDiv)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectDst, BinaryDiv)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectNone, BinaryUseLhs)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectNone, BinaryUseLhs))
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectNone, BinaryUseLhs)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectDst, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectSrc, SelectEdge, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectEdge, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectDst, SelectSrc, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectSrc, BinaryDot)) \
MSVC_EXPAND(GEN(__VA_ARGS__, SelectEdge, SelectDst, BinaryDot))
//////////////////////////////////////////////////////////////////////////
// Defines reducer category. Each category is an empty structure.
......
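
In plain terms, BinaryDot computes a vector dot product in the forward pass while its
backward helpers keep the scalar per-element convention of the other functors: the
partial derivative of lhs.rhs with respect to one lhs element is the matching rhs
element, and vice versa. A short sketch of the math (not the C++ code):

    import numpy as np

    def binary_dot_call(lhs, rhs):
        # forward: sum_i lhs[i] * rhs[i] over the len values of one feature slot
        return float(np.dot(lhs, rhs))

    def binary_dot_backward_lhs(lhs_i, rhs_i, out):
        # d(lhs . rhs) / d lhs[i] = rhs[i]  (BackwardLhs returns rhs)
        return rhs_i

    def binary_dot_backward_rhs(lhs_i, rhs_i, out):
        # d(lhs . rhs) / d rhs[i] = lhs[i]  (BackwardRhs returns lhs)
        return lhs_i
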
......@@ -28,7 +28,7 @@ namespace kernel {
///////////////////////////////////////////////////////////////////////////////
template <int XPU, typename Idx, typename DType, typename Reducer>
GData<Idx, DType> AllocGData(
GData<Idx, DType> AllocGData(const std::string& op,
const DLContext& ctx, int64_t x_len,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
......@@ -49,6 +49,15 @@ GData<Idx, DType> AllocGData(
if (!utils::IsNoneArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
}
// for dot operation: vector [dot] vector
if (op == binary_op::kDot) {
// get size of vector
gdata.data_len = lhs_data->shape[lhs_data->ndim - 1];
} else {
gdata.data_len = 1;
}
// fill out data with zero values
utils::Fill<XPU>(ctx, gdata.out_data, utils::NElements(out_data), Zero<Reducer>::value);
return gdata;
......@@ -92,7 +101,7 @@ void BinaryReduceImpl(
DGL_DTYPE_SWITCH(dtype, DType, {
DGL_IDX_TYPE_SWITCH(bits, Idx, {
REDUCER_SWITCH(reducer, XPU, DType, Reducer, {
auto gdata = AllocGData<XPU, Idx, DType, Reducer>(
auto gdata = AllocGData<XPU, Idx, DType, Reducer>(op,
rtcfg.ctx, x_len, lhs_mapping, rhs_mapping,
lhs_data, rhs_data, out_mapping, out_data);
OP_TARGET_SWITCH(op, lhs, rhs, DType, BinaryOp, LeftTarget, RightTarget, {
......@@ -110,7 +119,7 @@ void BinaryReduceImpl(
template <int XPU, typename Idx, typename DType>
BackwardGData<Idx, DType> AllocBackwardGData(
const DLContext& ctx, int64_t x_len,
const std::string& op, const DLContext& ctx, int64_t x_len,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,
runtime::NDArray grad_out_data,
......@@ -144,6 +153,14 @@ BackwardGData<Idx, DType> AllocBackwardGData(
if (!utils::IsNoneArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
}
// for dot operation: vector [dot] vector
if (op == binary_op::kDot) {
// get size of vector
gdata.data_len = lhs_data->shape[lhs_data->ndim - 1];
} else {
gdata.data_len = 1;
}
return gdata;
}
......@@ -182,13 +199,14 @@ void BackwardBinaryReduceImpl(
const bool req_lhs = !utils::IsNoneArray(grad_lhs_data);
const bool req_rhs = !utils::IsNoneArray(grad_rhs_data);
const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) {
// TODO(minjie): divide
LOG(FATAL) << "reduce mean is not supported.";
}
DGL_DTYPE_SWITCH(dtype, DType, {
DGL_IDX_TYPE_SWITCH(bits, Idx, {
auto gdata = AllocBackwardGData<XPU, Idx, DType>(
auto gdata = AllocBackwardGData<XPU, Idx, DType>(op,
rtcfg.ctx, x_len, lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
grad_lhs_data, grad_rhs_data);
......@@ -240,6 +258,8 @@ BcastGData<NDim, Idx, DType> AllocBcastGData(
if (!utils::IsNoneArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
}
gdata.data_len = info.data_len;
// fill out data with zero values
utils::Fill<XPU>(ctx, gdata.out_data, utils::NElements(out_data), Zero<Reducer>::value);
return gdata;
......@@ -280,6 +300,7 @@ void BinaryReduceBcastImpl(
const DLDataType& dtype = out_data->dtype;
const int bcast_ndim = info.out_shape.size();
const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) {
// TODO(minjie): divide
LOG(FATAL) << "reduce mean is not supported.";
......@@ -334,6 +355,8 @@ BackwardBcastGData<NDim, Idx, DType> AllocBackwardBcastGData(
if (!utils::IsNoneArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
}
gdata.data_len = info.data_len;
// data
gdata.lhs_data = static_cast<DType*>(lhs->data);
gdata.rhs_data = static_cast<DType*>(rhs->data);
......@@ -387,6 +410,7 @@ void BackwardBinaryReduceBcastImpl(
const bool req_lhs = !utils::IsNoneArray(grad_lhs);
const bool req_rhs = !utils::IsNoneArray(grad_rhs);
const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) {
// TODO(minjie): divide
LOG(FATAL) << "reduce mean is not supported.";
......
......@@ -36,6 +36,8 @@ template <typename Idx, typename DType>
struct GData {
// length along x(feature) dimension
int64_t x_length{0};
// size of each data element: 1 for element-wise ops, the vector length (data_len) for dot
int64_t data_len;
// number of rows of the output tensor
int64_t out_size{0};
// input data
......@@ -119,6 +121,8 @@ template <typename Idx, typename DType>
struct BackwardGData {
// length along x(feature) dimension
int64_t x_length{0};
// size of each data element: 1 for element-wise ops, the vector length (data_len) for dot
int64_t data_len;
// number of rows of the output tensor
int64_t out_size{0};
// input data
......@@ -222,6 +226,8 @@ struct BcastGData {
int64_t lhs_len{0}, rhs_len{0};
int64_t lhs_shape[NDim]{0}, lhs_stride[NDim]{0};
int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
// size of each data element: 1 for element-wise ops, the vector length (data_len) for dot
int64_t data_len;
// input data
DType *lhs_data{nullptr}, *rhs_data{nullptr};
// input id mappings
......@@ -326,6 +332,8 @@ struct BackwardBcastGData {
int64_t lhs_shape[NDim]{0}, lhs_stride[NDim]{0};
int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
int64_t out_shape[NDim]{0}, out_stride[NDim]{0};
// size of each data element: 1 for element-wise ops, the vector length (data_len) for dot
int64_t data_len;
// input id mappings
Idx *lhs_mapping{nullptr}, *rhs_mapping{nullptr}, *out_mapping{nullptr};
// input data
......
......@@ -27,6 +27,7 @@ struct BackwardBinaryReduce {
static inline void ApplyEdge(
Idx src, Idx dst, Idx eid, BackwardGData<Idx, DType>* gdata) {
const int64_t D = gdata->x_length;
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -39,28 +40,46 @@ struct BackwardBinaryReduce {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * D;
DType* rhsoff = gdata->rhs_data + rid * D;
DType* lhsoff = gdata->lhs_data + lid * D * len;
DType* rhsoff = gdata->rhs_data + rid * D * len;
DType* outoff = gdata->out_data + oid * D;
DType* gradlhsoff = gdata->grad_lhs_data + lid * D;
DType* gradrhsoff = gdata->grad_rhs_data + rid * D;
DType* gradlhsoff = gdata->grad_lhs_data + lid * D * len;
DType* gradrhsoff = gdata->grad_rhs_data + rid * D * len;
DType* gradoutoff = gdata->grad_out_data + oid * D;
for (int64_t tx = 0; tx < D; ++tx) {
DType lhs = Functors::Read(lhsoff + tx);
DType rhs = Functors::Read(rhsoff + tx);
DType out = Functors::Read(outoff + tx);
DType grad_out = Functors::Read(gradoutoff + tx);
DType e = Functors::Op(lhs, rhs);
DType e = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
DType grad_e = grad_out * Functors::BackwardWrite(e, out);
if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
DType* lhs_base = lhsoff + tx * len;
DType* rhs_base = rhsoff + tx * len;
if (Mode == binary_op::kGradBoth) {
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
DType grad = grad_lhs + grad_rhs;
#pragma omp atomic
gradlhsoff[tx] += grad_lhs;
}
if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
gradlhsoff[tx * len + i] += grad;
}
} else if (Mode == binary_op::kGradLhs) {
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
#pragma omp atomic
gradlhsoff[tx * len + i] += grad_lhs;
}
} else if (Mode == binary_op::kGradRhs) {
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
#pragma omp atomic
gradrhsoff[tx] += grad_rhs;
gradrhsoff[tx * len + i] += grad_rhs;
}
}
}
}
......@@ -76,6 +95,7 @@ struct BackwardBinaryReduceBcast {
}
static inline void ApplyEdge(
Idx src, Idx dst, Idx eid, BackwardBcastGData<NDim, Idx, DType>* gdata) {
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -88,32 +108,53 @@ struct BackwardBinaryReduceBcast {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len;
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
DType* outoff = gdata->out_data + oid * gdata->out_len;
DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len;
DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len;
DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len * len;
DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len * len;
DType* gradoutoff = gdata->grad_out_data + oid * gdata->out_len;
int64_t tmp[NDim]; // store unraveled idx.
for (int64_t tx = 0; tx < gdata->out_len; ++tx) {
Unravel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, tmp);
DType lhs = Functors::Read(lhsoff +
Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride));
DType rhs = Functors::Read(rhsoff +
Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride));
DType out = Functors::Read(outoff + tx);
DType grad_out = Functors::Read(gradoutoff + tx);
DType e = Functors::Op(lhs, rhs);
DType e = Functors::Op(
lhsoff + Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride) * len,
rhsoff + Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride) * len,
len);
DType grad_e = grad_out * Functors::BackwardWrite(e, out);
if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
DType* lhs_base = lhsoff +
Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride) * len;
DType* rhs_base = rhsoff +
Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride) * len;
if (Mode == binary_op::kGradBoth) {
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
DType grad = grad_lhs + grad_rhs;
#pragma omp atomic
gradlhsoff[tx] += grad_lhs;
}
if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
gradlhsoff[tx * len + i] += grad;
}
} else if (Mode == binary_op::kGradLhs) {
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
#pragma omp atomic
gradlhsoff[tx * len + i] += grad_lhs;
}
} else if (Mode == binary_op::kGradRhs) {
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
#pragma omp atomic
gradrhsoff[tx] += grad_rhs;
gradrhsoff[tx * len + i] += grad_rhs;
}
}
}
}
......@@ -137,8 +178,8 @@ struct BackwardFunctorsTempl {
Idx src, Idx edge, Idx dst) {
return RightSelector::Call(src, edge, dst);
}
static inline DType Op(DType lhs, DType rhs) {
return BinaryOp::Call(lhs, rhs);
static inline DType Op(DType* lhs, DType* rhs, int64_t len) {
return BinaryOp::Call(lhs, rhs, len);
}
static inline DType Read(DType* addr) {
return *addr;
......
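
A quick numerical check of what the per-element backward loops compute for the dot op
with a sum reducer (a sketch; shapes and the finite-difference tolerance are illustrative):

    import numpy as np

    D, length = 3, 4                      # x_length and data_len in the kernel
    lhs = np.random.randn(D, length)
    rhs = np.random.randn(D, length)
    grad_out = np.random.randn(D)         # one incoming gradient per output element

    e = np.einsum('dl,dl->d', lhs, rhs)   # forward: one dot product per position tx
    grad_lhs = grad_out[:, None] * rhs    # what gradlhsoff[tx * len + i] accumulates

    # finite-difference check on a single element
    eps, (i, j) = 1e-6, (1, 2)
    lhs2 = lhs.copy(); lhs2[i, j] += eps
    e2 = np.einsum('dl,dl->d', lhs2, rhs)
    fd = (np.dot(grad_out, e2) - np.dot(grad_out, e)) / eps
    assert abs(fd - grad_lhs[i, j]) < 1e-3
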
......@@ -29,6 +29,7 @@ struct BinaryReduce {
static inline void ApplyEdge(
Idx src, Idx dst, Idx eid, GData<Idx, DType>* gdata) {
const int64_t D = gdata->x_length;
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -41,13 +42,11 @@ struct BinaryReduce {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * D;
DType* rhsoff = gdata->rhs_data + rid * D;
DType* lhsoff = gdata->lhs_data + lid * D * len;
DType* rhsoff = gdata->rhs_data + rid * D * len;
DType* outoff = gdata->out_data + oid * D;
for (int64_t tx = 0; tx < D; ++tx) {
DType lhs = Functors::Read(lhsoff + tx);
DType rhs = Functors::Read(rhsoff + tx);
DType out = Functors::Op(lhs, rhs);
DType out = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
Functors::Write(outoff + tx, out);
}
}
......@@ -80,6 +79,7 @@ struct BinaryReduceBcast {
}
static inline void ApplyEdge(
Idx src, Idx dst, Idx eid, BcastGData<NDim, Idx, DType>* gdata) {
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -92,17 +92,17 @@ struct BinaryReduceBcast {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len;  // each feature position spans len values
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
DType* outoff = gdata->out_data + oid * gdata->out_len;
int64_t tmp[NDim]; // store unraveled idx.
for (int64_t tx = 0; tx < gdata->out_len; ++tx) {
Unravel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, tmp);
DType lhs = Functors::Read(lhsoff +
Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride));
DType rhs = Functors::Read(rhsoff +
Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride));
DType out = Functors::Op(lhs, rhs);
DType out = Functors::Op(
lhsoff + Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride) * len,
rhsoff + Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride) * len,
len);
Functors::Write(outoff + tx, out);
}
}
......@@ -125,11 +125,8 @@ struct FunctorsTempl {
Idx src, Idx edge, Idx dst) {
return RightSelector::Call(src, edge, dst);
}
static inline DType Op(DType lhs, DType rhs) {
return BinaryOp::Call(lhs, rhs);
}
static inline DType Read(DType* addr) {
return *addr;
static inline DType Op(DType *lhs, DType *rhs, int64_t len) {
return BinaryOp::Call(lhs, rhs, len);
}
static inline void Write(DType* addr, DType val) {
Reducer::Call(addr, val);
......
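
For the dot op, the flat-pointer arithmetic above means each row buffer now holds D
feature positions of len contiguous values, and each of the D outputs is one dot
product over such a slice. A small array-level sketch:

    import numpy as np

    def binary_reduce_dot_row(lhs_row, rhs_row, D, length):
        # lhs_row / rhs_row play the role of lhsoff / rhsoff:
        # flat buffers of size D * length for one lid / rid
        out = np.empty(D)
        for tx in range(D):
            l = lhs_row[tx * length:(tx + 1) * length]
            r = rhs_row[tx * length:(tx + 1) * length]
            out[tx] = np.dot(l, r)        # Functors::Op(lhsoff + tx*len, rhsoff + tx*len, len)
        return out
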
......@@ -29,6 +29,7 @@ struct BackwardBinaryReduce {
const int64_t D = gdata->x_length;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = blockDim.x * gridDim.x;
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -41,50 +42,52 @@ struct BackwardBinaryReduce {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * D;
DType* rhsoff = gdata->rhs_data + rid * D;
DType* lhsoff = gdata->lhs_data + lid * D * len;
DType* rhsoff = gdata->rhs_data + rid * D * len;
DType* outoff = gdata->out_data + oid * D;
DType* gradlhsoff = gdata->grad_lhs_data + lid * D;
DType* gradrhsoff = gdata->grad_rhs_data + rid * D;
DType* gradlhsoff = gdata->grad_lhs_data + lid * D * len;
DType* gradrhsoff = gdata->grad_rhs_data + rid * D * len;
DType* gradoutoff = gdata->grad_out_data + oid * D;
while (tx < D) {
DType lhs = Functors::Read(lhsoff + tx);
DType rhs = Functors::Read(rhsoff + tx);
DType out = Functors::Read(outoff + tx);
DType grad_out = Functors::Read(gradoutoff + tx);
DType e = Functors::Op(lhs, rhs);
DType e = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
DType grad_e = grad_out * Functors::BackwardWrite(e, out);
if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
AtomicAdd(gradlhsoff + tx, grad_lhs);
}
if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
AtomicAdd(gradrhsoff + tx, grad_rhs);
DType* lhs_base = lhsoff + tx * len;
DType* rhs_base = rhsoff + tx * len;
if (Mode == binary_op::kGradBoth) {
#pragma unroll
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
DType grad = grad_lhs + grad_rhs;
AtomicAdd(gradlhsoff + tx * len + i, grad);
}
} else if (Mode == binary_op::kGradLhs) {
#pragma unroll
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
AtomicAdd(gradlhsoff + tx * len + i, grad_lhs);
}
} else if (Mode == binary_op::kGradRhs) {
#pragma unroll
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
AtomicAdd(gradrhsoff + tx * len + i, grad_rhs);
}
}
tx += stride_x;
}
}
};
// Convert flattened index to multi-dimension index (assume row-major).
__device__ __forceinline__ void Unravel(
int64_t idx, int ndim, const int64_t* shape, const int64_t* stride, int64_t* out) {
for (int d = 0; d < ndim; ++d) {
out[d] = (idx / stride[d]) % shape[d];
}
}
// Convert multi-dimension index to flattened index (assume row-major).
__device__ __forceinline__ int64_t Ravel(
const int64_t* idx, int ndim, const int64_t* shape, const int64_t* stride) {
int64_t out = 0;
for (int d = 0; d < ndim; ++d) {
out += min(idx[d], shape[d] - 1) * stride[d];
}
return out;
}
// Minigun UDF to compute backward binary reduce with broadcasting.
template <int Mode, int NDim, typename Idx, typename DType, typename Functors>
struct BackwardBinaryReduceBcast {
......@@ -96,6 +99,7 @@ struct BackwardBinaryReduceBcast {
Idx src, Idx dst, Idx eid, BackwardBcastGData<NDim, Idx, DType>* gdata) {
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = blockDim.x * gridDim.x;
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -108,30 +112,51 @@ struct BackwardBinaryReduceBcast {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len;
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
DType* outoff = gdata->out_data + oid * gdata->out_len;
DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len;
DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len;
DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len * len;
DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len * len;
DType* gradoutoff = gdata->grad_out_data + oid * gdata->out_len;
int64_t tmp[NDim]; // store unraveled idx.
while (tx < gdata->out_len) {
Unravel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, tmp);
DType lhs = Functors::Read(lhsoff +
Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride));
DType rhs = Functors::Read(rhsoff +
Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride));
int64_t lhs_add = 0;
int64_t rhs_add = 0;
UnravelRavel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride,
gdata->lhs_shape, gdata->lhs_stride,
gdata->rhs_shape, gdata->rhs_stride, &lhs_add, &rhs_add);
DType out = Functors::Read(outoff + tx);
DType grad_out = Functors::Read(gradoutoff + tx);
DType e = Functors::Op(lhs, rhs);
DType e = Functors::Op(lhsoff + lhs_add * len, rhsoff + rhs_add * len, len);
DType grad_e = grad_out * Functors::BackwardWrite(e, out);
if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
AtomicAdd(gradlhsoff + tx, grad_lhs);
}
if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
AtomicAdd(gradrhsoff + tx, grad_rhs);
DType* lhs_base = lhsoff + lhs_add * len;
DType* rhs_base = rhsoff + rhs_add * len;
if (Mode == binary_op::kGradBoth) {
#pragma unroll
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
DType grad = grad_lhs + grad_rhs;
AtomicAdd(gradlhsoff + tx * len + i, grad);
}
} else if (Mode == binary_op::kGradLhs) {
#pragma unroll
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
AtomicAdd(gradlhsoff + tx * len + i, grad_lhs);
}
} else if (Mode == binary_op::kGradRhs) {
#pragma unroll
for (int64_t i = 0; i < len; ++i) {
DType lhs = Functors::Read(lhs_base + i);
DType rhs = Functors::Read(rhs_base + i);
DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
AtomicAdd(gradrhsoff + tx * len + i, grad_rhs);
}
}
tx += stride_x;
}
......@@ -156,8 +181,8 @@ struct BackwardFunctorsTempl {
Idx src, Idx edge, Idx dst) {
return RightSelector::Call(src, edge, dst);
}
static __device__ __forceinline__ DType Op(DType lhs, DType rhs) {
return BinaryOp::Call(lhs, rhs);
static __device__ __forceinline__ DType Op(DType* lhs, DType* rhs, int64_t len) {
return BinaryOp::Call(lhs, rhs, len);
}
static __device__ __forceinline__ DType Read(DType* addr) {
return LDGReader<DType>::Call(addr);
......
......@@ -28,7 +28,8 @@ struct BinaryReduce {
Idx src, Idx dst, Idx eid, GData<Idx, DType>* gdata) {
const int64_t D = gdata->x_length;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = blockDim.x * gridDim.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -41,13 +42,11 @@ struct BinaryReduce {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * D;
DType* rhsoff = gdata->rhs_data + rid * D;
DType* lhsoff = gdata->lhs_data + lid * D * len;
DType* rhsoff = gdata->rhs_data + rid * D * len;
DType* outoff = gdata->out_data + oid * D;
while (tx < D) {
DType lhs = Functors::Read(lhsoff + tx);
DType rhs = Functors::Read(rhsoff + tx);
DType out = Functors::Op(lhs, rhs);
DType out = Functors::Op(lhsoff + tx * len, rhsoff + tx * len, len);
Functors::Write(outoff + tx, out);
tx += stride_x;
}
......@@ -72,13 +71,12 @@ __device__ __forceinline__ void UnravelRavel(
int64_t o_st = out_stride[d];
int64_t rhs_sh = rhs_shape[d];
int64_t rhs_st = rhs_stride[d];
int64_t i = (idx / o_st) % o_sh;
/*
* Simplified form of rhs_out += min(i, rhs_sh - 1) * rhs_st;
* rhs_sh be o_sh or 1
*/
if (rhs_sh > i) {
if (rhs_sh > i) {
*rhs_out += i * rhs_st;
}
}
......@@ -90,7 +88,7 @@ __device__ __forceinline__ void UnravelRavel(
int64_t o_st = out_stride[d];
int64_t lhs_sh = lhs_shape[d];
int64_t lhs_st = lhs_stride[d];
int64_t i = (idx / o_st) % o_sh;
/*
* Simplified form of lhs_out += min(i, lhs_sh - 1) * lhs_st;
......@@ -114,7 +112,8 @@ struct BinaryReduceBcast {
static __device__ __forceinline__ void ApplyEdge(
Idx src, Idx dst, Idx eid, BcastGData<NDim, Idx, DType>* gdata) {
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = blockDim.x * gridDim.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const int64_t len = gdata->data_len;
Idx lid = Functors::SelectLeft(src, eid, dst);
Idx rid = Functors::SelectRight(src, eid, dst);
Idx oid = Functors::SelectOut(src, eid, dst);
......@@ -127,8 +126,8 @@ struct BinaryReduceBcast {
if (gdata->out_mapping) {
oid = Functors::GetId(oid, gdata->out_mapping);
}
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len * len;  // each feature position spans len values
DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len * len;
DType* outoff = gdata->out_data + oid * gdata->out_len;
while (tx < gdata->out_len) {
int64_t lhs_add = 0;
......@@ -136,9 +135,8 @@ struct BinaryReduceBcast {
UnravelRavel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride,
gdata->lhs_shape, gdata->lhs_stride,
gdata->rhs_shape, gdata->rhs_stride, &lhs_add, &rhs_add);
DType lhs = Functors::Read(lhsoff + lhs_add);
DType rhs = Functors::Read(rhsoff + rhs_add);
DType out = Functors::Op(lhs, rhs);
DType out = Functors::Op(lhsoff + lhs_add * len, rhsoff + rhs_add * len, len);
Functors::Write(outoff + tx, out);
tx += stride_x;
}
......@@ -162,11 +160,8 @@ struct FunctorsTempl {
Idx src, Idx edge, Idx dst) {
return RightSelector::Call(src, edge, dst);
}
static __device__ __forceinline__ DType Op(DType lhs, DType rhs) {
return BinaryOp::Call(lhs, rhs);
}
static __device__ __forceinline__ DType Read(DType* addr) {
return LDGReader<DType>::Call(addr);
static __device__ __forceinline__ DType Op(DType *lhs, DType *rhs, int64_t len) {
return BinaryOp::Call(lhs, rhs, len);
}
static __device__ __forceinline__ void Write(DType* addr, DType val) {
Reducer::Call(addr, val);
......
......@@ -115,6 +115,10 @@ def matmul(a, b):
"""Compute Matrix Multiplication between a and b"""
pass
def dot(a, b):
"""Compute Dot between a and b"""
pass
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
......
......@@ -86,6 +86,9 @@ def prod(x, dim):
def matmul(a, b):
return nd.dot(a, b)
def dot(a, b):
return nd.sum(mul(a, b), axis=-1)
record_grad = autograd.record
......
......@@ -82,6 +82,9 @@ def prod(x, dim):
def matmul(a, b):
return a @ b
def dot(a, b):
return sum(mul(a, b), dim=-1)
class record_grad(object):
def __init__(self):
pass
......
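
Both backend implementations above reduce only the trailing axis; a numpy reference
for comparison (illustrative):

    import numpy as np

    a = np.random.randn(2, 3, 4)
    b = np.random.randn(2, 3, 4)
    ref = np.sum(a * b, axis=-1)                          # shape (2, 3)
    assert np.allclose(ref, np.einsum('...k,...k->...', a, b))
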
......@@ -26,33 +26,52 @@ def udf_max(nodes):
D1 = 5
D2 = 3
D3 = 4
D4 = 10 # NOTE(xiang): length of the feature vector reduced by the dot operator
builtin = {'sum': fn.sum, 'max': fn.max, 'mean': fn.mean}
udf_reduce = {'sum': udf_sum, 'max': udf_max, 'mean': udf_mean}
fill_value = {'sum': 0, 'max': float("-inf")}
def generate_feature(g, broadcast='none'):
def generate_feature(g, broadcast='none', binary_op='none'):
"""Create graph with src, edge, dst feature. broadcast can be 'u',
'e', 'v', 'none'
"""
nv = g.number_of_nodes()
ne = g.number_of_edges()
if broadcast == 'e':
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
elif broadcast == 'u':
u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
elif broadcast == 'v':
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))
if binary_op == 'dot':
if broadcast == 'e':
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1, D4)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
elif broadcast == 'u':
u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
elif broadcast == 'v':
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))
else:
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
else:
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
if broadcast == 'e':
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
elif broadcast == 'u':
u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
elif broadcast == 'v':
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))
else:
u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
return u, v, e
......@@ -62,7 +81,7 @@ def test_copy_src_reduce():
# NOTE(zihao): add self-loop to avoid zero-degree nodes.
# https://github.com/dmlc/dgl/issues/761
g.add_edges(g.nodes(), g.nodes())
hu, hv, he = generate_feature(g, 'none')
hu, hv, he = generate_feature(g, 'none', 'none')
if partial:
nid = F.tensor(list(range(0, 100, 2)))
......@@ -123,7 +142,7 @@ def test_copy_edge_reduce():
g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
# NOTE(zihao): add self-loop to avoid zero-degree nodes.
g.add_edges(g.nodes(), g.nodes())
hu, hv, he = generate_feature(g, 'none')
hu, hv, he = generate_feature(g, 'none', 'none')
if partial:
nid = F.tensor(list(range(0, 100, 2)))
......@@ -182,7 +201,7 @@ def test_copy_edge_reduce():
def test_all_binary_builtins():
def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast='none'):
# initialize node/edge features with uniform(-1, 1)
hu, hv, he = generate_feature(g, broadcast)
hu, hv, he = generate_feature(g, broadcast, binary_op)
if binary_op == 'div':
# op = div
# lhs range: [-1, 1]
......@@ -321,14 +340,15 @@ def test_all_binary_builtins():
for lhs, rhs in product(target, target):
if lhs == rhs:
continue
for binary_op in ["add", "sub", "mul", "div"]:
for binary_op in ["add", "sub", "mul", "div", "dot"]:
for reducer in ["sum", "max", "min", "prod", "mean"]:
for broadcast in ["none", lhs, rhs]:
for partial in [False, True]:
_test(g, lhs, rhs, binary_op, reducer, partial, nid,
broadcast=broadcast)
if __name__ == '__main__':
#test_copy_src_reduce()
#test_copy_edge_reduce()
test_all_binary_builtins()
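
Outside the parametrized test above, the new builtin can also be sanity-checked
directly against a hand-computed dot product; a minimal sketch with the PyTorch
backend (graph size, vector length and tolerance are illustrative):

    import torch
    import dgl
    import dgl.function as fn

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
    h = torch.randn(4, 10)                # vector length 10, like D4 above
    g.ndata['h'] = h

    g.apply_edges(fn.u_dot_v('h', 'h', 's'))
    u, v = g.edges()
    expected = (h[u] * h[v]).sum(dim=-1)
    assert torch.allclose(g.edata['s'].view(-1), expected, atol=1e-5)
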