Unverified Commit c6890c23 authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[Performance] In HeteroNodeView, build arange on target device, instead of on...


[Performance] In HeteroNodeView, build arange on target device, instead of on CPU and copying it (#2266)

* Build arange on target device

* Utilize arange ctx argument in view.py: HeteroNodeView.__call__

* Work around uint64 error in TF to_dlpack

* Restore else clause
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent d453d72d
...@@ -1168,7 +1168,7 @@ def sort_1d(input): ...@@ -1168,7 +1168,7 @@ def sort_1d(input):
""" """
pass pass
def arange(start, stop, dtype): def arange(start, stop, dtype, ctx):
"""Create a 1D range int64 tensor. """Create a 1D range int64 tensor.
Parameters Parameters
...@@ -1178,7 +1178,9 @@ def arange(start, stop, dtype): ...@@ -1178,7 +1178,9 @@ def arange(start, stop, dtype):
stop : int stop : int
The range stop. The range stop.
dtype: str dtype: str
The dtype of result tensor The dtype of result tensor.
ctx : Device context object.
Device context.
Returns Returns
------- -------
......
...@@ -355,11 +355,11 @@ def sort_1d(input): ...@@ -355,11 +355,11 @@ def sort_1d(input):
idx = nd.cast(idx, dtype='int64') idx = nd.cast(idx, dtype='int64')
return val, idx return val, idx
def arange(start, stop, dtype=np.int64, ctx=None):
    """Create a 1D integer range NDArray [start, stop) on context *ctx*.

    An empty range (start >= stop) is built explicitly as an empty array
    rather than delegated to ``nd.arange``.
    """
    if start >= stop:
        return nd.array([], dtype=dtype, ctx=ctx)
    return nd.arange(start, stop, dtype=dtype, ctx=ctx)
def rand_shuffle(arr): def rand_shuffle(arr):
return mx.nd.random.shuffle(arr) return mx.nd.random.shuffle(arr)
......
...@@ -286,8 +286,8 @@ def nonzero_1d(input): ...@@ -286,8 +286,8 @@ def nonzero_1d(input):
def sort_1d(input): def sort_1d(input):
return th.sort(input) return th.sort(input)
def arange(start, stop, dtype=th.int64, ctx=None):
    """Create a 1D range tensor [start, stop) with *dtype* on device *ctx*."""
    rng = th.arange(start, stop, dtype=dtype, device=ctx)
    return rng
def rand_shuffle(arr): def rand_shuffle(arr):
idx = th.randperm(len(arr)) idx = th.randperm(len(arr))
......
...@@ -421,8 +421,10 @@ def sort_1d(input): ...@@ -421,8 +421,10 @@ def sort_1d(input):
return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64) return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64)
def arange(start, stop, dtype=tf.int64, ctx=None):
    """Create a 1D range tensor [start, stop) placed on the device *ctx*.

    Falls back to "/cpu:0" when no (truthy) context is given.
    """
    device = ctx if ctx else "/cpu:0"
    with tf.device(device):
        return tf.range(start, stop, dtype=dtype)
...@@ -444,10 +446,13 @@ def zerocopy_from_numpy(np_array): ...@@ -444,10 +446,13 @@ def zerocopy_from_numpy(np_array):
def zerocopy_to_dgl_ndarray(data):
    """Convert a TF tensor to a DGL NDArray via DLPack without copying.

    NOTE: TF doesn't keep signed tensors on GPU due to legacy issues with
    shape inference, so int32/int64 GPU tensors are viewed as their unsigned
    counterpart before export and cast back to signed afterwards.
    """
    needs_unsigned_view = (device_type(data.device) == 'gpu'
                           and data.dtype in (tf.int32, tf.int64))
    if needs_unsigned_view:
        unsigned = tf.uint32 if data.dtype == tf.int32 else tf.uint64
        data = tf.cast(data, unsigned)
        return nd.cast_to_signed(nd.from_dlpack(zerocopy_to_dlpack(data)))
    return nd.from_dlpack(zerocopy_to_dlpack(data))
......
...@@ -40,9 +40,9 @@ class HeteroNodeView(object): ...@@ -40,9 +40,9 @@ class HeteroNodeView(object):
def __call__(self, ntype=None):
    """Return all node IDs of node type *ntype* as a 1D range tensor.

    The tensor is created directly on the graph's device (via the ctx
    argument of F.arange), avoiding a CPU build followed by a copy.
    """
    ntid = self._typeid_getter(ntype)
    num_nodes = self._graph._graph.number_of_nodes(ntid)
    return F.arange(0, num_nodes, dtype=self._graph.idtype,
                    ctx=self._graph.device)
class HeteroNodeDataView(MutableMapping): class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called.""" """The data view class when G.ndata[ntype] is called."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment