Unverified Commit 5f2f100b authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

[Feature] convert np.ndarray to backend tensor when setting ndata/edata(#850)

parent 9df8cd32
......@@ -270,7 +270,7 @@ def zerocopy_to_numpy(input):
return asnumpy(input)
def zerocopy_from_numpy(np_array):
return th.from_numpy(np_array)
return th.as_tensor(np_array)
def zerocopy_to_dgl_ndarray(input):
return nd.from_dlpack(dlpack.to_dlpack(input.contiguous()))
......
......@@ -4,6 +4,8 @@ from __future__ import absolute_import
from collections import namedtuple
from collections.abc import MutableMapping
import numpy as np
from .base import ALL, is_all, DGLError
from . import backend as F
......@@ -57,6 +59,8 @@ class NodeDataView(MutableMapping):
return self._graph.get_n_repr(self._nodes)[key]
def __setitem__(self, key, val):
if isinstance(val, np.ndarray):
val = F.zerocopy_from_numpy(val)
self._graph.set_n_repr({key : val}, self._nodes)
def __delitem__(self, key):
......@@ -125,6 +129,8 @@ class EdgeDataView(MutableMapping):
return self._graph.get_e_repr(self._edges)[key]
def __setitem__(self, key, val):
if isinstance(val, np.ndarray):
val = F.zerocopy_from_numpy(val)
self._graph.set_e_repr({key : val}, self._edges)
def __delitem__(self, key):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment