Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c086d454
Commit
c086d454
authored
Sep 16, 2018
by
Minjie Wang
Browse files
Use TVMContext
parent
b24daa66
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
14 additions
and
34 deletions
+14
-34
python/dgl/__init__.py
python/dgl/__init__.py
+0
-1
python/dgl/_ffi/runtime_ctypes.py
python/dgl/_ffi/runtime_ctypes.py
+3
-0
python/dgl/backend/pytorch.py
python/dgl/backend/pytorch.py
+11
-11
python/dgl/context.py
python/dgl/context.py
+0
-21
python/dgl/graph.py
python/dgl/graph.py
+0
-1
No files found.
python/dgl/__init__.py
View file @
c086d454
...
@@ -10,7 +10,6 @@ from ._ffi.base import DGLError, __version__
...
@@ -10,7 +10,6 @@ from ._ffi.base import DGLError, __version__
from
.base
import
ALL
from
.base
import
ALL
from
.batch
import
batch
,
unbatch
from
.batch
import
batch
,
unbatch
from
.context
import
cpu
,
gpu
from
.generator
import
*
from
.generator
import
*
from
.graph
import
DGLGraph
,
__MSG__
,
__REPR__
from
.graph
import
DGLGraph
,
__MSG__
,
__REPR__
from
.subgraph
import
DGLSubGraph
from
.subgraph
import
DGLSubGraph
python/dgl/_ffi/runtime_ctypes.py
View file @
c086d454
...
@@ -218,6 +218,9 @@ class TVMContext(ctypes.Structure):
...
@@ -218,6 +218,9 @@ class TVMContext(ctypes.Structure):
return
"%s(%d)"
%
(
return
"%s(%d)"
%
(
TVMContext
.
MASK2STR
[
self
.
device_type
],
self
.
device_id
)
TVMContext
.
MASK2STR
[
self
.
device_type
],
self
.
device_id
)
def
__hash__
(
self
):
return
hash
((
self
.
device_type
,
self
.
device_id
))
class
TVMArray
(
ctypes
.
Structure
):
class
TVMArray
(
ctypes
.
Structure
):
"""TVMValue in C API"""
"""TVMValue in C API"""
...
...
python/dgl/backend/pytorch.py
View file @
c086d454
...
@@ -4,7 +4,6 @@ import torch as th
...
@@ -4,7 +4,6 @@ import torch as th
from
.._ffi.runtime_ctypes
import
TVMType
,
TVMContext
,
TVMArray
from
.._ffi.runtime_ctypes
import
TVMType
,
TVMContext
,
TVMArray
from
.._ffi.runtime_ctypes
import
TypeCode
,
tvm_shape_index_t
from
.._ffi.runtime_ctypes
import
TypeCode
,
tvm_shape_index_t
from
..context
import
cpu
,
gpu
# Tensor types
# Tensor types
Tensor
=
th
.
Tensor
Tensor
=
th
.
Tensor
...
@@ -67,22 +66,23 @@ sort = th.sort
...
@@ -67,22 +66,23 @@ sort = th.sort
arange
=
th
.
arange
arange
=
th
.
arange
mul
=
th
.
mul
mul
=
th
.
mul
def
to_context
(
x
,
ctx
):
def
to_context
(
arr
,
ctx
):
if
ctx
is
None
:
if
ctx
is
None
:
return
x
return
arr
elif
ctx
.
device
==
'gpu'
:
elif
ctx
.
device
_type
==
TVMContext
.
STR2MASK
[
'cuda'
]
:
th
.
cuda
.
set_device
(
ctx
.
device_id
)
th
.
cuda
.
set_device
(
ctx
.
device_id
)
return
x
.
cuda
()
return
arr
.
cuda
()
elif
ctx
.
device
==
'cpu'
:
elif
ctx
.
device
_type
==
TVMContext
.
STR2MASK
[
'cpu'
]
:
return
x
.
cpu
()
return
arr
.
cpu
()
else
:
else
:
raise
RuntimeError
(
'Invalid context'
,
ctx
)
raise
RuntimeError
(
'Invalid context'
,
ctx
)
def
get_context
(
x
):
def
get_context
(
arr
):
if
x
.
device
.
type
==
'cpu'
:
if
arr
.
device
.
type
==
'cpu'
:
return
cpu
(
)
return
TVMContext
(
TVMContext
.
STR2MASK
[
'cpu'
],
0
)
else
:
else
:
return
gpu
(
x
.
device
.
index
)
return
TVMContext
(
TVMContext
.
STR2MASK
[
arr
.
device
.
type
],
arr
.
device
.
index
)
def
asdglarray
(
arr
):
def
asdglarray
(
arr
):
assert
arr
.
is_contiguous
()
assert
arr
.
is_contiguous
()
...
...
python/dgl/context.py
deleted
100644 → 0
View file @
b24daa66
"""DGL's device context shim."""
class
Context
(
object
):
def
__init__
(
self
,
dev
,
devid
=-
1
):
self
.
device
=
dev
self
.
device_id
=
devid
def
__str__
(
self
):
return
'{}:{}'
.
format
(
self
.
device
,
self
.
device_id
)
def
__eq__
(
self
,
other
):
return
self
.
device
==
other
.
device
and
self
.
device_id
==
other
.
device_id
def
__hash__
(
self
):
return
hash
((
self
.
device
,
self
.
device_id
))
def
gpu
(
gpuid
):
return
Context
(
'gpu'
,
gpuid
)
def
cpu
():
return
Context
(
'cpu'
)
python/dgl/graph.py
View file @
c086d454
...
@@ -10,7 +10,6 @@ from .base import ALL, is_all, __MSG__, __REPR__
...
@@ -10,7 +10,6 @@ from .base import ALL, is_all, __MSG__, __REPR__
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.backend
import
Tensor
from
.backend
import
Tensor
from
.cached_graph
import
CachedGraph
,
create_cached_graph
from
.cached_graph
import
CachedGraph
,
create_cached_graph
from
.
import
context
from
.frame
import
FrameRef
,
merge_frames
from
.frame
import
FrameRef
,
merge_frames
from
.nx_adapt
import
nx_init
from
.nx_adapt
import
nx_init
from
.
import
scheduler
from
.
import
scheduler
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment