"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "168fc2cfdaa4489623785cf8d5ba0e2f67eac2ba"
Unverified commit dfff53bc, authored by Mingbang Wang, committed by GitHub
Browse files

[GraphBolt] add test for PR#6873 (#6923)

parent 0f3bfd7e
...@@ -158,7 +158,7 @@ def preprocess_ondisk_dataset( ...@@ -158,7 +158,7 @@ def preprocess_ondisk_dataset(
graph_feature["name"] graph_feature["name"]
] = edge_data ] = edge_data
if not is_homogeneous: if not is_homogeneous:
# For homogeneous graph, a node/edge feature must cover all # For heterogeneous graph, a node/edge feature must cover all
# node/edge types. # node/edge types.
for feat_name, feat_data in g.ndata.items(): for feat_name, feat_data in g.ndata.items():
existing_types = set(feat_data.keys()) existing_types = set(feat_data.keys())
......
...@@ -165,6 +165,12 @@ def random_homo_graphbolt_graph( ...@@ -165,6 +165,12 @@ def random_homo_graphbolt_graph(
- format: {edge_fmt} - format: {edge_fmt}
path: {edge_path} path: {edge_path}
feature_data: feature_data:
- domain: node
type: null
name: feat
format: numpy
in_memory: true
path: {node_feat_path}
- domain: edge - domain: edge
type: null type: null
name: feat name: feat
...@@ -250,6 +256,16 @@ def genereate_raw_data_for_hetero_dataset( ...@@ -250,6 +256,16 @@ def genereate_raw_data_for_hetero_dataset(
np.save(os.path.join(test_dir, node_feat_path), node_feats) np.save(os.path.join(test_dir, node_feat_path), node_feats)
node_feats_path[ntype] = node_feat_path node_feats_path[ntype] = node_feat_path
# Generate edge features.
edge_feats_path = {}
os.makedirs(os.path.join(test_dir, "data"), exist_ok=True)
for etype, num_edge in num_edges.items():
src_ntype, etype_str, dst_ntype = etype
edge_feat_path = os.path.join("data", f"{etype_str}-feat.npy")
edge_feats = np.random.rand(num_edge, num_classes)
np.save(os.path.join(test_dir, edge_feat_path), edge_feats)
edge_feats_path[etype_str] = edge_feat_path
# Generate train/test/valid set. # Generate train/test/valid set.
os.makedirs(os.path.join(test_dir, "set"), exist_ok=True) os.makedirs(os.path.join(test_dir, "set"), exist_ok=True)
user_ids = torch.arange(num_nodes["user"]) user_ids = torch.arange(num_nodes["user"])
...@@ -285,6 +301,31 @@ def genereate_raw_data_for_hetero_dataset( ...@@ -285,6 +301,31 @@ def genereate_raw_data_for_hetero_dataset(
- type: "user:click:item" - type: "user:click:item"
format: {edge_fmt} format: {edge_fmt}
path: {edges_path["click"]} path: {edges_path["click"]}
feature_data:
- domain: node
type: user
name: feat
format: numpy
in_memory: true
path: {node_feats_path["user"]}
- domain: node
type: item
name: feat
format: numpy
in_memory: true
path: {node_feats_path["item"]}
- domain: edge
type: "user:follow:user"
name: feat
format: numpy
in_memory: true
path: {edge_feats_path["follow"]}
- domain: edge
type: "user:click:item"
name: feat
format: numpy
in_memory: true
path: {edge_feats_path["click"]}
feature_data: feature_data:
- domain: node - domain: node
type: user type: user
......
...@@ -1136,9 +1136,14 @@ def test_OnDiskDataset_preprocess_homogeneous(edge_fmt): ...@@ -1136,9 +1136,14 @@ def test_OnDiskDataset_preprocess_homogeneous(edge_fmt):
assert fused_csc_sampling_graph.total_num_nodes == num_nodes assert fused_csc_sampling_graph.total_num_nodes == num_nodes
assert fused_csc_sampling_graph.total_num_edges == num_edges assert fused_csc_sampling_graph.total_num_edges == num_edges
assert ( assert (
fused_csc_sampling_graph.edge_attributes is None fused_csc_sampling_graph.node_attributes is not None
or gb.ORIGINAL_EDGE_ID and "feat" in fused_csc_sampling_graph.node_attributes
)
assert (
fused_csc_sampling_graph.edge_attributes is not None
and gb.ORIGINAL_EDGE_ID
not in fused_csc_sampling_graph.edge_attributes not in fused_csc_sampling_graph.edge_attributes
and "feat" in fused_csc_sampling_graph.edge_attributes
) )
num_samples = 100 num_samples = 100
...@@ -2147,7 +2152,14 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt): ...@@ -2147,7 +2152,14 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id, edge_fmt):
assert isinstance(graph, gb.FusedCSCSamplingGraph) assert isinstance(graph, gb.FusedCSCSamplingGraph)
assert graph.total_num_nodes == num_nodes assert graph.total_num_nodes == num_nodes
assert graph.total_num_edges == num_edges assert graph.total_num_edges == num_edges
assert graph.edge_attributes is not None assert (
graph.node_attributes is not None
and "feat" in graph.node_attributes
)
assert (
graph.edge_attributes is not None
and "feat" in graph.edge_attributes
)
assert ( assert (
not include_original_edge_id not include_original_edge_id
) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes ) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes
...@@ -2220,7 +2232,14 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt): ...@@ -2220,7 +2232,14 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt):
assert graph.total_num_edges == sum( assert graph.total_num_edges == sum(
num_edge for num_edge in num_edges.values() num_edge for num_edge in num_edges.values()
) )
assert graph.edge_attributes is not None assert (
graph.node_attributes is not None
and "feat" in graph.node_attributes
)
assert (
graph.edge_attributes is not None
and "feat" in graph.edge_attributes
)
assert ( assert (
not include_original_edge_id not include_original_edge_id
) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes ) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment