Unverified Commit d2eca855 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Speed up exclude edges (#6464)

parent 72b3e078
/**
 * Copyright (c) 2023 by Contributors
 *
 * @file graphbolt/isin.h
 * @brief isin op.
 */
#ifndef GRAPHBOLT_ISIN_H_
#define GRAPHBOLT_ISIN_H_
#include <torch/torch.h>
namespace graphbolt {
namespace sampling {
/**
 * @brief Tests if each element of elements is in test_elements. Returns a
 * boolean tensor of the same shape as elements that is True for elements
 * in test_elements and False otherwise. Enhances torch.isin by implementing
 * multi-threaded searching; see the reference semantics at
 * https://pytorch.org/docs/stable/generated/torch.isin.html
 *
 * @param elements Input elements.
 * @param test_elements Values against which to test for each input element.
 *
 * @return
 * A boolean tensor of the same shape as elements that is True for elements
 * in test_elements and False otherwise.
 *
 */
torch::Tensor IsIn(
    const torch::Tensor& elements, const torch::Tensor& test_elements);
}  // namespace sampling
}  // namespace graphbolt
#endif  // GRAPHBOLT_ISIN_H_
/**
* Copyright (c) 2023 by Contributors
*
* @file isin.cc
* @brief Isin op.
*/
#include <graphbolt/isin.h>
namespace {
// Minimum number of elements a single parallel_for task processes; chunks
// smaller than this are not worth the thread-scheduling overhead.
// NOTE: `static` is redundant inside an anonymous namespace (entities there
// already have internal linkage), so it is dropped here.
constexpr int kSearchGrainSize = 4096;
}  // namespace
namespace graphbolt {
namespace sampling {
torch::Tensor IsIn(
    const torch::Tensor& elements, const torch::Tensor& test_elements) {
  // Sort the test set once up front so every membership query below becomes
  // an O(log n) binary search instead of a linear scan.
  auto sorted_test_elements = std::get<0>(test_elements.sort(
      /*stable=*/false, /*dim=*/0, /*descending=*/false));
  auto result = torch::empty_like(elements, torch::kBool);
  const int64_t test_count = test_elements.size(0);
  const int64_t element_count = elements.size(0);
  AT_DISPATCH_INTEGRAL_TYPES(
      elements.scalar_type(), "IsInOperation", ([&] {
        const scalar_t* input = elements.data_ptr<scalar_t>();
        const scalar_t* haystack = sorted_test_elements.data_ptr<scalar_t>();
        bool* out = result.data_ptr<bool>();
        // Split the queries across threads; each chunk handles at least
        // kSearchGrainSize elements to amortize scheduling cost.
        torch::parallel_for(
            0, element_count, kSearchGrainSize,
            [&](int64_t begin, int64_t end) {
              for (int64_t i = begin; i < end; ++i) {
                out[i] = std::binary_search(
                    haystack, haystack + test_count, input[i]);
              }
            });
      }));
  return result;
}
} // namespace sampling
} // namespace graphbolt
......@@ -5,6 +5,7 @@
*/
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/isin.h>
#include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h>
......@@ -56,6 +57,7 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact);
m.def("isin", &IsIn);
}
} // namespace sampling
......
"""Base types and utilities for Graph Bolt."""
import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
......@@ -11,12 +12,35 @@ __all__ = [
"etype_str_to_tuple",
"etype_tuple_to_str",
"CopyTo",
"isin",
]
CANONICAL_ETYPE_DELIMITER = ":"
ORIGINAL_EDGE_ID = "_ORIGINAL_EDGE_ID"
def isin(elements, test_elements):
    """Tests if each element of elements is in test_elements. Returns a boolean
    tensor of the same shape as elements that is True for elements in
    test_elements and False otherwise.

    Dispatches to the multi-threaded C++ GraphBolt kernel rather than
    ``torch.isin``.

    Parameters
    ----------
    elements : torch.Tensor
        A 1D tensor represents the input elements.
    test_elements : torch.Tensor
        A 1D tensor represents the values to test against for each input.

    Returns
    -------
    torch.Tensor
        A 1D boolean tensor with the same shape as `elements`, True where the
        element appears in `test_elements`.

    Examples
    --------
    >>> isin(torch.tensor([1, 2, 3, 4]), torch.tensor([2, 3]))
    tensor([False,  True,  True, False])
    """
    # The C++ kernel only supports 1D inputs; reject anything else up front.
    assert elements.dim() == 1, "Elements should be 1D tensor."
    assert test_elements.dim() == 1, "Test_elements should be 1D tensor."
    return torch.ops.graphbolt.isin(elements, test_elements)
def etype_tuple_to_str(c_etype):
"""Convert canonical etype from tuple to string.
......
......@@ -4,7 +4,7 @@ from typing import Dict, Tuple, Union
import torch
from .base import etype_str_to_tuple
from .base import etype_str_to_tuple, isin
__all__ = ["SampledSubgraph"]
......@@ -85,6 +85,7 @@ class SampledSubgraph:
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
assume_num_node_within_int32: bool = True,
):
r"""Exclude edges from the sampled subgraph.
......@@ -103,6 +104,10 @@ class SampledSubgraph:
should be a pair of tensors representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a dictionary
of edge types and the corresponding edges to exclude.
assume_num_node_within_int32: bool
If True, assumes the value of node IDs in the provided `edges` fall
within the int32 range, which can significantly enhance computation
speed. Default: True
Returns
-------
......@@ -133,6 +138,10 @@ class SampledSubgraph:
>>> print(result.original_edge_ids)
{"A:relation:B": tensor([19])}
"""
# TODO: Add support for values > int32, then remove this assertion.
assert (
assume_num_node_within_int32
), "Values > int32 are not supported yet."
assert isinstance(self.node_pairs, tuple) == isinstance(edges, tuple), (
"The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous."
......@@ -150,7 +159,9 @@ class SampledSubgraph:
self.original_row_node_ids,
self.original_column_node_ids,
)
index = _exclude_homo_edges(reverse_edges, edges)
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph(self, index))
else:
index = {}
......@@ -172,7 +183,9 @@ class SampledSubgraph:
original_column_node_ids,
)
index[etype] = _exclude_homo_edges(
reverse_edges, edges.get(etype)
reverse_edges,
edges.get(etype),
assume_num_node_within_int32,
)
return calling_class(*_slice_subgraph(self, index))
......@@ -193,17 +206,17 @@ def _relabel_two_arrays(lhs_array, rhs_array):
return mapping[: lhs_array.numel()], mapping[lhs_array.numel() :]
def _exclude_homo_edges(edges, edges_to_exclude, assume_num_node_within_int32):
    """Return the indices of edges that are not in edges_to_exclude.

    Parameters
    ----------
    edges : Tuple[torch.Tensor, torch.Tensor]
        The (src, dst) node-ID tensors of the sampled edges.
    edges_to_exclude : Tuple[torch.Tensor, torch.Tensor]
        The (src, dst) node-ID tensors of the edges to exclude.
    assume_num_node_within_int32 : bool
        If True, node IDs are assumed to fit within the int32 range, so each
        edge can be packed into a single int64 key, reducing exclusion to a
        fast 1D `isin` test.

    Returns
    -------
    torch.Tensor
        The indices of the edges to keep.
    """
    if assume_num_node_within_int32:
        # Pack each (src, dst) pair into one int64 key: src occupies the high
        # 32 bits and dst the low 32 bits. This is only collision-free while
        # both IDs fit in int32 (guarded by the caller's assertion).
        val = edges[0] << 32 | edges[1]
        val_to_exclude = edges_to_exclude[0] << 32 | edges_to_exclude[1]
    else:
        # TODO: Add support for value > int32.
        raise NotImplementedError(
            "Values out of range int32 are not supported yet"
        )
    # Keep every edge whose packed key does not appear in the exclusion set.
    mask = ~isin(val, val_to_exclude)
    return torch.nonzero(mask, as_tuple=True)[0]
......
......@@ -123,3 +123,30 @@ def test_etype_str_to_tuple():
),
):
_ = gb.etype_str_to_tuple(c_etype_str)
def test_isin():
    # Membership of a small query tensor against a two-element test set.
    queries = torch.tensor([2, 3, 5, 5, 20, 13, 11])
    targets = torch.tensor([2, 5])
    expected = torch.tensor([True, False, True, True, False, False, False])
    assert torch.equal(gb.isin(queries, targets), expected)
def test_isin_big_data():
    # Cross-check the multi-threaded GraphBolt kernel against torch.isin on
    # inputs large enough to exercise the parallel path.
    elements = torch.randint(0, 10000, (10000000,))
    test_elements = torch.randint(0, 10000, (500000,))
    assert torch.equal(
        gb.isin(elements, test_elements),
        torch.isin(elements, test_elements),
    )
def test_isin_non_1D_dim():
elements = torch.tensor([[2, 3], [5, 5], [20, 13]])
test_elements = torch.tensor([2, 5])
with pytest.raises(Exception):
gb.isin(elements, test_elements)
elements = torch.tensor([2, 3, 5, 5, 20, 13])
test_elements = torch.tensor([[2, 5]])
with pytest.raises(Exception):
gb.isin(elements, test_elements)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment