Unverified Commit 382a2de7 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[GraphBolt] Refactor SampledSubgraph and update the corresponding method. (#6533)

parent a24a38bc
......@@ -60,7 +60,7 @@ Standard Implementations
UniformNegativeSampler
NeighborSampler
LayerNeighborSampler
SampledSubgraphImpl
FusedSampledSubgraphImpl
BasicFeatureStore
TorchBasedFeature
TorchBasedFeatureStore
......
......@@ -6,7 +6,7 @@
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#include <graphbolt/sampled_subgraph.h>
#include <graphbolt/fused_sampled_subgraph.h>
#include <graphbolt/shared_memory.h>
#include <torch/torch.h>
......@@ -172,9 +172,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Return the subgraph induced on the inbound edges of the given nodes.
* @param nodes Type agnostic node IDs to form the subgraph.
*
* @return SampledSubgraph.
* @return FusedSampledSubgraph.
*/
c10::intrusive_ptr<SampledSubgraph> InSubgraph(
c10::intrusive_ptr<FusedSampledSubgraph> InSubgraph(
const torch::Tensor& nodes) const;
/**
......@@ -208,10 +208,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const;
......@@ -276,7 +276,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private:
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl(
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;
......
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/sampled_subgraph.h
* @file graphbolt/fused_sampled_subgraph.h
* @brief Header file of sampled sub graph.
*/
#ifndef GRAPHBOLT_SAMPLED_SUBGRAPH_H_
#define GRAPHBOLT_SAMPLED_SUBGRAPH_H_
#ifndef GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_
#define GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_
#include <torch/custom_class.h>
#include <torch/torch.h>
......@@ -24,7 +24,8 @@ namespace sampling {
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto original_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
*
* SampledSubgraph sampledSubgraph(indptr, indices, original_column_node_ids);
* FusedSampledSubgraph sampledSubgraph(indptr, indices,
* original_column_node_ids);
* ```
*
* The `original_column_node_ids` indicates that nodes `[3, 3, 101]` in the
......@@ -37,10 +38,10 @@ namespace sampling {
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row.
*/
struct SampledSubgraph : torch::CustomClassHolder {
struct FusedSampledSubgraph : torch::CustomClassHolder {
public:
/**
* @brief Constructor for the SampledSubgraph struct.
* @brief Constructor for the FusedSampledSubgraph struct.
*
* @param indptr CSC format index pointer array.
* @param indices CSC format index array.
......@@ -51,7 +52,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @param original_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
*/
SampledSubgraph(
FusedSampledSubgraph(
torch::Tensor indptr, torch::Tensor indices,
torch::Tensor original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
......@@ -64,7 +65,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge) {}
SampledSubgraph() = default;
FusedSampledSubgraph() = default;
/**
* @brief CSC format index pointer array, where the implicit node ids are
......@@ -118,4 +119,4 @@ struct SampledSubgraph : torch::CustomClassHolder {
} // namespace sampling
} // namespace graphbolt
#endif // GRAPHBOLT_SAMPLED_SUBGRAPH_H_
#endif // GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_
......@@ -182,7 +182,7 @@ FusedCSCSamplingGraph::GetState() const {
return state;
}
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const {
using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100;
......@@ -211,7 +211,7 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
torch::Tensor compact_indptr =
torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype());
compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx}));
return c10::make_intrusive<SampledSubgraph>(
return c10::make_intrusive<FusedSampledSubgraph>(
compact_indptr.cumsum(0), torch::cat(indices_arr), nonzero_idx - 1,
torch::arange(0, NumNodes()), torch::cat(edge_ids_arr),
type_per_edge_
......@@ -305,7 +305,8 @@ auto GetPickFn(
}
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl(
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0);
......@@ -417,12 +418,12 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>(
return c10::make_intrusive<FusedSampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, subgraph_type_per_edge);
}
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const {
......
......@@ -15,17 +15,18 @@ namespace graphbolt {
namespace sampling {
TORCH_LIBRARY(graphbolt, m) {
m.class_<SampledSubgraph>("SampledSubgraph")
m.class_<FusedSampledSubgraph>("FusedSampledSubgraph")
.def(torch::init<>())
.def_readwrite("indptr", &SampledSubgraph::indptr)
.def_readwrite("indices", &SampledSubgraph::indices)
.def_readwrite("indptr", &FusedSampledSubgraph::indptr)
.def_readwrite("indices", &FusedSampledSubgraph::indices)
.def_readwrite(
"original_row_node_ids", &SampledSubgraph::original_row_node_ids)
"original_row_node_ids", &FusedSampledSubgraph::original_row_node_ids)
.def_readwrite(
"original_column_node_ids",
&SampledSubgraph::original_column_node_ids)
.def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);
&FusedSampledSubgraph::original_column_node_ids)
.def_readwrite(
"original_edge_ids", &FusedSampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge);
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
.def("num_nodes", &FusedCSCSamplingGraph::NumNodes)
.def("num_edges", &FusedCSCSamplingGraph::NumEdges)
......
......@@ -15,7 +15,7 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import SampledSubgraphImpl
from .sampled_subgraph_impl import FusedSampledSubgraphImpl
__all__ = [
......@@ -305,7 +305,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert len(torch.unique(nodes)) == len(
nodes
), "Nodes cannot have duplicate values."
# TODO: change the result to 'SampledSubgraphImpl'.
# TODO: change the result to 'FusedSampledSubgraphImpl'.
return self._c_csc_graph.in_subgraph(nodes)
def _convert_to_sampled_subgraph(
......@@ -313,7 +313,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph: torch.ScriptObject,
):
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'SampledSubgraphImpl'."""
subgraph to general struct 'FusedSampledSubgraphImpl'."""
column_num = (
C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
)
......@@ -353,7 +353,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids[etype] = original_edge_ids[mask]
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
return SampledSubgraphImpl(
return FusedSampledSubgraphImpl(
node_pairs=node_pairs, original_edge_ids=original_edge_ids
)
......@@ -370,7 +370,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
) -> SampledSubgraphImpl:
) -> FusedSampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -411,7 +411,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
equalling the total number of edges.
Returns
-------
SampledSubgraphImpl
FusedSampledSubgraphImpl
The sampled subgraph.
Examples
......@@ -548,7 +548,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
) -> SampledSubgraphImpl:
) -> FusedSampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
......@@ -591,7 +591,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
equalling the total number of edges.
Returns
-------
SampledSubgraphImpl
FusedSampledSubgraphImpl
The sampled subgraph.
Examples
......
......@@ -5,7 +5,7 @@ from torch.utils.data import functional_datapipe
from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs
from .sampled_subgraph_impl import SampledSubgraphImpl
from .sampled_subgraph_impl import FusedSampledSubgraphImpl
__all__ = ["NeighborSampler", "LayerNeighborSampler"]
......@@ -124,7 +124,7 @@ class NeighborSampler(SubgraphSampler):
) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds)
else:
raise RuntimeError("Not implemented yet.")
subgraph = SampledSubgraphImpl(
subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
......
......@@ -8,11 +8,11 @@ import torch
from ..base import etype_str_to_tuple
from ..sampled_subgraph import SampledSubgraph
__all__ = ["SampledSubgraphImpl"]
__all__ = ["FusedSampledSubgraphImpl"]
@dataclass
class SampledSubgraphImpl(SampledSubgraph):
class FusedSampledSubgraphImpl(SampledSubgraph):
r"""Sampled subgraph of FusedCSCSamplingGraph.
Examples
......@@ -22,7 +22,7 @@ class SampledSubgraphImpl(SampledSubgraph):
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.FusedSampledSubgraphImpl(
... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
......
......@@ -453,16 +453,16 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
if isinstance(val, list):
if len(val) == 0:
val = "[]"
# Special handling of SampledSubgraphImpl data. Each element of
# Special handling of FusedSampledSubgraphImpl data. Each element of
# the data occupies one row and is further structured.
elif isinstance(
val[0],
dgl.graphbolt.impl.sampled_subgraph_impl.SampledSubgraphImpl,
dgl.graphbolt.impl.sampled_subgraph_impl.FusedSampledSubgraphImpl,
):
sampledsubgraph_strs = []
for sampledsubgraph in val:
ss_attributes = _get_attributes(sampledsubgraph)
sampledsubgraph_str = "SampledSubgraphImpl("
sampledsubgraph_str = "FusedSampledSubgraphImpl("
for ss_name in ss_attributes:
ss_val = str(getattr(sampledsubgraph, ss_name))
sampledsubgraph_str = (
......
......@@ -123,7 +123,7 @@ class SampledSubgraph:
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.FusedSampledSubgraphImpl(
... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
......
......@@ -39,7 +39,7 @@ def create_homo_minibatch():
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
......@@ -93,7 +93,7 @@ def create_hetero_minibatch():
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
......@@ -142,7 +142,7 @@ def test_minibatch_representation():
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
......@@ -191,11 +191,11 @@ def test_minibatch_representation():
)
expect_result = str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
sampled_subgraphs=[FusedSampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
original_column_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
FusedSampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
original_column_node_ids=tensor([10, 11]),
original_edge_ids=tensor([10, 15, 17]),
original_row_node_ids=tensor([10, 11, 12]),)],
......@@ -260,7 +260,7 @@ def test_dgl_minibatch_representation():
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
......
......@@ -4,7 +4,7 @@ import backend as F
import pytest
import torch
from dgl.graphbolt.impl.sampled_subgraph_impl import SampledSubgraphImpl
from dgl.graphbolt.impl.sampled_subgraph_impl import FusedSampledSubgraphImpl
def _assert_container_equal(lhs, rhs):
......@@ -42,7 +42,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
original_column_node_ids = None
dst_to_exclude = torch.tensor([4])
original_edge_ids = torch.Tensor([5, 9, 10])
subgraph = SampledSubgraphImpl(
subgraph = FusedSampledSubgraphImpl(
node_pairs,
original_column_node_ids,
original_row_node_ids,
......@@ -95,7 +95,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
......@@ -158,7 +158,7 @@ def test_sampled_subgraph_to_device():
}
dst_to_exclude = torch.tensor([10, 12])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
......
......@@ -77,7 +77,7 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs = []
for _ in range(3):
subgraphs.append(
gb.SampledSubgraphImpl(
gb.FusedSampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])),
original_edge_ids=torch.randint(
0, graph.total_num_edges, (10,)
......@@ -168,7 +168,7 @@ def test_FeatureFetcher_with_edges_hetero():
}
for _ in range(3):
subgraphs.append(
gb.SampledSubgraphImpl(
gb.FusedSampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])),
original_edge_ids=original_edge_ids,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment