"vscode:/vscode.git/clone" did not exist on "c448e61b55842e186eec01c8e8c2dea9507784e4"
Unverified Commit f8c60a18 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[Graphbolt] Introduce an `is_homogeneous` option for converting from COO to CSC format. (#6202)

parent a38bb5d1
......@@ -712,7 +712,7 @@ def save_csc_sampling_graph(graph, filename):
print(f"CSCSamplingGraph has been saved to {filename}.")
def from_dglgraph(g: DGLGraph) -> CSCSamplingGraph:
def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
"""Convert a DGLGraph to CSCSamplingGraph."""
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
# Initialize metadata.
......@@ -726,7 +726,8 @@ def from_dglgraph(g: DGLGraph) -> CSCSamplingGraph:
indptr, indices, _ = homo_g.adj_tensors("csc")
ntype_count.insert(0, 0)
node_type_offset = torch.cumsum(torch.LongTensor(ntype_count), 0)
type_per_edge = homo_g.edata[ETYPE]
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE]
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
......
......@@ -152,7 +152,7 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str:
g.edata[graph_feature["name"]] = edge_data
# 4. Convert the DGLGraph to a CSCSamplingGraph.
csc_sampling_graph = from_dglgraph(g)
csc_sampling_graph = from_dglgraph(g, is_homogeneous)
# 5. Save the CSCSamplingGraph and modify the output_config.
output_config["graph_topology"] = {}
......
......@@ -949,7 +949,7 @@ def test_OnDiskDataset_preprocess_homogeneous():
torch.arange(num_samples),
torch.tensor([fanout]),
)
assert len(list(subgraph.node_pairs.values())[0][0]) <= num_samples
assert len(subgraph.node_pairs[0]) <= num_samples
def test_OnDiskDataset_preprocess_path():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment