Commit 3683a774 authored by Gan Quan

Merge branch 'cpp' of github.com:jermainewang/dgl into cpp

parents f1ede61f c9e3c658
 #!/usr/bin/env groovy
+
+def setup() {
+    sh 'easy_install nose'
+    sh 'git submodule init'
+    sh 'git submodule update'
+}
+
+def build_dgl() {
+    sh 'if [ -d build ]; then rm -rf build; fi; mkdir build'
+    dir('python') {
+        sh 'python3 setup.py install'
+    }
+    dir ('build') {
+        sh 'cmake ..'
+        sh 'make -j$(nproc)'
+    }
+}
+
+def unit_test() {
+    withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) {
+        sh 'nosetests tests -v --with-xunit'
+        sh 'nosetests tests/pytorch -v --with-xunit'
+        sh 'nosetests tests/graph_index -v --with-xunit'
+    }
+}
+
+def example_test(dev) {
+    dir ('tests/scripts') {
+        withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) {
+            sh "./test_examples.sh ${dev}"
+        }
+    }
+}
+
 pipeline {
     agent none
     stages {
@@ -7,36 +42,27 @@ pipeline {
             agent {
                 docker {
                     image 'lingfanyu/dgl-cpu'
-                    args '-u root'
                 }
             }
             stages {
                 stage('SETUP') {
                     steps {
-                        sh 'easy_install nose'
-                        sh 'git submodule init'
-                        sh 'git submodule update'
+                        setup()
                     }
                 }
                 stage('BUILD') {
                     steps {
-                        sh 'if [ -d build ]; then rm -rf build; fi; mkdir build'
-                        dir('python') {
-                            sh 'python3 setup.py install'
-                        }
-                        dir ('build') {
-                            sh 'cmake ..'
-                            sh 'make -j$(nproc)'
-                        }
+                        build_dgl()
                     }
                 }
-                stage('TEST') {
+                stage('UNIT TEST') {
                     steps {
-                        withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) {
-                            sh 'nosetests tests -v --with-xunit'
-                            sh 'nosetests tests/pytorch -v --with-xunit'
-                            sh 'nosetests tests/graph_index -v --with-xunit'
-                        }
+                        unit_test()
+                    }
+                }
+                stage('EXAMPLE TEST') {
+                    steps {
+                        example_test('CPU')
                     }
                 }
             }
@@ -50,36 +76,28 @@ pipeline {
             agent {
                 docker {
                     image 'lingfanyu/dgl-gpu'
-                    args '--runtime nvidia -u root'
+                    args '--runtime nvidia'
                 }
            }
             stages {
                 stage('SETUP') {
                     steps {
-                        sh 'easy_install nose'
-                        sh 'git submodule init'
-                        sh 'git submodule update'
+                        setup()
                     }
                 }
                 stage('BUILD') {
                     steps {
-                        sh 'if [ -d build ]; then rm -rf build; fi; mkdir build'
-                        dir('python') {
-                            sh 'python3 setup.py install'
-                        }
-                        dir ('build') {
-                            sh 'cmake ..'
-                            sh 'make -j$(nproc)'
-                        }
+                        build_dgl()
                    }
                 }
-                stage('TEST') {
+                stage('UNIT TEST') {
                     steps {
-                        withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) {
-                            sh 'nosetests tests -v --with-xunit'
-                            sh 'nosetests tests/pytorch -v --with-xunit'
-                            sh 'nosetests tests/graph_index -v --with-xunit'
-                        }
+                        unit_test()
+                    }
+                }
+                stage('EXAMPLE TEST') {
+                    steps {
+                        example_test('GPU')
                     }
                 }
             }
...
@@ -56,7 +56,7 @@ class GCN(nn.Module):
         g.apply_nodes(apply_node_func=
                       lambda node: F.dropout(node['h'], p=self.dropout))
         self.g.update_all(fn.copy_src(src='h', out='m'),
-                          fn.sum(msgs='m', out='h'),
+                          fn.sum(msg='m', out='h'),
                           layer)
         return self.g.pop_n_repr('h')
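For context, `update_all` pairs a builtin message function with a builtin reducer, and the reducer keyword is now `msg` rather than `msgs`. A minimal sketch, assuming `g` is a DGLGraph whose nodes already carry a feature 'h' and that the trailing apply-function argument (`layer` above) is optional:

import dgl.function as fn

# Message phase: copy each source node's 'h' into message field 'm'.
# Reduce phase: sum the incoming 'm' messages into a new 'h'.
g.update_all(fn.copy_src(src='h', out='m'),
             fn.sum(msg='m', out='h'))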
...
-// DGL Graph interface
-#ifndef DGL_DGLGRAPH_H_
-#define DGL_DGLGRAPH_H_
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file dgl/graph.h
+ * \brief DGL graph index class.
+ */
+#ifndef DGL_GRAPH_H_
+#define DGL_GRAPH_H_

-#include <stdint.h>
+#include <vector>
+#include <cstdint>
 #include "runtime/ndarray.h"

 namespace dgl {
@@ -17,7 +22,7 @@ class GraphOp;
 struct Subgraph;

 /*!
- * \brief Base dgl graph class.
+ * \brief Base dgl graph index class.
  *
  * DGL's graph is directed. Vertices are integers enumerated from zero. Edges
  * are uniquely identified by the two endpoints. Multi-edge is currently not
@@ -41,7 +46,7 @@ class Graph {
   } EdgeArray;

   /*! \brief default constructor */
-  Graph(bool multigraph = false) : is_multigraph_(multigraph) {}
+  explicit Graph(bool multigraph = false) : is_multigraph_(multigraph) {}

   /*! \brief default copy constructor */
   Graph(const Graph& other) = default;
@@ -192,7 +197,7 @@ class Graph {
    * \return the id arrays of the two endpoints of the edges.
    */
   EdgeArray InEdges(IdArray vids) const;

   /*!
    * \brief Get the out edges of the vertex.
    * \note The returned src id array is filled with vid.
@@ -347,4 +352,4 @@ struct Subgraph {
 }  // namespace dgl

-#endif  // DGL_DGLGRAPH_H_
+#endif  // DGL_GRAPH_H_
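To illustrate the documented semantics (directed edges, vertices enumerated from zero, edges identified by their two endpoints), here is a hypothetical snippet against the Python-side DGLGraph wrapper; the `add_nodes`/`add_edge` names are assumed from the Python API and do not appear in this diff:

from dgl import DGLGraph

g = DGLGraph()    # directed; multigraph support is off by default
g.add_nodes(3)    # vertices are the integers 0, 1, 2
g.add_edge(0, 1)  # an edge is uniquely identified by its endpoints
g.add_edge(1, 2)  # 1 -> 2; the reverse edge 2 -> 1 would be distinct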
-// Graph operations
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file dgl/graph_op.h
+ * \brief Operations on graph index.
+ */
 #ifndef DGL_GRAPH_OP_H_
 #define DGL_GRAPH_OP_H_

+#include <vector>
 #include "graph.h"

 namespace dgl {
...
# C API and runtime
Borrowed and adapted from the TVM project.
@@ -7,8 +7,8 @@
  * used by compiled tvm operators, usually user do not need to use these
  * function directly.
  */
-#ifndef TVM_RUNTIME_C_BACKEND_API_H_
-#define TVM_RUNTIME_C_BACKEND_API_H_
+#ifndef DGL_RUNTIME_C_BACKEND_API_H_
+#define DGL_RUNTIME_C_BACKEND_API_H_

 #include "c_runtime_api.h"
@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle,
 #ifdef __cplusplus
 }  // TVM_EXTERN_C
 #endif
-#endif  // TVM_RUNTIME_C_BACKEND_API_H_
+#endif  // DGL_RUNTIME_C_BACKEND_API_H_
 /*!
  * Copyright (c) 2016 by Contributors
- * \file tvm/runtime/c_runtime_api.h
+ * \file dgl/runtime/c_runtime_api.h
  * \brief TVM runtime library.
  *
  * The philosophy of TVM project is to customize the compilation
@@ -15,8 +15,8 @@
  * - Use TVMFuncListGlobalNames to get global function name
  * - Use TVMFuncCall to call these functions.
  */
-#ifndef TVM_RUNTIME_C_RUNTIME_API_H_
-#define TVM_RUNTIME_C_RUNTIME_API_H_
+#ifndef DGL_RUNTIME_C_RUNTIME_API_H_
+#define DGL_RUNTIME_C_RUNTIME_API_H_

 // Macros to do weak linking
 #ifdef _MSC_VER
@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
 #ifdef __cplusplus
 }  // TVM_EXTERN_C
 #endif
-#endif  // TVM_RUNTIME_C_RUNTIME_API_H_
+#endif  // DGL_RUNTIME_C_RUNTIME_API_H_
 /*!
  * Copyright (c) 2016 by Contributors
- * \file tvm/runtime/device_api.h
+ * \file dgl/runtime/device_api.h
  * \brief Abstract device memory management API
  */
-#ifndef TVM_RUNTIME_DEVICE_API_H_
-#define TVM_RUNTIME_DEVICE_API_H_
+#ifndef DGL_RUNTIME_DEVICE_API_H_
+#define DGL_RUNTIME_DEVICE_API_H_

 #include <string>
 #include "packed_func.h"
@@ -180,4 +180,4 @@ class DeviceAPI {
 constexpr int kRPCSessMask = 128;
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_DEVICE_API_H_
+#endif  // DGL_RUNTIME_DEVICE_API_H_
 /*!
  * Copyright (c) 2017 by Contributors
- * \file tvm/runtime/module.h
+ * \file dgl/runtime/module.h
  * \brief Runtime container of the functions generated by TVM,
  *  This is used to support dynamically link, load and save
  *  functions from different convention under unified API.
  */
-#ifndef TVM_RUNTIME_MODULE_H_
-#define TVM_RUNTIME_MODULE_H_
+#ifndef DGL_RUNTIME_MODULE_H_
+#define DGL_RUNTIME_MODULE_H_

 #include <dmlc/io.h>
 #include <memory>
@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const {
 }  // namespace tvm
 #include "packed_func.h"
-#endif  // TVM_RUNTIME_MODULE_H_
+#endif  // DGL_RUNTIME_MODULE_H_
 /*!
  * Copyright (c) 2017 by Contributors
- * \file tvm/runtime/ndarray.h
+ * \file dgl/runtime/ndarray.h
  * \brief Abstract device memory management API
  */
-#ifndef TVM_RUNTIME_NDARRAY_H_
-#define TVM_RUNTIME_NDARRAY_H_
+#ifndef DGL_RUNTIME_NDARRAY_H_
+#define DGL_RUNTIME_NDARRAY_H_

 #include <atomic>
 #include <vector>
@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_NDARRAY_H_
+#endif  // DGL_RUNTIME_NDARRAY_H_
 /*!
  * Copyright (c) 2017 by Contributors
- * \file tvm/runtime/packed_func.h
+ * \file dgl/runtime/packed_func.h
  * \brief Type-erased function used across TVM API.
  */
-#ifndef TVM_RUNTIME_PACKED_FUNC_H_
-#define TVM_RUNTIME_PACKED_FUNC_H_
+#ifndef DGL_RUNTIME_PACKED_FUNC_H_
+#define DGL_RUNTIME_PACKED_FUNC_H_

 #include <dmlc/logging.h>
 #include <functional>
@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
 }
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_PACKED_FUNC_H_
+#endif  // DGL_RUNTIME_PACKED_FUNC_H_
 /*!
  * Copyright (c) 2017 by Contributors
- * \file tvm/runtime/registry.h
+ * \file dgl/runtime/registry.h
  * \brief This file defines the TVM global function registry.
  *
  *  The registered functions will be made available to front-end
@@ -22,8 +22,8 @@
  *   });
  * \endcode
  */
-#ifndef TVM_RUNTIME_REGISTRY_H_
-#define TVM_RUNTIME_REGISTRY_H_
+#ifndef DGL_RUNTIME_REGISTRY_H_
+#define DGL_RUNTIME_REGISTRY_H_

 #include <string>
 #include <vector>
@@ -141,4 +141,4 @@ class Registry {
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_REGISTRY_H_
+#endif  // DGL_RUNTIME_REGISTRY_H_
 /*!
  * Copyright (c) 2017 by Contributors
- * \file tvm/runtime/serializer.h
+ * \file dgl/runtime/serializer.h
  * \brief Serializer extension to support TVM data types
  *  Include this file to enable serialization of DLDataType, DLContext
  */
-#ifndef TVM_RUNTIME_SERIALIZER_H_
-#define TVM_RUNTIME_SERIALIZER_H_
+#ifndef DGL_RUNTIME_SERIALIZER_H_
+#define DGL_RUNTIME_SERIALIZER_H_

 #include <dmlc/io.h>
 #include <dmlc/serializer.h>
@@ -48,4 +48,4 @@ struct Handler<DLContext> {
 }  // namespace serializer
 }  // namespace dmlc
-#endif  // TVM_RUNTIME_SERIALIZER_H_
+#endif  // DGL_RUNTIME_SERIALIZER_H_
 /*!
  * Copyright (c) 2018 by Contributors
- * \file tvm/runtime/threading_backend.h
+ * \file dgl/runtime/threading_backend.h
  * \brief Utilities for manipulating thread pool threads.
  */
-#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
-#define TVM_RUNTIME_THREADING_BACKEND_H_
+#ifndef DGL_RUNTIME_THREADING_BACKEND_H_
+#define DGL_RUNTIME_THREADING_BACKEND_H_

 #include <functional>
 #include <memory>
@@ -82,4 +82,4 @@ int MaxConcurrency();
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_THREADING_BACKEND_H_
+#endif  // DGL_RUNTIME_THREADING_BACKEND_H_
 /*!
  * Copyright (c) 2017 by Contributors
- * \file tvm/runtime/util.h
+ * \file dgl/runtime/util.h
  * \brief Useful runtime util.
  */
-#ifndef TVM_RUNTIME_UTIL_H_
-#define TVM_RUNTIME_UTIL_H_
+#ifndef DGL_RUNTIME_UTIL_H_
+#define DGL_RUNTIME_UTIL_H_

 #include "c_runtime_api.h"
@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int {
 }  // namespace intrinsic
 }  // namespace ir
 }  // namespace tvm
-#endif  // TVM_RUNTIME_UTIL_H_
+#endif  // DGL_RUNTIME_UTIL_H_
-// DGL Scheduler interface
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file dgl/scheduler.h
+ * \brief Operations on graph index.
+ */
 #ifndef DGL_SCHEDULER_H_
 #define DGL_SCHEDULER_H_

-#include "runtime/ndarray.h"
 #include <vector>
+#include "runtime/ndarray.h"

 namespace dgl {
@@ -25,8 +29,8 @@ namespace sched {
  */
 std::vector<IdArray> DegreeBucketing(const IdArray& vids);

 }  // namespace sched
 }  // namespace dgl
 #endif  // DGL_SCHEDULER_H_
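DegreeBucketing groups destination vertices by how many messages they receive, so that each bucket can be reduced with a single batched tensor operation. A rough NumPy sketch of the idea (an illustration only, not the C++ implementation):

import numpy as np

def degree_bucketing(vids):
    # vids holds the destination vertex id of each message; a vertex
    # appearing k times has in-degree k in this message batch.
    uniq, counts = np.unique(vids, return_counts=True)
    return {deg: uniq[counts == deg] for deg in np.unique(counts)}

# Messages destined for [0, 1, 1, 2, 2, 2] yield:
# degree 1 -> [0], degree 2 -> [1], degree 3 -> [2]
buckets = degree_bucketing(np.array([0, 1, 1, 2, 2, 2]))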
@@ -11,5 +11,5 @@ from ._ffi.base import DGLError, __version__
 from .base import ALL
 from .batched_graph import *
 from .generator import *
-from .graph import DGLGraph, __MSG__, __REPR__
+from .graph import DGLGraph
 from .subgraph import DGLSubGraph
@@ -11,7 +11,4 @@ ALL = "__ALL__"
 def is_all(arg):
     return isinstance(arg, str) and arg == ALL

-__MSG__ = "__MSG__"
-__REPR__ = "__REPR__"
-
 dgl_warning = warnings.warn
@@ -4,17 +4,25 @@ from __future__ import absolute_import
 import operator

 import dgl.backend as F

-__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
+__all__ = ["src_mul_edge", "copy_src", "copy_edge"]

 class MessageFunction(object):
+    """Base builtin message function class."""
+
     def __call__(self, src, edge):
+        """Regular computation of this builtin.
+
+        This will be used when optimization is not available.
+        """
         raise NotImplementedError

     def name(self):
+        """Return the name of this builtin function."""
         raise NotImplementedError

     def is_spmv_supported(self, g):
+        """Return whether the SPMV optimization is supported."""
         raise NotImplementedError
@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
     def __init__(self, fn_list):
         if not isinstance(fn_list, (list, tuple)):
             fn_list = [fn_list]
-        else:
-            # sanity check on out field
-            for fn in fn_list:
-                # cannot perform check for udf
-                if isinstance(fn, MessageFunction) and fn.out_field is None:
-                    raise RuntimeError("Not specifying out field for multiple message is ambiguous")
         self.fn_list = fn_list

     def is_spmv_supported(self, g):
@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
             if ret is None:
                 ret = msg
             else:
-                try:
-                    # ret and msg must be dict
-                    ret.update(msg)
-                except:
-                    raise RuntimeError("Must specify out field for multiple message")
+                # ret and msg must be dict
+                ret.update(msg)
         return ret

     def name(self):
@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):

 def _is_spmv_supported_node_feat(g, field):
-    if field is None:
-        feat = g.get_n_repr()
-    else:
-        feat = g.get_n_repr()[field]
+    """Return whether the node feature shape supports SPMV optimization.
+
+    Only scalar and vector features are supported currently.
+    """
+    feat = g.get_n_repr()[field]
     shape = F.shape(feat)
     return len(shape) == 1 or len(shape) == 2

 def _is_spmv_supported_edge_feat(g, field):
-    # check shape, only scalar edge feature can be optimized at the moment
-    if field is None:
-        feat = g.get_e_repr()
-    else:
-        feat = g.get_e_repr()[field]
+    """Return whether the edge feature shape supports SPMV optimization.
+
+    Only scalar feature is supported currently.
+    """
+    feat = g.get_e_repr()[field]
     shape = F.shape(feat)
     return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)

 class SrcMulEdgeMessageFunction(MessageFunction):
-    def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
+    def __init__(self, mul_op, src_field, edge_field, out_field):
         self.mul_op = mul_op
         self.src_field = src_field
         self.edge_field = edge_field
@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
                and _is_spmv_supported_edge_feat(g, self.edge_field)

     def __call__(self, src, edge):
-        if self.src_field is not None:
-            src = src[self.src_field]
-        if self.edge_field is not None:
-            edge = edge[self.edge_field]
-        ret = self.mul_op(src, edge)
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        ret = self.mul_op(src[self.src_field], edge[self.edge_field])
+        return {self.out_field : ret}

     def name(self):
         return "src_mul_edge"

 class CopySrcMessageFunction(MessageFunction):
-    def __init__(self, src_field=None, out_field=None):
+    def __init__(self, src_field, out_field):
         self.src_field = src_field
         self.out_field = out_field
@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
         return _is_spmv_supported_node_feat(g, self.src_field)

     def __call__(self, src, edge):
-        if self.src_field is not None:
-            ret = src[self.src_field]
-        else:
-            ret = src
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        return {self.out_field : src[self.src_field]}

     def name(self):
         return "copy_src"
@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
         return "copy_edge"

-def src_mul_edge(src=None, edge=None, out=None):
-    """TODO(minjie): docstring """
+def src_mul_edge(src, edge, out):
+    """Builtin message function that computes message by multiplying source node features
+    with edge features.
+
+    Parameters
+    ----------
+    src : str
+        The source feature name.
+    edge : str
+        The edge feature name.
+    out : str
+        The output message name.
+    """
     return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)

-def copy_src(src=None, out=None):
-    """TODO(minjie): docstring """
+def copy_src(src, out):
+    """Builtin message function that computes message using source node feature.
+
+    Parameters
+    ----------
+    src : str
+        The source feature name.
+    out : str
+        The output message name.
+    """
     return CopySrcMessageFunction(src, out)

-def copy_edge(edge=None, out=None):
-    """TODO(minjie): docstring """
+def copy_edge(edge, out):
+    """Builtin message function that computes message using edge feature.
+
+    Parameters
+    ----------
+    edge : str
+        The edge feature name.
+    out : str
+        The output message name.
+    """
     return CopyEdgeMessageFunction(edge, out)
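With the field names now mandatory, every builtin returns a dict keyed by its `out` field, which is what lets BundledMessageFunction merge results by a plain dict update. A short usage sketch, assuming a node feature 'h' and a scalar edge feature 'w':

import dgl.function as fn

msg = fn.copy_src(src='h', out='m')                # forward the source feature
msg = fn.src_mul_edge(src='h', edge='w', out='m')  # weight it by the edge feature
msg = fn.copy_edge(edge='w', out='m')              # forward the edge feature itself

# Bundled builtins need distinct out fields so the merged dict keeps both.
bundle = [fn.copy_src(src='h', out='m1'), fn.copy_edge(edge='w', out='m2')]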
@@ -3,27 +3,30 @@ from __future__ import absolute_import
 from .. import backend as F

-__all__ = ["ReduceFunction", "sum", "max"]
+__all__ = ["sum", "max"]

 class ReduceFunction(object):
+    """Base builtin reduce function class."""
+
     def __call__(self, node, msgs):
+        """Regular computation of this builtin.
+
+        This will be used when optimization is not available.
+        """
         raise NotImplementedError

     def name(self):
+        """Return the name of this builtin function."""
         raise NotImplementedError

     def is_spmv_supported(self):
+        """Return whether the SPMV optimization is supported."""
         raise NotImplementedError

 class BundledReduceFunction(ReduceFunction):
     def __init__(self, fn_list):
         if not isinstance(fn_list, (list, tuple)):
             fn_list = [fn_list]
-        else:
-            # sanity check on out field
-            for fn in fn_list:
-                if isinstance(fn, ReduceFunction) and fn.out_field is None:
-                    raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
         self.fn_list = fn_list

     def is_spmv_supported(self):
@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
             if ret is None:
                 ret = rpr
             else:
-                try:
-                    # ret and rpr must be dict
-                    ret.update(rpr)
-                except:
-                    raise RuntimeError("Must specify out field for multiple reudce")
+                # ret and rpr must be dict
+                ret.update(rpr)
         return ret

     def name(self):
         return "bundled"

 class ReducerFunctionTemplate(ReduceFunction):
-    def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None):
+    def __init__(self, name, op, msg_field, out_field):
         self.name = name
-        self.batch_op = batch_op
-        self.nonbatch_op = nonbatch_op
+        self.op = op
         self.msg_field = msg_field
         self.out_field = out_field

     def is_spmv_supported(self):
-        # TODO: support max
+        # NOTE: only sum is supported right now.
         return self.name == "sum"

     def __call__(self, node, msgs):
-        if isinstance(msgs, list):
-            if self.msg_field is None:
-                ret = self.nonbatch_op(msgs)
-            else:
-                ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
-        else:
-            if self.msg_field is None:
-                ret = self.batch_op(msgs, 1)
-            else:
-                ret = self.batch_op(msgs[self.msg_field], 1)
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        return {self.out_field : self.op(msgs[self.msg_field], 1)}

     def name(self):
         return self.name

-_python_sum = sum
-def sum(msgs=None, out=None):
-    return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)
+def sum(msg, out):
+    """Builtin reduce function that aggregates messages by sum.
+
+    Parameters
+    ----------
+    msg : str
+        The message name.
+    out : str
+        The output node feature name.
+    """
+    return ReducerFunctionTemplate("sum", F.sum, msg, out)

-_python_max = max
-def max(msgs=None, out=None):
-    return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out)
+def max(msg, out):
+    """Builtin reduce function that aggregates messages by max.
+
+    Parameters
+    ----------
+    msg : str
+        The message name.
+    out : str
+        The output node feature name.
+    """
+    return ReducerFunctionTemplate("max", F.max, msg, out)