Unverified Commit 523bbb4c authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix preprocess issue for single ntype/etype graph (#7011)

parent 0bfe34d9
......@@ -118,7 +118,18 @@ def preprocess_ondisk_dataset(
# 2. Load the edge data and create a DGLGraph.
if "graph" not in input_config:
raise RuntimeError("Invalid config: does not contain graph field.")
is_homogeneous = "type" not in input_config["graph"]["nodes"][0]
# For any graph that node/edge types are specified, we construct DGLGraph
# with `dgl.heterograph()` even there's only one node/edge type. This is
# because we want to save the node/edge types in the graph. So the logic of
# checking whether the graph is homogeneous is different from the logic in
# `DGLGraph.is_homogeneous()`. Otherwise, we construct DGLGraph with
# `dgl.graph()`.
is_homogeneous = (
len(input_config["graph"]["nodes"]) == 1
and len(input_config["graph"]["edges"]) == 1
and "type" not in input_config["graph"]["nodes"][0]
and "type" not in input_config["graph"]["edges"][0]
)
if is_homogeneous:
# Homogeneous graph.
num_nodes = input_config["graph"]["nodes"][0]["num"]
......@@ -178,19 +189,23 @@ def preprocess_ondisk_dataset(
if not is_homogeneous:
# For heterogenous graph, a node/edge feature must cover all
# node/edge types.
for feat_name, feat_data in g.ndata.items():
existing_types = set(feat_data.keys())
assert existing_types == set(g.ntypes), (
f"Node feature {feat_name} does not cover all node types."
+ f"Existing types: {existing_types}."
+ f"Expected types: {g.ntypes}."
)
for feat_name, feat_data in g.edata.items():
existing_types = set(feat_data.keys())
assert existing_types == set(g.canonical_etypes), (
f"Edge feature {feat_name} does not cover all edge types."
+ f"Existing types: {existing_types}."
+ f"Expected types: {g.etypes}."
ntypes = g.ntypes
assert all(
set(g.nodes[ntypes[0]].data.keys())
== set(g.nodes[ntype].data.keys())
for ntype in ntypes
), (
"Node feature does not cover all node types: "
+ f"{set(g.nodes[ntype].data.keys() for ntype in ntypes)}."
)
etypes = g.canonical_etypes
assert all(
set(g.edges[etypes[0]].data.keys())
== set(g.edges[etype].data.keys())
for etype in etypes
), (
"Edge feature does not cover all edge types: "
+ f"{set(g.edges[etype].data.keys() for etype in etypes)}."
)
# 4. Convert the DGLGraph to a FusedCSCSamplingGraph.
......
......@@ -2742,3 +2742,83 @@ def test_OnDiskDataset_load_tasks_selectively():
dataset = gb.OnDiskDataset(test_dir).load(tasks=2)
dataset = None
def test_OnDiskDataset_preprocess_graph_with_single_type():
"""Test for graph with single node/edge type."""
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
dataset_name = "graphbolt_test"
num_nodes = 4000
num_edges = 20000
# Generate random edges.
nodes = np.repeat(np.arange(num_nodes), 5)
neighbors = np.random.randint(0, num_nodes, size=(num_edges))
edges = np.stack([nodes, neighbors], axis=1)
# Wrtie into edges/edge.csv
os.makedirs(os.path.join(test_dir, "edges/"), exist_ok=True)
edges = pd.DataFrame(edges, columns=["src", "dst"])
edges.to_csv(
os.path.join(test_dir, "edges/edge.csv"),
index=False,
header=False,
)
# Generate random graph edge-feats.
edge_feats = np.random.rand(num_edges, 5)
os.makedirs(os.path.join(test_dir, "data/"), exist_ok=True)
np.save(os.path.join(test_dir, "data/edge-feat.npy"), edge_feats)
# Generate random node-feats.
node_feats = np.random.rand(num_nodes, 10)
np.save(os.path.join(test_dir, "data/node-feat.npy"), node_feats)
yaml_content = f"""
dataset_name: {dataset_name}
graph: # graph structure and required attributes.
nodes:
- num: {num_nodes}
type: author
edges:
- type: author:collab:author
format: csv
path: edges/edge.csv
feature_data:
- domain: edge
type: author:collab:author
name: feat
format: numpy
path: data/edge-feat.npy
- domain: node
type: author
name: feat
format: numpy
path: data/node-feat.npy
"""
yaml_file = os.path.join(test_dir, "metadata.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
dataset = gb.OnDiskDataset(test_dir).load()
assert dataset.dataset_name == dataset_name
graph = dataset.graph
assert isinstance(graph, gb.FusedCSCSamplingGraph)
assert graph.total_num_nodes == num_nodes
assert graph.total_num_edges == num_edges
assert (
graph.node_attributes is not None
and "feat" in graph.node_attributes
)
assert (
graph.edge_attributes is not None
and "feat" in graph.edge_attributes
)
assert torch.equal(graph.node_type_offset, torch.tensor([0, num_nodes]))
assert torch.equal(
graph.type_per_edge,
torch.zeros(num_edges),
)
assert graph.edge_type_to_id == {"author:collab:author": 0}
assert graph.node_type_to_id == {"author": 0}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment