"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "549df65a771be2443fd9b67d33eac3d0f1b13965"
Unverified Commit 208fa368 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Unique and Compact (#6738)

parent a259767b
@@ -10,20 +10,54 @@
namespace graphbolt {
namespace ops {

/**
 * @brief Sorts the given input and also returns the original indexes.
 *
 * @param input A tensor containing IDs.
 * @param num_bits An integer such that all elements of the input tensor are
 * less than (1 << num_bits).
 *
 * @return
 * - A tuple of tensors, the first one includes the sorted input, the second
 * contains the original positions of the sorted result.
 */
std::pair<torch::Tensor, torch::Tensor> Sort(
    torch::Tensor input, int num_bits = 0);
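For context, a minimal usage sketch of the new Sort overload (an editor's illustration, not part of the commit): the wrapper name SortExample, the ID values, and the CUDA placement are assumptions.

// Editor's sketch: sort CUDA IDs and recover their original positions.
#include <graphbolt/cuda_ops.h>
#include <torch/torch.h>

void SortExample() {
  // All IDs are below 16, so 4 bits suffice for the radix sort; passing
  // num_bits = 0 would let the implementation infer the width instead.
  torch::Tensor ids = torch::tensor(
      {7, 2, 9, 2}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  auto [sorted, positions] = graphbolt::ops::Sort(ids, /*num_bits=*/4);
  // sorted    == {2, 2, 7, 9}
  // positions == {1, 3, 0, 2}: the original index of each sorted element
  // (radix sort is stable, so equal keys keep their relative order).
}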
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
*
* NOTE:
* 1. The shape of all tensors must be 1-D.
* 2. Should be called if all input tensors are on device memory.
*
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
*
* NOTE:
* 1. The shape of all tensors must be 1-D.
 * 2. Should be called if the indices tensor is on pinned memory.
*
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
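To make the column selection concrete, a small editor's sketch; the CSC values are invented, all tensors sit on the GPU as the NOTE of the first variant requires, and IndexSelectCSCExample is a hypothetical wrapper.

// Editor's sketch: select columns {0, 2} of a 3-column CSC matrix.
#include <graphbolt/cuda_ops.h>
#include <torch/torch.h>

void IndexSelectCSCExample() {
  auto opts = torch::dtype(torch::kInt64).device(torch::kCUDA);
  // Column j owns indices[indptr[j] : indptr[j + 1]].
  torch::Tensor indptr = torch::tensor({0, 2, 3, 6}, opts);
  torch::Tensor indices = torch::tensor({4, 5, 6, 7, 8, 9}, opts);
  torch::Tensor nodes = torch::tensor({0, 2}, opts);
  auto [out_indptr, out_indices] =
      graphbolt::ops::IndexSelectCSCImpl(indptr, indices, nodes);
  // out_indptr  == {0, 2, 5}, shape (M + 1,).
  // out_indices == {4, 5, 7, 8, 9}: columns 0 and 2 back to back.
}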
/**
 * @brief Slices the indptr tensor with nodes and returns the indegrees of the
 * given nodes and their indptr values.
@@ -39,10 +73,68 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
    torch::Tensor indptr, torch::Tensor nodes);

/**
 * @brief Computes the exclusive prefix sum of the given input.
*
* @param input The input tensor.
*
* @return The prefix sum result such that r[i] = \sum_{j=0}^{i-1} input[j]
*/
torch::Tensor ExclusiveCumSum(torch::Tensor input);
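A quick editor's illustration of the exclusive semantics; the input values, the CUDA placement, and the wrapper name are assumptions.

// Editor's sketch: r[i] = input[0] + ... + input[i - 1].
#include <graphbolt/cuda_ops.h>
#include <torch/torch.h>

void ExclusiveCumSumExample() {
  torch::Tensor input = torch::tensor(
      {3, 1, 4}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  torch::Tensor r = graphbolt::ops::ExclusiveCumSum(input);
  // r == {0, 3, 4}: r[0] is the empty sum, and input[2] never contributes.
}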
/**
* @brief Select rows from input tensor according to index tensor.
*
* NOTE:
* 1. The shape of input tensor can be multi-dimensional, but the index tensor
* must be 1-D.
* 2. Should be called if input is on pinned memory and index is on pinned
* memory or GPU memory.
*
* @param input Input tensor with shape (N, ...).
* @param index Index tensor with shape (M,).
* @return torch::Tensor Output tensor with shape (M, ...).
*/
torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index);
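An editor's sketch of the pinned-memory row gather; the feature values are invented, the index lives on the GPU as the NOTE above permits, and UVAIndexSelectExample is a hypothetical wrapper.

// Editor's sketch: gather rows of a pinned host tensor with a GPU index.
#include <graphbolt/cuda_ops.h>
#include <torch/torch.h>

void UVAIndexSelectExample() {
  // A (4, 2) feature matrix in pinned memory, as the NOTE above requires.
  torch::Tensor input =
      torch::arange(8, torch::kFloat32).reshape({4, 2}).pin_memory();
  torch::Tensor index = torch::tensor(
      {3, 0}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  torch::Tensor out = graphbolt::ops::UVAIndexSelectImpl(input, index);
  // out has shape (2, 2) and holds rows 3 and 0 of input.
}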
/**
* @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
* 'src_ids' tensor and applies the uniqueness information to compact both
* source and destination tensors.
*
* The function performs two main operations:
* 1. Unique Operation: 'unique(concat(unique_dst_ids, src_ids))', in which
* the unique operator will guarantee the 'unique_dst_ids' are at the head of
* the result tensor.
* 2. Compact Operation: Utilizes the reverse mapping derived from the unique
* operation to transform 'src_ids' and 'dst_ids' into compacted IDs.
*
* @param src_ids A tensor containing source IDs.
* @param dst_ids A tensor containing destination IDs.
 * @param unique_dst_ids A tensor containing the unique destination IDs, i.e.
 * exactly the unique elements of 'dst_ids'.
*
* @return
* - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after
* removing duplicates. The indices in this tensor precisely match the compacted
* IDs of the corresponding elements.
* - The tensor corresponding to the 'src_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the 'dst_ids' tensor, where the entries are
* mapped to compacted IDs.
*
 * @example
 * torch::Tensor src_ids = src;
 * torch::Tensor dst_ids = dst;
* torch::Tensor unique_dst_ids = torch::unique(dst);
* auto result = UniqueAndCompact(src_ids, dst_ids, unique_dst_ids);
* torch::Tensor unique_ids = std::get<0>(result);
* torch::Tensor compacted_src_ids = std::get<1>(result);
* torch::Tensor compacted_dst_ids = std::get<2>(result);
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor src_ids, const torch::Tensor dst_ids,
const torch::Tensor unique_dst_ids, int num_bits = 0);
}  // namespace ops
}  // namespace graphbolt
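To ground the UniqueAndCompact contract with numbers, an editor's walk-through; the ID values are invented, CUDA placement is assumed since this overload is the device path, and the values are chosen so the CPU hash-map path and the sorting CUDA path agree on the result order.

// Editor's sketch: UniqueAndCompact on four edges.
#include <graphbolt/cuda_ops.h>
#include <torch/torch.h>

void UniqueAndCompactExample() {
  auto opts = torch::dtype(torch::kInt64).device(torch::kCUDA);
  torch::Tensor src_ids = torch::tensor({20, 10, 30, 20}, opts);
  torch::Tensor dst_ids = torch::tensor({10, 10, 40, 40}, opts);
  torch::Tensor unique_dst_ids = torch::tensor({10, 40}, opts);
  auto [unique_ids, compacted_src, compacted_dst] =
      graphbolt::ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids);
  // unique_dst_ids stays at the head; the remaining unique src IDs follow:
  // unique_ids    == {10, 40, 20, 30}
  // compacted_src == {2, 0, 3, 2}: each src ID's position in unique_ids.
  // compacted_dst == {0, 0, 1, 1}
}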
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/unique_and_compact_impl.cu
* @brief Unique and compact operator implementation on CUDA.
*/
#include <c10/cuda/CUDAStream.h>
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/logical.h>
#include <thrust/reduce.h>
#include <thrust/remove.h>
#include <cub/cub.cuh>
#include "./common.h"
#include "./utils.h"
namespace graphbolt {
namespace ops {
template <typename scalar_t>
struct EqualityFunc {
const scalar_t* sorted_order;
const scalar_t* found_locations;
const scalar_t* searched_items;
__host__ __device__ auto operator()(int64_t i) {
return sorted_order[found_locations[i]] == searched_items[i];
}
};
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor src_ids, const torch::Tensor dst_ids,
const torch::Tensor unique_dst_ids, int num_bits) {
TORCH_CHECK(
src_ids.scalar_type() == dst_ids.scalar_type() &&
dst_ids.scalar_type() == unique_dst_ids.scalar_type(),
"Dtypes of tensors passed to UniqueAndCompact need to be identical.");
auto allocator = cuda::GetAllocator();
auto stream = cuda::GetCurrentStream();
const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
return AT_DISPATCH_INTEGRAL_TYPES(
src_ids.scalar_type(), "unique_and_compact", ([&] {
auto src_ids_ptr = src_ids.data_ptr<scalar_t>();
auto dst_ids_ptr = dst_ids.data_ptr<scalar_t>();
auto unique_dst_ids_ptr = unique_dst_ids.data_ptr<scalar_t>();
        // If the given num_bits argument is not in a reasonable range,
        // we recompute it to speed up the expensive sort operations.
if (num_bits <= 0 || num_bits > sizeof(scalar_t) * 8) {
auto max_id = thrust::reduce(
exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
static_cast<scalar_t>(0), thrust::maximum<scalar_t>{});
max_id = thrust::reduce(
exec_policy, unique_dst_ids_ptr,
unique_dst_ids_ptr + unique_dst_ids.size(0), max_id,
thrust::maximum<scalar_t>{});
num_bits = cuda::NumberOfBits(max_id + 1);
}
// Sort the unique_dst_ids tensor.
auto sorted_unique_dst_ids =
allocator.AllocateStorage<scalar_t>(unique_dst_ids.size(0));
{
size_t workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, unique_dst_ids_ptr,
sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
temp.get(), workspace_size, unique_dst_ids_ptr,
sorted_unique_dst_ids.get(), unique_dst_ids.size(0), 0, num_bits,
stream));
}
// Mark dst nodes in the src_ids tensor.
auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
thrust::binary_search(
exec_policy, sorted_unique_dst_ids.get(),
sorted_unique_dst_ids.get() + unique_dst_ids.size(0), src_ids_ptr,
src_ids_ptr + src_ids.size(0), is_dst.get());
        // Keep only the nodes of src_ids that are not dst nodes, hence
        // only_src.
auto only_src = allocator.AllocateStorage<scalar_t>(src_ids.size(0));
auto only_src_size =
thrust::remove_copy_if(
exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
is_dst.get(), only_src.get(), thrust::identity<bool>{}) -
only_src.get();
auto sorted_only_src =
allocator.AllocateStorage<scalar_t>(only_src_size);
{ // Sort the only_src tensor so that we can unique it with Encode
// operation later.
size_t workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
nullptr, workspace_size, only_src.get(), sorted_only_src.get(),
only_src_size, 0, num_bits, stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortKeys(
temp.get(), workspace_size, only_src.get(), sorted_only_src.get(),
only_src_size, 0, num_bits, stream));
}
auto unique_only_src = torch::empty(only_src_size, src_ids.options());
auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();
auto unique_only_src_cnt = allocator.AllocateStorage<scalar_t>(1);
{ // Compute the unique operation on the only_src tensor.
size_t workspace_size;
CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
nullptr, workspace_size, sorted_only_src.get(),
unique_only_src_ptr, cub::DiscardOutputIterator{},
unique_only_src_cnt.get(), only_src_size, stream));
auto temp = allocator.AllocateStorage<char>(workspace_size);
CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
temp.get(), workspace_size, sorted_only_src.get(),
unique_only_src_ptr, cub::DiscardOutputIterator{},
unique_only_src_cnt.get(), only_src_size, stream));
}
auto unique_only_src_size = cuda::CopyScalar(unique_only_src_cnt.get());
unique_only_src = unique_only_src.slice(
0, 0, static_cast<scalar_t>(unique_only_src_size));
auto real_order = torch::cat({unique_dst_ids, unique_only_src});
// Sort here so that binary search can be used to lookup new_ids.
auto [sorted_order, new_ids] = Sort(real_order, num_bits);
auto sorted_order_ptr = sorted_order.data_ptr<scalar_t>();
auto new_ids_ptr = new_ids.data_ptr<int64_t>();
        // Holds the found locations of the src and dst ids in the
        // sorted_order. Later used to look up the new ids of the src_ids and
        // dst_ids tensors.
auto new_src_ids_loc =
allocator.AllocateStorage<scalar_t>(src_ids.size(0));
auto new_dst_ids_loc =
allocator.AllocateStorage<scalar_t>(dst_ids.size(0));
thrust::lower_bound(
exec_policy, sorted_order_ptr,
sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
thrust::lower_bound(
exec_policy, sorted_order_ptr,
sorted_order_ptr + sorted_order.size(0), dst_ids_ptr,
dst_ids_ptr + dst_ids.size(0), new_dst_ids_loc.get());
{ // Check if unique_dst_ids includes all dst_ids.
thrust::counting_iterator<int64_t> iota(0);
auto equal_it = thrust::make_transform_iterator(
iota, EqualityFunc<scalar_t>{
sorted_order_ptr, new_dst_ids_loc.get(), dst_ids_ptr});
auto all_exist = thrust::all_of(
exec_policy, equal_it, equal_it + dst_ids.size(0),
thrust::identity<bool>());
if (!all_exist) {
throw std::out_of_range("Some ids not found.");
}
}
// Finally, lookup the new compact ids of the src and dst tensors via
// gather operations.
auto new_src_ids = torch::empty_like(src_ids);
auto new_dst_ids = torch::empty_like(dst_ids);
thrust::gather(
exec_policy, new_src_ids_loc.get(),
new_src_ids_loc.get() + src_ids.size(0),
new_ids.data_ptr<int64_t>(), new_src_ids.data_ptr<scalar_t>());
thrust::gather(
exec_policy, new_dst_ids_loc.get(),
new_dst_ids_loc.get() + dst_ids.size(0),
new_ids.data_ptr<int64_t>(), new_dst_ids.data_ptr<scalar_t>());
return std::make_tuple(real_order, new_src_ids, new_dst_ids);
}));
}
} // namespace ops
} // namespace graphbolt
@@ -5,17 +5,27 @@
 * @brief Unique and compact op.
 */
#include <graphbolt/cuda_ops.h>
#include <graphbolt/unique_and_compact.h>
#include <unordered_map>
#include "./concurrent_id_hash_map.h"
#include "./macro.h"
#include "./utils.h"
namespace graphbolt {
namespace sampling {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
    const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
    const torch::Tensor unique_dst_ids) {
  if (utils::is_accessible_from_gpu(src_ids) &&
      utils::is_accessible_from_gpu(dst_ids) &&
      utils::is_accessible_from_gpu(unique_dst_ids)) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
        c10::DeviceType::CUDA, "unique_and_compact",
        { return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); });
  }
  torch::Tensor compacted_src_ids;
  torch::Tensor compacted_dst_ids;
  torch::Tensor unique_ids;
...
@@ -128,7 +128,10 @@ def unique_and_compact_node_pairs(
    # Collect all source and destination nodes for each node type.
    src_nodes = defaultdict(list)
    dst_nodes = defaultdict(list)
    device = None
    for etype, (src_node, dst_node) in node_pairs.items():
        if device is None:
            device = src_node.device
        src_type, _, dst_type = etype_str_to_tuple(etype)
        src_nodes[src_type].append(src_node)
        dst_nodes[dst_type].append(dst_node)
@@ -145,7 +148,7 @@ def unique_and_compact_node_pairs(
    compacted_src = {}
    compacted_dst = {}
    dtype = list(src_nodes.values())[0].dtype
    default_tensor = torch.tensor([], dtype=dtype, device=device)
    for ntype in ntypes:
        src = src_nodes.get(ntype, default_tensor)
        unique_dst = unique_dst_nodes.get(ntype, default_tensor)
@@ -247,7 +250,10 @@ def unique_and_compact_csc_formats(
    # Collect all source and destination nodes for each node type.
    indices = defaultdict(list)
    device = None
    for etype, csc_format in csc_formats.items():
        if device is None:
            device = csc_format.indices.device
        assert csc_format.indptr[-1] == len(
            csc_format.indices
        ), "The last element of indptr should be the same as the length of indices."
@@ -262,7 +268,7 @@ def unique_and_compact_csc_formats(
    unique_nodes = {}
    compacted_indices = {}
    dtype = list(indices.values())[0].dtype
    default_tensor = torch.tensor([], dtype=dtype, device=device)
    for ntype in ntypes:
        indice = indices.get(ntype, default_tensor)
        unique_dst = unique_dst_nodes.get(ntype, default_tensor)
@@ -271,7 +277,9 @@ def unique_and_compact_csc_formats(
            compacted_indices[ntype],
            _,
        ) = torch.ops.graphbolt.unique_and_compact(
            indice,
            torch.tensor([], dtype=indice.dtype, device=device),
            unique_dst,
        )
    compacted_csc_formats = {}
...
@@ -3,11 +3,15 @@ import backend as F
import dgl
import dgl.graphbolt
import torch
import torch.multiprocessing as mp

from . import gb_test_utils


def test_DataLoader():
    # https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing
    mp.set_start_method("spawn", force=True)
    N = 40
    B = 4
    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
...
import backend as F
import dgl.graphbolt as gb
import pytest
import torch
@@ -71,14 +72,26 @@ def test_find_reverse_edges_circual_reverse_types():
def test_unique_and_compact_hetero():
    N1 = torch.tensor(
        [0, 5, 2, 7, 12, 7, 9, 5, 6, 2, 3, 4, 1, 0, 9], device=F.ctx()
    )
    N2 = torch.tensor([0, 3, 3, 5, 2, 7, 2, 8, 4, 9, 2, 3], device=F.ctx())
    N3 = torch.tensor([1, 2, 6, 6, 1, 8, 3, 6, 3, 2], device=F.ctx())
    expected_unique = {
        "n1": torch.tensor([0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()),
        "n2": torch.tensor([0, 3, 5, 2, 7, 8, 4, 9], device=F.ctx()),
        "n3": torch.tensor([1, 2, 6, 8, 3], device=F.ctx()),
    }
    if N1.is_cuda:
        expected_reverse_id = {
            k: v.sort()[1] for k, v in expected_unique.items()
        }
        expected_unique = {k: v.sort()[0] for k, v in expected_unique.items()}
    else:
        expected_reverse_id = {
            k: torch.arange(0, v.shape[0], device=F.ctx())
            for k, v in expected_unique.items()
        }
    nodes_dict = {
        "n1": N1.split(5),
        "n2": N2.split(4),
@@ -86,21 +99,21 @@
    }
    expected_nodes_dict = {
        "n1": [
            torch.tensor([0, 1, 2, 3, 4], device=F.ctx()),
            torch.tensor([3, 5, 1, 6, 2], device=F.ctx()),
            torch.tensor([7, 8, 9, 0, 5], device=F.ctx()),
        ],
        "n2": [
            torch.tensor([0, 1, 1, 2], device=F.ctx()),
            torch.tensor([3, 4, 3, 5], device=F.ctx()),
            torch.tensor([6, 7, 3, 1], device=F.ctx()),
        ],
        "n3": [
            torch.tensor([0, 1], device=F.ctx()),
            torch.tensor([2, 2], device=F.ctx()),
            torch.tensor([0, 3], device=F.ctx()),
            torch.tensor([4, 2], device=F.ctx()),
            torch.tensor([4, 1], device=F.ctx()),
        ],
    }
@@ -113,17 +126,29 @@
        expected_nodes = expected_nodes_dict[ntype]
        assert isinstance(nodes, list)
        for expected_node, node in zip(expected_nodes, nodes):
            node = expected_reverse_id[ntype][node]
            assert torch.equal(expected_node, node)
def test_unique_and_compact_homo():
    N = torch.tensor(
        [0, 5, 2, 7, 12, 7, 9, 5, 6, 2, 3, 4, 1, 0, 9], device=F.ctx()
    )
    expected_unique_N = torch.tensor(
        [0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()
    )
    if N.is_cuda:
        expected_reverse_id_N = expected_unique_N.sort()[1]
        expected_unique_N = expected_unique_N.sort()[0]
    else:
        expected_reverse_id_N = torch.arange(
            0, expected_unique_N.shape[0], device=F.ctx()
        )
    nodes_list = N.split(5)
    expected_nodes_list = [
        torch.tensor([0, 1, 2, 3, 4], device=F.ctx()),
        torch.tensor([3, 5, 1, 6, 2], device=F.ctx()),
        torch.tensor([7, 8, 9, 0, 5], device=F.ctx()),
    ]

    unique, compacted = gb.unique_and_compact(nodes_list)
@@ -131,42 +156,59 @@
    assert torch.equal(unique, expected_unique_N)
    assert isinstance(compacted, list)
    for expected_node, node in zip(expected_nodes_list, compacted):
        node = expected_reverse_id_N[node]
        assert torch.equal(expected_node, node)
def test_unique_and_compact_node_pairs_hetero():
    node_pairs = {
        "n1:e1:n2": (
            torch.tensor([1, 3, 4, 6, 2, 7, 9, 4, 2, 6], device=F.ctx()),
            torch.tensor([2, 2, 2, 4, 1, 1, 1, 3, 3, 3], device=F.ctx()),
        ),
        "n1:e2:n3": (
            torch.tensor([5, 2, 6, 4, 7, 2, 8, 1, 3, 0], device=F.ctx()),
            torch.tensor([1, 3, 3, 3, 2, 2, 2, 7, 7, 7], device=F.ctx()),
        ),
        "n2:e3:n3": (
            torch.tensor([2, 5, 4, 1, 4, 3, 6, 0], device=F.ctx()),
            torch.tensor([1, 1, 3, 3, 2, 2, 7, 7], device=F.ctx()),
        ),
    }
    expected_unique_nodes = {
        "n1": torch.tensor([1, 3, 4, 6, 2, 7, 9, 5, 8, 0], device=F.ctx()),
        "n2": torch.tensor([1, 2, 3, 4, 5, 6, 0], device=F.ctx()),
        "n3": torch.tensor([1, 2, 3, 7], device=F.ctx()),
    }
if expected_unique_nodes["n1"].is_cuda:
expected_reverse_id = {
"n1": expected_unique_nodes["n1"].sort()[1],
"n2": torch.tensor([0, 1, 2, 3, 6, 4, 5], device=F.ctx()),
"n3": expected_unique_nodes["n3"].sort()[1],
}
expected_unique_nodes = {
"n1": expected_unique_nodes["n1"].sort()[0],
"n2": torch.tensor([1, 2, 3, 4, 0, 5, 6], device=F.ctx()),
"n3": expected_unique_nodes["n3"].sort()[0],
}
else:
expected_reverse_id = {
k: torch.arange(0, v.shape[0], device=F.ctx())
for k, v in expected_unique_nodes.items()
}
    expected_node_pairs = {
        "n1:e1:n2": (
            torch.tensor([0, 1, 2, 3, 4, 5, 6, 2, 4, 3], device=F.ctx()),
            torch.tensor([1, 1, 1, 3, 0, 0, 0, 2, 2, 2], device=F.ctx()),
        ),
        "n1:e2:n3": (
            torch.tensor([7, 4, 3, 2, 5, 4, 8, 0, 1, 9], device=F.ctx()),
            torch.tensor([0, 2, 2, 2, 1, 1, 1, 3, 3, 3], device=F.ctx()),
        ),
        "n2:e3:n3": (
            torch.tensor([1, 4, 3, 0, 3, 2, 5, 6], device=F.ctx()),
            torch.tensor([0, 0, 2, 2, 1, 1, 3, 3], device=F.ctx()),
        ),
    }
@@ -178,19 +220,26 @@
        assert torch.equal(nodes, expected_nodes)
    for etype, pair in compacted_node_pairs.items():
        u, v = pair
        ntype1, _, ntype2 = etype.split(":")
        u = expected_reverse_id[ntype1][u]
        v = expected_reverse_id[ntype2][v]
        expected_u, expected_v = expected_node_pairs[etype]
        assert torch.equal(u, expected_u)
        assert torch.equal(v, expected_v)
def test_unique_and_compact_node_pairs_homo():
    dst_nodes = torch.tensor([1, 1, 3, 3, 5, 5, 2, 6, 6, 6, 6], device=F.ctx())
    src_nodes = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6], device=F.ctx())
    node_pairs = (src_nodes, dst_nodes)
    expected_unique_nodes = torch.tensor([1, 2, 3, 5, 6, 4], device=F.ctx())
    expected_dst_nodes = torch.tensor(
        [0, 0, 2, 2, 3, 3, 1, 4, 4, 4, 4], device=F.ctx()
    )
    expected_src_nodes = torch.tensor(
        [1, 2, 0, 5, 3, 1, 3, 0, 5, 5, 4], device=F.ctx()
    )

    unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
        node_pairs
    )
@@ -199,12 +248,17 @@
    u, v = compacted_node_pairs
    assert torch.equal(u, expected_src_nodes)
    assert torch.equal(v, expected_dst_nodes)
    assert torch.equal(
        unique_nodes[:5], torch.tensor([1, 2, 3, 5, 6], device=F.ctx())
    )


def test_incomplete_unique_dst_nodes_():
    node_pairs = (
        torch.arange(0, 50, device=F.ctx()),
        torch.arange(100, 150, device=F.ctx()),
    )
    unique_dst_nodes = torch.arange(150, 200, device=F.ctx())
    with pytest.raises(IndexError):
        gb.unique_and_compact_node_pairs(node_pairs, unique_dst_nodes)
...