Unverified Commit 21e172d0 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] fix numpy integer bug (#580)

parent 40dc1859
...@@ -35,7 +35,8 @@ def random_walk(g, seeds, num_traces, num_hops): ...@@ -35,7 +35,8 @@ def random_walk(g, seeds, num_traces, num_hops):
if len(seeds) == 0: if len(seeds) == 0:
return utils.toindex([]).tousertensor() return utils.toindex([]).tousertensor()
seeds = utils.toindex(seeds).todgltensor() seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLRandomWalk(g._graph._handle, seeds, num_traces, num_hops) traces = _CAPI_DGLRandomWalk(g._graph._handle,
seeds, int(num_traces), int(num_hops))
return F.zerocopy_from_dlpack(traces.to_dlpack()) return F.zerocopy_from_dlpack(traces.to_dlpack())
...@@ -108,8 +109,8 @@ def random_walk_with_restart( ...@@ -108,8 +109,8 @@ def random_walk_with_restart(
return [] return []
seeds = utils.toindex(seeds).todgltensor() seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLRandomWalkWithRestart( traces = _CAPI_DGLRandomWalkWithRestart(
g._graph._handle, seeds, restart_prob, max_nodes_per_seed, g._graph._handle, seeds, restart_prob, int(max_nodes_per_seed),
max_visit_counts, max_frequent_visited_nodes) int(max_visit_counts), int(max_frequent_visited_nodes))
return _split_traces(traces) return _split_traces(traces)
...@@ -160,8 +161,8 @@ def bipartite_single_sided_random_walk_with_restart( ...@@ -160,8 +161,8 @@ def bipartite_single_sided_random_walk_with_restart(
return [] return []
seeds = utils.toindex(seeds).todgltensor() seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart( traces = _CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart(
g._graph._handle, seeds, restart_prob, max_nodes_per_seed, g._graph._handle, seeds, restart_prob, int(max_nodes_per_seed),
max_visit_counts, max_frequent_visited_nodes) int(max_visit_counts), int(max_frequent_visited_nodes))
return _split_traces(traces) return _split_traces(traces)
_init_api('dgl.randomwalk', __name__) _init_api('dgl.randomwalk', __name__)
...@@ -130,8 +130,7 @@ class ThreadPrefetchingWrapper(PrefetchingWrapper, threading.Thread): ...@@ -130,8 +130,7 @@ class ThreadPrefetchingWrapper(PrefetchingWrapper, threading.Thread):
class NodeFlowSampler(object): class NodeFlowSampler(object):
''' '''Base class that generates NodeFlows from a graph.
Base class that generates NodeFlows from a graph.
Class properties Class properties
---------------- ----------------
...@@ -301,10 +300,10 @@ class NeighborSampler(NodeFlowSampler): ...@@ -301,10 +300,10 @@ class NeighborSampler(NodeFlowSampler):
assert node_prob is None, 'non-uniform node probability not supported' assert node_prob is None, 'non-uniform node probability not supported'
assert isinstance(expand_factor, Integral), 'non-int expand_factor not supported' assert isinstance(expand_factor, Integral), 'non-int expand_factor not supported'
self._expand_factor = expand_factor self._expand_factor = int(expand_factor)
self._num_hops = num_hops self._num_hops = int(num_hops)
self._add_self_loop = add_self_loop self._add_self_loop = add_self_loop
self._num_workers = num_workers self._num_workers = int(num_workers)
self._neighbor_type = neighbor_type self._neighbor_type = neighbor_type
def fetch(self, current_nodeflow_index): def fetch(self, current_nodeflow_index):
...@@ -366,7 +365,7 @@ class LayerSampler(NodeFlowSampler): ...@@ -366,7 +365,7 @@ class LayerSampler(NodeFlowSampler):
assert node_prob is None, 'non-uniform node probability not supported' assert node_prob is None, 'non-uniform node probability not supported'
self._num_workers = num_workers self._num_workers = int(num_workers)
self._neighbor_type = neighbor_type self._neighbor_type = neighbor_type
self._layer_sizes = utils.toindex(layer_sizes) self._layer_sizes = utils.toindex(layer_sizes)
......
...@@ -76,7 +76,7 @@ class GraphIndex(object): ...@@ -76,7 +76,7 @@ class GraphIndex(object):
num : int num : int
Number of nodes to be added. Number of nodes to be added.
""" """
_CAPI_DGLGraphAddVertices(self._handle, num) _CAPI_DGLGraphAddVertices(self._handle, int(num))
self.clear_cache() self.clear_cache()
def add_edge(self, u, v): def add_edge(self, u, v):
...@@ -186,7 +186,7 @@ class GraphIndex(object): ...@@ -186,7 +186,7 @@ class GraphIndex(object):
bool bool
True if the node exists, False otherwise. True if the node exists, False otherwise.
""" """
return bool(_CAPI_DGLGraphHasVertex(self._handle, vid)) return bool(_CAPI_DGLGraphHasVertex(self._handle, int(vid)))
def has_nodes(self, vids): def has_nodes(self, vids):
"""Return true if the nodes exist. """Return true if the nodes exist.
...@@ -255,7 +255,8 @@ class GraphIndex(object): ...@@ -255,7 +255,8 @@ class GraphIndex(object):
utils.Index utils.Index
Array of predecessors Array of predecessors
""" """
return utils.toindex(_CAPI_DGLGraphPredecessors(self._handle, v, radius)) return utils.toindex(_CAPI_DGLGraphPredecessors(
self._handle, int(v), int(radius)))
def successors(self, v, radius=1): def successors(self, v, radius=1):
"""Return the successors of the node. """Return the successors of the node.
...@@ -272,7 +273,8 @@ class GraphIndex(object): ...@@ -272,7 +273,8 @@ class GraphIndex(object):
utils.Index utils.Index
Array of successors Array of successors
""" """
return utils.toindex(_CAPI_DGLGraphSuccessors(self._handle, v, radius)) return utils.toindex(_CAPI_DGLGraphSuccessors(
self._handle, int(v), int(radius)))
def edge_id(self, u, v): def edge_id(self, u, v):
"""Return the id array of all edges between u and v. """Return the id array of all edges between u and v.
...@@ -364,7 +366,7 @@ class GraphIndex(object): ...@@ -364,7 +366,7 @@ class GraphIndex(object):
The edge ids. The edge ids.
""" """
if len(v) == 1: if len(v) == 1:
edge_array = _CAPI_DGLGraphInEdges_1(self._handle, v[0]) edge_array = _CAPI_DGLGraphInEdges_1(self._handle, int(v[0]))
else: else:
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphInEdges_2(self._handle, v_array) edge_array = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
...@@ -391,7 +393,7 @@ class GraphIndex(object): ...@@ -391,7 +393,7 @@ class GraphIndex(object):
The edge ids. The edge ids.
""" """
if len(v) == 1: if len(v) == 1:
edge_array = _CAPI_DGLGraphOutEdges_1(self._handle, v[0]) edge_array = _CAPI_DGLGraphOutEdges_1(self._handle, int(v[0]))
else: else:
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphOutEdges_2(self._handle, v_array) edge_array = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
...@@ -848,7 +850,7 @@ class GraphIndex(object): ...@@ -848,7 +850,7 @@ class GraphIndex(object):
assert self.is_readonly() assert self.is_readonly()
self._handle = _CAPI_DGLGraphCSRCreateMMap( self._handle = _CAPI_DGLGraphCSRCreateMMap(
shared_mem_name, shared_mem_name,
num_nodes, num_edges, int(num_nodes), int(num_edges),
self._multigraph, self._multigraph,
edge_dir) edge_dir)
......
...@@ -40,7 +40,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id): ...@@ -40,7 +40,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
_CAPI_DGLSenderAddReceiver(sender, ip_addr, port, recv_id) _CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id))
def _sender_connect(sender): def _sender_connect(sender):
"""Connect to all the Receiver """Connect to all the Receiver
...@@ -70,7 +70,7 @@ def _send_nodeflow(sender, nodeflow, recv_id): ...@@ -70,7 +70,7 @@ def _send_nodeflow(sender, nodeflow, recv_id):
layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor() layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor()
flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor() flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor()
_CAPI_SenderSendSubgraph(sender, _CAPI_SenderSendSubgraph(sender,
recv_id, int(recv_id),
graph_handle, graph_handle,
node_mapping, node_mapping,
edge_mapping, edge_mapping,
...@@ -87,7 +87,7 @@ def _send_end_signal(sender, recv_id): ...@@ -87,7 +87,7 @@ def _send_end_signal(sender, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
_CAPI_SenderSendEndSignal(sender, recv_id) _CAPI_SenderSendEndSignal(sender, int(recv_id))
def _create_receiver(): def _create_receiver():
"""Create a Receiver communicator via C api """Create a Receiver communicator via C api
...@@ -113,7 +113,7 @@ def _receiver_wait(receiver, ip_addr, port, num_sender): ...@@ -113,7 +113,7 @@ def _receiver_wait(receiver, ip_addr, port, num_sender):
num_sender : int num_sender : int
total number of Sender total number of Sender
""" """
_CAPI_DGLReceiverWait(receiver, ip_addr, port, num_sender) _CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
def _recv_nodeflow(receiver, graph): def _recv_nodeflow(receiver, graph):
"""Receive sampled subgraph (NodeFlow) from remote sampler. """Receive sampled subgraph (NodeFlow) from remote sampler.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment