Unverified commit 9116f673, authored by peizhou001, committed by GitHub

[Graphbolt]Unique and compact OP (#6098)

parent 88964a82
/**
* Copyright (c) 2023 by Contributors
*
* @file unique_and_compact.h
* @brief Unique and compact op.
*/
#ifndef GRAPHBOLT_UNIQUE_AND_COMPACT_H_
#define GRAPHBOLT_UNIQUE_AND_COMPACT_H_
#include <torch/torch.h>
namespace graphbolt {
namespace sampling {
/**
* @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
* 'src_ids' tensor and applies the uniqueness information to compact both
* source and destination tensors.
*
* The function performs two main operations:
 * 1. Unique operation: 'unique(concat(unique_dst_ids, src_ids))', where the
 * unique operator guarantees that the 'unique_dst_ids' appear at the head of
 * the result tensor.
* 2. Compact Operation: Utilizes the reverse mapping derived from the unique
* operation to transform 'src_ids' and 'dst_ids' into compacted IDs.
*
* @param src_ids A tensor containing source IDs.
* @param dst_ids A tensor containing destination IDs.
 * @param unique_dst_ids A tensor containing unique destination IDs, which are
 * exactly the unique elements of 'dst_ids'.
*
* @return
* - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after
* removing duplicates. The indices in this tensor precisely match the compacted
* IDs of the corresponding elements.
* - The tensor corresponding to the 'src_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the 'dst_ids' tensor, where the entries are
* mapped to compacted IDs.
*
* @example
 * torch::Tensor src_ids = src;  // 'src' holds the source IDs of the edges.
 * torch::Tensor dst_ids = dst;  // 'dst' holds the destination IDs.
* torch::Tensor unique_dst_ids = torch::unique(dst);
* auto result = UniqueAndCompact(src_ids, dst_ids, unique_dst_ids);
* torch::Tensor unique_ids = std::get<0>(result);
* torch::Tensor compacted_src_ids = std::get<1>(result);
* torch::Tensor compacted_dst_ids = std::get<2>(result);
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids);
} // namespace sampling
} // namespace graphbolt
#endif // GRAPHBOLT_UNIQUE_AND_COMPACT_H_
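To make the documented contract concrete, the following is a minimal pure-PyTorch sketch (not part of this commit; the helper name is hypothetical). It mirrors the behavior described above: 'unique_dst_ids' occupy the head of the returned unique tensor and both ID tensors are remapped through it. The tail ordering here is sorted, whereas the concurrent hash-map implementation leaves it unspecified.

import torch

def unique_and_compact_reference(src_ids, dst_ids, unique_dst_ids):
    # Keep the unique destination IDs at the head of the result, then append
    # the deduplicated source IDs that are not already destinations.
    extra = src_ids[~torch.isin(src_ids, unique_dst_ids)]
    unique_ids = torch.cat([unique_dst_ids, torch.unique(extra)])

    # Map every original ID to its position in unique_ids (its compacted ID).
    position = {v: i for i, v in enumerate(unique_ids.tolist())}

    def compact(ids):
        return torch.tensor([position[v] for v in ids.tolist()])

    return unique_ids, compact(src_ids), compact(dst_ids)

# src = [3, 5, 7, 5], dst = [1, 1, 3, 3], unique_dst = [1, 3] gives
# unique_ids = [1, 3, 5, 7], compacted_src = [1, 2, 3, 2], compacted_dst = [0, 0, 1, 1].
src = torch.tensor([3, 5, 7, 5])
dst = torch.tensor([1, 1, 3, 3])
unique_ids, compacted_src, compacted_dst = unique_and_compact_reference(
    src, dst, torch.unique(dst)
)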
@@ -6,6 +6,7 @@
 #include <graphbolt/csc_sampling_graph.h>
 #include <graphbolt/serialize.h>
+#include <graphbolt/unique_and_compact.h>
 
 namespace graphbolt {
 namespace sampling {
@@ -39,6 +40,7 @@ TORCH_LIBRARY(graphbolt, m) {
   m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
   m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
   m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory);
+  m.def("unique_and_compact", &UniqueAndCompact);
 }
 
 }  // namespace sampling
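Once the library carrying this registration is loaded (importing dgl.graphbolt is assumed to do that), the op becomes reachable through the torch.ops namespace, which is how the Python wrapper later in this commit calls it. A brief sketch with illustrative tensor values:

import torch
import dgl.graphbolt  # noqa: F401  # assumed to load the compiled graphbolt library

src = torch.tensor([3, 5, 7, 5])
dst = torch.tensor([1, 1, 3, 3])
unique_dst = torch.unique(dst)

# Returns (unique_ids, compacted_src_ids, compacted_dst_ids), matching the
# C++ signature registered above.
unique_ids, compacted_src, compacted_dst = torch.ops.graphbolt.unique_and_compact(
    src, dst, unique_dst
)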
/**
* Copyright (c) 2023 by Contributors
*
* @file unique_and_compact.cc
* @brief Unique and compact op.
*/
#include <graphbolt/unique_and_compact.h>
#include "./concurrent_id_hash_map.h"
namespace graphbolt {
namespace sampling {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
    const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
    const torch::Tensor unique_dst_ids) {
  torch::Tensor compacted_src_ids;
  torch::Tensor compacted_dst_ids;
  torch::Tensor unique_ids;
  auto num_dst = unique_dst_ids.size(0);
  // Concatenate with the unique destination IDs first so that they occupy the
  // head of the unique result.
  torch::Tensor ids = torch::cat({unique_dst_ids, src_ids});
  AT_DISPATCH_INTEGRAL_TYPES(ids.scalar_type(), "unique_and_compact", ([&] {
    // Build the ID hash map; the first 'num_dst' IDs stay at the head of
    // 'unique_ids'.
    ConcurrentIdHashMap<scalar_t> id_map;
    unique_ids = id_map.Init(ids, num_dst);
    // Map the original IDs to their compacted counterparts.
    compacted_src_ids = id_map.MapIds(src_ids);
    compacted_dst_ids = id_map.MapIds(dst_ids);
  }));
  return std::tuple(unique_ids, compacted_src_ids, compacted_dst_ids);
}
} // namespace sampling
} // namespace graphbolt
@@ -10,7 +10,11 @@ def unique_and_compact_node_pairs(
     node_pairs: Union[
         Tuple[torch.Tensor, torch.Tensor],
         Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
-    ]
+    ],
+    unique_dst_nodes: Union[
+        torch.Tensor,
+        Dict[str, torch.Tensor],
+    ] = None,
 ):
     """
     Compact node pairs and return unique nodes (per type).
@@ -26,6 +30,11 @@ def unique_and_compact_node_pairs(
         - If `node_pairs` is a dictionary: The keys should be edge type and
           the values should be corresponding node pairs. And IDs inside are
           heterogeneous ids.
+    unique_dst_nodes: torch.Tensor or Dict[str, torch.Tensor]
+        Unique nodes of all destination nodes in the node pairs.
+        - If `unique_dst_nodes` is a tensor: It means the graph is homogeneous.
+        - If `unique_dst_nodes` is a dictionary: The keys are node type and the
+          values are corresponding nodes. And IDs inside are heterogeneous ids.
 
     Returns
     -------
@@ -52,44 +61,59 @@ def unique_and_compact_node_pairs(
         {('n1', 'e1', 'n2'): (tensor([0, 1, 1]), tensor([0, 1, 0])),
          ('n2', 'e2', 'n1'): (tensor([0, 1, 0]), tensor([0, 1, 1]))}
     """
-    is_homogeneous = not isinstance(node_pairs, Dict)
+    is_homogeneous = not isinstance(node_pairs, dict)
     if is_homogeneous:
         node_pairs = {("_N", "_E", "_N"): node_pairs}
-    nodes_dict = defaultdict(list)
-    # Collect nodes for each node type.
-    for etype, node_pair in node_pairs.items():
-        u_type, _, v_type = etype
-        u, v = node_pair
-        nodes_dict[u_type].append(u)
-        nodes_dict[v_type].append(v)
+        if unique_dst_nodes is not None:
+            assert isinstance(
+                unique_dst_nodes, torch.Tensor
+            ), "Edge type not supported in homogeneous graph."
+            unique_dst_nodes = {"_N": unique_dst_nodes}
-    unique_nodes_dict = {}
-    inverse_indices_dict = {}
-    for ntype, nodes in nodes_dict.items():
-        collected_nodes = torch.cat(nodes)
-        # Compact and find unique nodes.
-        unique_nodes, inverse_indices = torch.unique(
-            collected_nodes,
-            return_inverse=True,
-        )
-        unique_nodes_dict[ntype] = unique_nodes
-        inverse_indices_dict[ntype] = inverse_indices
+    # Collect all source and destination nodes for each node type.
+    src_nodes = defaultdict(list)
+    dst_nodes = defaultdict(list)
+    for etype, (src_node, dst_node) in node_pairs.items():
+        src_nodes[etype[0]].append(src_node)
+        dst_nodes[etype[2]].append(dst_node)
+    src_nodes = {ntype: torch.cat(nodes) for ntype, nodes in src_nodes.items()}
+    dst_nodes = {ntype: torch.cat(nodes) for ntype, nodes in dst_nodes.items()}
+    # Compute unique destination nodes if not provided.
+    if unique_dst_nodes is None:
+        unique_dst_nodes = {
+            ntype: torch.unique(nodes) for ntype, nodes in dst_nodes.items()
+        }
+    ntypes = set(dst_nodes.keys()) | set(src_nodes.keys())
+    unique_nodes = {}
+    compacted_src = {}
+    compacted_dst = {}
+    dtype = list(src_nodes.values())[0].dtype
+    default_tensor = torch.tensor([], dtype=dtype)
+    for ntype in ntypes:
+        src = src_nodes.get(ntype, default_tensor)
+        unique_dst = unique_dst_nodes.get(ntype, default_tensor)
+        dst = dst_nodes.get(ntype, default_tensor)
+        (
+            unique_nodes[ntype],
+            compacted_src[ntype],
+            compacted_dst[ntype],
+        ) = torch.ops.graphbolt.unique_and_compact(src, dst, unique_dst)
-    # Map back in same order as collect.
     compacted_node_pairs = {}
-    unique_nodes = unique_nodes_dict
-    for etype, node_pair in node_pairs.items():
-        u_type, _, v_type = etype
-        u, v = node_pair
-        u_size, v_size = u.numel(), v.numel()
-        u = inverse_indices_dict[u_type][:u_size]
-        inverse_indices_dict[u_type] = inverse_indices_dict[u_type][u_size:]
-        v = inverse_indices_dict[v_type][:v_size]
-        inverse_indices_dict[v_type] = inverse_indices_dict[v_type][v_size:]
-        compacted_node_pairs[etype] = (u, v)
+    # Map back with the same order.
+    for etype, pair in node_pairs.items():
+        num_elem = pair[0].size(0)
+        src_type, _, dst_type = etype
+        src = compacted_src[src_type][:num_elem]
+        dst = compacted_dst[dst_type][:num_elem]
+        compacted_node_pairs[etype] = (src, dst)
+        compacted_src[src_type] = compacted_src[src_type][num_elem:]
+        compacted_dst[dst_type] = compacted_dst[dst_type][num_elem:]
-    # Return singleton for homogeneous graph.
+    # Return singleton for a homogeneous graph.
     if is_homogeneous:
        compacted_node_pairs = list(compacted_node_pairs.values())[0]
-        unique_nodes = list(unique_nodes_dict.values())[0]
+        unique_nodes = list(unique_nodes.values())[0]
     return unique_nodes, compacted_node_pairs
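The new unique_dst_nodes argument is only exercised by the tests below; here is a brief usage sketch for the heterogeneous case (tensor values are illustrative, and precomputing the unique destination nodes is optional since the function derives them when the argument is omitted):

import torch
import dgl.graphbolt as gb

node_pairs = {
    ("n1", "e1", "n2"): (torch.tensor([1, 3, 4]), torch.tensor([2, 2, 5])),
    ("n2", "e2", "n1"): (torch.tensor([2, 5]), torch.tensor([3, 3])),
}
# Unique destination nodes per node type, matching the destinations above.
unique_dst_nodes = {
    "n2": torch.tensor([2, 5]),
    "n1": torch.tensor([3]),
}
unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
    node_pairs, unique_dst_nodes
)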
 import dgl.graphbolt as gb
+import pytest
 import torch
@@ -6,28 +7,14 @@ def test_unique_and_compact_node_pairs_hetero():
     N1 = torch.randint(0, 50, (30,))
     N2 = torch.randint(0, 50, (20,))
     N3 = torch.randint(0, 50, (10,))
-    unique_N1, compacted_N1 = torch.unique(N1, return_inverse=True)
-    unique_N2, compacted_N2 = torch.unique(N2, return_inverse=True)
-    unique_N3, compacted_N3 = torch.unique(N3, return_inverse=True)
+    unique_N1 = torch.unique(N1)
+    unique_N2 = torch.unique(N2)
+    unique_N3 = torch.unique(N3)
     expected_unique_nodes = {
         "n1": unique_N1,
         "n2": unique_N2,
         "n3": unique_N3,
     }
-    expected_compacted_pairs = {
-        ("n1", "e1", "n2"): (
-            compacted_N1[:20],
-            compacted_N2,
-        ),
-        ("n1", "e2", "n3"): (
-            compacted_N1[20:30],
-            compacted_N3,
-        ),
-        ("n2", "e3", "n3"): (
-            compacted_N2[10:],
-            compacted_N3,
-        ),
-    }
     node_pairs = {
         ("n1", "e1", "n2"): (
             N1[:20],
@@ -46,27 +33,39 @@ def test_unique_and_compact_node_pairs_hetero():
     unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
         node_pairs
     )
+    for ntype, nodes in unique_nodes.items():
+        expected_nodes = expected_unique_nodes[ntype]
+        assert torch.equal(torch.sort(nodes)[0], expected_nodes)
     for etype, pair in compacted_node_pairs.items():
-        expected_u, expected_v = expected_compacted_pairs[etype]
         u, v = pair
+        u_type, _, v_type = etype
+        u, v = unique_nodes[u_type][u], unique_nodes[v_type][v]
+        expected_u, expected_v = node_pairs[etype]
         assert torch.equal(u, expected_u)
         assert torch.equal(v, expected_v)
-    for ntype, nodes in unique_nodes.items():
-        expected_nodes = expected_unique_nodes[ntype]
-        assert torch.equal(nodes, expected_nodes)
 
 def test_unique_and_compact_node_pairs_homo():
-    N = torch.randint(0, 50, (20,))
-    expected_unique_N, compacted_N = torch.unique(N, return_inverse=True)
-    expected_compacted_pairs = tuple(compacted_N.split(10))
+    N = torch.randint(0, 50, (200,))
+    expected_unique_N = torch.unique(N)
-    node_pairs = tuple(N.split(10))
+    node_pairs = tuple(N.split(100))
     unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
         node_pairs
     )
-    expected_u, expected_v = expected_compacted_pairs
+    assert torch.equal(torch.sort(unique_nodes)[0], expected_unique_N)
     u, v = compacted_node_pairs
+    u, v = unique_nodes[u], unique_nodes[v]
+    expected_u, expected_v = node_pairs
+    unique_v = torch.unique(expected_v)
     assert torch.equal(u, expected_u)
     assert torch.equal(v, expected_v)
-    assert torch.equal(unique_nodes, expected_unique_N)
+    assert torch.equal(unique_nodes[: unique_v.size(0)], unique_v)
 
+def test_incomplete_unique_dst_nodes_():
+    node_pairs = (torch.randint(0, 50, (50,)), torch.randint(100, 150, (50,)))
+    unique_dst_nodes = torch.arange(150, 200)
+    with pytest.raises(IndexError):
+        gb.unique_and_compact_node_pairs(node_pairs, unique_dst_nodes)