Unverified commit 397b7599, authored by czkkkkkk and committed by GitHub
Browse files

[Graphbolt] Support loading heterogeneous attributes in sampling graph. (#6873)

parent f5981789
......@@ -6,7 +6,7 @@ import torch
from dgl.utils import recursive_apply
from ...base import EID, ETYPE
from ...base import EID, ETYPE, NID, NTYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
......@@ -1117,7 +1117,9 @@ def from_dglgraph(
) -> FusedCSCSamplingGraph:
"""Convert a DGLGraph to FusedCSCSamplingGraph."""
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
homo_g, ntype_count, _ = to_homogeneous(
g, ndata=g.ndata, edata=g.edata, return_count=True
)
if is_homogeneous:
node_type_to_id = None
......@@ -1147,8 +1149,13 @@ def from_dglgraph(
)
node_attributes = {}
edge_attributes = {}
for feat_name, feat_data in homo_g.ndata.items():
if feat_name not in (NID, NTYPE):
node_attributes[feat_name] = feat_data
for feat_name, feat_data in homo_g.edata.items():
if feat_name not in (EID, ETYPE):
edge_attributes[feat_name] = feat_data
if include_original_edge_id:
# Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = torch.index_select(
......
......@@ -129,14 +129,41 @@ def preprocess_ondisk_dataset(
graph_feature["format"],
in_memory=in_memory,
)
if is_homogeneous:
g.ndata[graph_feature["name"]] = node_data
else:
g.nodes[graph_feature["type"]].data[
graph_feature["name"]
] = node_data
if graph_feature["domain"] == "edge":
edge_data = read_data(
os.path.join(dataset_dir, graph_feature["path"]),
graph_feature["format"],
in_memory=in_memory,
)
if is_homogeneous:
g.edata[graph_feature["name"]] = edge_data
else:
g.edges[etype_str_to_tuple(graph_feature["type"])].data[
graph_feature["name"]
] = edge_data
if not is_homogeneous:
# For heterogeneous graph, a node/edge feature must cover all
# node/edge types.
for feat_name, feat_data in g.ndata.items():
existing_types = set(feat_data.keys())
assert existing_types == set(g.ntypes), (
f"Node feature {feat_name} does not cover all node types."
+ f"Existing types: {existing_types}."
+ f"Expected types: {g.ntypes}."
)
for feat_name, feat_data in g.edata.items():
existing_types = set(feat_data.keys())
assert existing_types == set(g.canonical_etypes), (
f"Edge feature {feat_name} does not cover all edge types."
+ f"Existing types: {existing_types}."
+ f"Expected types: {g.etypes}."
)
# 4. Convert the DGLGraph to a FusedCSCSamplingGraph.
fused_csc_sampling_graph = from_dglgraph(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment