Unverified Commit 382a2de7 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[GraphBolt] Refactor SampledSubgraph and update the corresponding method. (#6533)

parent a24a38bc
...@@ -60,7 +60,7 @@ Standard Implementations ...@@ -60,7 +60,7 @@ Standard Implementations
UniformNegativeSampler UniformNegativeSampler
NeighborSampler NeighborSampler
LayerNeighborSampler LayerNeighborSampler
SampledSubgraphImpl FusedSampledSubgraphImpl
BasicFeatureStore BasicFeatureStore
TorchBasedFeature TorchBasedFeature
TorchBasedFeatureStore TorchBasedFeatureStore
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ #ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ #define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#include <graphbolt/sampled_subgraph.h> #include <graphbolt/fused_sampled_subgraph.h>
#include <graphbolt/shared_memory.h> #include <graphbolt/shared_memory.h>
#include <torch/torch.h> #include <torch/torch.h>
...@@ -172,9 +172,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -172,9 +172,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Return the subgraph induced on the inbound edges of the given nodes. * @brief Return the subgraph induced on the inbound edges of the given nodes.
* @param nodes Type agnostic node IDs to form the subgraph. * @param nodes Type agnostic node IDs to form the subgraph.
* *
* @return SampledSubgraph. * @return FusedSampledSubgraph.
*/ */
c10::intrusive_ptr<SampledSubgraph> InSubgraph( c10::intrusive_ptr<FusedSampledSubgraph> InSubgraph(
const torch::Tensor& nodes) const; const torch::Tensor& nodes) const;
/** /**
...@@ -208,10 +208,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -208,10 +208,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* a 1D floating-point or boolean tensor, with the number of elements * a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges. * equalling the total number of edges.
* *
* @return An intrusive pointer to a SampledSubgraph object containing the * @return An intrusive pointer to a FusedSampledSubgraph object containing
* sampled graph's information. * the sampled graph's information.
*/ */
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const; torch::optional<std::string> probs_name) const;
...@@ -276,7 +276,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -276,7 +276,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private: private:
template <typename NumPickFn, typename PickFn> template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl( c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn, const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const; PickFn pick_fn) const;
......
/** /**
* Copyright (c) 2023 by Contributors * Copyright (c) 2023 by Contributors
* @file graphbolt/sampled_subgraph.h * @file graphbolt/fused_sampled_subgraph.h
* @brief Header file of sampled sub graph. * @brief Header file of sampled sub graph.
*/ */
#ifndef GRAPHBOLT_SAMPLED_SUBGRAPH_H_ #ifndef GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_
#define GRAPHBOLT_SAMPLED_SUBGRAPH_H_ #define GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_
#include <torch/custom_class.h> #include <torch/custom_class.h>
#include <torch/torch.h> #include <torch/torch.h>
...@@ -24,7 +24,8 @@ namespace sampling { ...@@ -24,7 +24,8 @@ namespace sampling {
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64}); * auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto original_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64}); * auto original_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* *
* SampledSubgraph sampledSubgraph(indptr, indices, original_column_node_ids); * FusedSampledSubgraph sampledSubgraph(indptr, indices,
* original_column_node_ids);
* ``` * ```
* *
* The `original_column_node_ids` indicates that nodes `[3, 3, 101]` in the * The `original_column_node_ids` indicates that nodes `[3, 3, 101]` in the
...@@ -37,10 +38,10 @@ namespace sampling { ...@@ -37,10 +38,10 @@ namespace sampling {
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in * inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row. * the column while `2` in the row.
*/ */
struct SampledSubgraph : torch::CustomClassHolder { struct FusedSampledSubgraph : torch::CustomClassHolder {
public: public:
/** /**
* @brief Constructor for the SampledSubgraph struct. * @brief Constructor for the FusedSampledSubgraph struct.
* *
* @param indptr CSC format index pointer array. * @param indptr CSC format index pointer array.
* @param indices CSC format index array. * @param indices CSC format index array.
...@@ -51,7 +52,7 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -51,7 +52,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @param original_edge_ids Reverse edge ids in the original graph. * @param original_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge. * @param type_per_edge Type id of each edge.
*/ */
SampledSubgraph( FusedSampledSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
torch::Tensor original_column_node_ids, torch::Tensor original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt, torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
...@@ -64,7 +65,7 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -64,7 +65,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
original_edge_ids(original_edge_ids), original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge) {} type_per_edge(type_per_edge) {}
SampledSubgraph() = default; FusedSampledSubgraph() = default;
/** /**
* @brief CSC format index pointer array, where the implicit node ids are * @brief CSC format index pointer array, where the implicit node ids are
...@@ -118,4 +119,4 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -118,4 +119,4 @@ struct SampledSubgraph : torch::CustomClassHolder {
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
#endif // GRAPHBOLT_SAMPLED_SUBGRAPH_H_ #endif // GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_
...@@ -182,7 +182,7 @@ FusedCSCSamplingGraph::GetState() const { ...@@ -182,7 +182,7 @@ FusedCSCSamplingGraph::GetState() const {
return state; return state;
} }
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph( c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const { const torch::Tensor& nodes) const {
using namespace torch::indexing; using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100; const int32_t kDefaultGrainSize = 100;
...@@ -211,7 +211,7 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph( ...@@ -211,7 +211,7 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
torch::Tensor compact_indptr = torch::Tensor compact_indptr =
torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype()); torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype());
compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx})); compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx}));
return c10::make_intrusive<SampledSubgraph>( return c10::make_intrusive<FusedSampledSubgraph>(
compact_indptr.cumsum(0), torch::cat(indices_arr), nonzero_idx - 1, compact_indptr.cumsum(0), torch::cat(indices_arr), nonzero_idx - 1,
torch::arange(0, NumNodes()), torch::cat(edge_ids_arr), torch::arange(0, NumNodes()), torch::cat(edge_ids_arr),
type_per_edge_ type_per_edge_
...@@ -305,7 +305,8 @@ auto GetPickFn( ...@@ -305,7 +305,8 @@ auto GetPickFn(
} }
template <typename NumPickFn, typename PickFn> template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl( c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn, const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const { PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
...@@ -417,12 +418,12 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -417,12 +418,12 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt; torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids); if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>( return c10::make_intrusive<FusedSampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt, subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, subgraph_type_per_edge); subgraph_reverse_edge_ids, subgraph_type_per_edge);
} }
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const { torch::optional<std::string> probs_name) const {
......
...@@ -15,17 +15,18 @@ namespace graphbolt { ...@@ -15,17 +15,18 @@ namespace graphbolt {
namespace sampling { namespace sampling {
TORCH_LIBRARY(graphbolt, m) { TORCH_LIBRARY(graphbolt, m) {
m.class_<SampledSubgraph>("SampledSubgraph") m.class_<FusedSampledSubgraph>("FusedSampledSubgraph")
.def(torch::init<>()) .def(torch::init<>())
.def_readwrite("indptr", &SampledSubgraph::indptr) .def_readwrite("indptr", &FusedSampledSubgraph::indptr)
.def_readwrite("indices", &SampledSubgraph::indices) .def_readwrite("indices", &FusedSampledSubgraph::indices)
.def_readwrite( .def_readwrite(
"original_row_node_ids", &SampledSubgraph::original_row_node_ids) "original_row_node_ids", &FusedSampledSubgraph::original_row_node_ids)
.def_readwrite( .def_readwrite(
"original_column_node_ids", "original_column_node_ids",
&SampledSubgraph::original_column_node_ids) &FusedSampledSubgraph::original_column_node_ids)
.def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids) .def_readwrite(
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge); "original_edge_ids", &FusedSampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge);
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph") m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
.def("num_nodes", &FusedCSCSamplingGraph::NumNodes) .def("num_nodes", &FusedCSCSamplingGraph::NumNodes)
.def("num_edges", &FusedCSCSamplingGraph::NumEdges) .def("num_edges", &FusedCSCSamplingGraph::NumEdges)
......
...@@ -15,7 +15,7 @@ from ...convert import to_homogeneous ...@@ -15,7 +15,7 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import SampledSubgraphImpl from .sampled_subgraph_impl import FusedSampledSubgraphImpl
__all__ = [ __all__ = [
...@@ -305,7 +305,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -305,7 +305,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert len(torch.unique(nodes)) == len( assert len(torch.unique(nodes)) == len(
nodes nodes
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
# TODO: change the result to 'SampledSubgraphImpl'. # TODO: change the result to 'FusedSampledSubgraphImpl'.
return self._c_csc_graph.in_subgraph(nodes) return self._c_csc_graph.in_subgraph(nodes)
def _convert_to_sampled_subgraph( def _convert_to_sampled_subgraph(
...@@ -313,7 +313,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -313,7 +313,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph: torch.ScriptObject, C_sampled_subgraph: torch.ScriptObject,
): ):
"""An internal function used to convert a fused homogeneous sampled """An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'SampledSubgraphImpl'.""" subgraph to general struct 'FusedSampledSubgraphImpl'."""
column_num = ( column_num = (
C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1] C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
) )
...@@ -353,7 +353,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -353,7 +353,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids[etype] = original_edge_ids[mask] original_hetero_edge_ids[etype] = original_edge_ids[mask]
if has_original_eids: if has_original_eids:
original_edge_ids = original_hetero_edge_ids original_edge_ids = original_hetero_edge_ids
return SampledSubgraphImpl( return FusedSampledSubgraphImpl(
node_pairs=node_pairs, original_edge_ids=original_edge_ids node_pairs=node_pairs, original_edge_ids=original_edge_ids
) )
...@@ -370,7 +370,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -370,7 +370,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
) -> SampledSubgraphImpl: ) -> FusedSampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -411,7 +411,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -411,7 +411,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
equalling the total number of edges. equalling the total number of edges.
Returns Returns
------- -------
SampledSubgraphImpl FusedSampledSubgraphImpl
The sampled subgraph. The sampled subgraph.
Examples Examples
...@@ -548,7 +548,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -548,7 +548,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
) -> SampledSubgraphImpl: ) -> FusedSampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
...@@ -591,7 +591,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -591,7 +591,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
equalling the total number of edges. equalling the total number of edges.
Returns Returns
------- -------
SampledSubgraphImpl FusedSampledSubgraphImpl
The sampled subgraph. The sampled subgraph.
Examples Examples
......
...@@ -5,7 +5,7 @@ from torch.utils.data import functional_datapipe ...@@ -5,7 +5,7 @@ from torch.utils.data import functional_datapipe
from ..subgraph_sampler import SubgraphSampler from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs from ..utils import unique_and_compact_node_pairs
from .sampled_subgraph_impl import SampledSubgraphImpl from .sampled_subgraph_impl import FusedSampledSubgraphImpl
__all__ = ["NeighborSampler", "LayerNeighborSampler"] __all__ = ["NeighborSampler", "LayerNeighborSampler"]
...@@ -124,7 +124,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -124,7 +124,7 @@ class NeighborSampler(SubgraphSampler):
) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds) ) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds)
else: else:
raise RuntimeError("Not implemented yet.") raise RuntimeError("Not implemented yet.")
subgraph = SampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs, node_pairs=compacted_node_pairs,
original_column_node_ids=seeds, original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
......
...@@ -8,11 +8,11 @@ import torch ...@@ -8,11 +8,11 @@ import torch
from ..base import etype_str_to_tuple from ..base import etype_str_to_tuple
from ..sampled_subgraph import SampledSubgraph from ..sampled_subgraph import SampledSubgraph
__all__ = ["SampledSubgraphImpl"] __all__ = ["FusedSampledSubgraphImpl"]
@dataclass @dataclass
class SampledSubgraphImpl(SampledSubgraph): class FusedSampledSubgraphImpl(SampledSubgraph):
r"""Sampled subgraph of FusedCSCSamplingGraph. r"""Sampled subgraph of FusedCSCSamplingGraph.
Examples Examples
...@@ -22,7 +22,7 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -22,7 +22,7 @@ class SampledSubgraphImpl(SampledSubgraph):
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.FusedSampledSubgraphImpl(
... node_pairs=node_pairs, ... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids, ... original_row_node_ids=original_row_node_ids,
......
...@@ -453,16 +453,16 @@ def _minibatch_str(minibatch: MiniBatch) -> str: ...@@ -453,16 +453,16 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
if isinstance(val, list): if isinstance(val, list):
if len(val) == 0: if len(val) == 0:
val = "[]" val = "[]"
# Special handling of SampledSubgraphImpl data. Each element of # Special handling of FusedSampledSubgraphImpl data. Each element of
# the data occupies one row and is further structured. # the data occupies one row and is further structured.
elif isinstance( elif isinstance(
val[0], val[0],
dgl.graphbolt.impl.sampled_subgraph_impl.SampledSubgraphImpl, dgl.graphbolt.impl.sampled_subgraph_impl.FusedSampledSubgraphImpl,
): ):
sampledsubgraph_strs = [] sampledsubgraph_strs = []
for sampledsubgraph in val: for sampledsubgraph in val:
ss_attributes = _get_attributes(sampledsubgraph) ss_attributes = _get_attributes(sampledsubgraph)
sampledsubgraph_str = "SampledSubgraphImpl(" sampledsubgraph_str = "FusedSampledSubgraphImpl("
for ss_name in ss_attributes: for ss_name in ss_attributes:
ss_val = str(getattr(sampledsubgraph, ss_name)) ss_val = str(getattr(sampledsubgraph, ss_name))
sampledsubgraph_str = ( sampledsubgraph_str = (
......
...@@ -123,7 +123,7 @@ class SampledSubgraph: ...@@ -123,7 +123,7 @@ class SampledSubgraph:
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.FusedSampledSubgraphImpl(
... node_pairs=node_pairs, ... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids, ... original_row_node_ids=original_row_node_ids,
......
...@@ -39,7 +39,7 @@ def create_homo_minibatch(): ...@@ -39,7 +39,7 @@ def create_homo_minibatch():
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
...@@ -93,7 +93,7 @@ def create_hetero_minibatch(): ...@@ -93,7 +93,7 @@ def create_hetero_minibatch():
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
...@@ -142,7 +142,7 @@ def test_minibatch_representation(): ...@@ -142,7 +142,7 @@ def test_minibatch_representation():
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
...@@ -191,11 +191,11 @@ def test_minibatch_representation(): ...@@ -191,11 +191,11 @@ def test_minibatch_representation():
) )
expect_result = str( expect_result = str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])), sampled_subgraphs=[FusedSampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
original_column_node_ids=tensor([10, 11, 12, 13]), original_column_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]), original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_row_node_ids=tensor([10, 11, 12, 13]),), original_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])), FusedSampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
original_column_node_ids=tensor([10, 11]), original_column_node_ids=tensor([10, 11]),
original_edge_ids=tensor([10, 15, 17]), original_edge_ids=tensor([10, 15, 17]),
original_row_node_ids=tensor([10, 11, 12]),)], original_row_node_ids=tensor([10, 11, 12]),)],
...@@ -260,7 +260,7 @@ def test_dgl_minibatch_representation(): ...@@ -260,7 +260,7 @@ def test_dgl_minibatch_representation():
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
......
...@@ -4,7 +4,7 @@ import backend as F ...@@ -4,7 +4,7 @@ import backend as F
import pytest import pytest
import torch import torch
from dgl.graphbolt.impl.sampled_subgraph_impl import SampledSubgraphImpl from dgl.graphbolt.impl.sampled_subgraph_impl import FusedSampledSubgraphImpl
def _assert_container_equal(lhs, rhs): def _assert_container_equal(lhs, rhs):
...@@ -42,7 +42,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column): ...@@ -42,7 +42,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
original_column_node_ids = None original_column_node_ids = None
dst_to_exclude = torch.tensor([4]) dst_to_exclude = torch.tensor([4])
original_edge_ids = torch.Tensor([5, 9, 10]) original_edge_ids = torch.Tensor([5, 9, 10])
subgraph = SampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs, node_pairs,
original_column_node_ids, original_column_node_ids,
original_row_node_ids, original_row_node_ids,
...@@ -95,7 +95,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column): ...@@ -95,7 +95,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
original_column_node_ids = None original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2]) dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs, node_pairs=node_pairs,
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
...@@ -158,7 +158,7 @@ def test_sampled_subgraph_to_device(): ...@@ -158,7 +158,7 @@ def test_sampled_subgraph_to_device():
} }
dst_to_exclude = torch.tensor([10, 12]) dst_to_exclude = torch.tensor([10, 12])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs, node_pairs=node_pairs,
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
......
...@@ -77,7 +77,7 @@ def test_FeatureFetcher_with_edges_homo(): ...@@ -77,7 +77,7 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs = [] subgraphs = []
for _ in range(3): for _ in range(3):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])), node_pairs=(torch.tensor([]), torch.tensor([])),
original_edge_ids=torch.randint( original_edge_ids=torch.randint(
0, graph.total_num_edges, (10,) 0, graph.total_num_edges, (10,)
...@@ -168,7 +168,7 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -168,7 +168,7 @@ def test_FeatureFetcher_with_edges_hetero():
} }
for _ in range(3): for _ in range(3):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])), node_pairs=(torch.tensor([]), torch.tensor([])),
original_edge_ids=original_edge_ids, original_edge_ids=original_edge_ids,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment