Commit 0ef85f28 authored by lisj's avatar lisj
Browse files

恢复ctx转换中的代码,tests中对device的处理在cpu和cpu:0时判断为非同一设备

parent 351f7588
...@@ -89,18 +89,12 @@ def device_id(ctx): ...@@ -89,18 +89,12 @@ def device_id(ctx):
else: else:
return ctx.index return ctx.index
__devtype_th_map = {
1: "cpu",
2: "cuda", # cuda device
10: "cuda" # rocm device
}
def to_backend_ctx(dglctx): def to_backend_ctx(dglctx):
dev_type = dglctx.device_type dev_type = dglctx.device_type
if dev_type in __devtype_th_map: if dev_type == 1:
th_type = __devtype_th_map[dev_type] return th.device('cpu')
return th.device(th_type, dglctx.device_id) elif dev_type == 2 or dev_type == 10:
return th.device('cuda', dglctx.device_id)
else: else:
raise ValueError('Unsupported DGL device context:', dglctx) raise ValueError('Unsupported DGL device context:', dglctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment