"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e28f07812bae274f5ba49d26629de5541c9cff60"
Unverified Commit 9116f673 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt]Unique and compact OP (#6098)

parent 88964a82
/**
* Copyright (c) 2023 by Contributors
*
* @file unique_and_compact.h
* @brief Unique and compact op.
*/
#ifndef GRAPHBOLT_UNIQUE_AND_COMPACT_H_
#define GRAPHBOLT_UNIQUE_AND_COMPACT_H_
#include <torch/torch.h>
namespace graphbolt {
namespace sampling {
/**
* @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
* 'src_ids' tensor and applies the uniqueness information to compact both
* source and destination tensors.
*
* The function performs two main operations:
* 1. Unique Operation: 'unique(concat(unique_dst_ids, src_ids))', in which
* the unique operator guarantees that the 'unique_dst_ids' are at the head of
* the result tensor.
* 2. Compact Operation: Utilizes the reverse mapping derived from the unique
* operation to transform 'src_ids' and 'dst_ids' into compacted IDs.
*
* @param src_ids A tensor containing source IDs.
* @param dst_ids A tensor containing destination IDs.
* @param unique_dst_ids A tensor containing unique destination IDs, which is
* exactly all the unique elements in 'dst_ids'. Note that, unlike the other
* two parameters, this one is taken by value.
*
* NOTE(review): what happens when 'unique_dst_ids' does not cover every
* element of 'dst_ids' is not specified by this declaration; callers should
* treat that as a violation of the contract above -- TODO confirm the failure
* mode against the implementation.
*
* @return A tuple of three tensors, in this order:
* - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after
* removing duplicates. The indices in this tensor precisely match the compacted
* IDs of the corresponding elements.
* - The tensor corresponding to the 'src_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the 'dst_ids' tensor, where the entries are
* mapped to compacted IDs.
*
* @example
* // 'src' and 'dst' stand for caller-provided ID tensors (illustrative).
* torch::Tensor src_ids = src;
* torch::Tensor dst_ids = dst;
* torch::Tensor unique_dst_ids = torch::unique(dst);
* auto result = UniqueAndCompact(src_ids, dst_ids, unique_dst_ids);
* torch::Tensor unique_ids = std::get<0>(result);
* torch::Tensor compacted_src_ids = std::get<1>(result);
* torch::Tensor compacted_dst_ids = std::get<2>(result);
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids);
} // namespace sampling
} // namespace graphbolt
#endif // GRAPHBOLT_UNIQUE_AND_COMPACT_H_
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <graphbolt/csc_sampling_graph.h> #include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/serialize.h> #include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h>
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
...@@ -39,6 +40,7 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -39,6 +40,7 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph); m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph); m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory); m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact);
} }
} // namespace sampling } // namespace sampling
......
/**
* Copyright (c) 2023 by Contributors
*
* @file unique_and_compact.cc
* @brief Unique and compact op.
*/
#include <graphbolt/unique_and_compact.h>
#include "./concurrent_id_hash_map.h"
namespace graphbolt {
namespace sampling {
/**
 * @brief Deduplicates 'concat(unique_dst_ids, src_ids)' and rewrites 'src_ids'
 * and 'dst_ids' in terms of the resulting compacted IDs.
 *
 * @param src_ids A tensor containing source IDs.
 * @param dst_ids A tensor containing destination IDs.
 * @param unique_dst_ids A tensor containing the unique destination IDs; its
 * length is forwarded to the hash map as the number of destination IDs.
 *
 * @return A tuple (unique_ids, compacted_src_ids, compacted_dst_ids).
 *
 * @throws c10::Error if the three input tensors do not share one dtype.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
    const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
    const torch::Tensor unique_dst_ids) {
  // The hash map below is instantiated for exactly one ID type, chosen from
  // the dtype of the concatenated tensor, and is then asked to map 'src_ids'
  // and 'dst_ids' as well. A mixed-dtype input could otherwise leave those
  // tensors with a dtype different from the one the map was built for, so
  // reject it up front with an explicit message.
  TORCH_CHECK(
      src_ids.scalar_type() == dst_ids.scalar_type() &&
          dst_ids.scalar_type() == unique_dst_ids.scalar_type(),
      "Dtypes of tensors passed to UniqueAndCompact need to be identical.");

  torch::Tensor compacted_src_ids;
  torch::Tensor compacted_dst_ids;
  torch::Tensor unique_ids;
  const auto num_dst = unique_dst_ids.size(0);
  // Destination IDs go first so that the hash map sees them at the head of
  // the concatenated tensor.
  torch::Tensor ids = torch::cat({unique_dst_ids, src_ids});
  AT_DISPATCH_INTEGRAL_TYPES(ids.scalar_type(), "unique_and_compact", ([&] {
    ConcurrentIdHashMap<scalar_t> id_map;
    unique_ids = id_map.Init(ids, num_dst);
    compacted_src_ids = id_map.MapIds(src_ids);
    compacted_dst_ids = id_map.MapIds(dst_ids);
  }));
  return std::tuple(unique_ids, compacted_src_ids, compacted_dst_ids);
}
} // namespace sampling
} // namespace graphbolt
...@@ -10,7 +10,11 @@ def unique_and_compact_node_pairs( ...@@ -10,7 +10,11 @@ def unique_and_compact_node_pairs(
node_pairs: Union[ node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]], Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
] ],
unique_dst_nodes: Union[
torch.Tensor,
Dict[str, torch.Tensor],
] = None,
): ):
""" """
Compact node pairs and return unique nodes (per type). Compact node pairs and return unique nodes (per type).
...@@ -26,6 +30,11 @@ def unique_and_compact_node_pairs( ...@@ -26,6 +30,11 @@ def unique_and_compact_node_pairs(
- If `node_pairs` is a dictionary: The keys should be edge type and - If `node_pairs` is a dictionary: The keys should be edge type and
the values should be corresponding node pairs. And IDs inside are the values should be corresponding node pairs. And IDs inside are
heterogeneous ids. heterogeneous ids.
unique_dst_nodes: torch.Tensor or Dict[str, torch.Tensor]
Unique nodes of all destination nodes in the node pairs.
- If `unique_dst_nodes` is a tensor: It means the graph is homogeneous.
- If `unique_dst_nodes` is a dictionary: The keys are node type and the
values are corresponding nodes. And IDs inside are heterogeneous ids.
Returns Returns
------- -------
...@@ -52,44 +61,59 @@ def unique_and_compact_node_pairs( ...@@ -52,44 +61,59 @@ def unique_and_compact_node_pairs(
{('n1', 'e1', 'n2'): (tensor([0, 1, 1]), tensor([0, 1, 0])), {('n1', 'e1', 'n2'): (tensor([0, 1, 1]), tensor([0, 1, 0])),
('n2', 'e2', 'n1'): (tensor([0, 1, 0]), tensor([0, 1, 1]))} ('n2', 'e2', 'n1'): (tensor([0, 1, 0]), tensor([0, 1, 1]))}
""" """
is_homogeneous = not isinstance(node_pairs, Dict) is_homogeneous = not isinstance(node_pairs, dict)
if is_homogeneous: if is_homogeneous:
node_pairs = {("_N", "_E", "_N"): node_pairs} node_pairs = {("_N", "_E", "_N"): node_pairs}
nodes_dict = defaultdict(list) if unique_dst_nodes is not None:
# Collect nodes for each node type. assert isinstance(
for etype, node_pair in node_pairs.items(): unique_dst_nodes, torch.Tensor
u_type, _, v_type = etype ), "Edge type not supported in homogeneous graph."
u, v = node_pair unique_dst_nodes = {"_N": unique_dst_nodes}
nodes_dict[u_type].append(u)
nodes_dict[v_type].append(v)
unique_nodes_dict = {} # Collect all source and destination nodes for each node type.
inverse_indices_dict = {} src_nodes = defaultdict(list)
for ntype, nodes in nodes_dict.items(): dst_nodes = defaultdict(list)
collected_nodes = torch.cat(nodes) for etype, (src_node, dst_node) in node_pairs.items():
# Compact and find unique nodes. src_nodes[etype[0]].append(src_node)
unique_nodes, inverse_indices = torch.unique( dst_nodes[etype[2]].append(dst_node)
collected_nodes, src_nodes = {ntype: torch.cat(nodes) for ntype, nodes in src_nodes.items()}
return_inverse=True, dst_nodes = {ntype: torch.cat(nodes) for ntype, nodes in dst_nodes.items()}
) # Compute unique destination nodes if not provided.
unique_nodes_dict[ntype] = unique_nodes if unique_dst_nodes is None:
inverse_indices_dict[ntype] = inverse_indices unique_dst_nodes = {
ntype: torch.unique(nodes) for ntype, nodes in dst_nodes.items()
}
ntypes = set(dst_nodes.keys()) | set(src_nodes.keys())
unique_nodes = {}
compacted_src = {}
compacted_dst = {}
dtype = list(src_nodes.values())[0].dtype
default_tensor = torch.tensor([], dtype=dtype)
for ntype in ntypes:
src = src_nodes.get(ntype, default_tensor)
unique_dst = unique_dst_nodes.get(ntype, default_tensor)
dst = dst_nodes.get(ntype, default_tensor)
(
unique_nodes[ntype],
compacted_src[ntype],
compacted_dst[ntype],
) = torch.ops.graphbolt.unique_and_compact(src, dst, unique_dst)
# Map back in same order as collect.
compacted_node_pairs = {} compacted_node_pairs = {}
unique_nodes = unique_nodes_dict # Map back with the same order.
for etype, node_pair in node_pairs.items(): for etype, pair in node_pairs.items():
u_type, _, v_type = etype num_elem = pair[0].size(0)
u, v = node_pair src_type, _, dst_type = etype
u_size, v_size = u.numel(), v.numel() src = compacted_src[src_type][:num_elem]
u = inverse_indices_dict[u_type][:u_size] dst = compacted_dst[dst_type][:num_elem]
inverse_indices_dict[u_type] = inverse_indices_dict[u_type][u_size:] compacted_node_pairs[etype] = (src, dst)
v = inverse_indices_dict[v_type][:v_size] compacted_src[src_type] = compacted_src[src_type][num_elem:]
inverse_indices_dict[v_type] = inverse_indices_dict[v_type][v_size:] compacted_dst[dst_type] = compacted_dst[dst_type][num_elem:]
compacted_node_pairs[etype] = (u, v)
# Return singleton for homogeneous graph. # Return singleton for a homogeneous graph.
if is_homogeneous: if is_homogeneous:
compacted_node_pairs = list(compacted_node_pairs.values())[0] compacted_node_pairs = list(compacted_node_pairs.values())[0]
unique_nodes = list(unique_nodes_dict.values())[0] unique_nodes = list(unique_nodes.values())[0]
return unique_nodes, compacted_node_pairs return unique_nodes, compacted_node_pairs
import dgl.graphbolt as gb import dgl.graphbolt as gb
import pytest
import torch import torch
...@@ -6,28 +7,14 @@ def test_unique_and_compact_node_pairs_hetero(): ...@@ -6,28 +7,14 @@ def test_unique_and_compact_node_pairs_hetero():
N1 = torch.randint(0, 50, (30,)) N1 = torch.randint(0, 50, (30,))
N2 = torch.randint(0, 50, (20,)) N2 = torch.randint(0, 50, (20,))
N3 = torch.randint(0, 50, (10,)) N3 = torch.randint(0, 50, (10,))
unique_N1, compacted_N1 = torch.unique(N1, return_inverse=True) unique_N1 = torch.unique(N1)
unique_N2, compacted_N2 = torch.unique(N2, return_inverse=True) unique_N2 = torch.unique(N2)
unique_N3, compacted_N3 = torch.unique(N3, return_inverse=True) unique_N3 = torch.unique(N3)
expected_unique_nodes = { expected_unique_nodes = {
"n1": unique_N1, "n1": unique_N1,
"n2": unique_N2, "n2": unique_N2,
"n3": unique_N3, "n3": unique_N3,
} }
expected_compacted_pairs = {
("n1", "e1", "n2"): (
compacted_N1[:20],
compacted_N2,
),
("n1", "e2", "n3"): (
compacted_N1[20:30],
compacted_N3,
),
("n2", "e3", "n3"): (
compacted_N2[10:],
compacted_N3,
),
}
node_pairs = { node_pairs = {
("n1", "e1", "n2"): ( ("n1", "e1", "n2"): (
N1[:20], N1[:20],
...@@ -46,27 +33,39 @@ def test_unique_and_compact_node_pairs_hetero(): ...@@ -46,27 +33,39 @@ def test_unique_and_compact_node_pairs_hetero():
unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs( unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
node_pairs node_pairs
) )
for ntype, nodes in unique_nodes.items():
expected_nodes = expected_unique_nodes[ntype]
assert torch.equal(torch.sort(nodes)[0], expected_nodes)
for etype, pair in compacted_node_pairs.items(): for etype, pair in compacted_node_pairs.items():
expected_u, expected_v = expected_compacted_pairs[etype]
u, v = pair u, v = pair
u_type, _, v_type = etype
u, v = unique_nodes[u_type][u], unique_nodes[v_type][v]
expected_u, expected_v = node_pairs[etype]
assert torch.equal(u, expected_u) assert torch.equal(u, expected_u)
assert torch.equal(v, expected_v) assert torch.equal(v, expected_v)
for ntype, nodes in unique_nodes.items():
expected_nodes = expected_unique_nodes[ntype]
assert torch.equal(nodes, expected_nodes)
def test_unique_and_compact_node_pairs_homo(): def test_unique_and_compact_node_pairs_homo():
N = torch.randint(0, 50, (20,)) N = torch.randint(0, 50, (200,))
expected_unique_N, compacted_N = torch.unique(N, return_inverse=True) expected_unique_N = torch.unique(N)
expected_compacted_pairs = tuple(compacted_N.split(10))
node_pairs = tuple(N.split(10)) node_pairs = tuple(N.split(100))
unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs( unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
node_pairs node_pairs
) )
expected_u, expected_v = expected_compacted_pairs assert torch.equal(torch.sort(unique_nodes)[0], expected_unique_N)
u, v = compacted_node_pairs u, v = compacted_node_pairs
u, v = unique_nodes[u], unique_nodes[v]
expected_u, expected_v = node_pairs
unique_v = torch.unique(expected_v)
assert torch.equal(u, expected_u) assert torch.equal(u, expected_u)
assert torch.equal(v, expected_v) assert torch.equal(v, expected_v)
assert torch.equal(unique_nodes, expected_unique_N) assert torch.equal(unique_nodes[: unique_v.size(0)], unique_v)
def test_incomplete_unique_dst_nodes_():
    """Passing unique destination nodes that miss the real ones must raise."""
    # Destination IDs are drawn from [100, 150) while the supplied unique
    # destination IDs span [150, 200), so none of the destinations is covered.
    sources = torch.randint(0, 50, (50,))
    destinations = torch.randint(100, 150, (50,))
    pairs = (sources, destinations)
    uncovered_dst = torch.arange(150, 200)
    with pytest.raises(IndexError):
        gb.unique_and_compact_node_pairs(pairs, uncovered_dst)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment