Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3683a774
Commit
3683a774
authored
Oct 18, 2018
by
Gan Quan
Browse files
Merge branch 'cpp' of github.com:jermainewang/dgl into cpp
parents
f1ede61f
c9e3c658
Changes
45
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
219 additions
and
178 deletions
+219
-178
Jenkinsfile
Jenkinsfile
+54
-36
examples/pytorch/gcn/gcn_spmv.py
examples/pytorch/gcn/gcn_spmv.py
+1
-1
include/dgl/graph.h
include/dgl/graph.h
+13
-8
include/dgl/graph_op.h
include/dgl/graph_op.h
+6
-1
include/dgl/runtime/README.md
include/dgl/runtime/README.md
+0
-3
include/dgl/runtime/c_backend_api.h
include/dgl/runtime/c_backend_api.h
+3
-3
include/dgl/runtime/c_runtime_api.h
include/dgl/runtime/c_runtime_api.h
+4
-4
include/dgl/runtime/device_api.h
include/dgl/runtime/device_api.h
+4
-4
include/dgl/runtime/module.h
include/dgl/runtime/module.h
+4
-4
include/dgl/runtime/ndarray.h
include/dgl/runtime/ndarray.h
+4
-4
include/dgl/runtime/packed_func.h
include/dgl/runtime/packed_func.h
+4
-4
include/dgl/runtime/registry.h
include/dgl/runtime/registry.h
+4
-4
include/dgl/runtime/serializer.h
include/dgl/runtime/serializer.h
+4
-4
include/dgl/runtime/threading_backend.h
include/dgl/runtime/threading_backend.h
+4
-4
include/dgl/runtime/util.h
include/dgl/runtime/util.h
+4
-4
include/dgl/scheduler.h
include/dgl/scheduler.h
+9
-5
python/dgl/__init__.py
python/dgl/__init__.py
+1
-1
python/dgl/base.py
python/dgl/base.py
+0
-3
python/dgl/function/message.py
python/dgl/function/message.py
+59
-46
python/dgl/function/reducer.py
python/dgl/function/reducer.py
+37
-35
No files found.
Jenkinsfile
View file @
3683a774
#
!
/usr/
bin
/
env
groovy
def
setup
()
{
sh
'easy_install nose'
sh
'git submodule init'
sh
'git submodule update'
}
def
build_dgl
()
{
sh
'if [ -d build ]; then rm -rf build; fi; mkdir build'
dir
(
'python'
)
{
sh
'python3 setup.py install'
}
dir
(
'build'
)
{
sh
'cmake ..'
sh
'make -j$(nproc)'
}
}
def
unit_test
()
{
withEnv
([
"DGL_LIBRARY_PATH=${env.WORKSPACE}/build"
])
{
sh
'nosetests tests -v --with-xunit'
sh
'nosetests tests/pytorch -v --with-xunit'
sh
'nosetests tests/graph_index -v --with-xunit'
}
}
def
example_test
(
dev
)
{
dir
(
'tests/scripts'
)
{
withEnv
([
"DGL_LIBRARY_PATH=${env.WORKSPACE}/build"
])
{
sh
"./test_examples.sh ${dev}"
}
}
}
pipeline
{
agent
none
stages
{
...
...
@@ -7,36 +42,27 @@ pipeline {
agent
{
docker
{
image
'lingfanyu/dgl-cpu'
args
'-u root'
}
}
stages
{
stage
(
'SETUP'
)
{
steps
{
sh
'easy_install nose'
sh
'git submodule init'
sh
'git submodule update'
setup
()
}
}
stage
(
'BUILD'
)
{
steps
{
sh
'if [ -d build ]; then rm -rf build; fi; mkdir build'
dir
(
'python'
)
{
sh
'python3 setup.py install'
}
dir
(
'build'
)
{
sh
'cmake ..'
sh
'make -j$(nproc)'
}
build_dgl
()
}
}
stage
(
'UNIT TEST'
)
{
steps
{
unit_test
()
}
}
stage
(
'TEST'
)
{
stage
(
'
EXAMPLE
TEST'
)
{
steps
{
withEnv
([
"DGL_LIBRARY_PATH=${env.WORKSPACE}/build"
])
{
sh
'nosetests tests -v --with-xunit'
sh
'nosetests tests/pytorch -v --with-xunit'
sh
'nosetests tests/graph_index -v --with-xunit'
}
example_test
(
'CPU'
)
}
}
}
...
...
@@ -50,36 +76,28 @@ pipeline {
agent
{
docker
{
image
'lingfanyu/dgl-gpu'
args
'--runtime nvidia
-u root
'
args
'--runtime nvidia'
}
}
stages
{
stage
(
'SETUP'
)
{
steps
{
sh
'easy_install nose'
sh
'git submodule init'
sh
'git submodule update'
setup
()
}
}
stage
(
'BUILD'
)
{
steps
{
sh
'if [ -d build ]; then rm -rf build; fi; mkdir build'
dir
(
'python'
)
{
sh
'python3 setup.py install'
}
dir
(
'build'
)
{
sh
'cmake ..'
sh
'make -j$(nproc)'
}
build_dgl
()
}
}
stage
(
'UNIT TEST'
)
{
steps
{
unit_test
()
}
}
stage
(
'TEST'
)
{
stage
(
'
EXAMPLE
TEST'
)
{
steps
{
withEnv
([
"DGL_LIBRARY_PATH=${env.WORKSPACE}/build"
])
{
sh
'nosetests tests -v --with-xunit'
sh
'nosetests tests/pytorch -v --with-xunit'
sh
'nosetests tests/graph_index -v --with-xunit'
}
example_test
(
'GPU'
)
}
}
}
...
...
examples/pytorch/gcn/gcn_spmv.py
View file @
3683a774
...
...
@@ -56,7 +56,7 @@ class GCN(nn.Module):
g
.
apply_nodes
(
apply_node_func
=
lambda
node
:
F
.
dropout
(
node
[
'h'
],
p
=
self
.
dropout
))
self
.
g
.
update_all
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
s
=
'm'
,
out
=
'h'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
),
layer
)
return
self
.
g
.
pop_n_repr
(
'h'
)
...
...
include/dgl/graph.h
View file @
3683a774
// DGL Graph interface
#ifndef DGL_DGLGRAPH_H_
#define DGL_DGLGRAPH_H_
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph.h
* \brief DGL graph index class.
*/
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <stdint.h>
#include <vector>
#include <cstdint>
#include "runtime/ndarray.h"
namespace
dgl
{
...
...
@@ -17,7 +22,7 @@ class GraphOp;
struct
Subgraph
;
/*!
* \brief Base dgl graph class.
* \brief Base dgl graph
index
class.
*
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
...
...
@@ -41,7 +46,7 @@ class Graph {
}
EdgeArray
;
/*! \brief default constructor */
Graph
(
bool
multigraph
=
false
)
:
is_multigraph_
(
multigraph
)
{}
explicit
Graph
(
bool
multigraph
=
false
)
:
is_multigraph_
(
multigraph
)
{}
/*! \brief default copy constructor */
Graph
(
const
Graph
&
other
)
=
default
;
...
...
@@ -192,7 +197,7 @@ class Graph {
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray
InEdges
(
IdArray
vids
)
const
;
/*!
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
...
...
@@ -347,4 +352,4 @@ struct Subgraph {
}
// namespace dgl
#endif // DGL_
DGL
GRAPH_H_
#endif // DGL_GRAPH_H_
include/dgl/graph_op.h
View file @
3683a774
// Graph operations
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph_op.h
* \brief Operations on graph index.
*/
#ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
namespace
dgl
{
...
...
include/dgl/runtime/README.md
deleted
100644 → 0
View file @
f1ede61f
# C API and runtime
Borrowed and adapted from TVM project.
include/dgl/runtime/c_backend_api.h
View file @
3683a774
...
...
@@ -7,8 +7,8 @@
* used by compiled tvm operators, usually user do not need to use these
* function directly.
*/
#ifndef
TVM
_RUNTIME_C_BACKEND_API_H_
#define
TVM
_RUNTIME_C_BACKEND_API_H_
#ifndef
DGL
_RUNTIME_C_BACKEND_API_H_
#define
DGL
_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h"
...
...
@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle,
#ifdef __cplusplus
}
// TVM_EXTERN_C
#endif
#endif //
TVM
_RUNTIME_C_BACKEND_API_H_
#endif //
DGL
_RUNTIME_C_BACKEND_API_H_
include/dgl/runtime/c_runtime_api.h
View file @
3683a774
/*!
* Copyright (c) 2016 by Contributors
* \file
tvm
/runtime/c_runtime_api.h
* \file
dgl
/runtime/c_runtime_api.h
* \brief TVM runtime library.
*
* The philosophy of TVM project is to customize the compilation
...
...
@@ -15,8 +15,8 @@
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/
#ifndef
TVM
_RUNTIME_C_RUNTIME_API_H_
#define
TVM
_RUNTIME_C_RUNTIME_API_H_
#ifndef
DGL
_RUNTIME_C_RUNTIME_API_H_
#define
DGL
_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking
#ifdef _MSC_VER
...
...
@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
#ifdef __cplusplus
}
// TVM_EXTERN_C
#endif
#endif //
TVM
_RUNTIME_C_RUNTIME_API_H_
#endif //
DGL
_RUNTIME_C_RUNTIME_API_H_
include/dgl/runtime/device_api.h
View file @
3683a774
/*!
* Copyright (c) 2016 by Contributors
* \file
tvm
/runtime/device_api.h
* \file
dgl
/runtime/device_api.h
* \brief Abstract device memory management API
*/
#ifndef
TVM
_RUNTIME_DEVICE_API_H_
#define
TVM
_RUNTIME_DEVICE_API_H_
#ifndef
DGL
_RUNTIME_DEVICE_API_H_
#define
DGL
_RUNTIME_DEVICE_API_H_
#include <string>
#include "packed_func.h"
...
...
@@ -180,4 +180,4 @@ class DeviceAPI {
constexpr
int
kRPCSessMask
=
128
;
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_DEVICE_API_H_
#endif //
DGL
_RUNTIME_DEVICE_API_H_
include/dgl/runtime/module.h
View file @
3683a774
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/module.h
* \file
dgl
/runtime/module.h
* \brief Runtime container of the functions generated by TVM,
* This is used to support dynamically link, load and save
* functions from different convention under unified API.
*/
#ifndef
TVM
_RUNTIME_MODULE_H_
#define
TVM
_RUNTIME_MODULE_H_
#ifndef
DGL
_RUNTIME_MODULE_H_
#define
DGL
_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory>
...
...
@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const {
}
// namespace tvm
#include "packed_func.h"
#endif //
TVM
_RUNTIME_MODULE_H_
#endif //
DGL
_RUNTIME_MODULE_H_
include/dgl/runtime/ndarray.h
View file @
3683a774
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/ndarray.h
* \file
dgl
/runtime/ndarray.h
* \brief Abstract device memory management API
*/
#ifndef
TVM
_RUNTIME_NDARRAY_H_
#define
TVM
_RUNTIME_NDARRAY_H_
#ifndef
DGL
_RUNTIME_NDARRAY_H_
#define
DGL
_RUNTIME_NDARRAY_H_
#include <atomic>
#include <vector>
...
...
@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_NDARRAY_H_
#endif //
DGL
_RUNTIME_NDARRAY_H_
include/dgl/runtime/packed_func.h
View file @
3683a774
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/packed_func.h
* \file
dgl
/runtime/packed_func.h
* \brief Type-erased function used across TVM API.
*/
#ifndef
TVM
_RUNTIME_PACKED_FUNC_H_
#define
TVM
_RUNTIME_PACKED_FUNC_H_
#ifndef
DGL
_RUNTIME_PACKED_FUNC_H_
#define
DGL
_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h>
#include <functional>
...
...
@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
}
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_PACKED_FUNC_H_
#endif //
DGL
_RUNTIME_PACKED_FUNC_H_
include/dgl/runtime/registry.h
View file @
3683a774
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/registry.h
* \file
dgl
/runtime/registry.h
* \brief This file defines the TVM global function registry.
*
* The registered functions will be made available to front-end
...
...
@@ -22,8 +22,8 @@
* });
* \endcode
*/
#ifndef
TVM
_RUNTIME_REGISTRY_H_
#define
TVM
_RUNTIME_REGISTRY_H_
#ifndef
DGL
_RUNTIME_REGISTRY_H_
#define
DGL
_RUNTIME_REGISTRY_H_
#include <string>
#include <vector>
...
...
@@ -141,4 +141,4 @@ class Registry {
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_REGISTRY_H_
#endif //
DGL
_RUNTIME_REGISTRY_H_
include/dgl/runtime/serializer.h
View file @
3683a774
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/serializer.h
* \file
dgl
/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef
TVM
_RUNTIME_SERIALIZER_H_
#define
TVM
_RUNTIME_SERIALIZER_H_
#ifndef
DGL
_RUNTIME_SERIALIZER_H_
#define
DGL
_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
...
...
@@ -48,4 +48,4 @@ struct Handler<DLContext> {
}
// namespace serializer
}
// namespace dmlc
#endif //
TVM
_RUNTIME_SERIALIZER_H_
#endif //
DGL
_RUNTIME_SERIALIZER_H_
include/dgl/runtime/threading_backend.h
View file @
3683a774
/*!
* Copyright (c) 2018 by Contributors
* \file
tvm
/runtime/threading_backend.h
* \file
dgl
/runtime/threading_backend.h
* \brief Utilities for manipulating thread pool threads.
*/
#ifndef
TVM
_RUNTIME_THREADING_BACKEND_H_
#define
TVM
_RUNTIME_THREADING_BACKEND_H_
#ifndef
DGL
_RUNTIME_THREADING_BACKEND_H_
#define
DGL
_RUNTIME_THREADING_BACKEND_H_
#include <functional>
#include <memory>
...
...
@@ -82,4 +82,4 @@ int MaxConcurrency();
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_THREADING_BACKEND_H_
#endif //
DGL
_RUNTIME_THREADING_BACKEND_H_
include/dgl/runtime/util.h
View file @
3683a774
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/util.h
* \file
dgl
/runtime/util.h
* \brief Useful runtime util.
*/
#ifndef
TVM
_RUNTIME_UTIL_H_
#define
TVM
_RUNTIME_UTIL_H_
#ifndef
DGL
_RUNTIME_UTIL_H_
#define
DGL
_RUNTIME_UTIL_H_
#include "c_runtime_api.h"
...
...
@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int {
}
// namespace intrinsic
}
// namespace ir
}
// namespace tvm
#endif //
TVM
_RUNTIME_UTIL_H_
#endif //
DGL
_RUNTIME_UTIL_H_
include/dgl/scheduler.h
View file @
3683a774
// DGL Scheduler interface
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/scheduler.h
* \brief Operations on graph index.
*/
#ifndef DGL_SCHEDULER_H_
#define DGL_SCHEDULER_H_
#include "runtime/ndarray.h"
#include <vector>
#include "runtime/ndarray.h"
namespace
dgl
{
...
...
@@ -25,8 +29,8 @@ namespace sched {
*/
std
::
vector
<
IdArray
>
DegreeBucketing
(
const
IdArray
&
vids
);
}
// namespace sched
}
// namespace sched
}
// namespace dgl
}
// namespace dgl
#endif // DGL_SCHEDULER_H_
#endif
// DGL_SCHEDULER_H_
python/dgl/__init__.py
View file @
3683a774
...
...
@@ -11,5 +11,5 @@ from ._ffi.base import DGLError, __version__
from
.base
import
ALL
from
.batched_graph
import
*
from
.generator
import
*
from
.graph
import
DGLGraph
,
__MSG__
,
__REPR__
from
.graph
import
DGLGraph
from
.subgraph
import
DGLSubGraph
python/dgl/base.py
View file @
3683a774
...
...
@@ -11,7 +11,4 @@ ALL = "__ALL__"
def
is_all
(
arg
):
return
isinstance
(
arg
,
str
)
and
arg
==
ALL
__MSG__
=
"__MSG__"
__REPR__
=
"__REPR__"
dgl_warning
=
warnings
.
warn
python/dgl/function/message.py
View file @
3683a774
...
...
@@ -4,17 +4,25 @@ from __future__ import absolute_import
import
operator
import
dgl.backend
as
F
__all__
=
[
"MessageFunction"
,
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
__all__
=
[
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
class
MessageFunction
(
object
):
"""Base builtin message function class."""
def
__call__
(
self
,
src
,
edge
):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise
NotImplementedError
def
name
(
self
):
"""Return the name of this builtin function."""
raise
NotImplementedError
def
is_spmv_supported
(
self
,
g
):
"""Return whether the SPMV optimization is supported."""
raise
NotImplementedError
...
...
@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
# cannot perform check for udf
if
isinstance
(
fn
,
MessageFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple message is ambiguous"
)
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
,
g
):
...
...
@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
if
ret
is
None
:
ret
=
msg
else
:
try
:
# ret and msg must be dict
ret
.
update
(
msg
)
except
:
raise
RuntimeError
(
"Must specify out field for multiple message"
)
# ret and msg must be dict
ret
.
update
(
msg
)
return
ret
def
name
(
self
):
...
...
@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
def
_is_spmv_supported_node_feat
(
g
,
field
):
if
field
is
None
:
feat
=
g
.
get_n_repr
()
else
:
feat
=
g
.
get_n_repr
()[
field
]
"""Return whether the node feature shape supports SPMV optimization.
Only scalar and vector features are supported currently.
"""
feat
=
g
.
get_n_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
len
(
shape
)
==
2
def
_is_spmv_supported_edge_feat
(
g
,
field
):
# check shape, only scalar edge feature can be optimized at the moment
if
field
is
None
:
feat
=
g
.
get_e_repr
()
else
:
feat
=
g
.
get_e_repr
()[
field
]
"""Return whether the edge feature shape supports SPMV optimization.
Only scalar feature is supported currently.
"""
feat
=
g
.
get_e_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
(
len
(
shape
)
==
2
and
shape
[
1
]
==
1
)
class
SrcMulEdgeMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
mul_op
,
src_field
=
None
,
edge_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
mul_op
,
src_field
,
edge_field
,
out_field
):
self
.
mul_op
=
mul_op
self
.
src_field
=
src_field
self
.
edge_field
=
edge_field
...
...
@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
and
_is_spmv_supported_edge_feat
(
g
,
self
.
edge_field
)
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
src
=
src
[
self
.
src_field
]
if
self
.
edge_field
is
not
None
:
edge
=
edge
[
self
.
edge_field
]
ret
=
self
.
mul_op
(
src
,
edge
)
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
ret
=
self
.
mul_op
(
src
[
self
.
src_field
],
edge
[
self
.
edge_field
])
return
{
self
.
out_field
:
ret
}
def
name
(
self
):
return
"src_mul_edge"
class
CopySrcMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
src_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
src_field
,
out_field
):
self
.
src_field
=
src_field
self
.
out_field
=
out_field
...
...
@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
return
_is_spmv_supported_node_feat
(
g
,
self
.
src_field
)
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
ret
=
src
[
self
.
src_field
]
else
:
ret
=
src
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
return
{
self
.
out_field
:
src
[
self
.
src_field
]}
def
name
(
self
):
return
"copy_src"
...
...
@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
return
"copy_edge"
def
src_mul_edge
(
src
=
None
,
edge
=
None
,
out
=
None
):
"""TODO(minjie): docstring """
def
src_mul_edge
(
src
,
edge
,
out
):
"""Builtin message function that computes message by multiplying source node features
with edge features.
Parameters
----------
src : str
The source feature name.
edge : str
The edge feature name.
out : str
The output message name.
"""
return
SrcMulEdgeMessageFunction
(
operator
.
mul
,
src
,
edge
,
out
)
def
copy_src
(
src
=
None
,
out
=
None
):
"""TODO(minjie): docstring """
def
copy_src
(
src
,
out
):
"""Builtin message function that computes message using source node feature.
Parameters
----------
src : str
The source feature name.
out : str
The output message name.
"""
return
CopySrcMessageFunction
(
src
,
out
)
def
copy_edge
(
edge
=
None
,
out
=
None
):
"""TODO(minjie): docstring """
def
copy_edge
(
edge
,
out
):
"""Builtin message function that computes message using edge feature.
Parameters
----------
edge : str
The edge feature name.
out : str
The output message name.
"""
return
CopyEdgeMessageFunction
(
edge
,
out
)
python/dgl/function/reducer.py
View file @
3683a774
...
...
@@ -3,27 +3,30 @@ from __future__ import absolute_import
from
..
import
backend
as
F
__all__
=
[
"ReduceFunction"
,
"sum"
,
"max"
]
__all__
=
[
"sum"
,
"max"
]
class
ReduceFunction
(
object
):
"""Base builtin reduce function class."""
def
__call__
(
self
,
node
,
msgs
):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise
NotImplementedError
def
name
(
self
):
"""Return the name of this builtin function."""
raise
NotImplementedError
def
is_spmv_supported
(
self
):
"""Return whether the SPMV optimization is supported."""
raise
NotImplementedError
class
BundledReduceFunction
(
ReduceFunction
):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
if
isinstance
(
fn
,
ReduceFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple reduce is ambiguous"
)
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
):
...
...
@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
if
ret
is
None
:
ret
=
rpr
else
:
try
:
# ret and rpr must be dict
ret
.
update
(
rpr
)
except
:
raise
RuntimeError
(
"Must specify out field for multiple reudce"
)
# ret and rpr must be dict
ret
.
update
(
rpr
)
return
ret
def
name
(
self
):
return
"bundled"
class
ReducerFunctionTemplate
(
ReduceFunction
):
def
__init__
(
self
,
name
,
batch_op
,
nonbatch_
op
,
msg_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
name
,
op
,
msg_field
,
out_field
):
self
.
name
=
name
self
.
batch_op
=
batch_op
self
.
nonbatch_op
=
nonbatch_op
self
.
op
=
op
self
.
msg_field
=
msg_field
self
.
out_field
=
out_field
def
is_spmv_supported
(
self
):
#
TODO: support max
#
NOTE: only sum is supported right now.
return
self
.
name
==
"sum"
def
__call__
(
self
,
node
,
msgs
):
if
isinstance
(
msgs
,
list
):
if
self
.
msg_field
is
None
:
ret
=
self
.
nonbatch_op
(
msgs
)
else
:
ret
=
self
.
nonbatch_op
([
msg
[
self
.
msg_field
]
for
msg
in
msgs
])
else
:
if
self
.
msg_field
is
None
:
ret
=
self
.
batch_op
(
msgs
,
1
)
else
:
ret
=
self
.
batch_op
(
msgs
[
self
.
msg_field
],
1
)
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
return
{
self
.
out_field
:
self
.
op
(
msgs
[
self
.
msg_field
],
1
)}
def
name
(
self
):
return
self
.
name
_python_sum
=
sum
def
sum
(
msgs
=
None
,
out
=
None
):
return
ReducerFunctionTemplate
(
"sum"
,
F
.
sum
,
_python_sum
,
msgs
,
out
)
def
sum
(
msg
,
out
):
"""Builtin reduce function that aggregates messages by sum.
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return
ReducerFunctionTemplate
(
"sum"
,
F
.
sum
,
msg
,
out
)
def
max
(
msg
,
out
):
"""Builtin reduce function that aggregates messages by max.
_python_max
=
max
def
max
(
msgs
=
None
,
out
=
None
):
return
ReducerFunctionTemplate
(
"max"
,
F
.
max
,
_python_max
,
msgs
,
out
)
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return
ReducerFunctionTemplate
(
"max"
,
F
.
max
,
msg
,
out
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment