Commit c086d454 authored by Minjie Wang's avatar Minjie Wang
Browse files

Use TVMContext

parent b24daa66
...@@ -10,7 +10,6 @@ from ._ffi.base import DGLError, __version__ ...@@ -10,7 +10,6 @@ from ._ffi.base import DGLError, __version__
from .base import ALL from .base import ALL
from .batch import batch, unbatch from .batch import batch, unbatch
from .context import cpu, gpu
from .generator import * from .generator import *
from .graph import DGLGraph, __MSG__, __REPR__ from .graph import DGLGraph, __MSG__, __REPR__
from .subgraph import DGLSubGraph from .subgraph import DGLSubGraph
...@@ -218,6 +218,9 @@ class TVMContext(ctypes.Structure): ...@@ -218,6 +218,9 @@ class TVMContext(ctypes.Structure):
return "%s(%d)" % ( return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id) TVMContext.MASK2STR[self.device_type], self.device_id)
def __hash__(self):
    """Hash consistent with equality: keyed on (device_type, device_id)."""
    key = (self.device_type, self.device_id)
    return hash(key)
class TVMArray(ctypes.Structure): class TVMArray(ctypes.Structure):
"""TVMValue in C API""" """TVMValue in C API"""
......
...@@ -4,7 +4,6 @@ import torch as th ...@@ -4,7 +4,6 @@ import torch as th
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
from ..context import cpu, gpu
# Tensor types # Tensor types
Tensor = th.Tensor Tensor = th.Tensor
...@@ -67,22 +66,23 @@ sort = th.sort ...@@ -67,22 +66,23 @@ sort = th.sort
arange = th.arange arange = th.arange
mul = th.mul mul = th.mul
def to_context(arr, ctx):
    """Copy tensor *arr* to the device described by *ctx*.

    A ``None`` context returns *arr* untouched.  Otherwise
    ``ctx.device_type`` selects the target: the 'cuda' mask moves the
    tensor onto GPU ``ctx.device_id``; the 'cpu' mask moves it to host
    memory.  Any other device type raises ``RuntimeError``.
    """
    if ctx is None:
        return arr
    dev_type = ctx.device_type
    if dev_type == TVMContext.STR2MASK['cuda']:
        # Select the target GPU before the copy so .cuda() lands on it.
        th.cuda.set_device(ctx.device_id)
        return arr.cuda()
    if dev_type == TVMContext.STR2MASK['cpu']:
        return arr.cpu()
    raise RuntimeError('Invalid context', ctx)
def get_context(arr):
    """Return a TVMContext describing the device *arr* currently lives on.

    CPU tensors always map to device id 0; any other torch device type is
    translated through ``TVMContext.STR2MASK`` with its device index.
    """
    dev = arr.device
    if dev.type == 'cpu':
        return TVMContext(TVMContext.STR2MASK['cpu'], 0)
    return TVMContext(TVMContext.STR2MASK[dev.type], dev.index)
def asdglarray(arr): def asdglarray(arr):
assert arr.is_contiguous() assert arr.is_contiguous()
......
"""DGL's device context shim."""
class Context(object):
    """Device context: a device-kind string plus an integer device id.

    Equality and hashing are defined on the ``(device, device_id)`` pair,
    so Context objects can be used as dictionary keys and compared freely.
    """

    def __init__(self, dev, devid=-1):
        # dev: device kind string (e.g. 'cpu' or 'gpu');
        # devid: device index, -1 meaning "unspecified" (used for CPU).
        self.device = dev
        self.device_id = devid

    def __str__(self):
        return '{}:{}'.format(self.device, self.device_id)

    def __eq__(self, other):
        # Fix: comparing against a non-Context object used to raise
        # AttributeError; returning NotImplemented lets Python fall back
        # to its default (identity-based) comparison instead.
        if not isinstance(other, Context):
            return NotImplemented
        return (self.device == other.device
                and self.device_id == other.device_id)

    def __hash__(self):
        return hash((self.device, self.device_id))
def gpu(gpuid):
    """Construct a Context for GPU number *gpuid*."""
    ctx = Context('gpu', gpuid)
    return ctx
def cpu():
    """Construct the CPU Context (device id left at its default, -1)."""
    ctx = Context('cpu')
    return ctx
...@@ -10,7 +10,6 @@ from .base import ALL, is_all, __MSG__, __REPR__ ...@@ -10,7 +10,6 @@ from .base import ALL, is_all, __MSG__, __REPR__
from . import backend as F from . import backend as F
from .backend import Tensor from .backend import Tensor
from .cached_graph import CachedGraph, create_cached_graph from .cached_graph import CachedGraph, create_cached_graph
from . import context
from .frame import FrameRef, merge_frames from .frame import FrameRef, merge_frames
from .nx_adapt import nx_init from .nx_adapt import nx_init
from . import scheduler from . import scheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment