"src/libtorchaudio/pybind/pybind.cpp" did not exist on "f70b970ab11694c035db0062df6b60f2550c1d43"
Unverified commit c6890c23, authored by nv-dlasalle, committed by GitHub
Browse files

[Performance] In HeteroNodeView, build arange on target device, instead of on...


[Performance] In HeteroNodeView, build arange on target device, instead of on CPU and copying it (#2266)

* Build arange on target device

* Utilize arange device in viewpy:HeteroNodeView.__call__

* Work around uint64 error in TF to_dlpack

* Restore else clause
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent d453d72d
......@@ -1168,7 +1168,7 @@ def sort_1d(input):
"""
pass
def arange(start, stop, dtype):
def arange(start, stop, dtype, ctx):
"""Create a 1D range int64 tensor.
Parameters
......@@ -1178,7 +1178,9 @@ def arange(start, stop, dtype):
stop : int
The range stop.
dtype: str
The dtype of result tensor
The dtype of result tensor.
ctx : Device context object.
Device context.
Returns
-------
......
......@@ -355,11 +355,11 @@ def sort_1d(input):
idx = nd.cast(idx, dtype='int64')
return val, idx
def arange(start, stop, dtype=np.int64, ctx=None):
    """Create a 1D integer range NDArray on the requested MXNet context.

    Parameters
    ----------
    start : int
        The range start (inclusive).
    stop : int
        The range stop (exclusive).
    dtype : numpy dtype, optional
        The dtype of the result tensor. Defaults to int64.
    ctx : mx.Context, optional
        Device context to allocate the result on. ``None`` lets MXNet use
        its default context, preserving the pre-``ctx`` behavior.

    Returns
    -------
    mxnet.nd.NDArray
        A 1D tensor ``[start, start+1, ..., stop-1]``; empty when
        ``start >= stop``.
    """
    if start >= stop:
        # nd.arange raises on an empty range, so build the empty array
        # explicitly — directly on the target context.
        return nd.array([], dtype=dtype, ctx=ctx)
    else:
        return nd.arange(start, stop, dtype=dtype, ctx=ctx)
def rand_shuffle(arr):
    """Return a randomly shuffled copy of *arr* (MXNet NDArray)."""
    return mx.nd.random.shuffle(arr)
......
......@@ -286,8 +286,8 @@ def nonzero_1d(input):
def sort_1d(input):
    """Sort a 1D tensor in ascending order.

    Returns the (sorted values, original indices) pair produced by
    ``torch.sort``.
    """
    sorted_result = th.sort(input)
    return sorted_result
def arange(start, stop, dtype=th.int64, ctx=None):
    """Create a 1D integer range tensor on the requested device.

    Parameters
    ----------
    start : int
        The range start (inclusive).
    stop : int
        The range stop (exclusive).
    dtype : torch.dtype, optional
        The dtype of the result tensor. Defaults to int64.
    ctx : torch.device, optional
        Device to allocate the result on. ``None`` lets torch use the
        current default device, preserving the pre-``ctx`` behavior.

    Returns
    -------
    torch.Tensor
        A 1D tensor ``[start, start+1, ..., stop-1]``; empty when
        ``start >= stop``.
    """
    # Building directly on the target device avoids a CPU allocation
    # followed by a copy (the motivation for the ctx parameter).
    return th.arange(start, stop, dtype=dtype, device=ctx)
def rand_shuffle(arr):
idx = th.randperm(len(arr))
......
......@@ -421,8 +421,10 @@ def sort_1d(input):
return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64)
def arange(start, stop, dtype=tf.int64, ctx=None):
    """Create a 1D integer range tensor on the requested TF device.

    Parameters
    ----------
    start : int
        The range start (inclusive).
    stop : int
        The range stop (exclusive).
    dtype : tf.DType, optional
        The dtype of the result tensor. Defaults to int64.
    ctx : str, optional
        TF device string (e.g. ``"/gpu:0"``). A falsy value falls back to
        ``"/cpu:0"``, matching the behavior before ``ctx`` was added.

    Returns
    -------
    tf.Tensor
        A 1D tensor ``[start, start+1, ..., stop-1]``.
    """
    if not ctx:
        ctx = "/cpu:0"
    with tf.device(ctx):
        t = tf.range(start, stop, dtype=dtype)
    return t
......@@ -444,10 +446,13 @@ def zerocopy_from_numpy(np_array):
def zerocopy_to_dgl_ndarray(data):
    """Convert a TF tensor to a DGL NDArray via DLPack without copying.

    NOTE: TF doesn't keep signed int tensors on GPU due to legacy issues
    with shape inference, so GPU int32/int64 tensors are reinterpreted as
    unsigned for the DLPack handoff and converted back to signed on the
    DGL side.
    """
    on_gpu = device_type(data.device) == 'gpu'
    if on_gpu and data.dtype in (tf.int32, tf.int64):
        # Widen-preserving reinterpretation: int32 -> uint32, int64 -> uint64.
        unsigned_dtype = tf.uint32 if data.dtype == tf.int32 else tf.uint64
        handoff = tf.cast(data, unsigned_dtype)
        return nd.cast_to_signed(nd.from_dlpack(zerocopy_to_dlpack(handoff)))
    return nd.from_dlpack(zerocopy_to_dlpack(data))
......
......@@ -40,9 +40,9 @@ class HeteroNodeView(object):
def __call__(self, ntype=None):
    """Return the node ID tensor of the given node type.

    Parameters
    ----------
    ntype : str, optional
        The node type name; ``None`` means the graph's only node type.

    Returns
    -------
    Tensor
        ``arange(0, num_nodes(ntype))`` with the graph's idtype, allocated
        on the graph's device.
    """
    ntid = self._typeid_getter(ntype)
    # Build the arange directly on the graph's device instead of creating
    # it on CPU and copying it over (perf: avoids an extra allocation+copy).
    ret = F.arange(0, self._graph._graph.number_of_nodes(ntid),
                   dtype=self._graph.idtype, ctx=self._graph.device)
    return ret
class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.