Unverified Commit 92a46d12 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graohbolt] Reorg CSCSamplingGraph folder (#5974)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent 9a3456cd
......@@ -5,7 +5,6 @@ import sys
import torch
from .._ffi import libinfo
from .graph_storage import *
from .itemset import *
from .minibatch_sampler import *
from .feature_store import *
......@@ -44,5 +43,3 @@ def load_graphbolt():
load_graphbolt()
SampledSubgraph = torch.classes.graphbolt.SampledSubgraph
"""Graphbolt graph module."""
from .csc_sampling_graph import *
......@@ -2,3 +2,4 @@
from .ondisk_dataset import *
from .ondisk_metadata import *
from .torch_based_feature_store import *
from .csc_sampling_graph import *
......@@ -187,7 +187,7 @@ class CSCSamplingGraph:
Returns
-------
SampledSubgraph
torch.classes.graphbolt.SampledSubgraph
The in subgraph.
"""
# Ensure nodes is 1-D tensor.
......@@ -244,7 +244,7 @@ class CSCSamplingGraph:
to the number of edges.
Returns
-------
SampledSubgraph
torch.classes.graphbolt.SampledSubgraph
The sampled subgraph.
Examples
......
......@@ -3,10 +3,10 @@
from typing import Dict, List, Tuple
from ..dataset import Dataset
from ..graph_storage import CSCSamplingGraph, load_csc_sampling_graph
from ..itemset import ItemSet, ItemSetDict
from ..utils import read_data, tensor_to_tuple
from .csc_sampling_graph import CSCSamplingGraph, load_csc_sampling_graph
from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet
from .torch_based_feature_store import (
load_feature_stores,
......
import dgl.graphbolt as gb
import scipy.sparse as sp
import torch
def rand_csc_graph(N, density):
adj = sp.random(N, N, density)
adj = adj + adj.T
adj = adj.tocsc()
indptr = torch.LongTensor(adj.indptr)
indices = torch.LongTensor(adj.indices)
graph = gb.from_csc(indptr, indices)
return graph
def random_homo_graph(num_nodes, num_edges):
csc_indptr = torch.randint(0, num_edges, (num_nodes + 1,))
csc_indptr = torch.sort(csc_indptr)[0]
csc_indptr[0] = 0
csc_indptr[-1] = num_edges
indices = torch.randint(0, num_nodes, (num_edges,))
return csc_indptr, indices
def get_metadata(num_ntypes, num_etypes):
ntypes = {f"n{i}": i for i in range(num_ntypes)}
etypes = {}
count = 0
for n1 in range(num_ntypes):
for n2 in range(n1, num_ntypes):
if count >= num_etypes:
break
etypes.update({(f"n{n1}", f"e{count}", f"n{n2}"): count})
count += 1
return gb.GraphMetadata(ntypes, etypes)
def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
csc_indptr, indices = random_homo_graph(num_nodes, num_edges)
metadata = get_metadata(num_ntypes, num_etypes)
# Randomly get node type split point.
node_type_offset = torch.sort(
torch.randint(0, num_nodes, (num_ntypes + 1,))
)[0]
node_type_offset[0] = 0
node_type_offset[-1] = num_nodes
type_per_edge = []
for i in range(num_nodes):
num = csc_indptr[i + 1] - csc_indptr[i]
type_per_edge.append(
torch.sort(torch.randint(0, num_etypes, (num,)))[0]
)
type_per_edge = torch.cat(type_per_edge, dim=0)
return (csc_indptr, indices, node_type_offset, type_per_edge, metadata)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment