Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9c135fd5
Unverified
Commit
9c135fd5
authored
Oct 19, 2018
by
VoVAllen
Committed by
GitHub
Oct 19, 2018
Browse files
Merge pull request #4 from jermainewang/master
Sync with latest commit
parents
9d3f299d
00add9f2
Changes
73
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
148 additions
and
89 deletions
+148
-89
include/dgl/graph.h
include/dgl/graph.h
+13
-8
include/dgl/graph_op.h
include/dgl/graph_op.h
+6
-1
include/dgl/runtime/README.md
include/dgl/runtime/README.md
+0
-3
include/dgl/runtime/c_backend_api.h
include/dgl/runtime/c_backend_api.h
+3
-3
include/dgl/runtime/c_runtime_api.h
include/dgl/runtime/c_runtime_api.h
+4
-4
include/dgl/runtime/device_api.h
include/dgl/runtime/device_api.h
+4
-4
include/dgl/runtime/module.h
include/dgl/runtime/module.h
+4
-4
include/dgl/runtime/ndarray.h
include/dgl/runtime/ndarray.h
+4
-4
include/dgl/runtime/packed_func.h
include/dgl/runtime/packed_func.h
+4
-4
include/dgl/runtime/registry.h
include/dgl/runtime/registry.h
+4
-4
include/dgl/runtime/serializer.h
include/dgl/runtime/serializer.h
+4
-4
include/dgl/runtime/threading_backend.h
include/dgl/runtime/threading_backend.h
+4
-4
include/dgl/runtime/util.h
include/dgl/runtime/util.h
+4
-4
include/dgl/scheduler.h
include/dgl/scheduler.h
+9
-5
python/dgl/__init__.py
python/dgl/__init__.py
+1
-3
python/dgl/backend/mxnet.py
python/dgl/backend/mxnet.py
+33
-3
python/dgl/backend/pytorch.py
python/dgl/backend/pytorch.py
+10
-23
python/dgl/base.py
python/dgl/base.py
+6
-2
python/dgl/data/citation_graph.py
python/dgl/data/citation_graph.py
+23
-0
python/dgl/data/utils.py
python/dgl/data/utils.py
+8
-2
No files found.
include/dgl/graph.h
View file @
9c135fd5
// DGL Graph interface
/*!
#ifndef DGL_DGLGRAPH_H_
* Copyright (c) 2018 by Contributors
#define DGL_DGLGRAPH_H_
* \file dgl/graph.h
* \brief DGL graph index class.
*/
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <stdint.h>
#include <vector>
#include <cstdint>
#include "runtime/ndarray.h"
#include "runtime/ndarray.h"
namespace
dgl
{
namespace
dgl
{
...
@@ -17,7 +22,7 @@ class GraphOp;
...
@@ -17,7 +22,7 @@ class GraphOp;
struct
Subgraph
;
struct
Subgraph
;
/*!
/*!
* \brief Base dgl graph class.
* \brief Base dgl graph
index
class.
*
*
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
* are uniquely identified by the two endpoints. Multi-edge is currently not
...
@@ -41,7 +46,7 @@ class Graph {
...
@@ -41,7 +46,7 @@ class Graph {
}
EdgeArray
;
}
EdgeArray
;
/*! \brief default constructor */
/*! \brief default constructor */
Graph
(
bool
multigraph
=
false
)
:
is_multigraph_
(
multigraph
)
{}
explicit
Graph
(
bool
multigraph
=
false
)
:
is_multigraph_
(
multigraph
)
{}
/*! \brief default copy constructor */
/*! \brief default copy constructor */
Graph
(
const
Graph
&
other
)
=
default
;
Graph
(
const
Graph
&
other
)
=
default
;
...
@@ -192,7 +197,7 @@ class Graph {
...
@@ -192,7 +197,7 @@ class Graph {
* \return the id arrays of the two endpoints of the edges.
* \return the id arrays of the two endpoints of the edges.
*/
*/
EdgeArray
InEdges
(
IdArray
vids
)
const
;
EdgeArray
InEdges
(
IdArray
vids
)
const
;
/*!
/*!
* \brief Get the out edges of the vertex.
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
* \note The returned src id array is filled with vid.
...
@@ -347,4 +352,4 @@ struct Subgraph {
...
@@ -347,4 +352,4 @@ struct Subgraph {
}
// namespace dgl
}
// namespace dgl
#endif // DGL_
DGL
GRAPH_H_
#endif // DGL_GRAPH_H_
include/dgl/graph_op.h
View file @
9c135fd5
// Graph operations
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph_op.h
* \brief Operations on graph index.
*/
#ifndef DGL_GRAPH_OP_H_
#ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
#include "graph.h"
namespace
dgl
{
namespace
dgl
{
...
...
include/dgl/runtime/README.md
deleted
100644 → 0
View file @
9d3f299d
# C API and runtime
Borrowed and adapted from TVM project.
include/dgl/runtime/c_backend_api.h
View file @
9c135fd5
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
* used by compiled tvm operators, usually user do not need to use these
* used by compiled tvm operators, usually user do not need to use these
* function directly.
* function directly.
*/
*/
#ifndef
TVM
_RUNTIME_C_BACKEND_API_H_
#ifndef
DGL
_RUNTIME_C_BACKEND_API_H_
#define
TVM
_RUNTIME_C_BACKEND_API_H_
#define
DGL
_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h"
#include "c_runtime_api.h"
...
@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle,
...
@@ -136,4 +136,4 @@ TVM_DLL int TVMBackendRunOnce(void** handle,
#ifdef __cplusplus
#ifdef __cplusplus
}
// TVM_EXTERN_C
}
// TVM_EXTERN_C
#endif
#endif
#endif //
TVM
_RUNTIME_C_BACKEND_API_H_
#endif //
DGL
_RUNTIME_C_BACKEND_API_H_
include/dgl/runtime/c_runtime_api.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2016 by Contributors
* \file
tvm
/runtime/c_runtime_api.h
* \file
dgl
/runtime/c_runtime_api.h
* \brief TVM runtime library.
* \brief TVM runtime library.
*
*
* The philosophy of TVM project is to customize the compilation
* The philosophy of TVM project is to customize the compilation
...
@@ -15,8 +15,8 @@
...
@@ -15,8 +15,8 @@
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
* - Use TVMFuncCall to call these functions.
*/
*/
#ifndef
TVM
_RUNTIME_C_RUNTIME_API_H_
#ifndef
DGL
_RUNTIME_C_RUNTIME_API_H_
#define
TVM
_RUNTIME_C_RUNTIME_API_H_
#define
DGL
_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking
// Macros to do weak linking
#ifdef _MSC_VER
#ifdef _MSC_VER
...
@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
...
@@ -530,4 +530,4 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
#ifdef __cplusplus
#ifdef __cplusplus
}
// TVM_EXTERN_C
}
// TVM_EXTERN_C
#endif
#endif
#endif //
TVM
_RUNTIME_C_RUNTIME_API_H_
#endif //
DGL
_RUNTIME_C_RUNTIME_API_H_
include/dgl/runtime/device_api.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2016 by Contributors
* \file
tvm
/runtime/device_api.h
* \file
dgl
/runtime/device_api.h
* \brief Abstract device memory management API
* \brief Abstract device memory management API
*/
*/
#ifndef
TVM
_RUNTIME_DEVICE_API_H_
#ifndef
DGL
_RUNTIME_DEVICE_API_H_
#define
TVM
_RUNTIME_DEVICE_API_H_
#define
DGL
_RUNTIME_DEVICE_API_H_
#include <string>
#include <string>
#include "packed_func.h"
#include "packed_func.h"
...
@@ -180,4 +180,4 @@ class DeviceAPI {
...
@@ -180,4 +180,4 @@ class DeviceAPI {
constexpr
int
kRPCSessMask
=
128
;
constexpr
int
kRPCSessMask
=
128
;
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_DEVICE_API_H_
#endif //
DGL
_RUNTIME_DEVICE_API_H_
include/dgl/runtime/module.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/module.h
* \file
dgl
/runtime/module.h
* \brief Runtime container of the functions generated by TVM,
* \brief Runtime container of the functions generated by TVM,
* This is used to support dynamically link, load and save
* This is used to support dynamically link, load and save
* functions from different convention under unified API.
* functions from different convention under unified API.
*/
*/
#ifndef
TVM
_RUNTIME_MODULE_H_
#ifndef
DGL
_RUNTIME_MODULE_H_
#define
TVM
_RUNTIME_MODULE_H_
#define
DGL
_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <dmlc/io.h>
#include <memory>
#include <memory>
...
@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const {
...
@@ -174,4 +174,4 @@ inline const ModuleNode* Module::operator->() const {
}
// namespace tvm
}
// namespace tvm
#include "packed_func.h"
#include "packed_func.h"
#endif //
TVM
_RUNTIME_MODULE_H_
#endif //
DGL
_RUNTIME_MODULE_H_
include/dgl/runtime/ndarray.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/ndarray.h
* \file
dgl
/runtime/ndarray.h
* \brief Abstract device memory management API
* \brief Abstract device memory management API
*/
*/
#ifndef
TVM
_RUNTIME_NDARRAY_H_
#ifndef
DGL
_RUNTIME_NDARRAY_H_
#define
TVM
_RUNTIME_NDARRAY_H_
#define
DGL
_RUNTIME_NDARRAY_H_
#include <atomic>
#include <atomic>
#include <vector>
#include <vector>
...
@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
...
@@ -422,4 +422,4 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_NDARRAY_H_
#endif //
DGL
_RUNTIME_NDARRAY_H_
include/dgl/runtime/packed_func.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/packed_func.h
* \file
dgl
/runtime/packed_func.h
* \brief Type-erased function used across TVM API.
* \brief Type-erased function used across TVM API.
*/
*/
#ifndef
TVM
_RUNTIME_PACKED_FUNC_H_
#ifndef
DGL
_RUNTIME_PACKED_FUNC_H_
#define
TVM
_RUNTIME_PACKED_FUNC_H_
#define
DGL
_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <functional>
#include <functional>
...
@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
...
@@ -1212,4 +1212,4 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
}
}
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_PACKED_FUNC_H_
#endif //
DGL
_RUNTIME_PACKED_FUNC_H_
include/dgl/runtime/registry.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/registry.h
* \file
dgl
/runtime/registry.h
* \brief This file defines the TVM global function registry.
* \brief This file defines the TVM global function registry.
*
*
* The registered functions will be made available to front-end
* The registered functions will be made available to front-end
...
@@ -22,8 +22,8 @@
...
@@ -22,8 +22,8 @@
* });
* });
* \endcode
* \endcode
*/
*/
#ifndef
TVM
_RUNTIME_REGISTRY_H_
#ifndef
DGL
_RUNTIME_REGISTRY_H_
#define
TVM
_RUNTIME_REGISTRY_H_
#define
DGL
_RUNTIME_REGISTRY_H_
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -141,4 +141,4 @@ class Registry {
...
@@ -141,4 +141,4 @@ class Registry {
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_REGISTRY_H_
#endif //
DGL
_RUNTIME_REGISTRY_H_
include/dgl/runtime/serializer.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/serializer.h
* \file
dgl
/runtime/serializer.h
* \brief Serializer extension to support TVM data types
* \brief Serializer extension to support TVM data types
* Include this file to enable serialization of DLDataType, DLContext
* Include this file to enable serialization of DLDataType, DLContext
*/
*/
#ifndef
TVM
_RUNTIME_SERIALIZER_H_
#ifndef
DGL
_RUNTIME_SERIALIZER_H_
#define
TVM
_RUNTIME_SERIALIZER_H_
#define
DGL
_RUNTIME_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <dmlc/serializer.h>
...
@@ -48,4 +48,4 @@ struct Handler<DLContext> {
...
@@ -48,4 +48,4 @@ struct Handler<DLContext> {
}
// namespace serializer
}
// namespace serializer
}
// namespace dmlc
}
// namespace dmlc
#endif //
TVM
_RUNTIME_SERIALIZER_H_
#endif //
DGL
_RUNTIME_SERIALIZER_H_
include/dgl/runtime/threading_backend.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2018 by Contributors
* \file
tvm
/runtime/threading_backend.h
* \file
dgl
/runtime/threading_backend.h
* \brief Utilities for manipulating thread pool threads.
* \brief Utilities for manipulating thread pool threads.
*/
*/
#ifndef
TVM
_RUNTIME_THREADING_BACKEND_H_
#ifndef
DGL
_RUNTIME_THREADING_BACKEND_H_
#define
TVM
_RUNTIME_THREADING_BACKEND_H_
#define
DGL
_RUNTIME_THREADING_BACKEND_H_
#include <functional>
#include <functional>
#include <memory>
#include <memory>
...
@@ -82,4 +82,4 @@ int MaxConcurrency();
...
@@ -82,4 +82,4 @@ int MaxConcurrency();
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_THREADING_BACKEND_H_
#endif //
DGL
_RUNTIME_THREADING_BACKEND_H_
include/dgl/runtime/util.h
View file @
9c135fd5
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file
tvm
/runtime/util.h
* \file
dgl
/runtime/util.h
* \brief Useful runtime util.
* \brief Useful runtime util.
*/
*/
#ifndef
TVM
_RUNTIME_UTIL_H_
#ifndef
DGL
_RUNTIME_UTIL_H_
#define
TVM
_RUNTIME_UTIL_H_
#define
DGL
_RUNTIME_UTIL_H_
#include "c_runtime_api.h"
#include "c_runtime_api.h"
...
@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int {
...
@@ -50,4 +50,4 @@ enum TVMStructFieldKind : int {
}
// namespace intrinsic
}
// namespace intrinsic
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_UTIL_H_
#endif //
DGL
_RUNTIME_UTIL_H_
include/dgl/scheduler.h
View file @
9c135fd5
// DGL Scheduler interface
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/scheduler.h
* \brief Operations on graph index.
*/
#ifndef DGL_SCHEDULER_H_
#ifndef DGL_SCHEDULER_H_
#define DGL_SCHEDULER_H_
#define DGL_SCHEDULER_H_
#include "runtime/ndarray.h"
#include <vector>
#include <vector>
#include "runtime/ndarray.h"
namespace
dgl
{
namespace
dgl
{
...
@@ -25,8 +29,8 @@ namespace sched {
...
@@ -25,8 +29,8 @@ namespace sched {
*/
*/
std
::
vector
<
IdArray
>
DegreeBucketing
(
const
IdArray
&
vids
);
std
::
vector
<
IdArray
>
DegreeBucketing
(
const
IdArray
&
vids
);
}
// namespace sched
}
// namespace sched
}
// namespace dgl
}
// namespace dgl
#endif // DGL_SCHEDULER_H_
#endif
// DGL_SCHEDULER_H_
python/dgl/__init__.py
View file @
9c135fd5
from
.
import
backend
from
.
import
backend
from
.
import
data
from
.
import
data
from
.
import
function
from
.
import
function
from
.
import
generator
from
.
import
nn
from
.
import
nn
from
._ffi.runtime_ctypes
import
TypeCode
from
._ffi.runtime_ctypes
import
TypeCode
...
@@ -10,6 +9,5 @@ from ._ffi.base import DGLError, __version__
...
@@ -10,6 +9,5 @@ from ._ffi.base import DGLError, __version__
from
.base
import
ALL
from
.base
import
ALL
from
.batched_graph
import
*
from
.batched_graph
import
*
from
.generator
import
*
from
.graph
import
DGLGraph
from
.graph
import
DGLGraph
,
__MSG__
,
__REPR__
from
.subgraph
import
DGLSubGraph
from
.subgraph
import
DGLSubGraph
python/dgl/backend/mxnet.py
View file @
9c135fd5
...
@@ -46,8 +46,14 @@ def from_numpy(np_data):
...
@@ -46,8 +46,14 @@ def from_numpy(np_data):
def
pack
(
tensors
):
def
pack
(
tensors
):
return
F
.
concat
(
*
tensors
,
dim
=
0
)
return
F
.
concat
(
*
tensors
,
dim
=
0
)
def
unpack
(
x
,
indices_or_sections
=
1
):
def
unpack
(
x
,
split_sizes_or_sections
=
1
):
return
th
.
split
(
x
,
indices_or_sections
)
if
isinstance
(
split_sizes_or_sections
,
list
):
np_arr
=
x
.
asnumpy
()
indices
=
np
.
cumsum
(
split_sizes_or_sections
)
res
=
np
.
split
(
np_arr
,
indices
[:
-
1
])
return
[
tensor
(
arr
,
dtype
=
x
.
dtype
)
for
arr
in
res
]
else
:
return
F
.
split
(
x
,
split_sizes_or_sections
)
# TODO this doesn't exist for symbol.
# TODO this doesn't exist for symbol.
def
shape
(
x
):
def
shape
(
x
):
...
@@ -66,7 +72,10 @@ def unique(x):
...
@@ -66,7 +72,10 @@ def unique(x):
return
mx
.
nd
.
array
(
tmp
,
ctx
=
x
.
context
,
dtype
=
x
.
dtype
)
return
mx
.
nd
.
array
(
tmp
,
ctx
=
x
.
context
,
dtype
=
x
.
dtype
)
def
gather_row
(
data
,
row_index
):
def
gather_row
(
data
,
row_index
):
return
data
[
row_index
,]
if
isinstance
(
row_index
,
F
.
NDArray
):
return
F
.
take
(
data
,
row_index
)
else
:
return
data
[
row_index
,]
scatter_row
=
mx
.
nd
.
contrib
.
index_copy
scatter_row
=
mx
.
nd
.
contrib
.
index_copy
...
@@ -114,6 +123,27 @@ def get_context(x):
...
@@ -114,6 +123,27 @@ def get_context(x):
def
_typestr
(
arr_dtype
):
def
_typestr
(
arr_dtype
):
return
arr_dtype
return
arr_dtype
def
get_tvmtype
(
arr
):
arr_dtype
=
arr
.
dtype
if
arr_dtype
==
np
.
float16
:
return
TVMType
(
'float16'
)
elif
arr_dtype
==
np
.
float32
:
return
TVMType
(
'float32'
)
elif
arr_dtype
==
np
.
float64
:
return
TVMType
(
'float64'
)
elif
arr_dtype
==
np
.
int16
:
return
TVMType
(
'int16'
)
elif
arr_dtype
==
np
.
int32
:
return
TVMType
(
'int32'
)
elif
arr_dtype
==
np
.
int64
:
return
TVMType
(
'int64'
)
elif
arr_dtype
==
np
.
int8
:
return
TVMType
(
'int8'
)
elif
arr_dtype
==
np
.
uint8
:
return
TVMType
(
'uint8'
)
else
:
raise
RuntimeError
(
'Unsupported data type:'
,
arr_dtype
)
def
zerocopy_to_dlpack
(
arr
):
def
zerocopy_to_dlpack
(
arr
):
"""Return a dlpack compatible array using zero copy."""
"""Return a dlpack compatible array using zero copy."""
return
arr
.
to_dlpack_for_read
()
return
arr
.
to_dlpack_for_read
()
...
...
python/dgl/backend/pytorch.py
View file @
9c135fd5
...
@@ -93,23 +93,24 @@ def get_context(arr):
...
@@ -93,23 +93,24 @@ def get_context(arr):
return
TVMContext
(
return
TVMContext
(
TVMContext
.
STR2MASK
[
arr
.
device
.
type
],
arr
.
device
.
index
)
TVMContext
.
STR2MASK
[
arr
.
device
.
type
],
arr
.
device
.
index
)
def
_typestr
(
arr_dtype
):
def
get_tvmtype
(
arr
):
arr_dtype
=
arr
.
dtype
if
arr_dtype
in
(
th
.
float16
,
th
.
half
):
if
arr_dtype
in
(
th
.
float16
,
th
.
half
):
return
'float16'
return
TVMType
(
'float16'
)
elif
arr_dtype
in
(
th
.
float32
,
th
.
float
):
elif
arr_dtype
in
(
th
.
float32
,
th
.
float
):
return
'float32'
return
TVMType
(
'float32'
)
elif
arr_dtype
in
(
th
.
float64
,
th
.
double
):
elif
arr_dtype
in
(
th
.
float64
,
th
.
double
):
return
'float64'
return
TVMType
(
'float64'
)
elif
arr_dtype
in
(
th
.
int16
,
th
.
short
):
elif
arr_dtype
in
(
th
.
int16
,
th
.
short
):
return
'int16'
return
TVMType
(
'int16'
)
elif
arr_dtype
in
(
th
.
int32
,
th
.
int
):
elif
arr_dtype
in
(
th
.
int32
,
th
.
int
):
return
'int32'
return
TVMType
(
'int32'
)
elif
arr_dtype
in
(
th
.
int64
,
th
.
long
):
elif
arr_dtype
in
(
th
.
int64
,
th
.
long
):
return
'int64'
return
TVMType
(
'int64'
)
elif
arr_dtype
==
th
.
int8
:
elif
arr_dtype
==
th
.
int8
:
return
'int8'
return
TVMType
(
'int8'
)
elif
arr_dtype
==
th
.
uint8
:
elif
arr_dtype
==
th
.
uint8
:
return
'uint8'
return
TVMType
(
'uint8'
)
else
:
else
:
raise
RuntimeError
(
'Unsupported data type:'
,
arr_dtype
)
raise
RuntimeError
(
'Unsupported data type:'
,
arr_dtype
)
...
@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
...
@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data."""
"""Return a tensor that shares the numpy data."""
return
th
.
from_numpy
(
np_data
)
return
th
.
from_numpy
(
np_data
)
'''
data = arr_data
assert data.is_contiguous()
arr = TVMArray()
shape = c_array(tvm_shape_index_t, tuple(data.shape))
arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(_typestr(data.dtype))
arr.ndim = len(shape)
arr.ctx = get_context(data)
return arr
'''
def
nonzero_1d
(
arr
):
def
nonzero_1d
(
arr
):
"""Return a 1D tensor with nonzero element indices in a 1D vector"""
"""Return a 1D tensor with nonzero element indices in a 1D vector"""
assert
arr
.
dim
()
==
1
assert
arr
.
dim
()
==
1
...
...
python/dgl/base.py
View file @
9c135fd5
"""Module for base types and utilities."""
"""Module for base types and utilities."""
from
__future__
import
absolute_import
import
warnings
from
._ffi.base
import
DGLError
# A special argument for selecting all nodes/edges.
# A special argument for selecting all nodes/edges.
ALL
=
"__ALL__"
ALL
=
"__ALL__"
...
@@ -6,5 +11,4 @@ ALL = "__ALL__"
...
@@ -6,5 +11,4 @@ ALL = "__ALL__"
def
is_all
(
arg
):
def
is_all
(
arg
):
return
isinstance
(
arg
,
str
)
and
arg
==
ALL
return
isinstance
(
arg
,
str
)
and
arg
==
ALL
__MSG__
=
"__MSG__"
dgl_warning
=
warnings
.
warn
__REPR__
=
"__REPR__"
python/dgl/data/citation_graph.py
View file @
9c135fd5
...
@@ -206,10 +206,33 @@ def get_gnp_generator(args):
...
@@ -206,10 +206,33 @@ def get_gnp_generator(args):
return
nx
.
fast_gnp_random_graph
(
n
,
p
,
seed
,
True
)
return
nx
.
fast_gnp_random_graph
(
n
,
p
,
seed
,
True
)
return
_gen
return
_gen
class
ScipyGraph
(
object
):
"""A simple graph object that uses scipy matrix."""
def
__init__
(
self
,
mat
):
self
.
_mat
=
mat
def
get_graph
(
self
):
return
self
.
_mat
def
number_of_nodes
(
self
):
return
self
.
_mat
.
shape
[
0
]
def
number_of_edges
(
self
):
return
self
.
_mat
.
getnnz
()
def
get_scipy_generator
(
args
):
n
=
args
.
syn_gnp_n
p
=
(
2
*
np
.
log
(
n
)
/
n
)
if
args
.
syn_gnp_p
==
0.
else
args
.
syn_gnp_p
def
_gen
(
seed
):
return
ScipyGraph
(
sp
.
random
(
n
,
n
,
p
,
format
=
'coo'
))
return
_gen
def
load_synthetic
(
args
):
def
load_synthetic
(
args
):
ty
=
args
.
syn_type
ty
=
args
.
syn_type
if
ty
==
'gnp'
:
if
ty
==
'gnp'
:
gen
=
get_gnp_generator
(
args
)
gen
=
get_gnp_generator
(
args
)
elif
ty
==
'scipy'
:
gen
=
get_scipy_generator
(
args
)
else
:
else
:
raise
ValueError
(
'Unknown graph generator type: {}'
.
format
(
ty
))
raise
ValueError
(
'Unknown graph generator type: {}'
.
format
(
ty
))
return
GCNSyntheticDataset
(
return
GCNSyntheticDataset
(
...
...
python/dgl/data/utils.py
View file @
9c135fd5
"""Dataset utilities."""
"""Dataset utilities."""
from
__future__
import
absolute_import
import
os
import
os
,
sys
import
hashlib
import
hashlib
import
warnings
import
warnings
import
zipfile
import
zipfile
...
@@ -125,17 +126,22 @@ def extract_archive(file, target_dir):
...
@@ -125,17 +126,22 @@ def extract_archive(file, target_dir):
target_dir : str
target_dir : str
Target directory of the archive to be uncompressed
Target directory of the archive to be uncompressed
"""
"""
if
os
.
path
.
exists
(
target_dir
):
return
if
file
.
endswith
(
'.gz'
)
or
file
.
endswith
(
'.tar'
)
or
file
.
endswith
(
'.tgz'
):
if
file
.
endswith
(
'.gz'
)
or
file
.
endswith
(
'.tar'
)
or
file
.
endswith
(
'.tgz'
):
archive
=
tarfile
.
open
(
file
,
'r'
)
archive
=
tarfile
.
open
(
file
,
'r'
)
elif
file
.
endswith
(
'.zip'
):
elif
file
.
endswith
(
'.zip'
):
archive
=
zipfile
.
ZipFile
(
file
,
'r'
)
archive
=
zipfile
.
ZipFile
(
file
,
'r'
)
else
:
else
:
raise
Exception
(
'Unrecognized file type: '
+
file
)
raise
Exception
(
'Unrecognized file type: '
+
file
)
print
(
'Extracting file to {}'
.
format
(
target_dir
))
archive
.
extractall
(
path
=
target_dir
)
archive
.
extractall
(
path
=
target_dir
)
archive
.
close
()
archive
.
close
()
def
get_download_dir
():
def
get_download_dir
():
dirname
=
'_download'
"""Get the absolute path to the download directory."""
curr_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
__file__
)))
dirname
=
os
.
path
.
join
(
curr_path
,
'../../../_download'
)
if
not
os
.
path
.
exists
(
dirname
):
if
not
os
.
path
.
exists
(
dirname
):
os
.
makedirs
(
dirname
)
os
.
makedirs
(
dirname
)
return
dirname
return
dirname
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment