Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
842d3768
Commit
842d3768
authored
Sep 16, 2018
by
Minjie Wang
Browse files
python api
parent
14d88497
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
139 additions
and
15 deletions
+139
-15
include/dgl/vector_view.h
include/dgl/vector_view.h
+3
-3
python/dgl/_ffi/_ctypes/function.py
python/dgl/_ffi/_ctypes/function.py
+4
-3
python/dgl/_ffi/function.py
python/dgl/_ffi/function.py
+4
-4
python/dgl/backend/pytorch.py
python/dgl/backend/pytorch.py
+18
-4
python/dgl/cgraph.py
python/dgl/cgraph.py
+28
-0
src/graph/graph.cc
src/graph/graph.cc
+1
-1
src/graph/graph_apis.cc
src/graph/graph_apis.cc
+81
-0
No files found.
include/dgl/vector_view.h
View file @
842d3768
...
@@ -103,7 +103,7 @@ class vector_view {
...
@@ -103,7 +103,7 @@ class vector_view {
*/
*/
ValueType
&
operator
[](
size_t
i
)
{
ValueType
&
operator
[](
size_t
i
)
{
CHECK
(
!
is_view_
);
CHECK
(
!
is_view_
);
return
data_
[
i
];
return
(
*
data_
)
[
i
];
}
}
/*!
/*!
...
@@ -113,9 +113,9 @@ class vector_view {
...
@@ -113,9 +113,9 @@ class vector_view {
*/
*/
const
ValueType
&
operator
[](
size_t
i
)
const
{
const
ValueType
&
operator
[](
size_t
i
)
const
{
if
(
is_view_
)
{
if
(
is_view_
)
{
return
data_
[
index_
[
i
]];
return
(
*
data_
)
[
index_
[
i
]];
}
else
{
}
else
{
return
data_
[
i
];
return
(
*
data_
)
[
i
];
}
}
}
}
...
...
python/dgl/_ffi/_ctypes/function.py
View file @
842d3768
...
@@ -118,9 +118,10 @@ def _make_tvm_args(args, temp_args):
...
@@ -118,9 +118,10 @@ def _make_tvm_args(args, temp_args):
elif
isinstance
(
arg
,
string_types
):
elif
isinstance
(
arg
,
string_types
):
values
[
i
].
v_str
=
c_str
(
arg
)
values
[
i
].
v_str
=
c_str
(
arg
)
type_codes
[
i
]
=
TypeCode
.
STR
type_codes
[
i
]
=
TypeCode
.
STR
elif
isinstance
(
arg
,
_CLASS_MODULE
):
# NOTE(minjie): module is not used in DGL
values
[
i
].
v_handle
=
arg
.
handle
#elif isinstance(arg, _CLASS_MODULE):
type_codes
[
i
]
=
TypeCode
.
MODULE_HANDLE
# values[i].v_handle = arg.handle
# type_codes[i] = TypeCode.MODULE_HANDLE
elif
isinstance
(
arg
,
FunctionBase
):
elif
isinstance
(
arg
,
FunctionBase
):
values
[
i
].
v_handle
=
arg
.
handle
values
[
i
].
v_handle
=
arg
.
handle
type_codes
[
i
]
=
TypeCode
.
FUNC_HANDLE
type_codes
[
i
]
=
TypeCode
.
FUNC_HANDLE
...
...
python/dgl/_ffi/function.py
View file @
842d3768
...
@@ -29,7 +29,7 @@ except IMPORT_EXCEPT:
...
@@ -29,7 +29,7 @@ except IMPORT_EXCEPT:
FunctionHandle
=
ctypes
.
c_void_p
FunctionHandle
=
ctypes
.
c_void_p
class
Function
(
_FunctionBase
):
class
Function
(
_FunctionBase
):
"""The PackedFunc object
used in TVM
.
"""The PackedFunc object.
Function plays an key role to bridge front and backend in TVM.
Function plays an key role to bridge front and backend in TVM.
Function provide a type-erased interface, you can call function with positional arguments.
Function provide a type-erased interface, you can call function with positional arguments.
...
@@ -275,7 +275,7 @@ def _init_api(namespace, target_module_name=None):
...
@@ -275,7 +275,7 @@ def _init_api(namespace, target_module_name=None):
"""
"""
target_module_name
=
(
target_module_name
=
(
target_module_name
if
target_module_name
else
namespace
)
target_module_name
if
target_module_name
else
namespace
)
if
namespace
.
startswith
(
"
tvm
."
):
if
namespace
.
startswith
(
"
dgl
."
):
_init_api_prefix
(
target_module_name
,
namespace
[
4
:])
_init_api_prefix
(
target_module_name
,
namespace
[
4
:])
else
:
else
:
_init_api_prefix
(
target_module_name
,
namespace
)
_init_api_prefix
(
target_module_name
,
namespace
)
...
@@ -288,7 +288,7 @@ def _init_api_prefix(module_name, prefix):
...
@@ -288,7 +288,7 @@ def _init_api_prefix(module_name, prefix):
if
prefix
==
"api"
:
if
prefix
==
"api"
:
fname
=
name
fname
=
name
if
name
.
startswith
(
"_"
):
if
name
.
startswith
(
"_"
):
target_module
=
sys
.
modules
[
"
tvm
._api_internal"
]
target_module
=
sys
.
modules
[
"
dgl
._api_internal"
]
else
:
else
:
target_module
=
module
target_module
=
module
else
:
else
:
...
@@ -302,7 +302,7 @@ def _init_api_prefix(module_name, prefix):
...
@@ -302,7 +302,7 @@ def _init_api_prefix(module_name, prefix):
f
=
get_global_func
(
name
)
f
=
get_global_func
(
name
)
ff
=
_get_api
(
f
)
ff
=
_get_api
(
f
)
ff
.
__name__
=
fname
ff
.
__name__
=
fname
ff
.
__doc__
=
(
"
TVM
PackedFunc %s. "
%
fname
)
ff
.
__doc__
=
(
"
DGL
PackedFunc %s. "
%
fname
)
setattr
(
target_module
,
ff
.
__name__
,
ff
)
setattr
(
target_module
,
ff
.
__name__
,
ff
)
_set_class_function
(
Function
)
_set_class_function
(
Function
)
python/dgl/backend/pytorch.py
View file @
842d3768
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
torch
as
th
import
torch
as
th
import
scipy.sparse
import
dgl.context
as
context
from
.._ffi.runtime_ctypes
import
TVMType
,
TVMContext
,
TVMArray
from
.._ffi.runtime_ctypes
import
TypeCode
,
tvm_shape_index_t
from
..
context
as
cpu
,
gpu
# Tensor types
# Tensor types
Tensor
=
th
.
Tensor
Tensor
=
th
.
Tensor
...
@@ -78,6 +80,18 @@ def to_context(x, ctx):
...
@@ -78,6 +80,18 @@ def to_context(x, ctx):
def
get_context
(
x
):
def
get_context
(
x
):
if
x
.
device
.
type
==
'cpu'
:
if
x
.
device
.
type
==
'cpu'
:
return
context
.
cpu
()
return
cpu
()
else
:
else
:
return
context
.
gpu
(
x
.
device
.
index
)
return
gpu
(
x
.
device
.
index
)
def
asdglarray
(
arr
):
assert
arr
.
is_contiguous
()
rst
=
TVMArray
()
rst
.
data
=
arr
.
data_ptr
()
rst
.
shape
=
c_array
(
tvm_shape_index_t
,
arr
.
shape
)
rst
.
strides
=
None
# TODO: dtype
rst
.
dtype
=
TVMType
(
arr
.
dtype
)
rst
.
ndim
=
arr
.
ndimension
()
# TODO: ctx
return
rst
python/dgl/cgraph.py
0 → 100644
View file @
842d3768
from
__future__
import
absolute_import
from
._ffi.function
import
_init_api
import
.backend
as
F
class
DGLGraph
(
object
):
def
__init__
(
self
):
self
.
_handle
=
_CAPI_DGLGraphCreate
()
def
__del__
(
self
):
_CAPI_DGLGraphFree
(
self
.
_handle
)
def
add_nodes
(
self
,
num
):
_CAPI_DGLGraphAddVertices
(
self
.
_handle
,
num
);
def
add_edge
(
self
,
u
,
v
):
_CAPI_DGLGraphAddEdge
(
self
.
_handle
,
u
,
v
);
def
add_edges
(
self
,
u
,
v
):
pass
def
number_of_nodes
(
self
):
return
_CAPI_DGLGraphNumVertices
(
self
.
_handle
)
def
number_of_edges
(
self
):
return
_CAPI_DGLGraphNumEdges
(
self
.
_handle
)
_init_api
(
"dgl.cgraph"
)
src/graph/graph.cc
View file @
842d3768
...
@@ -58,7 +58,7 @@ BoolArray Graph::HasVertices(IdArray vids) const {
...
@@ -58,7 +58,7 @@ BoolArray Graph::HasVertices(IdArray vids) const {
BoolArray
rst
=
BoolArray
::
Empty
({
len
},
vids
->
dtype
,
vids
->
ctx
);
BoolArray
rst
=
BoolArray
::
Empty
({
len
},
vids
->
dtype
,
vids
->
ctx
);
const
int64_t
*
vid_data
=
static_cast
<
int64_t
*>
(
vids
->
data
);
const
int64_t
*
vid_data
=
static_cast
<
int64_t
*>
(
vids
->
data
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
const
u
int64_t
nverts
=
NumVertices
();
const
int64_t
nverts
=
NumVertices
();
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
rst_data
[
i
]
=
(
vid_data
[
i
]
<
nverts
)
?
1
:
0
;
rst_data
[
i
]
=
(
vid_data
[
i
]
<
nverts
)
?
1
:
0
;
}
}
...
...
src/graph/graph_apis.cc
0 → 100644
View file @
842d3768
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/graph.h>
using
tvm
::
runtime
::
TVMArgs
;
using
tvm
::
runtime
::
TVMArgValue
;
using
tvm
::
runtime
::
TVMRetValue
;
using
tvm
::
runtime
::
PackedFunc
;
namespace
dgl
{
typedef
void
*
GraphHandle
;
void
DGLGraphCreate
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
new
Graph
();
*
rv
=
ghandle
;
}
TVM_REGISTER_GLOBAL
(
"cgraph._CAPI_DGLGraphCreate"
)
.
set_body
(
DGLGraphCreate
);
void
DGLGraphFree
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
delete
gptr
;
}
TVM_REGISTER_GLOBAL
(
"cgraph._CAPI_DGLGraphFree"
)
.
set_body
(
DGLGraphFree
);
void
DGLGraphAddVertices
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
uint64_t
num_vertices
=
args
[
1
];
gptr
->
AddVertices
(
num_vertices
);
}
TVM_REGISTER_GLOBAL
(
"cgraph._CAPI_DGLGraphAddVertices"
)
.
set_body
(
DGLGraphAddVertices
);
void
DGLGraphAddEdge
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
const
dgl_id_t
src
=
args
[
1
];
const
dgl_id_t
dst
=
args
[
2
];
gptr
->
AddEdge
(
src
,
dst
);
}
TVM_REGISTER_GLOBAL
(
"cgraph._CAPI_DGLGraphAddEdge"
)
.
set_body
(
DGLGraphAddEdge
);
void
DGLGraphAddEdges
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
const
IdArray
src
=
args
[
1
];
const
IdArray
dst
=
args
[
2
];
gptr
->
AddEdges
(
src
,
dst
);
}
TVM_REGISTER_GLOBAL
(
"cgraph._CAPI_DGLGraphAddEdges"
)
.
set_body
(
DGLGraphAddEdges
);
void
DGLGraphNumVertices
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
const
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
*
rv
=
static_cast
<
int64_t
>
(
gptr
->
NumVertices
());
}
TVM_REGISTER_GLOBAL
(
"cgraph._CAPI_DGLGraphNumVertices"
)
.
set_body
(
DGLGraphNumVertices
);
void
DGLGraphNumEdges
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
GraphHandle
ghandle
=
args
[
0
];
const
Graph
*
gptr
=
static_cast
<
Graph
*>
(
ghandle
);
*
rv
=
static_cast
<
int64_t
>
(
gptr
->
NumEdges
());
}
TVM_REGISTER_GLOBAL
(
"cgraph._CAPI_DGLGraphNumEdges"
)
.
set_body
(
DGLGraphNumEdges
);
}
// namespace dgl
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment