Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9c135fd5
You need to sign in or sign up before continuing.
Unverified
Commit
9c135fd5
authored
Oct 19, 2018
by
VoVAllen
Committed by
GitHub
Oct 19, 2018
Browse files
Merge pull request #4 from jermainewang/master
Sync with latest commit
parents
9d3f299d
00add9f2
Changes
73
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
148 additions
and
89 deletions
+148
-89
include/dgl/graph.h
include/dgl/graph.h
+13
-8
include/dgl/graph_op.h
include/dgl/graph_op.h
+6
-1
include/dgl/runtime/README.md
include/dgl/runtime/README.md
+0
-3
include/dgl/runtime/c_backend_api.h
include/dgl/runtime/c_backend_api.h
+3
-3
include/dgl/runtime/c_runtime_api.h
include/dgl/runtime/c_runtime_api.h
+4
-4
include/dgl/runtime/device_api.h
include/dgl/runtime/device_api.h
+4
-4
include/dgl/runtime/module.h
include/dgl/runtime/module.h
+4
-4
include/dgl/runtime/ndarray.h
include/dgl/runtime/ndarray.h
+4
-4
include/dgl/runtime/packed_func.h
include/dgl/runtime/packed_func.h
+4
-4
include/dgl/runtime/registry.h
include/dgl/runtime/registry.h
+4
-4
include/dgl/runtime/serializer.h
include/dgl/runtime/serializer.h
+4
-4
include/dgl/runtime/threading_backend.h
include/dgl/runtime/threading_backend.h
+4
-4
include/dgl/runtime/util.h
include/dgl/runtime/util.h
+4
-4
include/dgl/scheduler.h
include/dgl/scheduler.h
+9
-5
python/dgl/__init__.py
python/dgl/__init__.py
+1
-3
python/dgl/backend/mxnet.py
python/dgl/backend/mxnet.py
+33
-3
python/dgl/backend/pytorch.py
python/dgl/backend/pytorch.py
+10
-23
python/dgl/base.py
python/dgl/base.py
+6
-2
python/dgl/data/citation_graph.py
python/dgl/data/citation_graph.py
+23
-0
python/dgl/data/utils.py
python/dgl/data/utils.py
+8
-2
No files found.
include/dgl/graph.h
View file @
9c135fd5
// DGL Graph interface
#ifndef DGL_DGLGRAPH_H_
#define DGL_DGLGRAPH_H_
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph.h
* \brief DGL graph index class.
*/
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <stdint.h>
#include <vector>
#include <cstdint>
#include "runtime/ndarray.h"
namespace
dgl
{
...
...
@@ -17,7 +22,7 @@ class GraphOp;
struct
Subgraph
;
/*!
* \brief Base dgl graph class.
* \brief Base dgl graph
index
class.
*
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
...
...
@@ -41,7 +46,7 @@ class Graph {
}
EdgeArray
;
/*! \brief default constructor */
Graph
(
bool
multigraph
=
false
)
:
is_multigraph_
(
multigraph
)
{}
explicit
Graph
(
bool
multigraph
=
false
)
:
is_multigraph_
(
multigraph
)
{}
/*! \brief default copy constructor */
Graph
(
const
Graph
&
other
)
=
default
;
...
...
@@ -192,7 +197,7 @@ class Graph {
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray
InEdges
(
IdArray
vids
)
const
;
/*!
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
...
...
@@ -347,4 +352,4 @@ struct Subgraph {
}
// namespace dgl
#endif // DGL_
DGL
GRAPH_H_
#endif // DGL_GRAPH_H_
include/dgl/graph_op.h
View file @
9c135fd5
// Graph operations
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph_op.h
* \brief Operations on graph index.
*/
#ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
namespace
dgl
{
...
...
include/dgl/runtime/README.md
deleted
100644 → 0
View file @
9d3f299d
# C API and runtime
Borrowed and adapted from TVM project.
include/dgl/runtime/c_backend_api.h
View file @
9c135fd5
...
...
@@ -7,8 +7,8 @@
* used by compiled tvm operators, usually user do not need to use these
* function directly.
*/
#ifndef
TVM
_RUNTIME_C_BACKEND_API_H_
#define
TVM
_RUNTIME_C_BACKEND_API_H_
#ifndef
DGL
_RUNTIME_C_BACKEND_API_H_
#define
DGL
_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h"
...
...
@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle,
#ifdef __cplusplus
}
// TVM_EXTERN_C
#endif
#endif //
TVM
_RUNTIME_C_BACKEND_API_H_
#endif //
DGL
_RUNTIME_C_BACKEND_API_H_
include/dgl/runtime/c_runtime_api.h
View file @
9c135fd5
/*!
* Copyright (c) 2016 by Contributors
* \file
tvm
/runtime/c_runtime_api.h
* \file
dgl
/runtime/c_runtime_api.h
* \brief TVM runtime library.
*
* The philosophy of TVM project is to customize the compilation
...
...
@@ -15,8 +15,8 @@
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/
#ifndef
TVM
_RUNTIME_C_RUNTIME_API_H_
#define
TVM
_RUNTIME_C_RUNTIME_API_H_
#ifndef
DGL
_RUNTIME_C_RUNTIME_API_H_
#define
DGL
_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking
#ifdef _MSC_VER
...
...
@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
#ifdef __cplusplus
}
// TVM_EXTERN_C
#endif
#endif //
TVM
_RUNTIME_C_RUNTIME_API_H_
#endif //
DGL
_RUNTIME_C_RUNTIME_API_H_
include/dgl/runtime/device_api.h
View file @
9c135fd5
/*!
* Copyright (c) 2016 by Contributors
* \file
tvm
/runtime/device_api.h
* \file
dgl
/runtime/device_api.h
* \brief Abstract device memory management API
*/
#ifndef
TVM
_RUNTIME_DEVICE_API_H_
#define
TVM
_RUNTIME_DEVICE_API_H_
#ifndef
DGL
_RUNTIME_DEVICE_API_H_
#define
DGL
_RUNTIME_DEVICE_API_H_
#include <string>
#include "packed_func.h"
...
...
@@ -180,4 +180,4 @@ class DeviceAPI {
constexpr
int
kRPCSessMask
=
128
;
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_DEVICE_API_H_
#endif //
DGL
_RUNTIME_DEVICE_API_H_
include/dgl/runtime/module.h
View file @
9c135fd5
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/module.h
* \file
dgl
/runtime/module.h
* \brief Runtime container of the functions generated by TVM,
* This is used to support dynamically link, load and save
* functions from different convention under unified API.
*/
#ifndef
TVM
_RUNTIME_MODULE_H_
#define
TVM
_RUNTIME_MODULE_H_
#ifndef
DGL
_RUNTIME_MODULE_H_
#define
DGL
_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory>
...
...
@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const {
}
// namespace tvm
#include "packed_func.h"
#endif //
TVM
_RUNTIME_MODULE_H_
#endif //
DGL
_RUNTIME_MODULE_H_
include/dgl/runtime/ndarray.h
View file @
9c135fd5
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/ndarray.h
* \file
dgl
/runtime/ndarray.h
* \brief Abstract device memory management API
*/
#ifndef
TVM
_RUNTIME_NDARRAY_H_
#define
TVM
_RUNTIME_NDARRAY_H_
#ifndef
DGL
_RUNTIME_NDARRAY_H_
#define
DGL
_RUNTIME_NDARRAY_H_
#include <atomic>
#include <vector>
...
...
@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_NDARRAY_H_
#endif //
DGL
_RUNTIME_NDARRAY_H_
include/dgl/runtime/packed_func.h
View file @
9c135fd5
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/packed_func.h
* \file
dgl
/runtime/packed_func.h
* \brief Type-erased function used across TVM API.
*/
#ifndef
TVM
_RUNTIME_PACKED_FUNC_H_
#define
TVM
_RUNTIME_PACKED_FUNC_H_
#ifndef
DGL
_RUNTIME_PACKED_FUNC_H_
#define
DGL
_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h>
#include <functional>
...
...
@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
}
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_PACKED_FUNC_H_
#endif //
DGL
_RUNTIME_PACKED_FUNC_H_
include/dgl/runtime/registry.h
View file @
9c135fd5
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/registry.h
* \file
dgl
/runtime/registry.h
* \brief This file defines the TVM global function registry.
*
* The registered functions will be made available to front-end
...
...
@@ -22,8 +22,8 @@
* });
* \endcode
*/
#ifndef
TVM
_RUNTIME_REGISTRY_H_
#define
TVM
_RUNTIME_REGISTRY_H_
#ifndef
DGL
_RUNTIME_REGISTRY_H_
#define
DGL
_RUNTIME_REGISTRY_H_
#include <string>
#include <vector>
...
...
@@ -141,4 +141,4 @@ class Registry {
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_REGISTRY_H_
#endif //
DGL
_RUNTIME_REGISTRY_H_
include/dgl/runtime/serializer.h
View file @
9c135fd5
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/serializer.h
* \file
dgl
/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef
TVM
_RUNTIME_SERIALIZER_H_
#define
TVM
_RUNTIME_SERIALIZER_H_
#ifndef
DGL
_RUNTIME_SERIALIZER_H_
#define
DGL
_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
...
...
@@ -48,4 +48,4 @@ struct Handler<DLContext> {
}
// namespace serializer
}
// namespace dmlc
#endif //
TVM
_RUNTIME_SERIALIZER_H_
#endif //
DGL
_RUNTIME_SERIALIZER_H_
include/dgl/runtime/threading_backend.h
View file @
9c135fd5
/*!
* Copyright (c) 2018 by Contributors
* \file
tvm
/runtime/threading_backend.h
* \file
dgl
/runtime/threading_backend.h
* \brief Utilities for manipulating thread pool threads.
*/
#ifndef
TVM
_RUNTIME_THREADING_BACKEND_H_
#define
TVM
_RUNTIME_THREADING_BACKEND_H_
#ifndef
DGL
_RUNTIME_THREADING_BACKEND_H_
#define
DGL
_RUNTIME_THREADING_BACKEND_H_
#include <functional>
#include <memory>
...
...
@@ -82,4 +82,4 @@ int MaxConcurrency();
}
// namespace runtime
}
// namespace tvm
#endif //
TVM
_RUNTIME_THREADING_BACKEND_H_
#endif //
DGL
_RUNTIME_THREADING_BACKEND_H_
include/dgl/runtime/util.h
View file @
9c135fd5
/*!
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/util.h
* \file
dgl
/runtime/util.h
* \brief Useful runtime util.
*/
#ifndef
TVM
_RUNTIME_UTIL_H_
#define
TVM
_RUNTIME_UTIL_H_
#ifndef
DGL
_RUNTIME_UTIL_H_
#define
DGL
_RUNTIME_UTIL_H_
#include "c_runtime_api.h"
...
...
@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int {
}
// namespace intrinsic
}
// namespace ir
}
// namespace tvm
#endif //
TVM
_RUNTIME_UTIL_H_
#endif //
DGL
_RUNTIME_UTIL_H_
include/dgl/scheduler.h
View file @
9c135fd5
// DGL Scheduler interface
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/scheduler.h
* \brief Operations on graph index.
*/
#ifndef DGL_SCHEDULER_H_
#define DGL_SCHEDULER_H_
#include "runtime/ndarray.h"
#include <vector>
#include "runtime/ndarray.h"
namespace
dgl
{
...
...
@@ -25,8 +29,8 @@ namespace sched {
*/
std
::
vector
<
IdArray
>
DegreeBucketing
(
const
IdArray
&
vids
);
}
// namespace sched
}
// namespace sched
}
// namespace dgl
}
// namespace dgl
#endif // DGL_SCHEDULER_H_
#endif
// DGL_SCHEDULER_H_
python/dgl/__init__.py
View file @
9c135fd5
from
.
import
backend
from
.
import
data
from
.
import
function
from
.
import
generator
from
.
import
nn
from
._ffi.runtime_ctypes
import
TypeCode
...
...
@@ -10,6 +9,5 @@ from ._ffi.base import DGLError, __version__
from
.base
import
ALL
from
.batched_graph
import
*
from
.generator
import
*
from
.graph
import
DGLGraph
,
__MSG__
,
__REPR__
from
.graph
import
DGLGraph
from
.subgraph
import
DGLSubGraph
python/dgl/backend/mxnet.py
View file @
9c135fd5
...
...
@@ -46,8 +46,14 @@ def from_numpy(np_data):
def
pack
(
tensors
):
return
F
.
concat
(
*
tensors
,
dim
=
0
)
def
unpack
(
x
,
indices_or_sections
=
1
):
return
th
.
split
(
x
,
indices_or_sections
)
def
unpack
(
x
,
split_sizes_or_sections
=
1
):
if
isinstance
(
split_sizes_or_sections
,
list
):
np_arr
=
x
.
asnumpy
()
indices
=
np
.
cumsum
(
split_sizes_or_sections
)
res
=
np
.
split
(
np_arr
,
indices
[:
-
1
])
return
[
tensor
(
arr
,
dtype
=
x
.
dtype
)
for
arr
in
res
]
else
:
return
F
.
split
(
x
,
split_sizes_or_sections
)
# TODO this doesn't exist for symbol.
def
shape
(
x
):
...
...
@@ -66,7 +72,10 @@ def unique(x):
return
mx
.
nd
.
array
(
tmp
,
ctx
=
x
.
context
,
dtype
=
x
.
dtype
)
def
gather_row
(
data
,
row_index
):
return
data
[
row_index
,]
if
isinstance
(
row_index
,
F
.
NDArray
):
return
F
.
take
(
data
,
row_index
)
else
:
return
data
[
row_index
,]
scatter_row
=
mx
.
nd
.
contrib
.
index_copy
...
...
@@ -114,6 +123,27 @@ def get_context(x):
def
_typestr
(
arr_dtype
):
return
arr_dtype
def
get_tvmtype
(
arr
):
arr_dtype
=
arr
.
dtype
if
arr_dtype
==
np
.
float16
:
return
TVMType
(
'float16'
)
elif
arr_dtype
==
np
.
float32
:
return
TVMType
(
'float32'
)
elif
arr_dtype
==
np
.
float64
:
return
TVMType
(
'float64'
)
elif
arr_dtype
==
np
.
int16
:
return
TVMType
(
'int16'
)
elif
arr_dtype
==
np
.
int32
:
return
TVMType
(
'int32'
)
elif
arr_dtype
==
np
.
int64
:
return
TVMType
(
'int64'
)
elif
arr_dtype
==
np
.
int8
:
return
TVMType
(
'int8'
)
elif
arr_dtype
==
np
.
uint8
:
return
TVMType
(
'uint8'
)
else
:
raise
RuntimeError
(
'Unsupported data type:'
,
arr_dtype
)
def
zerocopy_to_dlpack
(
arr
):
"""Return a dlpack compatible array using zero copy."""
return
arr
.
to_dlpack_for_read
()
...
...
python/dgl/backend/pytorch.py
View file @
9c135fd5
...
...
@@ -93,23 +93,24 @@ def get_context(arr):
return
TVMContext
(
TVMContext
.
STR2MASK
[
arr
.
device
.
type
],
arr
.
device
.
index
)
def
_typestr
(
arr_dtype
):
def
get_tvmtype
(
arr
):
arr_dtype
=
arr
.
dtype
if
arr_dtype
in
(
th
.
float16
,
th
.
half
):
return
'float16'
return
TVMType
(
'float16'
)
elif
arr_dtype
in
(
th
.
float32
,
th
.
float
):
return
'float32'
return
TVMType
(
'float32'
)
elif
arr_dtype
in
(
th
.
float64
,
th
.
double
):
return
'float64'
return
TVMType
(
'float64'
)
elif
arr_dtype
in
(
th
.
int16
,
th
.
short
):
return
'int16'
return
TVMType
(
'int16'
)
elif
arr_dtype
in
(
th
.
int32
,
th
.
int
):
return
'int32'
return
TVMType
(
'int32'
)
elif
arr_dtype
in
(
th
.
int64
,
th
.
long
):
return
'int64'
return
TVMType
(
'int64'
)
elif
arr_dtype
==
th
.
int8
:
return
'int8'
return
TVMType
(
'int8'
)
elif
arr_dtype
==
th
.
uint8
:
return
'uint8'
return
TVMType
(
'uint8'
)
else
:
raise
RuntimeError
(
'Unsupported data type:'
,
arr_dtype
)
...
...
@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data."""
return
th
.
from_numpy
(
np_data
)
'''
data = arr_data
assert data.is_contiguous()
arr = TVMArray()
shape = c_array(tvm_shape_index_t, tuple(data.shape))
arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(_typestr(data.dtype))
arr.ndim = len(shape)
arr.ctx = get_context(data)
return arr
'''
def
nonzero_1d
(
arr
):
"""Return a 1D tensor with nonzero element indices in a 1D vector"""
assert
arr
.
dim
()
==
1
...
...
python/dgl/base.py
View file @
9c135fd5
"""Module for base types and utilities."""
from
__future__
import
absolute_import
import
warnings
from
._ffi.base
import
DGLError
# A special argument for selecting all nodes/edges.
ALL
=
"__ALL__"
...
...
@@ -6,5 +11,4 @@ ALL = "__ALL__"
def
is_all
(
arg
):
return
isinstance
(
arg
,
str
)
and
arg
==
ALL
__MSG__
=
"__MSG__"
__REPR__
=
"__REPR__"
dgl_warning
=
warnings
.
warn
python/dgl/data/citation_graph.py
View file @
9c135fd5
...
...
@@ -206,10 +206,33 @@ def get_gnp_generator(args):
return
nx
.
fast_gnp_random_graph
(
n
,
p
,
seed
,
True
)
return
_gen
class
ScipyGraph
(
object
):
"""A simple graph object that uses scipy matrix."""
def
__init__
(
self
,
mat
):
self
.
_mat
=
mat
def
get_graph
(
self
):
return
self
.
_mat
def
number_of_nodes
(
self
):
return
self
.
_mat
.
shape
[
0
]
def
number_of_edges
(
self
):
return
self
.
_mat
.
getnnz
()
def
get_scipy_generator
(
args
):
n
=
args
.
syn_gnp_n
p
=
(
2
*
np
.
log
(
n
)
/
n
)
if
args
.
syn_gnp_p
==
0.
else
args
.
syn_gnp_p
def
_gen
(
seed
):
return
ScipyGraph
(
sp
.
random
(
n
,
n
,
p
,
format
=
'coo'
))
return
_gen
def
load_synthetic
(
args
):
ty
=
args
.
syn_type
if
ty
==
'gnp'
:
gen
=
get_gnp_generator
(
args
)
elif
ty
==
'scipy'
:
gen
=
get_scipy_generator
(
args
)
else
:
raise
ValueError
(
'Unknown graph generator type: {}'
.
format
(
ty
))
return
GCNSyntheticDataset
(
...
...
python/dgl/data/utils.py
View file @
9c135fd5
"""Dataset utilities."""
from
__future__
import
absolute_import
import
os
import
os
,
sys
import
hashlib
import
warnings
import
zipfile
...
...
@@ -125,17 +126,22 @@ def extract_archive(file, target_dir):
target_dir : str
Target directory of the archive to be uncompressed
"""
if
os
.
path
.
exists
(
target_dir
):
return
if
file
.
endswith
(
'.gz'
)
or
file
.
endswith
(
'.tar'
)
or
file
.
endswith
(
'.tgz'
):
archive
=
tarfile
.
open
(
file
,
'r'
)
elif
file
.
endswith
(
'.zip'
):
archive
=
zipfile
.
ZipFile
(
file
,
'r'
)
else
:
raise
Exception
(
'Unrecognized file type: '
+
file
)
print
(
'Extracting file to {}'
.
format
(
target_dir
))
archive
.
extractall
(
path
=
target_dir
)
archive
.
close
()
def
get_download_dir
():
dirname
=
'_download'
"""Get the absolute path to the download directory."""
curr_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
__file__
)))
dirname
=
os
.
path
.
join
(
curr_path
,
'../../../_download'
)
if
not
os
.
path
.
exists
(
dirname
):
os
.
makedirs
(
dirname
)
return
dirname
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment