"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "42b0c38f81fe90de8f709133c2a5183ddff0d4d1"
Unverified Commit e9d5156e authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] rename from_fused_csc as fused_csc_sampling_graph (#6732)

parent f04cce25
...@@ -65,7 +65,7 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -65,7 +65,7 @@ TORCH_LIBRARY(graphbolt, m) {
g->SetState(state); g->SetState(state);
return g; return g;
}); });
m.def("from_fused_csc", &FusedCSCSamplingGraph::Create); m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
m.def( m.def(
"load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory); "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact); m.def("unique_and_compact", &UniqueAndCompact);
......
...@@ -1268,7 +1268,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1268,7 +1268,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE]) type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE])
# Sanity check. # Sanity check.
assert len(type_per_edge) == graph.num_edges() assert len(type_per_edge) == graph.num_edges()
csc_graph = graphbolt.from_fused_csc( csc_graph = graphbolt.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=None, node_type_offset=None,
......
...@@ -21,7 +21,7 @@ from .sampled_subgraph_impl import ( ...@@ -21,7 +21,7 @@ from .sampled_subgraph_impl import (
__all__ = [ __all__ = [
"FusedCSCSamplingGraph", "FusedCSCSamplingGraph",
"from_fused_csc", "fused_csc_sampling_graph",
"load_from_shared_memory", "load_from_shared_memory",
"from_dglgraph", "from_dglgraph",
] ]
...@@ -88,7 +88,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -88,7 +88,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, ... node_type_to_id=ntypes,
...@@ -139,7 +139,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -139,7 +139,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes) >>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset, >>> graph = gb.fused_csc_sampling_graph(indptr, indices, node_type_offset,
... type_per_edge, None, metadata) ... type_per_edge, None, metadata)
>>> print(graph.num_edges) >>> print(graph.num_edges)
{'N0:R0:N0': 2, 'N0:R1:N1': 1, 'N1:R2:N0': 2, 'N1:R3:N1': 3} {'N0:R0:N0': 2, 'N0:R1:N1': 1, 'N1:R2:N0': 2, 'N1:R3:N1': 3}
...@@ -337,7 +337,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -337,7 +337,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, ... node_type_to_id=ntypes,
...@@ -577,7 +577,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -577,7 +577,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, ... node_type_to_id=ntypes,
...@@ -765,7 +765,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -765,7 +765,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, >>> graph = gb.fused_csc_sampling_graph(indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, ... node_type_to_id=ntypes,
...@@ -893,7 +893,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -893,7 +893,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
return self return self
def from_fused_csc( def fused_csc_sampling_graph(
csc_indptr: torch.Tensor, csc_indptr: torch.Tensor,
indices: torch.Tensor, indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None, node_type_offset: Optional[torch.tensor] = None,
...@@ -936,7 +936,7 @@ def from_fused_csc( ...@@ -936,7 +936,7 @@ def from_fused_csc(
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3]) >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3]) >>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0]) >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> graph = graphbolt.from_fused_csc(csc_indptr, indices, >>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes, ... node_type_to_id=ntypes, edge_type_to_id=etypes,
...@@ -984,7 +984,7 @@ def from_fused_csc( ...@@ -984,7 +984,7 @@ def from_fused_csc(
0 0
), "node_type_offset length should be |ntypes| + 1." ), "node_type_offset length should be |ntypes| + 1."
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset,
...@@ -1086,7 +1086,7 @@ def from_dglgraph( ...@@ -1086,7 +1086,7 @@ def from_dglgraph(
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids] edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset, node_type_offset,
......
...@@ -31,7 +31,7 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -31,7 +31,7 @@ class InSubgraphSampler(SubgraphSampler):
>>> import torch >>> import torch
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14]) >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]) >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
>>> graph = gb.from_fused_csc(indptr, indices) >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> item_set = gb.ItemSet(len(indptr) - 1, names="seed_nodes") >>> item_set = gb.ItemSet(len(indptr) - 1, names="seed_nodes")
>>> item_sampler = gb.ItemSampler(item_set, batch_size=2) >>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
>>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph) >>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
......
...@@ -64,7 +64,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -64,7 +64,7 @@ class NeighborSampler(SubgraphSampler):
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_fused_csc(indptr, indices) >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...@@ -229,7 +229,7 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -229,7 +229,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_fused_csc(indptr, indices) >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
......
...@@ -32,7 +32,7 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -32,7 +32,7 @@ class UniformNegativeSampler(NegativeSampler):
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5]) >>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0]) >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_fused_csc(indptr, indices) >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
......
...@@ -18,7 +18,7 @@ def rand_csc_graph(N, density, bidirection_edge=False): ...@@ -18,7 +18,7 @@ def rand_csc_graph(N, density, bidirection_edge=False):
indptr = torch.LongTensor(adj.indptr) indptr = torch.LongTensor(adj.indptr)
indices = torch.LongTensor(adj.indices) indices = torch.LongTensor(adj.indices)
graph = gb.from_fused_csc(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
return graph return graph
......
...@@ -27,7 +27,7 @@ mp.set_sharing_strategy("file_system") ...@@ -27,7 +27,7 @@ mp.set_sharing_strategy("file_system")
def test_empty_graph(total_num_nodes): def test_empty_graph(total_num_nodes):
csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int) csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int)
indices = torch.tensor([]) indices = torch.tensor([])
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
assert graph.total_num_edges == 0 assert graph.total_num_edges == 0
assert graph.total_num_nodes == total_num_nodes assert graph.total_num_nodes == total_num_nodes
assert torch.equal(graph.csc_indptr, csc_indptr) assert torch.equal(graph.csc_indptr, csc_indptr)
...@@ -55,7 +55,7 @@ def test_hetero_empty_graph(total_num_nodes): ...@@ -55,7 +55,7 @@ def test_hetero_empty_graph(total_num_nodes):
node_type_offset[0] = 0 node_type_offset[0] = 0
node_type_offset[-1] = total_num_nodes node_type_offset[-1] = total_num_nodes
type_per_edge = torch.tensor([]) type_per_edge = torch.tensor([])
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -83,7 +83,7 @@ def test_hetero_empty_graph(total_num_nodes): ...@@ -83,7 +83,7 @@ def test_hetero_empty_graph(total_num_nodes):
) )
def test_type_to_id_with_ntype_exception(ntypes): def test_type_to_id_with_ntype_exception(ntypes):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
gb.from_fused_csc( gb.fused_csc_sampling_graph(
None, None, node_type_to_id=ntypes, edge_type_to_id={"e1": 1} None, None, node_type_to_id=ntypes, edge_type_to_id={"e1": 1}
) )
...@@ -106,7 +106,7 @@ def test_type_to_id_with_ntype_exception(ntypes): ...@@ -106,7 +106,7 @@ def test_type_to_id_with_ntype_exception(ntypes):
) )
def test_type_to_id_with_etype_exception(etypes): def test_type_to_id_with_etype_exception(etypes):
with pytest.raises(Exception): with pytest.raises(Exception):
gb.from_fused_csc( gb.fused_csc_sampling_graph(
None, None,
None, None,
node_type_to_id={"n1": 0, "n2": 1, "n3": 2}, node_type_to_id={"n1": 0, "n2": 1, "n3": 2},
...@@ -130,7 +130,7 @@ def test_homo_graph(total_num_nodes, total_num_edges): ...@@ -130,7 +130,7 @@ def test_homo_graph(total_num_nodes, total_num_edges):
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, indices, edge_attributes=edge_attributes csc_indptr, indices, edge_attributes=edge_attributes
) )
...@@ -171,7 +171,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -171,7 +171,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -209,7 +209,7 @@ def test_num_nodes_edges_homo(total_num_nodes, total_num_edges): ...@@ -209,7 +209,7 @@ def test_num_nodes_edges_homo(total_num_nodes, total_num_edges):
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, indices, edge_attributes=edge_attributes csc_indptr, indices, edge_attributes=edge_attributes
) )
...@@ -260,7 +260,7 @@ def test_num_nodes_hetero(): ...@@ -260,7 +260,7 @@ def test_num_nodes_hetero():
assert all(type_per_edge < len(etypes)) assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -309,7 +309,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset): ...@@ -309,7 +309,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
edge_type_to_id, edge_type_to_id,
) = gbt.random_hetero_graph(10, 50, num_ntypes, 5) ) = gbt.random_hetero_graph(10, 50, num_ntypes, 5)
with pytest.raises(Exception): with pytest.raises(Exception):
gb.from_fused_csc( gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -331,7 +331,7 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges): ...@@ -331,7 +331,7 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt") filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
...@@ -373,7 +373,7 @@ def test_load_save_hetero_graph( ...@@ -373,7 +373,7 @@ def test_load_save_hetero_graph(
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -410,7 +410,7 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges): ...@@ -410,7 +410,7 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
serialized = pickle.dumps(graph) serialized = pickle.dumps(graph)
graph2 = pickle.loads(serialized) graph2 = pickle.loads(serialized)
...@@ -454,7 +454,7 @@ def test_pickle_hetero_graph( ...@@ -454,7 +454,7 @@ def test_pickle_hetero_graph(
"a": torch.randn((total_num_edges,)), "a": torch.randn((total_num_edges,)),
"b": torch.randint(1, 10, (total_num_edges,)), "b": torch.randint(1, 10, (total_num_edges,)),
} }
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -511,7 +511,7 @@ def test_multiprocessing(): ...@@ -511,7 +511,7 @@ def test_multiprocessing():
edge_attributes = { edge_attributes = {
"a": torch.randn((total_num_edges,)), "a": torch.randn((total_num_edges,)),
} }
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -549,7 +549,7 @@ def test_in_subgraph_homogeneous(): ...@@ -549,7 +549,7 @@ def test_in_subgraph_homogeneous():
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
# Extract in subgraph. # Extract in subgraph.
nodes = torch.LongTensor([4, 1, 3]) nodes = torch.LongTensor([4, 1, 3])
...@@ -611,7 +611,7 @@ def test_in_subgraph_heterogeneous(): ...@@ -611,7 +611,7 @@ def test_in_subgraph_heterogeneous():
assert all(type_per_edge < len(etypes)) assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -686,7 +686,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype): ...@@ -686,7 +686,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
assert len(indptr) == total_num_nodes + 1 assert len(indptr) == total_num_nodes + 1
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype) nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
...@@ -730,7 +730,7 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -730,7 +730,7 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -838,7 +838,7 @@ def test_sample_neighbors_fanouts( ...@@ -838,7 +838,7 @@ def test_sample_neighbors_fanouts(
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -895,7 +895,7 @@ def test_sample_neighbors_replace( ...@@ -895,7 +895,7 @@ def test_sample_neighbors_replace(
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -938,7 +938,9 @@ def test_sample_neighbors_return_eids_homo(labor): ...@@ -938,7 +938,9 @@ def test_sample_neighbors_return_eids_homo(labor):
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)} edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)}
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -984,7 +986,7 @@ def test_sample_neighbors_return_eids_hetero(labor): ...@@ -984,7 +986,7 @@ def test_sample_neighbors_return_eids_hetero(labor):
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1043,7 +1045,9 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -1043,7 +1045,9 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -1088,7 +1092,9 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): ...@@ -1088,7 +1092,9 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
edge_attributes = {"probs_or_mask": probs_or_mask} edge_attributes = {"probs_or_mask": probs_or_mask}
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -1143,7 +1149,7 @@ def test_homo_graph_on_shared_memory( ...@@ -1143,7 +1149,7 @@ def test_homo_graph_on_shared_memory(
} }
else: else:
edge_attributes = None edge_attributes = None
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, indices, edge_attributes=edge_attributes csc_indptr, indices, edge_attributes=edge_attributes
) )
...@@ -1218,7 +1224,7 @@ def test_hetero_graph_on_shared_memory( ...@@ -1218,7 +1224,7 @@ def test_hetero_graph_on_shared_memory(
} }
else: else:
edge_attributes = None edge_attributes = None
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1343,7 +1349,7 @@ def test_multiprocessing_with_shared_memory(): ...@@ -1343,7 +1349,7 @@ def test_multiprocessing_with_shared_memory():
node_type_offset.share_memory_() node_type_offset.share_memory_()
type_per_edge.share_memory_() type_per_edge.share_memory_()
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1559,7 +1565,9 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name): ...@@ -1559,7 +1565,9 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1]) nodes = torch.LongTensor([0, 1])
...@@ -1642,7 +1650,7 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -1642,7 +1650,7 @@ def test_sample_neighbors_hetero_pick_number(
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
...@@ -1731,7 +1739,7 @@ def test_csc_sampling_graph_to_device(): ...@@ -1731,7 +1739,7 @@ def test_csc_sampling_graph_to_device():
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
...@@ -1774,7 +1782,7 @@ def test_sample_neighbors_homo_csc_format(): ...@@ -1774,7 +1782,7 @@ def test_sample_neighbors_homo_csc_format():
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -1819,7 +1827,7 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -1819,7 +1827,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1924,7 +1932,7 @@ def test_sample_neighbors_fanouts_csc_format( ...@@ -1924,7 +1932,7 @@ def test_sample_neighbors_fanouts_csc_format(
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1985,7 +1993,7 @@ def test_sample_neighbors_replace_csc_format( ...@@ -1985,7 +1993,7 @@ def test_sample_neighbors_replace_csc_format(
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -2034,7 +2042,9 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor): ...@@ -2034,7 +2042,9 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)} edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)}
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -2083,7 +2093,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor): ...@@ -2083,7 +2093,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -2142,7 +2152,9 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name): ...@@ -2142,7 +2152,9 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -2190,7 +2202,9 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask): ...@@ -2190,7 +2202,9 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
edge_attributes = {"probs_or_mask": probs_or_mask} edge_attributes = {"probs_or_mask": probs_or_mask}
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -2257,7 +2271,9 @@ def test_sample_neighbors_homo_pick_number_csc_format( ...@@ -2257,7 +2271,9 @@ def test_sample_neighbors_homo_pick_number_csc_format(
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1]) nodes = torch.LongTensor([0, 1])
...@@ -2341,7 +2357,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format( ...@@ -2341,7 +2357,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
......
...@@ -15,7 +15,7 @@ def test_InSubgraphSampler_homo(): ...@@ -15,7 +15,7 @@ def test_InSubgraphSampler_homo():
""" """
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
graph = gb.from_fused_csc(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
seed_nodes = torch.LongTensor([0, 5, 3]) seed_nodes = torch.LongTensor([0, 5, 3])
item_set = gb.ItemSet(seed_nodes, names="seed_nodes") item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
...@@ -80,7 +80,7 @@ def test_InSubgraphSampler_hetero(): ...@@ -80,7 +80,7 @@ def test_InSubgraphSampler_hetero():
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
node_type_offset = torch.LongTensor([0, 3, 6]) node_type_offset = torch.LongTensor([0, 3, 6])
type_per_edge = torch.LongTensor([0, 0, 2, 0, 2, 0, 2, 1, 1, 1, 3, 3, 1, 3]) type_per_edge = torch.LongTensor([0, 0, 2, 0, 2, 0, 2, 1, 1, 1, 3, 3, 1, 3])
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr=indptr, csc_indptr=indptr,
indices=indices, indices=indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
......
...@@ -110,7 +110,7 @@ def get_hetero_graph(): ...@@ -110,7 +110,7 @@ def get_hetero_graph():
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5]) node_type_offset = torch.LongTensor([0, 2, 5])
return gb.from_fused_csc( return gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
......
...@@ -1006,7 +1006,7 @@ def test_OnDiskDataset_Graph_Exceptions(): ...@@ -1006,7 +1006,7 @@ def test_OnDiskDataset_Graph_Exceptions():
def test_OnDiskDataset_Graph_homogeneous(): def test_OnDiskDataset_Graph_homogeneous():
"""Test homogeneous graph topology.""" """Test homogeneous graph topology."""
csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000) csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000)
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt") graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
...@@ -1044,7 +1044,7 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -1044,7 +1044,7 @@ def test_OnDiskDataset_Graph_heterogeneous():
node_type_to_id, node_type_to_id,
edge_type_to_id, edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4) ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1833,7 +1833,7 @@ def test_OnDiskDataset_load_tasks(): ...@@ -1833,7 +1833,7 @@ def test_OnDiskDataset_load_tasks():
def test_OnDiskDataset_all_nodes_set_homo(): def test_OnDiskDataset_all_nodes_set_homo():
"""Test homograph's all nodes set of OnDiskDataset.""" """Test homograph's all nodes set of OnDiskDataset."""
csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000) csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000)
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt") graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
...@@ -1864,7 +1864,7 @@ def test_OnDiskDataset_all_nodes_set_hetero(): ...@@ -1864,7 +1864,7 @@ def test_OnDiskDataset_all_nodes_set_hetero():
node_type_to_id, node_type_to_id,
edge_type_to_id, edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4) ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
......
...@@ -145,7 +145,7 @@ def get_hetero_graph(): ...@@ -145,7 +145,7 @@ def get_hetero_graph():
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5]) node_type_offset = torch.LongTensor([0, 2, 5])
return gb.from_fused_csc( return gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
......
...@@ -38,7 +38,7 @@ def test_integration_link_prediction(): ...@@ -38,7 +38,7 @@ def test_integration_link_prediction():
) )
item_set = gb.ItemSet(node_pairs, names="node_pairs") item_set = gb.ItemSet(node_pairs, names="node_pairs")
graph = gb.from_fused_csc(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data) node_feature = gb.TorchBasedFeature(node_feature_data)
edge_feature = gb.TorchBasedFeature(edge_feature_data) edge_feature = gb.TorchBasedFeature(edge_feature_data)
...@@ -154,7 +154,7 @@ def test_integration_node_classification(): ...@@ -154,7 +154,7 @@ def test_integration_node_classification():
) )
item_set = gb.ItemSet(node_pairs, names="node_pairs") item_set = gb.ItemSet(node_pairs, names="node_pairs")
graph = gb.from_fused_csc(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data) node_feature = gb.TorchBasedFeature(node_feature_data)
edge_feature = gb.TorchBasedFeature(edge_feature_data) edge_feature = gb.TorchBasedFeature(edge_feature_data)
......
...@@ -121,7 +121,7 @@ def get_hetero_graph(): ...@@ -121,7 +121,7 @@ def get_hetero_graph():
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5]) node_type_offset = torch.LongTensor([0, 2, 5])
return gb.from_fused_csc( return gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -216,7 +216,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -216,7 +216,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
"A1": torch.randn(num_edges), "A1": torch.randn(num_edges),
"A2": torch.randn(num_edges), "A2": torch.randn(num_edges),
} }
graph = gb.from_fused_csc( graph = gb.fused_csc_sampling_graph(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment