Unverified Commit 1a8e0c00 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] remove GraphMeta and update API accordingly (#6709)

parent 50cbb0ec
......@@ -1255,14 +1255,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
graph, _, _, gpb, _, _, _ = load_partition(
part_config, part_id, load_feats=False
)
# Construct GraphMetadata.
_, _, ntypes, etypes = load_partition_book(part_config, part_id)
node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = {
_etype_tuple_to_str(etype): etid
for etid, etype in enumerate(etypes)
}
metadata = graphbolt.GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC indtpr and indices.
indptr, indices, _ = graph.adj().csc()
# Initalize type per edge.
......@@ -1275,7 +1273,8 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
indices,
node_type_offset=None,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
)
orig_graph_path = os.path.join(
os.path.dirname(part_config),
......
......@@ -20,7 +20,6 @@ from .sampled_subgraph_impl import (
__all__ = [
"GraphMetadata",
"FusedCSCSamplingGraph",
"from_fused_csc",
"load_from_shared_memory",
......@@ -28,65 +27,6 @@ __all__ = [
]
class GraphMetadata:
r"""Class for metadata of csc sampling graph."""
def __init__(
self,
node_type_to_id: Dict[str, int],
edge_type_to_id: Dict[str, int],
):
"""Initialize the GraphMetadata object.
Parameters
----------
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[str, int]
Dictionary from edge types to edge type IDs.
Raises
------
AssertionError
If any of the assertions fail.
"""
node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
self.node_type_to_id = node_type_to_id
self.edge_type_to_id = edge_type_to_id
class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format."""
......@@ -148,9 +88,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> print(graph.num_nodes)
{'N0': 2, 'N1': 3}
"""
......@@ -158,14 +100,14 @@ class FusedCSCSamplingGraph(SamplingGraph):
offset = self.node_type_offset
# Homogenous.
if offset is None or self.metadata is None:
if offset is None or self.node_type_to_id is None:
return self._c_csc_graph.num_nodes()
# Heterogenous
else:
num_nodes_per_type = {
_type: (offset[_idx + 1] - offset[_idx]).item()
for _type, _idx in self.metadata.node_type_to_id.items()
for _type, _idx in self.node_type_to_id.items()
}
return num_nodes_per_type
......@@ -358,21 +300,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge attributes dictionary."""
self._c_csc_graph.set_edge_attributes(edge_attributes)
@property
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns
-------
GraphMetadata or None
If present, returns the metadata of the graph.
"""
if self.node_type_to_id is None or self.edge_type_to_id is None:
return None
return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)
def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> FusedSampledSubgraphImpl:
......@@ -405,14 +332,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs)
......@@ -465,10 +394,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
# converted to heterogeneous graphs.
node_pairs = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items():
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
src_ntype_id = self.node_type_to_id[src_ntype]
dst_ntype_id = self.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id
hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = (
......@@ -486,7 +415,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes):
homogeneous_nodes = []
for ntype, ids in nodes.items():
ntype_id = self.metadata.node_type_to_id[ntype]
ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes)
......@@ -526,7 +455,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
subgraph_indptr = {}
node_edge_type = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items():
for etype, etype_id in self.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=indices.dtype
)
......@@ -539,7 +468,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Preprocessing saves the type of seed_nodes as the edge type
# of dst_ntype.
_, _, dst_ntype = etype_str_to_tuple(etype)
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
dst_ntype_id = self.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs
for i, seed in enumerate(column):
......@@ -553,7 +482,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
)
for etype, etype_id in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
src_ntype_id = self.node_type_to_id[src_ntype]
num_edges = torch.searchsorted(
type_per_edge[l:r], etype_id, right=True
).item()
......@@ -578,7 +507,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
indptr=torch.tensor(subgraph_indptr[etype]),
indices=subgraph_indice[etype],
)
for etype in self.metadata.edge_type_to_id.keys()
for etype in self.edge_type_to_id.keys()
}
return SampledSubgraphImpl(
node_pairs=node_pairs,
......@@ -644,13 +573,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch
>>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts)
......@@ -673,8 +604,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1
if self.metadata and self.metadata.edge_type_to_id:
expected_fanout_len = len(self.metadata.edge_type_to_id)
if self.edge_type_to_id:
expected_fanout_len = len(self.edge_type_to_id)
assert len(fanouts) in [
expected_fanout_len,
1,
......@@ -830,13 +761,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch
>>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
......@@ -907,7 +840,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "The 'node_type_offset' array is necessary for performing \
negative sampling by edge type."
_, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.metadata.node_type_to_id[dst_node_type]
dst_node_type_id = self.node_type_to_id[dst_node_type]
max_node_id = (
self.node_type_offset[dst_node_type_id + 1]
- self.node_type_offset[dst_node_type_id]
......@@ -965,8 +898,9 @@ def from_fused_csc(
indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None,
node_type_to_id: Optional[Dict[str, int]] = None,
edge_type_to_id: Optional[Dict[str, int]] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation.
......@@ -982,10 +916,12 @@ def from_fused_csc(
Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None.
node_type_to_id : Optional[Dict[str, int]], optional
Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns
-------
......@@ -995,8 +931,7 @@ def from_fused_csc(
Examples
--------
>>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
>>> etypes = {('n1', 'e1', 'n2'): 0, ('n1', 'e2', 'n3'): 1}
>>> metadata = graphbolt.GraphMetadata(ntypes, etypes)
>>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
>>> csc_indptr = torch.tensor([0, 2, 5, 7])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3])
......@@ -1004,18 +939,50 @@ def from_fused_csc(
>>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... edge_attributes=None, metadata=metadata)
... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,)
>>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
total_num_nodes=3, total_num_edges=7)
"""
if metadata and metadata.node_type_to_id and node_type_offset is not None:
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
0
), "node_type_offset length should be |ntypes| + 1."
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
if node_type_to_id is not None and edge_type_to_id is not None:
node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
if node_type_offset is not None:
assert len(node_type_to_id) + 1 == node_type_offset.size(
0
), "node_type_offset length should be |ntypes| + 1."
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
csc_indptr,
......@@ -1091,7 +1058,8 @@ def from_dglgraph(
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
if is_homogeneous:
metadata = None
node_type_to_id = None
edge_type_to_id = None
else:
# Initialize metadata.
node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
......@@ -1099,7 +1067,6 @@ def from_dglgraph(
etype_tuple_to_str(etype): g.get_etype_id(etype)
for etype in g.canonical_etypes
}
metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC matrix.
indptr, indices, edge_ids = homo_g.adj_tensors("csc")
......@@ -1118,8 +1085,6 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
indptr,
......
......@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler):
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.metadata.node_type_to_id.keys())
ntypes = list(self.graph.node_type_to_id.keys())
seeds = {
ntype: seeds.get(ntype, torch.LongTensor([]))
for ntype in ntypes
......
......@@ -705,9 +705,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None
assert all(new_g.type_per_edge == 0)
for node_type, type_id in new_g.metadata.node_type_to_id.items():
for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
......@@ -736,9 +736,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
for node_type, type_id in new_g.metadata.node_type_to_id.items():
for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
......
......@@ -32,7 +32,7 @@ def random_homo_graph(num_nodes, num_edges):
return csc_indptr, indices
def get_metadata(num_ntypes, num_etypes):
def get_type_to_id(num_ntypes, num_etypes):
ntypes = {f"n{i}": i for i in range(num_ntypes)}
etypes = {}
count = 0
......@@ -42,7 +42,7 @@ def get_metadata(num_ntypes, num_etypes):
break
etypes.update({f"n{n1}:e{count}:n{n2}": count})
count += 1
return gb.GraphMetadata(ntypes, etypes)
return ntypes, etypes
def get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes):
......@@ -82,7 +82,8 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
gb_g.indices,
gb_g.node_type_offset,
gb_g.type_per_edge,
gb_g.metadata,
gb_g.node_type_to_id,
gb_g.edge_type_to_id,
)
......
......@@ -42,7 +42,9 @@ def test_empty_graph(total_num_nodes):
def test_hetero_empty_graph(total_num_nodes):
csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int)
indices = torch.tensor([])
metadata = gbt.get_metadata(num_ntypes=3, num_etypes=5)
node_type_to_id, edge_type_to_id = gbt.get_type_to_id(
num_ntypes=3, num_etypes=5
)
# Some node types have no nodes.
if total_num_nodes == 0:
node_type_offset = torch.zeros((4,), dtype=int)
......@@ -56,17 +58,18 @@ def test_hetero_empty_graph(total_num_nodes):
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
None,
metadata,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=None,
)
assert graph.total_num_edges == 0
assert graph.total_num_nodes == total_num_nodes
assert torch.equal(graph.csc_indptr, csc_indptr)
assert torch.equal(graph.indices, indices)
assert graph.metadata.node_type_to_id == metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == metadata.edge_type_to_id
assert graph.node_type_to_id == node_type_to_id
assert graph.edge_type_to_id == edge_type_to_id
assert torch.equal(graph.node_type_offset, node_type_offset)
assert torch.equal(graph.type_per_edge, type_per_edge)
......@@ -78,9 +81,11 @@ def test_hetero_empty_graph(total_num_nodes):
@pytest.mark.parametrize(
"ntypes", [{"n1": 1, "n2": 1}, {5: 1, "n2": 2}, {"n1": 1.5, "n2": 2.0}]
)
def test_metadata_with_ntype_exception(ntypes):
with pytest.raises(Exception):
gb.GraphMetadata(ntypes, {"n1:e1:n2": 1})
def test_type_to_id_with_ntype_exception(ntypes):
with pytest.raises(AssertionError):
gb.from_fused_csc(
None, None, node_type_to_id=ntypes, edge_type_to_id={"e1": 1}
)
@unittest.skipIf(
......@@ -99,9 +104,14 @@ def test_metadata_with_ntype_exception(ntypes):
{"n1:e1:n2": 1.5},
],
)
def test_metadata_with_etype_exception(etypes):
def test_type_to_id_with_etype_exception(etypes):
with pytest.raises(Exception):
gb.GraphMetadata({"n1": 0, "n2": 1, "n3": 2}, etypes)
gb.from_fused_csc(
None,
None,
node_type_to_id={"n1": 0, "n2": 1, "n3": 2},
edge_type_to_id=etypes,
)
@unittest.skipIf(
......@@ -131,9 +141,10 @@ def test_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(indices, graph.indices)
assert graph.edge_attributes == edge_attributes
assert graph.metadata is None
assert graph.node_type_offset is None
assert graph.type_per_edge is None
assert graph.node_type_to_id is None
assert graph.edge_type_to_id is None
@unittest.skipIf(
......@@ -151,7 +162,8 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
......@@ -162,10 +174,11 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
)
assert graph.total_num_nodes == total_num_nodes
......@@ -176,8 +189,8 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
assert torch.equal(node_type_offset, graph.node_type_offset)
assert torch.equal(type_per_edge, graph.type_per_edge)
assert graph.edge_attributes == edge_attributes
assert metadata.node_type_to_id == graph.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph.metadata.edge_type_to_id
assert node_type_to_id == graph.node_type_to_id
assert edge_type_to_id == graph.edge_type_to_id
@unittest.skipIf(
......@@ -247,9 +260,13 @@ def test_num_nodes_hetero():
assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Verify nodes number per node types.
......@@ -283,12 +300,22 @@ def test_num_nodes_hetero():
)
def test_node_type_offset_wrong_legnth(node_type_offset):
num_ntypes = 3
csc_indptr, indices, _, type_per_edge, metadata = gbt.random_hetero_graph(
10, 50, num_ntypes, 5
)
(
csc_indptr,
indices,
_,
type_per_edge,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(10, 50, num_ntypes, 5)
with pytest.raises(Exception):
gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
)
......@@ -317,9 +344,11 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert graph.node_type_offset is None and graph2.node_type_offset is None
assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf(
......@@ -339,12 +368,18 @@ def test_load_save_hetero_graph(
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
)
with tempfile.TemporaryDirectory() as test_dir:
......@@ -359,8 +394,8 @@ def test_load_save_hetero_graph(
assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.edge_type_to_id == graph2.edge_type_to_id
@unittest.skipIf(
......@@ -410,7 +445,8 @@ def test_pickle_hetero_graph(
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
......@@ -421,10 +457,11 @@ def test_pickle_hetero_graph(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
)
serialized = pickle.dumps(graph)
......@@ -466,7 +503,8 @@ def test_multiprocessing():
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
......@@ -476,10 +514,11 @@ def test_multiprocessing():
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
)
p = mp.Process(
......@@ -572,9 +611,13 @@ def test_in_subgraph_heterogeneous():
assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Extract in subgraph.
......@@ -678,8 +721,6 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
......@@ -694,7 +735,8 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Sample on both node types.
......@@ -787,8 +829,6 @@ def test_sample_neighbors_fanouts(
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
......@@ -803,7 +843,8 @@ def test_sample_neighbors_fanouts(
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
......@@ -845,8 +886,6 @@ def test_sample_neighbors_replace(
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
......@@ -861,7 +900,8 @@ def test_sample_neighbors_replace(
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
......@@ -888,7 +928,6 @@ def test_sample_neighbors_return_eids_homo(labor):
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -933,8 +972,6 @@ def test_sample_neighbors_return_eids_hetero(labor):
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
......@@ -952,8 +989,9 @@ def test_sample_neighbors_return_eids_hetero(labor):
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
edge_attributes=edge_attributes,
metadata=metadata,
)
# Sample on both node types.
......@@ -991,7 +1029,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -1042,7 +1079,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
)
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -1143,9 +1179,10 @@ def test_homo_graph_on_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
assert graph1.metadata is None and graph2.metadata is None
assert graph1.node_type_offset is None and graph2.node_type_offset is None
assert graph1.type_per_edge is None and graph2.type_per_edge is None
assert graph1.node_type_to_id is None and graph2.node_type_to_id is None
assert graph1.edge_type_to_id is None and graph2.edge_type_to_id is None
@unittest.skipIf(
......@@ -1168,7 +1205,8 @@ def test_hetero_graph_on_shared_memory(
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
......@@ -1185,8 +1223,9 @@ def test_hetero_graph_on_shared_memory(
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
metadata=metadata,
)
shm_name = "test_hetero_g"
......@@ -1231,10 +1270,10 @@ def test_hetero_graph_on_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
assert metadata.node_type_to_id == graph1.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph1.metadata.edge_type_to_id
assert metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
assert node_type_to_id == graph1.node_type_to_id
assert edge_type_to_id == graph1.edge_type_to_id
assert node_type_to_id == graph2.node_type_to_id
assert edge_type_to_id == graph2.edge_type_to_id
def process_csc_sampling_graph_on_shared_memory(graph, data_queue, flag_queue):
......@@ -1293,7 +1332,8 @@ def test_multiprocessing_with_shared_memory():
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
......@@ -1308,8 +1348,9 @@ def test_multiprocessing_with_shared_memory():
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=None,
metadata=metadata,
)
ctx = mp.get_context("spawn") # Use spawn method.
......@@ -1374,7 +1415,8 @@ def test_from_dglgraph_homogeneous():
assert gb_g.total_num_edges == dgl_g.num_edges()
assert gb_g.node_type_offset is None
assert gb_g.type_per_edge is None
assert gb_g.metadata is None
assert gb_g.node_type_to_id is None
assert gb_g.edge_type_to_id is None
@unittest.skipIf(
......@@ -1432,7 +1474,7 @@ def test_from_dglgraph_heterogeneous():
list(
map(
lambda ss: ss.split(":")[1],
gb_g.metadata.edge_type_to_id.keys(),
gb_g.edge_type_to_id.keys(),
)
)
== dgl_g.etypes
......@@ -1456,13 +1498,13 @@ def test_from_dglgraph_heterogeneous():
[3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2]
),
)
assert gb_g.metadata.node_type_to_id == {
assert gb_g.node_type_to_id == {
"author": 0,
"field": 1,
"institution": 2,
"paper": 3,
}
assert gb_g.metadata.edge_type_to_id == {
assert gb_g.edge_type_to_id == {
"author:affiliated_with:institution": 0,
"author:writes:paper": 1,
"paper:cites:paper": 2,
......@@ -1584,7 +1626,6 @@ def test_sample_neighbors_hetero_pick_number(
"N0:R1:N2": 1,
"N0:R2:N3": 2,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
......@@ -1607,7 +1648,8 @@ def test_sample_neighbors_hetero_pick_number(
edge_attributes=edge_attributes,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Generate subgraph via sample neighbors.
......@@ -1673,7 +1715,6 @@ def test_csc_sampling_graph_to_device():
"N0:R1:N2": 1,
"N0:R2:N3": 2,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
......@@ -1696,7 +1737,8 @@ def test_csc_sampling_graph_to_device():
edge_attributes=edge_attributes,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Copy to device.
......@@ -1725,7 +1767,6 @@ def test_sample_neighbors_homo_csc_format():
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -1769,8 +1810,6 @@ def test_sample_neighbors_hetero_csc_format(labor):
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
......@@ -1785,7 +1824,8 @@ def test_sample_neighbors_hetero_csc_format(labor):
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Sample on both node types.
......@@ -1875,8 +1915,6 @@ def test_sample_neighbors_fanouts_csc_format(
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
......@@ -1891,7 +1929,8 @@ def test_sample_neighbors_fanouts_csc_format(
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
......@@ -1937,8 +1976,6 @@ def test_sample_neighbors_replace_csc_format(
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
......@@ -1953,7 +1990,8 @@ def test_sample_neighbors_replace_csc_format(
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
......@@ -1986,7 +2024,6 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -2034,8 +2071,6 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
......@@ -2054,7 +2089,8 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
edge_attributes=edge_attributes,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Sample on both node types.
......@@ -2092,7 +2128,6 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -2209,7 +2244,6 @@ def test_sample_neighbors_homo_pick_number_csc_format(
0 0 0 0 0 0
"""
# Initialize data.
total_num_nodes = 6
total_num_edges = 6
indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6])
indices = torch.LongTensor([0, 1, 2, 3, 4, 5])
......@@ -2291,7 +2325,6 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
"N2:R1:N0": 1,
"N3:R2:N0": 2,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
......@@ -2314,7 +2347,8 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
edge_attributes=edge_attributes,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Generate subgraph via sample neighbors.
......
......@@ -76,7 +76,6 @@ def test_InSubgraphSampler_hetero():
"N1:R2:N0": 2,
"N1:R3:N1": 3,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
node_type_offset = torch.LongTensor([0, 3, 6])
......@@ -86,7 +85,8 @@ def test_InSubgraphSampler_hetero():
indices=indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
item_set = gb.ItemSetDict(
......
......@@ -106,7 +106,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
......@@ -116,7 +115,8 @@ def get_hetero_graph():
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
......
......@@ -1026,11 +1026,12 @@ def test_OnDiskDataset_Graph_homogeneous():
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert (
graph.node_type_offset is None and graph2.node_type_offset is None
)
assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
def test_OnDiskDataset_Graph_heterogeneous():
......@@ -1040,10 +1041,16 @@ def test_OnDiskDataset_Graph_heterogeneous():
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
)
with tempfile.TemporaryDirectory() as test_dir:
......@@ -1065,8 +1072,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.edge_type_to_id == graph2.edge_type_to_id
def test_OnDiskDataset_Metadata():
......@@ -1854,15 +1861,17 @@ def test_OnDiskDataset_all_nodes_set_hetero():
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=None,
metadata=metadata,
)
with tempfile.TemporaryDirectory() as test_dir:
......
......@@ -141,7 +141,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
......@@ -151,7 +150,8 @@ def get_hetero_graph():
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
......
......@@ -117,7 +117,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
......@@ -127,7 +126,8 @@ def get_hetero_graph():
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
......@@ -207,7 +207,8 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gb_test_utils.random_hetero_graph(
num_nodes, num_edges, num_ntypes, num_etypes
)
......@@ -218,10 +219,11 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
)
itemset = gb.ItemSetDict(
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment