Unverified commit dfff53bc, authored by Mingbang Wang and committed by GitHub
Browse files

[GraphBolt] add test for PR#6873 (#6923)

parent 0f3bfd7e
......@@ -158,7 +158,7 @@ def preprocess_ondisk_dataset(
graph_feature["name"]
] = edge_data
if not is_homogeneous:
# For homogeneous graph, a node/edge feature must cover all
# For heterogeneous graph, a node/edge feature must cover all
# node/edge types.
for feat_name, feat_data in g.ndata.items():
existing_types = set(feat_data.keys())
......
......@@ -165,6 +165,12 @@ def random_homo_graphbolt_graph(
- format: {edge_fmt}
path: {edge_path}
feature_data:
- domain: node
type: null
name: feat
format: numpy
in_memory: true
path: {node_feat_path}
- domain: edge
type: null
name: feat
......@@ -250,6 +256,16 @@ def genereate_raw_data_for_hetero_dataset(
np.save(os.path.join(test_dir, node_feat_path), node_feats)
node_feats_path[ntype] = node_feat_path
# Generate edge features.
edge_feats_path = {}
os.makedirs(os.path.join(test_dir, "data"), exist_ok=True)
for etype, num_edge in num_edges.items():
src_ntype, etype_str, dst_ntype = etype
edge_feat_path = os.path.join("data", f"{etype_str}-feat.npy")
edge_feats = np.random.rand(num_edge, num_classes)
np.save(os.path.join(test_dir, edge_feat_path), edge_feats)
edge_feats_path[etype_str] = edge_feat_path
# Generate train/test/valid set.
os.makedirs(os.path.join(test_dir, "set"), exist_ok=True)
user_ids = torch.arange(num_nodes["user"])
......@@ -285,6 +301,31 @@ def genereate_raw_data_for_hetero_dataset(
- type: "user:click:item"
format: {edge_fmt}
path: {edges_path["click"]}
feature_data:
- domain: node
type: user
name: feat
format: numpy
in_memory: true
path: {node_feats_path["user"]}
- domain: node
type: item
name: feat
format: numpy
in_memory: true
path: {node_feats_path["item"]}
- domain: edge
type: "user:follow:user"
name: feat
format: numpy
in_memory: true
path: {edge_feats_path["follow"]}
- domain: edge
type: "user:click:item"
name: feat
format: numpy
in_memory: true
path: {edge_feats_path["click"]}
feature_data:
- domain: node
type: user
......
......@@ -1136,9 +1136,14 @@ def test_OnDiskDataset_preprocess_homogeneous(edge_fmt):
assert fused_csc_sampling_graph.total_num_nodes == num_nodes
assert fused_csc_sampling_graph.total_num_edges == num_edges
assert (
fused_csc_sampling_graph.edge_attributes is None
or gb.ORIGINAL_EDGE_ID
fused_csc_sampling_graph.node_attributes is not None
and "feat" in fused_csc_sampling_graph.node_attributes
)
assert (
fused_csc_sampling_graph.edge_attributes is not None
and gb.ORIGINAL_EDGE_ID
not in fused_csc_sampling_graph.edge_attributes
and "feat" in fused_csc_sampling_graph.edge_attributes
)
num_samples = 100
......@@ -2147,7 +2152,14 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt):
assert isinstance(graph, gb.FusedCSCSamplingGraph)
assert graph.total_num_nodes == num_nodes
assert graph.total_num_edges == num_edges
assert graph.edge_attributes is not None
assert (
graph.node_attributes is not None
and "feat" in graph.node_attributes
)
assert (
graph.edge_attributes is not None
and "feat" in graph.edge_attributes
)
assert (
not include_original_edge_id
) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes
......@@ -2220,7 +2232,14 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt):
assert graph.total_num_edges == sum(
num_edge for num_edge in num_edges.values()
)
assert graph.edge_attributes is not None
assert (
graph.node_attributes is not None
and "feat" in graph.node_attributes
)
assert (
graph.edge_attributes is not None
and "feat" in graph.edge_attributes
)
assert (
not include_original_edge_id
) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment