Unverified Commit 208fa368 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Unique and Compact (#6738)

parent a259767b
@@ -10,20 +10,54 @@
namespace graphbolt {
namespace ops {
/**
* @brief Sorts the given input and also returns the original indices.
*
* @param input A tensor containing IDs.
* @param num_bits An integer such that all elements of the input tensor are
* less than (1 << num_bits).
*
* @return
* - A pair of tensors: the first contains the sorted input, the second
* contains the original positions of the sorted elements.
*/
std::pair<torch::Tensor, torch::Tensor> Sort(
torch::Tensor input, int num_bits = 0);
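// Illustrative usage sketch (not part of this commit); assumes a CUDA tensor
// `ids` whose elements are all less than 16, so 4 bits suffice:
//   auto [sorted_ids, original_positions] = Sort(ids, /*num_bits=*/4);
//   // sorted_ids[i] == ids[original_positions[i]] holds for every i.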
/**
* @brief Select columns of a sparse matrix in CSC format according to the
* nodes tensor.
*
* NOTE:
* 1. The shape of all tensors must be 1-D.
* 2. Should be called when all input tensors are on device (GPU) memory.
*
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
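// Worked example (illustrative, not part of this commit): for
//   indptr  = {0, 2, 5, 6}, indices = {10, 11, 20, 21, 22, 30},
//   nodes   = {2, 0},
// the call returns
//   output_indptr  = {0, 1, 3}       // degrees {1, 2}, prefix-summed
//   output_indices = {30, 10, 11}    // columns 2 and 0, concatenated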
/**
* @brief Select columns of a sparse matrix in CSC format according to the
* nodes tensor.
*
* NOTE:
* 1. The shape of all tensors must be 1-D.
* 2. Should be called when the indices tensor is in pinned memory.
*
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
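// Illustrative sketch (assumed setup, not from this commit): the UVA variant
// reads `indices` directly from pinned host memory, which is useful when the
// graph does not fit in GPU memory:
//   auto pinned_indices = indices.pin_memory();
//   auto [out_indptr, out_indices] =
//       UVAIndexSelectCSCImpl(indptr, pinned_indices, nodes);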
/**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
* given nodes and their indptr values.
@@ -39,10 +73,68 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
torch::Tensor indptr, torch::Tensor nodes);
/**
* @brief Computes the exclusive prefix sum of the given input.
*
* @param input The input tensor.
*
* @return The prefix sum result such that r[i] = \sum_{j=0}^{i-1} input[j]
*/
torch::Tensor ExclusiveCumSum(torch::Tensor input);
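// Worked example (illustrative): for input = {3, 1, 4, 1}, the result is
// {0, 3, 4, 8}, since r[i] = \sum_{j=0}^{i-1} input[j].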
/**
* @brief Select rows from input tensor according to index tensor.
*
* NOTE:
* 1. The input tensor may be multi-dimensional, but the index tensor must
* be 1-D.
* 2. Should be called when input is in pinned memory and index is in pinned
* or GPU memory.
*
* @param input Input tensor with shape (N, ...).
* @param index Index tensor with shape (M,).
* @return torch::Tensor Output tensor with shape (M, ...).
*/
torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index);
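// Illustrative sketch (assumed names, not from this commit): gathers rows of
// a pinned host feature tensor directly from the GPU:
//   auto features = torch::rand({num_nodes, 128}).pin_memory();
//   auto selected = UVAIndexSelectImpl(features, index);  // shape (M, 128)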
/**
* @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
* 'src_ids' tensor and applies the uniqueness information to compact both
* source and destination tensors.
*
* The function performs two main operations:
* 1. Unique Operation: 'unique(concat(unique_dst_ids, src_ids))', in which
* the unique operator guarantees that the 'unique_dst_ids' are placed at the
* head of the result tensor.
* 2. Compact Operation: Utilizes the reverse mapping derived from the unique
* operation to transform 'src_ids' and 'dst_ids' into compacted IDs.
*
* @param src_ids A tensor containing source IDs.
* @param dst_ids A tensor containing destination IDs.
* @param unique_dst_ids A tensor containing unique destination IDs, which
* must be exactly the unique elements of 'dst_ids'.
*
* @return
* - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after
* removing duplicates. The indices in this tensor precisely match the compacted
* IDs of the corresponding elements.
* - The tensor corresponding to the 'src_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the 'dst_ids' tensor, where the entries are
* mapped to compacted IDs.
*
* @example
* // Given pre-existing tensors `src` and `dst`:
* torch::Tensor src_ids = src;
* torch::Tensor dst_ids = dst;
* torch::Tensor unique_dst_ids = torch::unique(dst);
* auto result = UniqueAndCompact(src_ids, dst_ids, unique_dst_ids);
* torch::Tensor unique_ids = std::get<0>(result);
* torch::Tensor compacted_src_ids = std::get<1>(result);
* torch::Tensor compacted_dst_ids = std::get<2>(result);
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor src_ids, const torch::Tensor dst_ids,
const torch::Tensor unique_dst_ids, int num_bits = 0);
} // namespace ops
} // namespace graphbolt
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/unique_and_compact_impl.cu
* @brief Unique and compact operator implementation on CUDA.
*/
#include <c10/cuda/CUDAStream.h>
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/logical.h>
#include <thrust/reduce.h>
#include <thrust/remove.h>
#include <cub/cub.cuh>
#include "./common.h"
#include "./utils.h"
namespace graphbolt {
namespace ops {
template <typename scalar_t>
struct EqualityFunc {
const scalar_t* sorted_order;
const scalar_t* found_locations;
const scalar_t* searched_items;
__host__ __device__ auto operator()(int64_t i) {
return sorted_order[found_locations[i]] == searched_items[i];
}
};
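// Note on EqualityFunc (explanatory, not part of this commit):
// thrust::lower_bound only returns the insertion point of each searched item,
// so equality must still be checked to confirm membership. E.g., searching
// for 5 in sorted_order = {2, 5, 7} yields location 1 with
// sorted_order[1] == 5 (found), while searching for 6 yields location 2 with
// sorted_order[2] == 7 != 6 (missing).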
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor src_ids, const torch::Tensor dst_ids,
const torch::Tensor unique_dst_ids, int num_bits) {
TORCH_CHECK(
src_ids.scalar_type() == dst_ids.scalar_type() &&
dst_ids.scalar_type() == unique_dst_ids.scalar_type(),
"Dtypes of tensors passed to UniqueAndCompact need to be identical.");
auto allocator = cuda::GetAllocator();
auto stream = cuda::GetCurrentStream();
const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
return AT_DISPATCH_INTEGRAL_TYPES(
src_ids.scalar_type(), "unique_and_compact", ([&] {
auto src_ids_ptr = src_ids.data_ptr<scalar_t>();
auto dst_ids_ptr = dst_ids.data_ptr<scalar_t>();
auto unique_dst_ids_ptr = unique_dst_ids.data_ptr<scalar_t>();
// If the given num_bits argument is not in the reasonable range,
// we recompute it to speed up the expensive sort operations below.
if (num_bits <= 0 || num_bits > sizeof(scalar_t) * 8) {
auto max_id = thrust::reduce(
exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
static_cast<scalar_t>(0), thrust::maximum<scalar_t>{});
max_id = thrust::reduce(
exec_policy, unique_dst_ids_ptr,
unique_dst_ids_ptr + unique_dst_ids.size(0), max_id,
thrust::maximum<scalar_t>{});
num_bits = cuda::NumberOfBits(max_id + 1);
}
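// For example, if the largest ID seen is 12, NumberOfBits(13) == 4, so
// the radix sorts below only need to process 4 bits per key.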
// Sort the unique_dst_ids tensor.
auto sorted_unique_dst_ids =
allocator.AllocateStorage<scalar_t>(unique_dst_ids.size(0));
{
size_t workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, unique_dst_ids_ptr,
sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
temp.get(), workspace_size, unique_dst_ids_ptr,
sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
stream));
}
// Mark dst nodes in the src_ids tensor.
auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
thrust::binary_search(
exec_policy, sorted_unique_dst_ids.get(),
sorted_unique_dst_ids.get() + unique_dst_ids.size(0), src_ids_ptr,
src_ids_ptr + src_ids.size(0), is_dst.get());
// Keep only the non-dst nodes of the src_ids tensor, hence only_src.
auto only_src = allocator.AllocateStorage<scalar_t>(src_ids.size(0));
auto only_src_size =
thrust::remove_copy_if(
exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
is_dst.get(), only_src.get(), thrust::identity<bool>{}) -
only_src.get();
auto sorted_only_src =
allocator.AllocateStorage<scalar_t>(only_src_size);
{ // Sort the only_src tensor so that we can unique it with Encode
// operation later.
size_t workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, only_src.get(), sorted_only_src.get(),
only_src_size, 0, num_bits, stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
temp.get(), workspace_size, only_src.get(), sorted_only_src.get(),
only_src_size, 0, num_bits, stream));
}
auto unique_only_src = torch::empty(only_src_size, src_ids.options());
auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();
auto unique_only_src_cnt = allocator.AllocateStorage<scalar_t>(1);
{ // Compute the unique operation on the only_src tensor.
size_t workspace_size;
CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
nullptr, workspace_size, sorted_only_src.get(),
unique_only_src_ptr, cub::DiscardOutputIterator{},
unique_only_src_cnt.get(), only_src_size, stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
temp.get(), workspace_size, sorted_only_src.get(),
unique_only_src_ptr, cub::DiscardOutputIterator{},
unique_only_src_cnt.get(), only_src_size, stream));
}
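// Note (explanatory): on sorted input, RunLengthEncode acts as a unique
// operation: unique_only_src receives each distinct ID once, the per-run
// counts are discarded, and unique_only_src_cnt receives the unique count.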
auto unique_only_src_size = cuda::CopyScalar(unique_only_src_cnt.get());
unique_only_src = unique_only_src.slice(
0, 0, static_cast<scalar_t>(unique_only_src_size));
auto real_order = torch::cat({unique_dst_ids, unique_only_src});
// Sort here so that binary search can be used to lookup new_ids.
auto [sorted_order, new_ids] = Sort(real_order, num_bits);
auto sorted_order_ptr = sorted_order.data_ptr<scalar_t>();
auto new_ids_ptr = new_ids.data_ptr<int64_t>();
// Holds the found locations of the src and dst ids in the sorted_order.
// Later used to look up the new ids of the src_ids and dst_ids
// tensors.
auto new_src_ids_loc =
allocator.AllocateStorage<scalar_t>(src_ids.size(0));
auto new_dst_ids_loc =
allocator.AllocateStorage<scalar_t>(dst_ids.size(0));
thrust::lower_bound(
exec_policy, sorted_order_ptr,
sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
thrust::lower_bound(
exec_policy, sorted_order_ptr,
sorted_order_ptr + sorted_order.size(0), dst_ids_ptr,
dst_ids_ptr + dst_ids.size(0), new_dst_ids_loc.get());
{ // Check if unique_dst_ids includes all dst_ids.
thrust::counting_iterator<int64_t> iota(0);
auto equal_it = thrust::make_transform_iterator(
iota, EqualityFunc<scalar_t>{
sorted_order_ptr, new_dst_ids_loc.get(), dst_ids_ptr});
auto all_exist = thrust::all_of(
exec_policy, equal_it, equal_it + dst_ids.size(0),
thrust::identity<bool>());
if (!all_exist) {
throw std::out_of_range("Some ids not found.");
}
}
// Finally, look up the new compact ids of the src and dst tensors via
// gather operations.
auto new_src_ids = torch::empty_like(src_ids);
auto new_dst_ids = torch::empty_like(dst_ids);
thrust::gather(
exec_policy, new_src_ids_loc.get(),
new_src_ids_loc.get() + src_ids.size(0),
new_ids.data_ptr<int64_t>(), new_src_ids.data_ptr<scalar_t>());
thrust::gather(
exec_policy, new_dst_ids_loc.get(),
new_dst_ids_loc.get() + dst_ids.size(0),
new_ids.data_ptr<int64_t>(), new_dst_ids.data_ptr<scalar_t>());
return std::make_tuple(real_order, new_src_ids, new_dst_ids);
}));
}
} // namespace ops
} // namespace graphbolt
@@ -5,17 +5,27 @@
* @brief Unique and compact op.
*/
#include <graphbolt/cuda_ops.h>
#include <graphbolt/unique_and_compact.h>
#include <unordered_map>
#include "./concurrent_id_hash_map.h"
#include "./macro.h"
#include "./utils.h"
namespace graphbolt {
namespace sampling {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids) {
if (utils::is_accessible_from_gpu(src_ids) &&
utils::is_accessible_from_gpu(dst_ids) &&
utils::is_accessible_from_gpu(unique_dst_ids)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "unique_and_compact",
{ return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); });
}
torch::Tensor compacted_src_ids;
torch::Tensor compacted_dst_ids;
torch::Tensor unique_ids;
......
@@ -128,7 +128,10 @@ def unique_and_compact_node_pairs(
# Collect all source and destination nodes for each node type.
src_nodes = defaultdict(list)
dst_nodes = defaultdict(list)
device = None
for etype, (src_node, dst_node) in node_pairs.items():
if device is None:
device = src_node.device
src_type, _, dst_type = etype_str_to_tuple(etype)
src_nodes[src_type].append(src_node)
dst_nodes[dst_type].append(dst_node)
@@ -145,7 +148,7 @@
compacted_src = {}
compacted_dst = {}
dtype = list(src_nodes.values())[0].dtype
default_tensor = torch.tensor([], dtype=dtype, device=device)
for ntype in ntypes:
src = src_nodes.get(ntype, default_tensor)
unique_dst = unique_dst_nodes.get(ntype, default_tensor)
@@ -247,7 +250,10 @@ def unique_and_compact_csc_formats(
# Collect all source and destination nodes for each node type.
indices = defaultdict(list)
device = None
for etype, csc_format in csc_formats.items():
if device is None:
device = csc_format.indices.device
assert csc_format.indptr[-1] == len(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
@@ -262,7 +268,7 @@
unique_nodes = {}
compacted_indices = {}
dtype = list(indices.values())[0].dtype
default_tensor = torch.tensor([], dtype=dtype, device=device)
for ntype in ntypes:
indice = indices.get(ntype, default_tensor)
unique_dst = unique_dst_nodes.get(ntype, default_tensor)
@@ -271,7 +277,9 @@
compacted_indices[ntype],
_,
) = torch.ops.graphbolt.unique_and_compact(
indice,
torch.tensor([], dtype=indice.dtype, device=device),
unique_dst,
)
compacted_csc_formats = {}
......
@@ -3,11 +3,15 @@ import backend as F
import dgl
import dgl.graphbolt
import torch
import torch.multiprocessing as mp
from . import gb_test_utils
def test_DataLoader():
# https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing
mp.set_start_method("spawn", force=True)
N = 40
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
......
import backend as F
import dgl.graphbolt as gb
import pytest
import torch
@@ -71,13 +72,25 @@ def test_find_reverse_edges_circual_reverse_types():
def test_unique_and_compact_hetero():
N1 = torch.tensor(
[0, 5, 2, 7, 12, 7, 9, 5, 6, 2, 3, 4, 1, 0, 9], device=F.ctx()
)
N2 = torch.tensor([0, 3, 3, 5, 2, 7, 2, 8, 4, 9, 2, 3], device=F.ctx())
N3 = torch.tensor([1, 2, 6, 6, 1, 8, 3, 6, 3, 2], device=F.ctx())
expected_unique = {
"n1": torch.tensor([0, 5, 2, 7, 12, 9, 6, 3, 4, 1]),
"n2": torch.tensor([0, 3, 5, 2, 7, 8, 4, 9]),
"n3": torch.tensor([1, 2, 6, 8, 3]),
"n1": torch.tensor([0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()),
"n2": torch.tensor([0, 3, 5, 2, 7, 8, 4, 9], device=F.ctx()),
"n3": torch.tensor([1, 2, 6, 8, 3], device=F.ctx()),
}
if N1.is_cuda:
expected_reverse_id = {
k: v.sort()[1] for k, v in expected_unique.items()
}
expected_unique = {k: v.sort()[0] for k, v in expected_unique.items()}
else:
expected_reverse_id = {
k: torch.arange(0, v.shape[0], device=F.ctx())
for k, v in expected_unique.items()
}
nodes_dict = {
"n1": N1.split(5),
@@ -86,21 +99,21 @@ }
}
expected_nodes_dict = {
"n1": [
torch.tensor([0, 1, 2, 3, 4], device=F.ctx()),
torch.tensor([3, 5, 1, 6, 2], device=F.ctx()),
torch.tensor([7, 8, 9, 0, 5], device=F.ctx()),
],
"n2": [
torch.tensor([0, 1, 1, 2], device=F.ctx()),
torch.tensor([3, 4, 3, 5], device=F.ctx()),
torch.tensor([6, 7, 3, 1], device=F.ctx()),
],
"n3": [
torch.tensor([0, 1], device=F.ctx()),
torch.tensor([2, 2], device=F.ctx()),
torch.tensor([0, 3], device=F.ctx()),
torch.tensor([4, 2], device=F.ctx()),
torch.tensor([4, 1], device=F.ctx()),
],
}
@@ -113,17 +126,29 @@ def test_unique_and_compact_hetero():
expected_nodes = expected_nodes_dict[ntype]
assert isinstance(nodes, list)
for expected_node, node in zip(expected_nodes, nodes):
node = expected_reverse_id[ntype][node]
assert torch.equal(expected_node, node)
def test_unique_and_compact_homo():
N = torch.tensor(
[0, 5, 2, 7, 12, 7, 9, 5, 6, 2, 3, 4, 1, 0, 9], device=F.ctx()
)
expected_unique_N = torch.tensor(
[0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()
)
if N.is_cuda:
expected_reverse_id_N = expected_unique_N.sort()[1]
expected_unique_N = expected_unique_N.sort()[0]
else:
expected_reverse_id_N = torch.arange(
0, expected_unique_N.shape[0], device=F.ctx()
)
nodes_list = N.split(5)
expected_nodes_list = [
torch.tensor([0, 1, 2, 3, 4], device=F.ctx()),
torch.tensor([3, 5, 1, 6, 2], device=F.ctx()),
torch.tensor([7, 8, 9, 0, 5], device=F.ctx()),
]
unique, compacted = gb.unique_and_compact(nodes_list)
@@ -131,42 +156,59 @@ def test_unique_and_compact_homo():
assert torch.equal(unique, expected_unique_N)
assert isinstance(compacted, list)
for expected_node, node in zip(expected_nodes_list, compacted):
node = expected_reverse_id_N[node]
assert torch.equal(expected_node, node)
def test_unique_and_compact_node_pairs_hetero():
node_pairs = {
"n1:e1:n2": (
torch.tensor([1, 3, 4, 6, 2, 7, 9, 4, 2, 6], device=F.ctx()),
torch.tensor([2, 2, 2, 4, 1, 1, 1, 3, 3, 3], device=F.ctx()),
),
"n1:e2:n3": (
torch.tensor([5, 2, 6, 4, 7, 2, 8, 1, 3, 0], device=F.ctx()),
torch.tensor([1, 3, 3, 3, 2, 2, 2, 7, 7, 7], device=F.ctx()),
),
"n2:e3:n3": (
torch.tensor([2, 5, 4, 1, 4, 3, 6, 0], device=F.ctx()),
torch.tensor([1, 1, 3, 3, 2, 2, 7, 7], device=F.ctx()),
),
}
expected_unique_nodes = {
"n1": torch.tensor([1, 3, 4, 6, 2, 7, 9, 5, 8, 0]),
"n2": torch.tensor([1, 2, 3, 4, 5, 6, 0]),
"n3": torch.tensor([1, 2, 3, 7]),
"n1": torch.tensor([1, 3, 4, 6, 2, 7, 9, 5, 8, 0], device=F.ctx()),
"n2": torch.tensor([1, 2, 3, 4, 5, 6, 0], device=F.ctx()),
"n3": torch.tensor([1, 2, 3, 7], device=F.ctx()),
}
if expected_unique_nodes["n1"].is_cuda:
expected_reverse_id = {
"n1": expected_unique_nodes["n1"].sort()[1],
"n2": torch.tensor([0, 1, 2, 3, 6, 4, 5], device=F.ctx()),
"n3": expected_unique_nodes["n3"].sort()[1],
}
expected_unique_nodes = {
"n1": expected_unique_nodes["n1"].sort()[0],
"n2": torch.tensor([1, 2, 3, 4, 0, 5, 6], device=F.ctx()),
"n3": expected_unique_nodes["n3"].sort()[0],
}
else:
expected_reverse_id = {
k: torch.arange(0, v.shape[0], device=F.ctx())
for k, v in expected_unique_nodes.items()
}
expected_node_pairs = {
"n1:e1:n2": (
torch.tensor([0, 1, 2, 3, 4, 5, 6, 2, 4, 3], device=F.ctx()),
torch.tensor([1, 1, 1, 3, 0, 0, 0, 2, 2, 2], device=F.ctx()),
),
"n1:e2:n3": (
torch.tensor([7, 4, 3, 2, 5, 4, 8, 0, 1, 9], device=F.ctx()),
torch.tensor([0, 2, 2, 2, 1, 1, 1, 3, 3, 3], device=F.ctx()),
),
"n2:e3:n3": (
torch.tensor([1, 4, 3, 0, 3, 2, 5, 6], device=F.ctx()),
torch.tensor([0, 0, 2, 2, 1, 1, 3, 3], device=F.ctx()),
),
}
@@ -178,19 +220,26 @@ def test_unique_and_compact_node_pairs_hetero():
assert torch.equal(nodes, expected_nodes)
for etype, pair in compacted_node_pairs.items():
u, v = pair
ntype1, _, ntype2 = etype.split(":")
u = expected_reverse_id[ntype1][u]
v = expected_reverse_id[ntype2][v]
expected_u, expected_v = expected_node_pairs[etype]
assert torch.equal(u, expected_u)
assert torch.equal(v, expected_v)
def test_unique_and_compact_node_pairs_homo():
dst_nodes = torch.tensor([1, 1, 3, 3, 5, 5, 2, 6, 6, 6, 6], device=F.ctx())
src_nodes = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6], device=F.ctx())
node_pairs = (src_nodes, dst_nodes)
expected_unique_nodes = torch.tensor([1, 2, 3, 5, 6, 4], device=F.ctx())
expected_dst_nodes = torch.tensor(
[0, 0, 2, 2, 3, 3, 1, 4, 4, 4, 4], device=F.ctx()
)
expected_src_nodes = torch.tensor(
[1, 2, 0, 5, 3, 1, 3, 0, 5, 5, 4], device=F.ctx()
)
unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
node_pairs
)
@@ -199,12 +248,17 @@ def test_unique_and_compact_node_pairs_homo():
u, v = compacted_node_pairs
assert torch.equal(u, expected_src_nodes)
assert torch.equal(v, expected_dst_nodes)
assert torch.equal(
unique_nodes[:5], torch.tensor([1, 2, 3, 5, 6], device=F.ctx())
)
def test_incomplete_unique_dst_nodes_():
node_pairs = (
torch.arange(0, 50, device=F.ctx()),
torch.arange(100, 150, device=F.ctx()),
)
unique_dst_nodes = torch.arange(150, 200, device=F.ctx())
with pytest.raises(IndexError):
gb.unique_and_compact_node_pairs(node_pairs, unique_dst_nodes)
......