"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "942b17ab2ea495bb5203804573e29fe8de18faf3"
Unverified Commit 397b7599 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Support loading heterogeneous attributes in sampling graph. (#6873)

parent f5981789
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from dgl.utils import recursive_apply from dgl.utils import recursive_apply
from ...base import EID, ETYPE from ...base import EID, ETYPE, NID, NTYPE
from ...convert import to_homogeneous from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
...@@ -1117,7 +1117,9 @@ def from_dglgraph( ...@@ -1117,7 +1117,9 @@ def from_dglgraph(
) -> FusedCSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Convert a DGLGraph to FusedCSCSamplingGraph.""" """Convert a DGLGraph to FusedCSCSamplingGraph."""
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True) homo_g, ntype_count, _ = to_homogeneous(
g, ndata=g.ndata, edata=g.edata, return_count=True
)
if is_homogeneous: if is_homogeneous:
node_type_to_id = None node_type_to_id = None
...@@ -1147,8 +1149,13 @@ def from_dglgraph( ...@@ -1147,8 +1149,13 @@ def from_dglgraph(
) )
node_attributes = {} node_attributes = {}
edge_attributes = {} edge_attributes = {}
for feat_name, feat_data in homo_g.ndata.items():
if feat_name not in (NID, NTYPE):
node_attributes[feat_name] = feat_data
for feat_name, feat_data in homo_g.edata.items():
if feat_name not in (EID, ETYPE):
edge_attributes[feat_name] = feat_data
if include_original_edge_id: if include_original_edge_id:
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = torch.index_select( edge_attributes[ORIGINAL_EDGE_ID] = torch.index_select(
......
...@@ -129,14 +129,41 @@ def preprocess_ondisk_dataset( ...@@ -129,14 +129,41 @@ def preprocess_ondisk_dataset(
graph_feature["format"], graph_feature["format"],
in_memory=in_memory, in_memory=in_memory,
) )
g.ndata[graph_feature["name"]] = node_data if is_homogeneous:
g.ndata[graph_feature["name"]] = node_data
else:
g.nodes[graph_feature["type"]].data[
graph_feature["name"]
] = node_data
if graph_feature["domain"] == "edge": if graph_feature["domain"] == "edge":
edge_data = read_data( edge_data = read_data(
os.path.join(dataset_dir, graph_feature["path"]), os.path.join(dataset_dir, graph_feature["path"]),
graph_feature["format"], graph_feature["format"],
in_memory=in_memory, in_memory=in_memory,
) )
g.edata[graph_feature["name"]] = edge_data if is_homogeneous:
g.edata[graph_feature["name"]] = edge_data
else:
g.edges[etype_str_to_tuple(graph_feature["type"])].data[
graph_feature["name"]
] = edge_data
if not is_homogeneous:
# For heterogeneous graph, a node/edge feature must cover all
# node/edge types.
for feat_name, feat_data in g.ndata.items():
existing_types = set(feat_data.keys())
assert existing_types == set(g.ntypes), (
f"Node feature {feat_name} does not cover all node types."
+ f"Existing types: {existing_types}."
+ f"Expected types: {g.ntypes}."
)
for feat_name, feat_data in g.edata.items():
existing_types = set(feat_data.keys())
assert existing_types == set(g.canonical_etypes), (
f"Edge feature {feat_name} does not cover all edge types."
+ f"Existing types: {existing_types}."
+ f"Expected types: {g.etypes}."
)
# 4. Convert the DGLGraph to a FusedCSCSamplingGraph. # 4. Convert the DGLGraph to a FusedCSCSamplingGraph.
fused_csc_sampling_graph = from_dglgraph( fused_csc_sampling_graph = from_dglgraph(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment