Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
758cb16e
Commit
758cb16e
authored
Sep 17, 2018
by
Minjie Wang
Browse files
ndarray argument
parent
c086d454
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
106 additions
and
9 deletions
+106
-9
python/dgl/backend/pytorch.py
python/dgl/backend/pytorch.py
+28
-8
python/dgl/cgraph.py
python/dgl/cgraph.py
+9
-1
python/dgl/ndarray.py
python/dgl/ndarray.py
+69
-0
No files found.
python/dgl/backend/pytorch.py
View file @
758cb16e
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
ctypes
import
torch
as
th
import
torch
as
th
from
.._ffi.base
import
_LIB
,
check_call
,
c_array
from
.._ffi.runtime_ctypes
import
TVMType
,
TVMContext
,
TVMArray
from
.._ffi.runtime_ctypes
import
TVMType
,
TVMContext
,
TVMArray
from
.._ffi.runtime_ctypes
import
TypeCode
,
tvm_shape_index_t
from
.._ffi.runtime_ctypes
import
TypeCode
,
tvm_shape_index_t
from
..
import
ndarray
as
nd
# Tensor types
# Tensor types
Tensor
=
th
.
Tensor
Tensor
=
th
.
Tensor
...
@@ -84,14 +87,31 @@ def get_context(arr):
...
@@ -84,14 +87,31 @@ def get_context(arr):
return
TVMContext
(
return
TVMContext
(
TVMContext
.
STR2MASK
[
arr
.
device
.
type
],
arr
.
device
.
index
)
TVMContext
.
STR2MASK
[
arr
.
device
.
type
],
arr
.
device
.
index
)
def
_typestr
(
arr_dtype
):
if
arr_dtype
in
(
th
.
float16
,
th
.
half
):
return
'float16'
elif
arr_dtype
in
(
th
.
float32
,
th
.
float
):
return
'float32'
elif
arr_dtype
in
(
th
.
float64
,
th
.
double
):
return
'float64'
elif
arr_dtype
in
(
th
.
int16
,
th
.
short
):
return
'int16'
elif
arr_dtype
in
(
th
.
int32
,
th
.
int
):
return
'int32'
elif
arr_dtype
in
(
th
.
int64
,
th
.
long
):
return
'int64'
elif
arr_dtype
==
th
.
int8
:
return
'int8'
elif
arr_dtype
==
th
.
uint8
:
return
'uint8'
else
:
raise
RuntimeError
(
'Unsupported data type:'
,
arr_dtype
)
def
asdglarray
(
arr
):
def
asdglarray
(
arr
):
"""The data is copied to the new array."""
assert
arr
.
is_contiguous
()
assert
arr
.
is_contiguous
()
rst
=
TVMArray
()
rst
=
nd
.
empty
(
tuple
(
arr
.
shape
),
_typestr
(
arr
.
dtype
),
get_context
(
arr
))
rst
.
data
=
arr
.
data_ptr
()
data
=
ctypes
.
cast
(
arr
.
data_ptr
(),
ctypes
.
c_void_p
)
rst
.
shape
=
c_array
(
tvm_shape_index_t
,
arr
.
shape
)
nbytes
=
ctypes
.
c_size_t
(
arr
.
numel
()
*
arr
.
element_size
())
rst
.
strides
=
None
check_call
(
_LIB
.
TVMArrayCopyFromBytes
(
rst
.
handle
,
data
,
nbytes
))
# TODO: dtype
rst
.
dtype
=
TVMType
(
arr
.
dtype
)
rst
.
ndim
=
arr
.
ndimension
()
# TODO: ctx
return
rst
return
rst
python/dgl/cgraph.py
View file @
758cb16e
...
@@ -2,6 +2,7 @@ from __future__ import absolute_import
...
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.
import
utils
class
DGLGraph
(
object
):
class
DGLGraph
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -17,7 +18,14 @@ class DGLGraph(object):
...
@@ -17,7 +18,14 @@ class DGLGraph(object):
_CAPI_DGLGraphAddEdge
(
self
.
_handle
,
u
,
v
);
_CAPI_DGLGraphAddEdge
(
self
.
_handle
,
u
,
v
);
def
add_edges
(
self
,
u
,
v
):
def
add_edges
(
self
,
u
,
v
):
pass
u
=
utils
.
Index
(
u
)
v
=
utils
.
Index
(
v
)
u_array
=
F
.
asdglarray
(
u
.
totensor
())
v_array
=
F
.
asdglarray
(
v
.
totensor
())
_CAPI_DGLGraphAddEdges
(
self
.
_handle
,
u_array
,
v_array
)
def
number_of_nodes
(
self
):
def
number_of_nodes
(
self
):
return
_CAPI_DGLGraphNumVertices
(
self
.
_handle
)
return
_CAPI_DGLGraphNumVertices
(
self
.
_handle
)
...
...
python/dgl/ndarray.py
0 → 100644
View file @
758cb16e
"""DGL Runtime NDArray API.
dgl.ndarray provides a minimum runtime array API to unify
different array libraries used as backend.
"""
# pylint: disable=invalid-name,unused-import
from
__future__
import
absolute_import
as
_abs
import
numpy
as
_np
from
._ffi.ndarray
import
TVMContext
,
TVMType
,
NDArrayBase
from
._ffi.ndarray
import
context
,
empty
,
from_dlpack
from
._ffi.ndarray
import
_set_class_ndarray
class
NDArray
(
NDArrayBase
):
"""Lightweight NDArray class for DGL framework."""
pass
def
cpu
(
dev_id
=
0
):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return
TVMContext
(
1
,
dev_id
)
def
gpu
(
dev_id
=
0
):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return
TVMContext
(
2
,
dev_id
)
def
array
(
arr
,
ctx
=
cpu
(
0
)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if
not
isinstance
(
arr
,
(
_np
.
ndarray
,
NDArray
)):
arr
=
_np
.
array
(
arr
)
return
empty
(
arr
.
shape
,
arr
.
dtype
,
ctx
).
copyfrom
(
arr
)
_set_class_ndarray
(
NDArray
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment