Unverified commit 91fe0c90, authored by LastWhisper, committed by GitHub
Browse files

[GraphBolt] Refactor CSCSamplingGraph and update the corresponding method. (#6515)

parent 3f3652e0
......@@ -1022,15 +1022,15 @@ def test_OnDiskDataset_Graph_Exceptions():
def test_OnDiskDataset_Graph_homogeneous():
"""Test homogeneous graph topology."""
csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000)
graph = gb.from_csc(csc_indptr, indices)
graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
yaml_content = f"""
graph_topology:
type: CSCSamplingGraph
type: FusedCSCSamplingGraph
path: {graph_path}
"""
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
......@@ -1063,17 +1063,17 @@ def test_OnDiskDataset_Graph_heterogeneous():
type_per_edge,
metadata,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
yaml_content = f"""
graph_topology:
type: CSCSamplingGraph
type: FusedCSCSamplingGraph
path: {graph_path}
"""
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
......@@ -1155,19 +1155,24 @@ def test_OnDiskDataset_preprocess_homogeneous():
assert "graph" not in processed_dataset
assert "graph_topology" in processed_dataset
csc_sampling_graph = gb.csc_sampling_graph.load_csc_sampling_graph(
os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
fused_csc_sampling_graph = (
gb.fused_csc_sampling_graph.load_fused_csc_sampling_graph(
os.path.join(
test_dir, processed_dataset["graph_topology"]["path"]
)
)
)
assert csc_sampling_graph.total_num_nodes == num_nodes
assert csc_sampling_graph.total_num_edges == num_edges
assert fused_csc_sampling_graph.total_num_nodes == num_nodes
assert fused_csc_sampling_graph.total_num_edges == num_edges
assert (
csc_sampling_graph.edge_attributes is None
or gb.ORIGINAL_EDGE_ID not in csc_sampling_graph.edge_attributes
fused_csc_sampling_graph.edge_attributes is None
or gb.ORIGINAL_EDGE_ID
not in fused_csc_sampling_graph.edge_attributes
)
num_samples = 100
fanout = 1
subgraph = csc_sampling_graph.sample_neighbors(
subgraph = fused_csc_sampling_graph.sample_neighbors(
torch.arange(num_samples),
torch.tensor([fanout]),
)
......@@ -1197,14 +1202,19 @@ def test_OnDiskDataset_preprocess_homogeneous():
)
with open(output_file, "rb") as f:
processed_dataset = yaml.load(f, Loader=yaml.Loader)
csc_sampling_graph = gb.csc_sampling_graph.load_csc_sampling_graph(
os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
fused_csc_sampling_graph = (
gb.fused_csc_sampling_graph.load_fused_csc_sampling_graph(
os.path.join(
test_dir, processed_dataset["graph_topology"]["path"]
)
)
)
assert (
csc_sampling_graph.edge_attributes is not None
and gb.ORIGINAL_EDGE_ID not in csc_sampling_graph.edge_attributes
fused_csc_sampling_graph.edge_attributes is not None
and gb.ORIGINAL_EDGE_ID
not in fused_csc_sampling_graph.edge_attributes
)
csc_sampling_graph = None
fused_csc_sampling_graph = None
def test_OnDiskDataset_preprocess_path():
......@@ -1350,8 +1360,8 @@ def test_OnDiskDataset_preprocess_yaml_content_unix():
target_yaml_content = f"""
dataset_name: {dataset_name}
graph_topology:
type: CSCSamplingGraph
path: preprocessed/csc_sampling_graph.tar
type: FusedCSCSamplingGraph
path: preprocessed/fused_csc_sampling_graph.tar
feature_data:
- domain: node
type: null
......@@ -1504,8 +1514,8 @@ def test_OnDiskDataset_preprocess_yaml_content_windows():
target_yaml_content = f"""
dataset_name: {dataset_name}
graph_topology:
type: CSCSamplingGraph
path: preprocessed\\csc_sampling_graph.tar
type: FusedCSCSamplingGraph
path: preprocessed\\fused_csc_sampling_graph.tar
feature_data:
- domain: node
type: null
......@@ -1703,7 +1713,7 @@ def test_OnDiskDataset_load_graph():
pydantic.ValidationError,
# As error message diffs in pydantic 1.x and 2.x, we just match
# keyword only.
match="'CSCSamplingGraph'",
match="'FusedCSCSamplingGraph'",
):
dataset.load()
......@@ -1858,15 +1868,15 @@ def test_OnDiskDataset_load_tasks():
def test_OnDiskDataset_all_nodes_set_homo():
"""Test homograph's all nodes set of OnDiskDataset."""
csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000)
graph = gb.from_csc(csc_indptr, indices)
graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
yaml_content = f"""
graph_topology:
type: CSCSamplingGraph
type: FusedCSCSamplingGraph
path: {graph_path}
"""
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
......@@ -1893,7 +1903,7 @@ def test_OnDiskDataset_all_nodes_set_hetero():
type_per_edge,
metadata,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset=node_type_offset,
......@@ -1903,12 +1913,12 @@ def test_OnDiskDataset_all_nodes_set_hetero():
)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
yaml_content = f"""
graph_topology:
type: CSCSamplingGraph
type: FusedCSCSamplingGraph
path: {graph_path}
"""
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
......
......@@ -119,7 +119,7 @@ def get_hetero_graph():
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
return gb.from_csc(
return gb.from_fused_csc(
indptr,
indices,
node_type_offset=node_type_offset,
......
......@@ -38,7 +38,7 @@ def test_integration_link_prediction():
)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
graph = gb.from_csc(indptr, indices)
graph = gb.from_fused_csc(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data)
edge_feature = gb.TorchBasedFeature(edge_feature_data)
......@@ -162,7 +162,7 @@ def test_integration_node_classification():
)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
graph = gb.from_csc(indptr, indices)
graph = gb.from_fused_csc(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data)
edge_feature = gb.TorchBasedFeature(edge_feature_data)
......
......@@ -120,7 +120,7 @@ def get_hetero_graph():
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
return gb.from_csc(
return gb.from_fused_csc(
indptr,
indices,
node_type_offset=node_type_offset,
......@@ -213,7 +213,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment