"...text-generation-inference.git" did not exist on "8b295aa498408ab526ce36bb726b5eaafa5e1593"
Unverified Commit 7815fe8a authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[CUDA] Make sanity check optional for `dgl.create_block`. (#7240)

parent 3c391533
...@@ -387,7 +387,12 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None): ...@@ -387,7 +387,12 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):
def create_block( def create_block(
data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None, device=None data_dict,
num_src_nodes=None,
num_dst_nodes=None,
idtype=None,
device=None,
node_count_check=True,
): ):
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object. """Create a message flow graph (MFG) as a :class:`DGLBlock` object.
...@@ -456,6 +461,9 @@ def create_block( ...@@ -456,6 +461,9 @@ def create_block(
the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the
returned graph is on CPU. If the specified :attr:`device` differs from that of the returned graph is on CPU. If the specified :attr:`device` differs from that of the
provided tensors, it casts the given tensors to the specified device first. provided tensors, it casts the given tensors to the specified device first.
node_count_check : bool, optional
When num_src_nodes and num_dst_nodes are passed, whether we should perform
sanity checks to ensure they are valid.
Returns Returns
------- -------
...@@ -540,13 +548,16 @@ def create_block( ...@@ -540,13 +548,16 @@ def create_block(
node_tensor_dict = {} node_tensor_dict = {}
for (sty, ety, dty), data in data_dict.items(): for (sty, ety, dty), data in data_dict.items():
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors( (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
data, idtype, bipartite=True data,
idtype,
bipartite=True,
infer_node_count=need_infer or node_count_check,
) )
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays) node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
if need_infer: if need_infer:
num_src_nodes[sty] = max(num_src_nodes[sty], urange) num_src_nodes[sty] = max(num_src_nodes[sty], urange)
num_dst_nodes[dty] = max(num_dst_nodes[dty], vrange) num_dst_nodes[dty] = max(num_dst_nodes[dty], vrange)
else: # sanity check elif node_count_check: # sanity check
if num_src_nodes[sty] < urange: if num_src_nodes[sty] < urange:
raise DGLError( raise DGLError(
"The given number of nodes of source node type {} must be larger" "The given number of nodes of source node type {} must be larger"
......
...@@ -303,6 +303,7 @@ class MiniBatch: ...@@ -303,6 +303,7 @@ class MiniBatch:
sampled_csc, sampled_csc,
num_src_nodes=num_src_nodes, num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes, num_dst_nodes=num_dst_nodes,
node_count_check=False,
) )
) )
......
...@@ -116,7 +116,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None): ...@@ -116,7 +116,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
SparseAdjTuple = namedtuple("SparseAdjTuple", ["format", "arrays"]) SparseAdjTuple = namedtuple("SparseAdjTuple", ["format", "arrays"])
def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): def graphdata2tensors(
data, idtype=None, bipartite=False, infer_node_count=True, **kwargs
):
"""Function to convert various types of data to edge tensors and infer """Function to convert various types of data to edge tensors and infer
the number of nodes. the number of nodes.
...@@ -137,6 +139,9 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): ...@@ -137,6 +139,9 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
bipartite : bool, optional bipartite : bool, optional
Whether infer number of nodes of a bipartite graph -- Whether infer number of nodes of a bipartite graph --
num_src and num_dst can be different. num_src and num_dst can be different.
infer_node_count : bool, optional
Whether infer number of nodes at all. If False, num_src and num_dst
are returned as None.
kwargs kwargs
- edge_id_attr_name : The name (str) of the edge attribute that stores the edge - edge_id_attr_name : The name (str) of the edge attribute that stores the edge
...@@ -186,23 +191,28 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs): ...@@ -186,23 +191,28 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
data.format, tuple(F.tensor(a) for a in data.arrays) data.format, tuple(F.tensor(a) for a in data.arrays)
) )
num_src, num_dst = None, None
if isinstance(data, SparseAdjTuple): if isinstance(data, SparseAdjTuple):
if idtype is not None: if idtype is not None:
data = SparseAdjTuple( data = SparseAdjTuple(
data.format, tuple(F.astype(a, idtype) for a in data.arrays) data.format, tuple(F.astype(a, idtype) for a in data.arrays)
) )
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, list): elif isinstance(data, list):
src, dst = elist2tensor(data, idtype) src, dst = elist2tensor(data, idtype)
data = SparseAdjTuple("coo", (src, dst)) data = SparseAdjTuple("coo", (src, dst))
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
# We can get scipy matrix's number of rows and columns easily. # We can get scipy matrix's number of rows and columns easily.
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
data = scipy2tensor(data, idtype) data = scipy2tensor(data, idtype)
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
# We can get networkx graph's number of sources and destinations easily. # We can get networkx graph's number of sources and destinations easily.
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite) if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
edge_id_attr_name = kwargs.get("edge_id_attr_name", None) edge_id_attr_name = kwargs.get("edge_id_attr_name", None)
if bipartite: if bipartite:
top_map = kwargs.get("top_map") top_map = kwargs.get("top_map")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment