"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "71ad16b463275ce91e9279ecc8233868f709cadf"
Unverified Commit ed3840fc authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Replace reverse with original (#6371)

parent 7e49ccef
......@@ -206,8 +206,8 @@ def to_dgl_blocks(sampled_subgraphs: gb.SampledSubgraphImpl):
blocks = [
dgl.create_block(
sampled_subgraph.node_pairs,
num_src_nodes=sampled_subgraph.reverse_row_node_ids.shape[0],
num_dst_nodes=sampled_subgraph.reverse_column_node_ids.shape[0],
num_src_nodes=sampled_subgraph.original_row_node_ids.shape[0],
num_dst_nodes=sampled_subgraph.original_column_node_ids.shape[0],
)
for sampled_subgraph in sampled_subgraphs
]
......
......@@ -22,17 +22,17 @@ namespace sampling {
* ```
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto reverse_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* auto original_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
*
* SampledSubgraph sampledSubgraph(indptr, indices, reverse_column_node_ids);
* SampledSubgraph sampledSubgraph(indptr, indices, original_column_node_ids);
* ```
*
* The `reverse_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* The `original_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* `reverse_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* `original_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* the original node ids without compaction.
*
* If `reverse_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* If `original_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the row nodes. Note this is
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row.
......@@ -44,24 +44,24 @@ struct SampledSubgraph : torch::CustomClassHolder {
*
* @param indptr CSC format index pointer array.
* @param indices CSC format index array.
* @param reverse_column_node_ids Row's reverse node ids in the original
* @param original_column_node_ids Row's reverse node ids in the original
* graph.
* @param reverse_row_node_ids Column's reverse node ids in the original
* @param original_row_node_ids Column's reverse node ids in the original
* graph.
* @param reverse_edge_ids Reverse edge ids in the original graph.
* @param original_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
*/
SampledSubgraph(
torch::Tensor indptr, torch::Tensor indices,
torch::Tensor reverse_column_node_ids,
torch::optional<torch::Tensor> reverse_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> reverse_edge_ids = torch::nullopt,
torch::Tensor original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt)
: indptr(indptr),
indices(indices),
reverse_column_node_ids(reverse_column_node_ids),
reverse_row_node_ids(reverse_row_node_ids),
reverse_edge_ids(reverse_edge_ids),
original_column_node_ids(original_column_node_ids),
original_row_node_ids(original_row_node_ids),
original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge) {}
SampledSubgraph() = default;
......@@ -69,14 +69,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
/**
* @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the
* `reverse_column_node_ids` field.
* `original_column_node_ids` field.
*/
torch::Tensor indptr;
/**
* @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the
* `reverse_row_node_ids` field.
* `original_row_node_ids` field.
*/
torch::Tensor indices;
......@@ -88,7 +88,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is required and the mapping relations can be inconsistent with
* column's.
*/
torch::Tensor reverse_column_node_ids;
torch::Tensor original_column_node_ids;
/**
* @brief Row's reverse node ids in the original graph. A graph structure
......@@ -98,14 +98,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is optional and the mapping relations can be inconsistent with
* row's.
*/
torch::optional<torch::Tensor> reverse_row_node_ids;
torch::optional<torch::Tensor> original_row_node_ids;
/**
* @brief Reverse edge ids in the original graph, the edge with id
* `reverse_edge_ids[i]` in the original graph is mapped to `i` in this
* `original_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed.
*/
torch::optional<torch::Tensor> reverse_edge_ids;
torch::optional<torch::Tensor> original_edge_ids;
/**
* @brief Type id of each edge, where type id is the corresponding index of
......
......@@ -17,10 +17,11 @@ TORCH_LIBRARY(graphbolt, m) {
.def_readwrite("indptr", &SampledSubgraph::indptr)
.def_readwrite("indices", &SampledSubgraph::indices)
.def_readwrite(
"reverse_row_node_ids", &SampledSubgraph::reverse_row_node_ids)
"original_row_node_ids", &SampledSubgraph::original_row_node_ids)
.def_readwrite(
"reverse_column_node_ids", &SampledSubgraph::reverse_column_node_ids)
.def_readwrite("reverse_edge_ids", &SampledSubgraph::reverse_edge_ids)
"original_column_node_ids",
&SampledSubgraph::original_column_node_ids)
.def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);
m.class_<CSCSamplingGraph>("CSCSamplingGraph")
.def("num_nodes", &CSCSamplingGraph::NumNodes)
......
......@@ -95,14 +95,14 @@ class FeatureFetcher(MiniBatchTransformer):
# Read Edge features.
if self.edge_feature_keys and data.sampled_subgraphs:
for i, subgraph in enumerate(data.sampled_subgraphs):
if subgraph.reverse_edge_ids is None:
if subgraph.original_edge_ids is None:
continue
if is_heterogeneous:
for (
type_name,
feature_names,
) in self.edge_feature_keys.items():
edges = subgraph.reverse_edge_ids.get(type_name, None)
edges = subgraph.original_edge_ids.get(type_name, None)
if edges is None:
continue
for feature_name in feature_names:
......@@ -119,6 +119,6 @@ class FeatureFetcher(MiniBatchTransformer):
"edge",
None,
feature_name,
subgraph.reverse_edge_ids,
subgraph.original_edge_ids,
)
return data
......@@ -225,7 +225,7 @@ class CSCSamplingGraph:
column_num = (
C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
)
column = C_sampled_subgraph.reverse_column_node_ids.repeat_interleave(
column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
column_num
)
row = C_sampled_subgraph.indices
......
......@@ -106,14 +106,14 @@ class NeighborSampler(SubgraphSampler):
self.replace,
self.prob_name,
)
reverse_column_node_ids = seeds
original_column_node_ids = seeds
seeds, compacted_node_pairs = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds
)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_node_pairs,
reverse_column_node_ids=reverse_column_node_ids,
reverse_row_node_ids=seeds,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=seeds,
)
subgraphs.insert(0, subgraph)
return seeds, subgraphs
......
......@@ -17,31 +17,33 @@ class SampledSubgraphImpl(SampledSubgraph):
--------
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
... reverse_row_node_ids=reverse_row_node_ids,
... reverse_edge_ids=reverse_edge_ids
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> print(subgraph.node_pairs)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.reverse_column_node_ids)
>>> print(subgraph.original_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(subgraph.reverse_row_node_ids)
>>> print(subgraph.original_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(subgraph.reverse_edge_ids)
>>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
"""
node_pairs: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
] = None
reverse_column_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_column_node_ids: Union[
Dict[str, torch.Tensor], torch.Tensor
] = None
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self):
if isinstance(self.node_pairs, dict):
......
......@@ -140,14 +140,14 @@ class MiniBatch:
blocks = []
for subgraph in self.sampled_subgraphs:
reverse_row_node_ids = subgraph.reverse_row_node_ids
original_row_node_ids = subgraph.original_row_node_ids
assert (
reverse_row_node_ids is not None
), "Missing `reverse_row_node_ids` in sampled subgraph."
reverse_column_node_ids = subgraph.reverse_column_node_ids
original_row_node_ids is not None
), "Missing `original_row_node_ids` in sampled subgraph."
original_column_node_ids = subgraph.original_column_node_ids
assert (
reverse_column_node_ids is not None
), "Missing `reverse_column_node_ids` in sampled subgraph."
original_column_node_ids is not None
), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous:
node_pairs = {
etype_str_to_tuple(etype): v
......@@ -155,16 +155,16 @@ class MiniBatch:
}
num_src_nodes = {
ntype: nodes.size(0)
for ntype, nodes in reverse_row_node_ids.items()
for ntype, nodes in original_row_node_ids.items()
}
num_dst_nodes = {
ntype: nodes.size(0)
for ntype, nodes in reverse_column_node_ids.items()
for ntype, nodes in original_column_node_ids.items()
}
else:
node_pairs = subgraph.node_pairs
num_src_nodes = reverse_row_node_ids.size(0)
num_dst_nodes = reverse_column_node_ids.size(0)
num_src_nodes = original_row_node_ids.size(0)
num_dst_nodes = original_column_node_ids.size(0)
blocks.append(
dgl.create_block(
node_pairs,
......@@ -194,15 +194,15 @@ class MiniBatch:
# Assign reverse node ids to the outermost layer's source nodes.
for node_type, reverse_ids in self.sampled_subgraphs[
0
].reverse_row_node_ids.items():
].original_row_node_ids.items():
blocks[0].srcnodes[node_type].data[dgl.NID] = reverse_ids
# Assign reverse edges ids.
for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids:
if subgraph.original_edge_ids:
for (
edge_type,
reverse_ids,
) in subgraph.reverse_edge_ids.items():
) in subgraph.original_edge_ids.items():
block.edges[etype_str_to_tuple(edge_type)].data[
dgl.EID
] = reverse_ids
......@@ -218,11 +218,11 @@ class MiniBatch:
block.edata[feature_name] = feature
blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[
0
].reverse_row_node_ids
].original_row_node_ids
# Assign reverse edges ids.
for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids is not None:
block.edata[dgl.EID] = subgraph.reverse_edge_ids
if subgraph.original_edge_ids is not None:
block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks
......
......@@ -29,16 +29,16 @@ class SampledSubgraph:
raise NotImplementedError
@property
def reverse_column_node_ids(
def original_column_node_ids(
self,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse column node ids the original graph.
Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
the mapped ids of the column.
- If `reverse_column_node_ids` is a tensor: It represents the
- If `original_column_node_ids` is a tensor: It represents the
original node ids.
- If `reverse_column_node_ids` is a dictionary: The keys should be
- If `original_column_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
If present, it means column IDs are compacted, and `node_pairs`
......@@ -47,16 +47,16 @@ class SampledSubgraph:
return None
@property
def reverse_row_node_ids(
def original_row_node_ids(
self,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse row node ids the original graph.
Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
the mapped ids of the row.
- If `reverse_row_node_ids` is a tensor: It represents the
- If `original_row_node_ids` is a tensor: It represents the
original node ids.
- If `reverse_row_node_ids` is a dictionary: The keys should be
- If `original_row_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
If present, it means row IDs are compacted, and `node_pairs`
......@@ -64,13 +64,13 @@ class SampledSubgraph:
return None
@property
def reverse_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
def original_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse edge ids the original graph.
Reverse edge ids in the original graph. This is useful when edge
features are needed.
- If `reverse_edge_ids` is a tensor: It represents the
- If `original_edge_ids` is a tensor: It represents the
original edge ids.
- If `reverse_edge_ids` is a dictionary: The keys should be
- If `original_edge_ids` is a dictionary: The keys should be
edge type and the values should be corresponding original
heterogeneous edge ids.
"""
......@@ -110,24 +110,24 @@ class SampledSubgraph:
--------
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
... reverse_row_node_ids=reverse_row_node_ids,
... reverse_edge_ids=reverse_edge_ids
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs)
{"A:relation:B": (tensor([0]), tensor([0]))}
>>> print(result.reverse_column_node_ids)
>>> print(result.original_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(result.reverse_row_node_ids)
>>> print(result.original_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(result.reverse_edge_ids)
>>> print(result.original_edge_ids)
{"A:relation:B": tensor([19])}
"""
assert isinstance(self.node_pairs, tuple) == isinstance(edges, tuple), (
......@@ -144,8 +144,8 @@ class SampledSubgraph:
if isinstance(self.node_pairs, tuple):
reverse_edges = _to_reverse_ids(
self.node_pairs,
self.reverse_row_node_ids,
self.reverse_column_node_ids,
self.original_row_node_ids,
self.original_column_node_ids,
)
index = _exclude_homo_edges(reverse_edges, edges)
return calling_class(*_slice_subgraph(self, index))
......@@ -153,20 +153,20 @@ class SampledSubgraph:
index = {}
for etype, pair in self.node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
reverse_row_node_ids = (
original_row_node_ids = (
None
if self.reverse_row_node_ids is None
else self.reverse_row_node_ids.get(src_type)
if self.original_row_node_ids is None
else self.original_row_node_ids.get(src_type)
)
reverse_column_node_ids = (
original_column_node_ids = (
None
if self.reverse_column_node_ids is None
else self.reverse_column_node_ids.get(dst_type)
if self.original_column_node_ids is None
else self.original_column_node_ids.get(dst_type)
)
reverse_edges = _to_reverse_ids(
pair,
reverse_row_node_ids,
reverse_column_node_ids,
original_row_node_ids,
original_column_node_ids,
)
index[etype] = _exclude_homo_edges(
reverse_edges, edges.get(etype)
......@@ -174,12 +174,12 @@ class SampledSubgraph:
return calling_class(*_slice_subgraph(self, index))
def _to_reverse_ids(node_pair, reverse_row_node_ids, reverse_column_node_ids):
def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
u, v = node_pair
if reverse_row_node_ids is not None:
u = reverse_row_node_ids[u]
if reverse_column_node_ids is not None:
v = reverse_column_node_ids[v]
if original_row_node_ids is not None:
u = original_row_node_ids[u]
if original_column_node_ids is not None:
v = original_column_node_ids[v]
return (u, v)
......@@ -224,7 +224,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
return (
_index_select(subgraph.node_pairs, index),
subgraph.reverse_column_node_ids,
subgraph.reverse_row_node_ids,
_index_select(subgraph.reverse_edge_ids, index),
subgraph.original_column_node_ids,
subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index),
)
......@@ -397,12 +397,12 @@ def test_in_subgraph_homogeneous():
assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
)
assert torch.equal(in_subgraph.reverse_column_node_ids, nodes)
assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal(
in_subgraph.reverse_row_node_ids, torch.arange(0, num_nodes)
in_subgraph.original_row_node_ids, torch.arange(0, num_nodes)
)
assert torch.equal(
in_subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
)
assert in_subgraph.type_per_edge is None
......@@ -463,12 +463,12 @@ def test_in_subgraph_heterogeneous():
assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
)
assert torch.equal(in_subgraph.reverse_column_node_ids, nodes)
assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal(
in_subgraph.reverse_row_node_ids, torch.arange(0, num_nodes)
in_subgraph.original_row_node_ids, torch.arange(0, num_nodes)
)
assert torch.equal(
in_subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
)
assert torch.equal(
in_subgraph.type_per_edge, torch.LongTensor([2, 2, 1, 3, 1, 3, 3])
......@@ -505,9 +505,9 @@ def test_sample_neighbors_homo():
# Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0)
assert sampled_num == 6
assert subgraph.reverse_column_node_ids is None
assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
@unittest.skipIf(
......@@ -568,9 +568,9 @@ def test_sample_neighbors_hetero(labor):
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
assert subgraph.reverse_column_node_ids is None
assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
# Sample on single node type.
nodes = {"n1": torch.LongTensor([0])}
......@@ -593,9 +593,9 @@ def test_sample_neighbors_hetero(labor):
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
assert subgraph.reverse_column_node_ids is None
assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
@unittest.skipIf(
......
......@@ -13,11 +13,11 @@ def test_to_dgl_blocks_hetero():
},
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
]
reverse_column_node_ids = [
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
]
reverse_row_node_ids = [
original_row_node_ids = [
{
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
......@@ -27,7 +27,7 @@ def test_to_dgl_blocks_hetero():
"B": torch.tensor([10, 11]),
},
]
reverse_edge_ids = [
original_edge_ids = [
{
relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]),
......@@ -46,9 +46,9 @@ def test_to_dgl_blocks_hetero():
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
blocks = gb.MiniBatch(
......@@ -63,7 +63,7 @@ def test_to_dgl_blocks_hetero():
assert torch.equal(edges[0], node_pairs[i][relation][0])
assert torch.equal(edges[1], node_pairs[i][relation][1])
assert torch.equal(
block.edges[etype].data[dgl.EID], reverse_edge_ids[i][relation]
block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
)
assert torch.equal(
block.edges[etype].data["x"],
......@@ -73,10 +73,10 @@ def test_to_dgl_blocks_hetero():
assert torch.equal(edges[0], node_pairs[0][reverse_relation][0])
assert torch.equal(edges[1], node_pairs[0][reverse_relation][1])
assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], reverse_row_node_ids[0]["A"]
blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID]["B"], reverse_row_node_ids[0]["B"]
blocks[0].srcdata[dgl.NID]["B"], original_row_node_ids[0]["B"]
)
assert torch.equal(
blocks[0].srcnodes["A"].data["x"], node_features[("A", "x")]
......@@ -97,15 +97,15 @@ def test_to_dgl_blocks_homo():
torch.tensor([1, 0, 0]),
),
]
reverse_column_node_ids = [
original_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
reverse_row_node_ids = [
original_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
reverse_edge_ids = [
original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
......@@ -119,9 +119,9 @@ def test_to_dgl_blocks_homo():
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
blocks = gb.MiniBatch(
......@@ -133,9 +133,9 @@ def test_to_dgl_blocks_homo():
for i, block in enumerate(blocks):
assert torch.equal(block.edges()[0], node_pairs[i][0])
assert torch.equal(block.edges()[1], node_pairs[i][1])
assert torch.equal(block.edata[dgl.EID], reverse_edge_ids[i])
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])
assert torch.equal(block.edata["x"], edge_features[i]["x"])
assert torch.equal(blocks[0].srcdata[dgl.NID], reverse_row_node_ids[0])
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
assert torch.equal(blocks[0].srcdata["x"], node_features["x"])
......@@ -150,15 +150,15 @@ def test_representation():
torch.tensor([1, 0, 0]),
),
]
reverse_column_node_ids = [
original_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
reverse_row_node_ids = [
original_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
reverse_edge_ids = [
original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
......@@ -172,9 +172,9 @@ def test_representation():
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
negative_srcs = torch.tensor([[8], [1], [6]])
......@@ -220,13 +220,13 @@ def test_representation():
expect_result = str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
reverse_column_node_ids=tensor([10, 11, 12, 13]),
reverse_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
reverse_row_node_ids=tensor([10, 11, 12, 13]),),
original_column_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
reverse_column_node_ids=tensor([10, 11]),
reverse_edge_ids=tensor([10, 15, 17]),
reverse_row_node_ids=tensor([10, 11, 12]),)],
original_column_node_ids=tensor([10, 11]),
original_edge_ids=tensor([10, 15, 17]),
original_row_node_ids=tensor([10, 11, 12]),)],
node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
(tensor([0, 1, 2]), tensor([1, 0, 0]))],
node_features={'x': tensor([7, 6, 2, 2])},
......
......@@ -26,24 +26,24 @@ def _assert_container_equal(lhs, rhs):
def test_exclude_edges_homo(reverse_row, reverse_column):
node_pairs = (torch.tensor([0, 2, 3]), torch.tensor([1, 4, 2]))
if reverse_row:
reverse_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
src_to_exclude = torch.tensor([11])
else:
reverse_row_node_ids = None
original_row_node_ids = None
src_to_exclude = torch.tensor([2])
if reverse_column:
reverse_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
dst_to_exclude = torch.tensor([9])
else:
reverse_column_node_ids = None
original_column_node_ids = None
dst_to_exclude = torch.tensor([4])
reverse_edge_ids = torch.Tensor([5, 9, 10])
original_edge_ids = torch.Tensor([5, 9, 10])
subgraph = SampledSubgraphImpl(
node_pairs,
reverse_column_node_ids,
reverse_row_node_ids,
reverse_edge_ids,
original_column_node_ids,
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = (src_to_exclude, dst_to_exclude)
result = subgraph.exclude_edges(edges_to_exclude)
......@@ -60,10 +60,10 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
_assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal(
result.reverse_column_node_ids, expected_column_node_ids
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.reverse_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.reverse_edge_ids, expected_edge_ids)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False])
......@@ -76,27 +76,27 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
}
if reverse_row:
reverse_row_node_ids = {
original_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
src_to_exclude = torch.tensor([15, 13])
else:
reverse_row_node_ids = None
original_row_node_ids = None
src_to_exclude = torch.tensor([2, 0])
if reverse_column:
reverse_column_node_ids = {
original_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
dst_to_exclude = torch.tensor([10, 12])
else:
reverse_column_node_ids = None
original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2])
reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids,
reverse_row_node_ids=reverse_row_node_ids,
reverse_edge_ids=reverse_edge_ids,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
)
edges_to_exclude = {
......@@ -128,7 +128,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
_assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal(
result.reverse_column_node_ids, expected_column_node_ids
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.reverse_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.reverse_edge_ids, expected_edge_ids)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
......@@ -65,7 +65,7 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])),
reverse_edge_ids=torch.randint(0, graph.num_edges, (10,)),
original_edge_ids=torch.randint(0, graph.num_edges, (10,)),
)
)
data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
......@@ -146,7 +146,7 @@ def test_FeatureFetcher_with_edges_hetero():
def add_node_and_edge_ids(seeds):
subgraphs = []
reverse_edge_ids = {
original_edge_ids = {
"n1:e1:n2": torch.randint(0, 50, (10,)),
"n2:e2:n1": torch.randint(0, 50, (10,)),
}
......@@ -154,7 +154,7 @@ def test_FeatureFetcher_with_edges_hetero():
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])),
reverse_edge_ids=reverse_edge_ids,
original_edge_ids=original_edge_ids,
)
)
data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment