Commit c086d454 authored by Minjie Wang's avatar Minjie Wang
Browse files

Use TVMContext

parent b24daa66
......@@ -10,7 +10,6 @@ from ._ffi.base import DGLError, __version__
from .base import ALL
from .batch import batch, unbatch
from .context import cpu, gpu
from .generator import *
from .graph import DGLGraph, __MSG__, __REPR__
from .subgraph import DGLSubGraph
......@@ -218,6 +218,9 @@ class TVMContext(ctypes.Structure):
return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id)
def __hash__(self):
return hash((self.device_type, self.device_id))
class TVMArray(ctypes.Structure):
"""TVMValue in C API"""
......
......@@ -4,7 +4,6 @@ import torch as th
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
from ..context import cpu, gpu
# Tensor types
Tensor = th.Tensor
......@@ -67,22 +66,23 @@ sort = th.sort
arange = th.arange
mul = th.mul
def to_context(x, ctx):
def to_context(arr, ctx):
if ctx is None:
return x
elif ctx.device == 'gpu':
return arr
elif ctx.device_type == TVMContext.STR2MASK['cuda']:
th.cuda.set_device(ctx.device_id)
return x.cuda()
elif ctx.device == 'cpu':
return x.cpu()
return arr.cuda()
elif ctx.device_type == TVMContext.STR2MASK['cpu']:
return arr.cpu()
else:
raise RuntimeError('Invalid context', ctx)
def get_context(x):
if x.device.type == 'cpu':
return cpu()
def get_context(arr):
if arr.device.type == 'cpu':
return TVMContext(TVMContext.STR2MASK['cpu'], 0)
else:
return gpu(x.device.index)
return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index)
def asdglarray(arr):
assert arr.is_contiguous()
......
"""DGL's device context shim."""
class Context(object):
def __init__(self, dev, devid=-1):
self.device = dev
self.device_id = devid
def __str__(self):
return '{}:{}'.format(self.device, self.device_id)
def __eq__(self, other):
return self.device == other.device and self.device_id == other.device_id
def __hash__(self):
return hash((self.device, self.device_id))
def gpu(gpuid):
return Context('gpu', gpuid)
def cpu():
return Context('cpu')
......@@ -10,7 +10,6 @@ from .base import ALL, is_all, __MSG__, __REPR__
from . import backend as F
from .backend import Tensor
from .cached_graph import CachedGraph, create_cached_graph
from . import context
from .frame import FrameRef, merge_frames
from .nx_adapt import nx_init
from . import scheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment