OpenDAS / dgl · Commits · 78df8101

Commit 78df8101 (unverified), authored Apr 07, 2024 by Muhammed Fatih BALIN, committed by GitHub on Apr 07, 2024.

[GraphBolt] Add optimized `unique_and_compact_batched`. (#7239)

Parent: 7eb4de4b
Showing 20 changed files with 795 additions and 194 deletions (+795, -194)
.gitmodules                                                      +3    -0
CMakeLists.txt                                                   +1    -1
graphbolt/CMakeLists.txt                                         +32   -5
graphbolt/build.bat                                              +2    -2
graphbolt/build.sh                                               +5    -1
graphbolt/include/graphbolt/cuda_ops.h                           +11   -0
graphbolt/include/graphbolt/unique_and_compact.h                 +6    -0
graphbolt/src/cuda/common.h                                      +26   -3
graphbolt/src/cuda/extension/unique_and_compact.h                +26   -0
graphbolt/src/cuda/extension/unique_and_compact_map.cu           +267  -0
graphbolt/src/cuda/unique_and_compact_impl.cu                    +228  -103
graphbolt/src/python_binding.cc                                  +1    -0
graphbolt/src/unique_and_compact.cc                              +32   -0
python/dgl/graphbolt/internal/sample_utils.py                    +12   -11
tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py  +1    -1
tests/python/pytorch/graphbolt/internal/test_sample_utils.py     +2    -2
tests/python/pytorch/graphbolt/test_subgraph_sampler.py          +104  -48
tests/python/pytorch/graphbolt/test_utils.py                     +34   -16
third_party/cccl                                                 +1    -1
third_party/cuco                                                 +1    -0
.gitmodules · View file @ 78df8101

@@ -28,3 +28,6 @@
 [submodule "third_party/liburing"]
     path = third_party/liburing
     url = https://github.com/axboe/liburing.git
+[submodule "third_party/cuco"]
+    path = third_party/cuco
+    url = https://github.com/NVIDIA/cuCollections.git
CMakeLists.txt · View file @ 78df8101

@@ -590,5 +590,5 @@ if(BUILD_GRAPHBOLT)
   endif(USE_CUDA)
   if(CMAKE_SYSTEM_NAME MATCHES "Linux")
     add_dependencies(graphbolt liburing)
-  endif(USE_CUDA)
+  endif()
 endif(BUILD_GRAPHBOLT)
graphbolt/CMakeLists.txt · View file @ 78df8101

@@ -58,10 +58,19 @@ if(USE_CUDA)
   if(DEFINED ENV{CUDAARCHS})
     set(CMAKE_CUDA_ARCHITECTURES $ENV{CUDAARCHS})
   endif()
+  set(CMAKE_CUDA_ARCHITECTURES_FILTERED ${CMAKE_CUDA_ARCHITECTURES})
+  # CUDA extension supports only sm_70 and up (Volta+).
+  list(FILTER CMAKE_CUDA_ARCHITECTURES_FILTERED EXCLUDE REGEX "[2-6][0-9]")
+  list(LENGTH CMAKE_CUDA_ARCHITECTURES_FILTERED CMAKE_CUDA_ARCHITECTURES_FILTERED_LEN)
+  if(CMAKE_CUDA_ARCHITECTURES_FILTERED_LEN EQUAL 0)
+    # Build the CUDA extension at least build for Volta.
+    set(CMAKE_CUDA_ARCHITECTURES_FILTERED "70")
+  endif()
+  set(LIB_GRAPHBOLT_CUDA_NAME "${LIB_GRAPHBOLT_NAME}_cuda")
 endif()
 add_library(${LIB_GRAPHBOLT_NAME} SHARED ${BOLT_SRC} ${BOLT_HEADERS})
-target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${BOLT_DIR}
+include_directories(BEFORE ${BOLT_DIR}
                     ${BOLT_HEADERS}
                     "../third_party/dmlc-core/include"
                     "../third_party/pcg/include")

@@ -73,12 +82,25 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
 endif()

 if(USE_CUDA)
+  file(GLOB BOLT_CUDA_EXTENSION_SRC
+       ${BOLT_DIR}/cuda/extension/*.cu
+       ${BOLT_DIR}/cuda/extension/*.cc)
+  # Until https://github.com/NVIDIA/cccl/issues/1083 is resolved, we need to
+  # compile the cuda/extension folder with Volta+ CUDA architectures.
+  add_library(${LIB_GRAPHBOLT_CUDA_NAME} STATIC ${BOLT_CUDA_EXTENSION_SRC}
+              ${BOLT_HEADERS})
+  target_link_libraries(${LIB_GRAPHBOLT_CUDA_NAME} "${TORCH_LIBRARIES}")
   set_target_properties(${LIB_GRAPHBOLT_NAME} PROPERTIES CUDA_STANDARD 17)
+  set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES CUDA_STANDARD 17)
+  set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES
+                        CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES_FILTERED}")
+  set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES
+                        POSITION_INDEPENDENT_CODE TRUE)
   message(STATUS "Use external CCCL library for a consistent API and performance for graphbolt.")
-  target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE
+  include_directories(BEFORE
                       "../third_party/cccl/thrust"
                       "../third_party/cccl/cub"
-                      "../third_party/cccl/libcudacxx/include")
+                      "../third_party/cccl/libcudacxx/include"
+                      "../third_party/cuco/include")
   message(STATUS "Use HugeCTR gpu_cache for graphbolt with INCLUDE_DIRS $ENV{GPU_CACHE_INCLUDE_DIRS}.")
   target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE $ENV{GPU_CACHE_INCLUDE_DIRS})

@@ -87,6 +109,11 @@ if(USE_CUDA)
   get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES)
   message(STATUS "CUDA_ARCHITECTURES for graphbolt: ${archs}")
+  get_property(archs TARGET ${LIB_GRAPHBOLT_CUDA_NAME} PROPERTY CUDA_ARCHITECTURES)
+  message(STATUS "CUDA_ARCHITECTURES for graphbolt extension: ${archs}")
+  target_link_libraries(${LIB_GRAPHBOLT_NAME} ${LIB_GRAPHBOLT_CUDA_NAME})
 endif()

 # The Torch CMake configuration only sets up the path for the MKL library when
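For readers less familiar with CMake's list(FILTER ... EXCLUDE REGEX), here is a minimal Python sketch (illustrative only, not part of the build) of what the CMAKE_CUDA_ARCHITECTURES_FILTERED logic above computes; the helper name filter_volta_plus is made up for this sketch.

import re

def filter_volta_plus(archs: str) -> str:
    # Mirrors list(FILTER ... EXCLUDE REGEX "[2-6][0-9]"): drop sm_20..sm_69.
    kept = [a for a in archs.split(";") if not re.search(r"[2-6][0-9]", a)]
    # Mirrors the fallback above: the CUDA extension is built at least for Volta.
    return ";".join(kept) if kept else "70"

assert filter_volta_plus("60;70;80") == "70;80"
assert filter_volta_plus("52;61") == "70"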
graphbolt/build.bat · View file @ 78df8101

@@ -11,7 +11,7 @@ IF x%1x == xx GOTO single
 FOR %%X IN (%*) DO (
 	DEL /S /Q *
-	"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
+	"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X -DTORCH_CUDA_ARCH_LIST=Volta .. -G "Visual Studio 16 2019" || EXIT /B 1
 	msbuild graphbolt.sln /m /nr:false || EXIT /B 1
 	COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
 )

@@ -21,7 +21,7 @@ GOTO end
 :single
 DEL /S /Q *
-"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
+"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DTORCH_CUDA_ARCH_LIST=Volta .. -G "Visual Studio 16 2019" || EXIT /B 1
 msbuild graphbolt.sln /m /nr:false || EXIT /B 1
 COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
graphbolt/build.sh · View file @ 78df8101

@@ -12,7 +12,11 @@ else
 	CPSOURCE=*.so
 fi

-CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DGPU_CACHE_BUILD_DIR=$BINDIR"
+# We build for the same architectures as DGL, thus we hardcode
+# TORCH_CUDA_ARCH_LIST and we need to at least compile for Volta. Until
+# https://github.com/NVIDIA/cccl/issues/1083 is resolved, we need to compile the
+# cuda/extension folder with Volta+ CUDA architectures.
+CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DGPU_CACHE_BUILD_DIR=$BINDIR -DTORCH_CUDA_ARCH_LIST=Volta"
 echo $CMAKE_FLAGS

 if [ $# -eq 0 ]; then
graphbolt/include/graphbolt/cuda_ops.h · View file @ 78df8101

@@ -235,6 +235,17 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     const torch::Tensor src_ids, const torch::Tensor dst_ids,
     const torch::Tensor unique_dst_ids, int num_bits = 0);

+/**
+ * @brief Batched version of UniqueAndCompact. The ith element of the return
+ * value is equal to passing the ith elements of the input arguments to
+ * UniqueAndCompact.
+ */
+std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+UniqueAndCompactBatched(
+    const std::vector<torch::Tensor>& src_ids,
+    const std::vector<torch::Tensor>& dst_ids,
+    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits = 0);
+
 }  // namespace ops
 }  // namespace graphbolt
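The doc comment above pins down the batched semantics. Below is a hedged Python sketch of the same contract, assuming the graphbolt extension is loaded so that the torch.ops.graphbolt operators are registered; the helper name is illustrative and not part of the library.

import torch
import dgl.graphbolt  # noqa: F401  (loads the graphbolt C++ extension)

def unique_and_compact_batched_reference(src_ids, dst_ids, unique_dst_ids):
    # Element-wise definition of the contract: the i-th result equals calling
    # the single-batch operator on the i-th inputs.
    return [
        torch.ops.graphbolt.unique_and_compact(s, d, u)
        for s, d, u in zip(src_ids, dst_ids, unique_dst_ids)
    ]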
graphbolt/include/graphbolt/unique_and_compact.h · View file @ 78df8101

@@ -50,6 +50,12 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
     const torch::Tensor unique_dst_ids);

+std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+UniqueAndCompactBatched(
+    const std::vector<torch::Tensor>& src_ids,
+    const std::vector<torch::Tensor>& dst_ids,
+    const std::vector<torch::Tensor> unique_dst_ids);
+
 }  // namespace sampling
 }  // namespace graphbolt
graphbolt/src/cuda/common.h · View file @ 78df8101

@@ -11,6 +11,7 @@
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAException.h>
 #include <c10/cuda/CUDAStream.h>
+#include <cuda.h>
 #include <cuda_runtime.h>
 #include <torch/script.h>

@@ -38,12 +39,17 @@ namespace cuda {
  *
  * int_array.get() gives the raw pointer.
  */
+template <typename value_t = char>
 struct CUDAWorkspaceAllocator {
+  static_assert(sizeof(char) == 1, "sizeof(char) == 1 should hold.");
   // Required by thrust to satisfy allocator requirements.
-  using value_type = char;
+  using value_type = value_t;

   explicit CUDAWorkspaceAllocator() { at::globalContext().lazyInitCUDA(); }

+  template <class U>
+  CUDAWorkspaceAllocator(CUDAWorkspaceAllocator<U> const&) noexcept {}
+
   CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;

   void operator()(void* ptr) const {

@@ -53,7 +59,7 @@ struct CUDAWorkspaceAllocator {
   // Required by thrust to satisfy allocator requirements.
   value_type* allocate(std::ptrdiff_t size) const {
     return reinterpret_cast<value_type*>(
-        c10::cuda::CUDACachingAllocator::raw_alloc(size));
+        c10::cuda::CUDACachingAllocator::raw_alloc(size * sizeof(value_type)));
   }

   // Required by thrust to satisfy allocator requirements.

@@ -63,7 +69,9 @@ struct CUDAWorkspaceAllocator {
   std::unique_ptr<T, CUDAWorkspaceAllocator> AllocateStorage(
       std::size_t size) const {
     return std::unique_ptr<T, CUDAWorkspaceAllocator>(
-        reinterpret_cast<T*>(allocate(sizeof(T) * size)), *this);
+        reinterpret_cast<T*>(
+            c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(T) * size)),
+        *this);
   }
 };

@@ -81,6 +89,21 @@ inline bool is_zero<dim3>(dim3 size) {
   return size.x == 0 || size.y == 0 || size.z == 0;
 }

+#define CUDA_DRIVER_CHECK(EXPR)                          \
+  do {                                                   \
+    CUresult __err = EXPR;                               \
+    if (__err != CUDA_SUCCESS) {                         \
+      const char* err_str;                               \
+      CUresult get_error_str_err C10_UNUSED =            \
+          cuGetErrorString(__err, &err_str);             \
+      if (get_error_str_err != CUDA_SUCCESS) {           \
+        AT_ERROR("CUDA driver error: unknown error");    \
+      } else {                                           \
+        AT_ERROR("CUDA driver error: ", err_str);        \
+      }                                                  \
+    }                                                    \
+  } while (0)
+
 #define CUDA_CALL(func) C10_CUDA_CHECK((func))

 #define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, ...)          \
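A small illustrative sketch (Python with ctypes, unrelated to the CUDA allocator itself) of why allocate() above must now scale the request by sizeof(value_type): thrust asks for a number of elements, the caching allocator works in bytes, and the char default keeps the old byte-for-byte behaviour.

import ctypes

def bytes_to_allocate(num_elements, value_type=ctypes.c_char):
    # allocate(n) must hand the raw byte allocator n * sizeof(value_type).
    return num_elements * ctypes.sizeof(value_type)

assert bytes_to_allocate(4) == 4                    # char: unchanged behaviour
assert bytes_to_allocate(4, ctypes.c_int64) == 32   # templated element type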
graphbolt/src/cuda/extension/unique_and_compact.h · 0 → 100644 · View file @ 78df8101

/**
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/unique_and_compact.h
 * @brief Unique and compact operator utilities on CUDA using hash table.
 */
#ifndef GRAPHBOLT_CUDA_UNIQUE_AND_COMPACT_H_
#define GRAPHBOLT_CUDA_UNIQUE_AND_COMPACT_H_

#include <torch/script.h>

#include <vector>

namespace graphbolt {
namespace ops {

std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatchedHashMapBased(
    const std::vector<torch::Tensor>& src_ids,
    const std::vector<torch::Tensor>& dst_ids,
    const std::vector<torch::Tensor>& unique_dst_ids);

}  // namespace ops
}  // namespace graphbolt

#endif  // GRAPHBOLT_CUDA_UNIQUE_AND_COMPACT_H_
graphbolt/src/cuda/extension/unique_and_compact_map.cu · 0 → 100644 · View file @ 78df8101

/**
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/unique_and_compact_map.cu
 * @brief Unique and compact operator implementation on CUDA using hash table.
 */
#include <graphbolt/cuda_ops.h>
#include <thrust/gather.h>

#include <cuco/static_map.cuh>
#include <cuda/std/atomic>
#include <numeric>

#include "../common.h"
#include "../utils.h"
#include "./unique_and_compact.h"

namespace graphbolt {
namespace ops {

// Support graphs with up to 2^kNodeIdBits nodes.
constexpr int kNodeIdBits = 40;

template <typename index_t, typename map_t>
__global__ void _InsertAndSetMinBatched(
    const int64_t num_edges, const int32_t* const indexes, index_t** pointers,
    const int64_t* const offsets, map_t map) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;

  while (i < num_edges) {
    const int64_t tensor_index = indexes[i];
    const auto tensor_offset = i - offsets[tensor_index];
    const int64_t node_id = pointers[tensor_index][tensor_offset];
    const auto batch_index = tensor_index / 2;
    const int64_t key = node_id | (batch_index << kNodeIdBits);
    auto [slot, is_new_key] = map.insert_and_find(cuco::pair{key, i});
    if (!is_new_key) {
      auto ref = ::cuda::atomic_ref<int64_t, ::cuda::thread_scope_device>{
          slot->second};
      ref.fetch_min(i, ::cuda::memory_order_relaxed);
    }
    i += stride;
  }
}

template <typename index_t, typename map_t>
__global__ void _IsInsertedBatched(
    const int64_t num_edges, const int32_t* const indexes, index_t** pointers,
    const int64_t* const offsets, map_t map, int64_t* valid) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;

  while (i < num_edges) {
    const int64_t tensor_index = indexes[i];
    const auto tensor_offset = i - offsets[tensor_index];
    const int64_t node_id = pointers[tensor_index][tensor_offset];
    const auto batch_index = tensor_index / 2;
    const int64_t key = node_id | (batch_index << kNodeIdBits);
    auto slot = map.find(key);
    valid[i] = slot->second == i;
    i += stride;
  }
}

template <typename index_t, typename map_t>
__global__ void _GetInsertedBatched(
    const int64_t num_edges, const int32_t* const indexes, index_t** pointers,
    const int64_t* const offsets, map_t map, const int64_t* const valid,
    index_t* unique_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;

  while (i < num_edges) {
    const auto valid_i = valid[i];
    if (valid_i + 1 == valid[i + 1]) {
      const int64_t tensor_index = indexes[i];
      const auto tensor_offset = i - offsets[tensor_index];
      const int64_t node_id = pointers[tensor_index][tensor_offset];
      const auto batch_index = tensor_index / 2;
      const int64_t key = node_id | (batch_index << kNodeIdBits);
      auto slot = map.find(key);
      const auto batch_offset = offsets[batch_index * 2];
      const auto new_id = valid_i - valid[batch_offset];
      unique_ids[valid_i] = node_id;
      slot->second = new_id;
    }
    i += stride;
  }
}

template <typename index_t, typename map_t>
__global__ void _MapIdsBatched(
    const int num_batches, const int64_t num_edges,
    const int32_t* const indexes, index_t** pointers,
    const int64_t* const offsets, map_t map, index_t* mapped_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;

  while (i < num_edges) {
    const int64_t tensor_index = indexes[i];
    int64_t batch_index;
    if (tensor_index >= 2 * num_batches) {
      batch_index = tensor_index - 2 * num_batches;
    } else if (tensor_index & 1) {
      batch_index = tensor_index / 2;
    } else {
      batch_index = -1;
    }
    // Only map src or dst ids.
    if (batch_index >= 0) {
      const auto tensor_offset = i - offsets[tensor_index];
      const int64_t node_id = pointers[tensor_index][tensor_offset];
      const int64_t key = node_id | (batch_index << kNodeIdBits);
      auto slot = map.find(key);
      mapped_ids[i] = slot->second;
    }
    i += stride;
  }
}

std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatchedHashMapBased(
    const std::vector<torch::Tensor>& src_ids,
    const std::vector<torch::Tensor>& dst_ids,
    const std::vector<torch::Tensor>& unique_dst_ids) {
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  auto scalar_type = src_ids.at(0).scalar_type();
  constexpr int BLOCK_SIZE = 512;
  const auto num_batches = src_ids.size();
  static_assert(
      sizeof(std::ptrdiff_t) == sizeof(int64_t),
      "Need to be compiled on a 64-bit system.");
  constexpr int batch_id_bits = sizeof(int64_t) * 8 - 1 - kNodeIdBits;
  TORCH_CHECK(
      num_batches <= (1 << batch_id_bits),
      "UniqueAndCompactBatched supports a batch size of up to ",
      1 << batch_id_bits);
  return AT_DISPATCH_INDEX_TYPES(
      scalar_type, "unique_and_compact", ([&] {
        // For 2 batches of inputs, stores the input tensor pointers in the
        // unique_dst, src, unique_dst, src, dst, dst order. Since there are
        // 3 * num_batches input tensors, we need the first 3 * num_batches to
        // store the input tensor pointers. Then, we store offsets in the rest
        // of the 3 * num_batches + 1 space as if they were stored contiguously.
        auto pointers_and_offsets = torch::empty(
            6 * num_batches + 1,
            c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));
        // Points to the input tensor pointers.
        auto pointers_ptr =
            reinterpret_cast<index_t**>(pointers_and_offsets.data_ptr());
        // Points to the input tensor storage logical offsets.
        auto offsets_ptr =
            pointers_and_offsets.data_ptr<int64_t>() + 3 * num_batches;
        for (std::size_t i = 0; i < num_batches; i++) {
          pointers_ptr[2 * i] = unique_dst_ids.at(i).data_ptr<index_t>();
          offsets_ptr[2 * i] = unique_dst_ids[i].size(0);
          pointers_ptr[2 * i + 1] = src_ids.at(i).data_ptr<index_t>();
          offsets_ptr[2 * i + 1] = src_ids[i].size(0);
          pointers_ptr[2 * num_batches + i] = dst_ids.at(i).data_ptr<index_t>();
          offsets_ptr[2 * num_batches + i] = dst_ids[i].size(0);
        }
        // Finish computing the offsets by taking a cumulative sum.
        std::exclusive_scan(
            offsets_ptr, offsets_ptr + 3 * num_batches + 1, offsets_ptr, 0ll);
        // Device version of the tensors defined above. We store the information
        // initially on the CPU, which are later copied to the device.
        auto pointers_and_offsets_dev = torch::empty(
            pointers_and_offsets.size(0),
            src_ids[0].options().dtype(pointers_and_offsets.scalar_type()));
        auto offsets_dev = pointers_and_offsets_dev.slice(0, 3 * num_batches);
        auto pointers_dev_ptr =
            reinterpret_cast<index_t**>(pointers_and_offsets_dev.data_ptr());
        auto offsets_dev_ptr = offsets_dev.data_ptr<int64_t>();
        CUDA_CALL(cudaMemcpyAsync(
            pointers_dev_ptr, pointers_ptr,
            sizeof(int64_t) * pointers_and_offsets.size(0),
            cudaMemcpyHostToDevice, stream));
        auto indexes = ExpandIndptrImpl(
            offsets_dev, torch::kInt32, torch::nullopt,
            offsets_ptr[3 * num_batches]);
        cuco::static_map map{
            offsets_ptr[2 * num_batches],
            0.5,  // load_factor
            cuco::empty_key{static_cast<int64_t>(-1)},
            cuco::empty_value{static_cast<int64_t>(-1)},
            {},
            cuco::linear_probing<1, cuco::default_hash_function<int64_t>>{},
            {},
            {},
            cuda::CUDAWorkspaceAllocator<cuco::pair<int64_t, int64_t>>{},
            cuco::cuda_stream_ref{stream},
        };
        C10_CUDA_KERNEL_LAUNCH_CHECK();  // Check the map constructor's success.
        const dim3 block(BLOCK_SIZE);
        const dim3 grid(
            (offsets_ptr[2 * num_batches] + BLOCK_SIZE - 1) / BLOCK_SIZE);
        CUDA_KERNEL_CALL(
            _InsertAndSetMinBatched, grid, block, 0,
            offsets_ptr[2 * num_batches], indexes.data_ptr<int32_t>(),
            pointers_dev_ptr, offsets_dev_ptr, map.ref(cuco::insert_and_find));
        auto valid = torch::empty(
            offsets_ptr[2 * num_batches] + 1,
            src_ids[0].options().dtype(torch::kInt64));
        CUDA_KERNEL_CALL(
            _IsInsertedBatched, grid, block, 0, offsets_ptr[2 * num_batches],
            indexes.data_ptr<int32_t>(), pointers_dev_ptr, offsets_dev_ptr,
            map.ref(cuco::find), valid.data_ptr<int64_t>());
        valid = ExclusiveCumSum(valid);
        auto unique_ids_offsets = torch::empty(
            num_batches + 1,
            c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));
        auto unique_ids_offsets_ptr = unique_ids_offsets.data_ptr<int64_t>();
        for (int64_t i = 0; i <= num_batches; i++) {
          unique_ids_offsets_ptr[i] = offsets_ptr[2 * i];
        }
        THRUST_CALL(
            gather, unique_ids_offsets_ptr,
            unique_ids_offsets_ptr + unique_ids_offsets.size(0),
            valid.data_ptr<int64_t>(), unique_ids_offsets_ptr);
        at::cuda::CUDAEvent unique_ids_offsets_event;
        unique_ids_offsets_event.record();
        auto unique_ids =
            torch::empty(offsets_ptr[2 * num_batches], src_ids[0].options());
        CUDA_KERNEL_CALL(
            _GetInsertedBatched, grid, block, 0, offsets_ptr[2 * num_batches],
            indexes.data_ptr<int32_t>(), pointers_dev_ptr, offsets_dev_ptr,
            map.ref(cuco::find), valid.data_ptr<int64_t>(),
            unique_ids.data_ptr<index_t>());
        auto mapped_ids =
            torch::empty(offsets_ptr[3 * num_batches], unique_ids.options());
        CUDA_KERNEL_CALL(
            _MapIdsBatched, grid, block, 0, num_batches,
            offsets_ptr[3 * num_batches], indexes.data_ptr<int32_t>(),
            pointers_dev_ptr, offsets_dev_ptr, map.ref(cuco::find),
            mapped_ids.data_ptr<index_t>());
        std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
            results;
        unique_ids_offsets_event.synchronize();
        for (int64_t i = 0; i < num_batches; i++) {
          results.emplace_back(
              unique_ids.slice(
                  0, unique_ids_offsets_ptr[i], unique_ids_offsets_ptr[i + 1]),
              mapped_ids.slice(
                  0, offsets_ptr[2 * i + 1], offsets_ptr[2 * i + 2]),
              mapped_ids.slice(
                  0, offsets_ptr[2 * num_batches + i],
                  offsets_ptr[2 * num_batches + i + 1]));
        }
        return results;
      }));
}

}  // namespace ops
}  // namespace graphbolt
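The hash-map kernels above pack the node id and the batch index into a single 64-bit key so that one cuco::static_map serves every batch in the call. A Python sketch of that packing and of the batch-size bound enforced by the TORCH_CHECK (constants mirror kNodeIdBits = 40):

K_NODE_ID_BITS = 40                       # graphs with up to 2**40 nodes
BATCH_ID_BITS = 64 - 1 - K_NODE_ID_BITS   # 23 bits left for the batch index

def pack_key(node_id, batch_index):
    assert 0 <= node_id < (1 << K_NODE_ID_BITS)
    assert 0 <= batch_index < (1 << BATCH_ID_BITS)
    return node_id | (batch_index << K_NODE_ID_BITS)

key = pack_key(node_id=123_456, batch_index=7)
assert key & ((1 << K_NODE_ID_BITS) - 1) == 123_456   # low bits: node id
assert key >> K_NODE_ID_BITS == 7                     # high bits: batch index
assert (1 << BATCH_ID_BITS) == 8_388_608              # max batches per call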
graphbolt/src/cuda/unique_and_compact_impl.cu · View file @ 78df8101

@@ -11,9 +11,12 @@
 #include <thrust/logical.h>

 #include <cub/cub.cuh>
+#include <mutex>
 #include <type_traits>
+#include <unordered_map>

 #include "./common.h"
+#include "./extension/unique_and_compact.h"
 #include "./utils.h"

 namespace graphbolt {

@@ -41,139 +44,261 @@ struct EqualityFunc {
 DefineCubReductionFunction(DeviceReduce::Max, Max);
 DefineCubReductionFunction(DeviceReduce::Min, Min);

-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
-    const torch::Tensor src_ids, const torch::Tensor dst_ids,
-    const torch::Tensor unique_dst_ids, int num_bits) {
-  TORCH_CHECK(
-      src_ids.scalar_type() == dst_ids.scalar_type() &&
-          dst_ids.scalar_type() == unique_dst_ids.scalar_type(),
-      "Dtypes of tensors passed to UniqueAndCompact need to be identical.");
-  auto allocator = cuda::GetAllocator();
-  auto stream = cuda::GetCurrentStream();
-  return AT_DISPATCH_INTEGRAL_TYPES(
-      src_ids.scalar_type(), "unique_and_compact", ([&] {
-        auto src_ids_ptr = src_ids.data_ptr<scalar_t>();
-        auto dst_ids_ptr = dst_ids.data_ptr<scalar_t>();
-        auto unique_dst_ids_ptr = unique_dst_ids.data_ptr<scalar_t>();
-        // If num_bits is not given, compute maximum vertex ids to compute
-        // num_bits later to speedup the expensive sort operations.
-        cuda::CopyScalar<scalar_t> max_id_src;
-        cuda::CopyScalar<scalar_t> max_id_dst;
-        if (num_bits == 0) {
-          max_id_src = Max(src_ids_ptr, src_ids.size(0));
-          max_id_dst = Max(unique_dst_ids_ptr, unique_dst_ids.size(0));
-        }
-        // Sort the unique_dst_ids tensor.
-        auto sorted_unique_dst_ids =
-            Sort<false>(unique_dst_ids_ptr, unique_dst_ids.size(0), num_bits);
-        auto sorted_unique_dst_ids_ptr =
-            sorted_unique_dst_ids.data_ptr<scalar_t>();
-        // Mark dst nodes in the src_ids tensor.
-        auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
-        THRUST_CALL(
-            binary_search, sorted_unique_dst_ids_ptr,
-            sorted_unique_dst_ids_ptr + unique_dst_ids.size(0), src_ids_ptr,
-            src_ids_ptr + src_ids.size(0), is_dst.get());
-        // Filter the non-dst nodes in the src_ids tensor, hence only_src.
-        auto only_src =
-            torch::empty(src_ids.size(0), sorted_unique_dst_ids.options());
-        {
-          auto is_src = thrust::make_transform_iterator(
-              is_dst.get(), thrust::logical_not<bool>{});
-          cuda::CopyScalar<int64_t> only_src_size;
-          CUB_CALL(
-              DeviceSelect::Flagged, src_ids_ptr, is_src,
-              only_src.data_ptr<scalar_t>(), only_src_size.get(),
-              src_ids.size(0));
-          stream.synchronize();
-          only_src = only_src.slice(0, 0, static_cast<int64_t>(only_src_size));
-        }
-        // The code block above synchronizes, ensuring safe access to max_id_src
-        // and max_id_dst.
-        if (num_bits == 0) {
-          num_bits = cuda::NumberOfBits(
-              1 + std::max(
-                      static_cast<scalar_t>(max_id_src),
-                      static_cast<scalar_t>(max_id_dst)));
-        }
-        // Sort the only_src tensor so that we can unique it later.
-        auto sorted_only_src = Sort<false>(
-            only_src.data_ptr<scalar_t>(), only_src.size(0), num_bits);
-        auto unique_only_src =
-            torch::empty(only_src.size(0), src_ids.options());
-        auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();
-        {  // Compute the unique operation on the only_src tensor.
-          cuda::CopyScalar<int64_t> unique_only_src_size;
-          CUB_CALL(
-              DeviceSelect::Unique, sorted_only_src.data_ptr<scalar_t>(),
-              unique_only_src_ptr, unique_only_src_size.get(),
-              only_src.size(0));
-          stream.synchronize();
-          unique_only_src = unique_only_src.slice(
-              0, 0, static_cast<int64_t>(unique_only_src_size));
-        }
-        auto real_order = torch::cat({unique_dst_ids, unique_only_src});
-        // Sort here so that binary search can be used to lookup new_ids.
-        torch::Tensor sorted_order, new_ids;
-        std::tie(sorted_order, new_ids) = Sort(real_order, num_bits);
-        auto sorted_order_ptr = sorted_order.data_ptr<scalar_t>();
-        auto new_ids_ptr = new_ids.data_ptr<int64_t>();
-        // Holds the found locations of the src and dst ids in the sorted_order.
-        // Later is used to lookup the new ids of the src_ids and dst_ids
-        // tensors.
-        auto new_dst_ids_loc =
-            allocator.AllocateStorage<scalar_t>(dst_ids.size(0));
-        THRUST_CALL(
-            lower_bound, sorted_order_ptr,
-            sorted_order_ptr + sorted_order.size(0), dst_ids_ptr,
-            dst_ids_ptr + dst_ids.size(0), new_dst_ids_loc.get());
-        cuda::CopyScalar<bool> all_exist;
-        // Check if unique_dst_ids includes all dst_ids.
-        if (dst_ids.size(0) > 0) {
-          thrust::counting_iterator<int64_t> iota(0);
-          auto equal_it = thrust::make_transform_iterator(
-              iota, EqualityFunc<scalar_t>{
-                        sorted_order_ptr, new_dst_ids_loc.get(), dst_ids_ptr});
-          all_exist = Min(equal_it, dst_ids.size(0));
-          all_exist.record();
-        }
-        auto new_src_ids_loc =
-            allocator.AllocateStorage<scalar_t>(src_ids.size(0));
-        THRUST_CALL(
-            lower_bound, sorted_order_ptr,
-            sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
-            src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
-        // Finally, lookup the new compact ids of the src and dst tensors via
-        // gather operations.
-        auto new_src_ids = torch::empty_like(src_ids);
-        THRUST_CALL(
-            gather, new_src_ids_loc.get(),
-            new_src_ids_loc.get() + src_ids.size(0),
-            new_ids.data_ptr<int64_t>(), new_src_ids.data_ptr<scalar_t>());
-        // Perform check before we gather for the dst indices.
-        if (dst_ids.size(0) > 0 && !static_cast<bool>(all_exist)) {
-          throw std::out_of_range("Some ids not found.");
-        }
-        auto new_dst_ids = torch::empty_like(dst_ids);
-        THRUST_CALL(
-            gather, new_dst_ids_loc.get(),
-            new_dst_ids_loc.get() + dst_ids.size(0),
-            new_ids.data_ptr<int64_t>(), new_dst_ids.data_ptr<scalar_t>());
-        return std::make_tuple(real_order, new_src_ids, new_dst_ids);
-      }));
-}
+std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+UniqueAndCompactBatchedSortBased(
+    const std::vector<torch::Tensor>& src_ids,
+    const std::vector<torch::Tensor>& dst_ids,
+    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits) {
+  auto allocator = cuda::GetAllocator();
+  auto stream = cuda::GetCurrentStream();
+  auto scalar_type = src_ids.at(0).scalar_type();
+  return AT_DISPATCH_INDEX_TYPES(
+      scalar_type, "unique_and_compact", ([&] {
+        std::vector<index_t*> src_ids_ptr, dst_ids_ptr, unique_dst_ids_ptr;
+        for (std::size_t i = 0; i < src_ids.size(); i++) {
+          src_ids_ptr.emplace_back(src_ids[i].data_ptr<index_t>());
+          dst_ids_ptr.emplace_back(dst_ids[i].data_ptr<index_t>());
+          unique_dst_ids_ptr.emplace_back(
+              unique_dst_ids[i].data_ptr<index_t>());
+        }
+        // If num_bits is not given, compute maximum vertex ids to compute
+        // num_bits later to speedup the expensive sort operations.
+        std::vector<cuda::CopyScalar<index_t>> max_id_src;
+        std::vector<cuda::CopyScalar<index_t>> max_id_dst;
+        for (std::size_t i = 0; num_bits == 0 && i < src_ids.size(); i++) {
+          max_id_src.emplace_back(Max(src_ids_ptr[i], src_ids[i].size(0)));
+          max_id_dst.emplace_back(
+              Max(unique_dst_ids_ptr[i], unique_dst_ids[i].size(0)));
+        }
+        // Sort the unique_dst_ids tensor.
+        std::vector<torch::Tensor> sorted_unique_dst_ids;
+        std::vector<index_t*> sorted_unique_dst_ids_ptr;
+        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {
+          sorted_unique_dst_ids.emplace_back(Sort<false>(
+              unique_dst_ids_ptr[i], unique_dst_ids[i].size(0), num_bits));
+          sorted_unique_dst_ids_ptr.emplace_back(
+              sorted_unique_dst_ids[i].data_ptr<index_t>());
+        }
+        // Mark dst nodes in the src_ids tensor.
+        std::vector<decltype(allocator.AllocateStorage<bool>(0))> is_dst;
+        for (std::size_t i = 0; i < src_ids.size(); i++) {
+          is_dst.emplace_back(
+              allocator.AllocateStorage<bool>(src_ids[i].size(0)));
+          THRUST_CALL(
+              binary_search, sorted_unique_dst_ids_ptr[i],
+              sorted_unique_dst_ids_ptr[i] + unique_dst_ids[i].size(0),
+              src_ids_ptr[i], src_ids_ptr[i] + src_ids[i].size(0),
+              is_dst[i].get());
+        }
+        // Filter the non-dst nodes in the src_ids tensor, hence only_src.
+        std::vector<torch::Tensor> only_src;
+        {
+          std::vector<cuda::CopyScalar<int64_t>> only_src_size;
+          for (std::size_t i = 0; i < src_ids.size(); i++) {
+            only_src.emplace_back(torch::empty(
+                src_ids[i].size(0), sorted_unique_dst_ids[i].options()));
+            auto is_src = thrust::make_transform_iterator(
+                is_dst[i].get(), thrust::logical_not<bool>{});
+            only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});
+            CUB_CALL(
+                DeviceSelect::Flagged, src_ids_ptr[i], is_src,
+                only_src[i].data_ptr<index_t>(), only_src_size[i].get(),
+                src_ids[i].size(0));
+          }
+          stream.synchronize();
+          for (std::size_t i = 0; i < only_src.size(); i++) {
+            only_src[i] = only_src[i].slice(
+                0, 0, static_cast<int64_t>(only_src_size[i]));
+          }
+        }
+        // The code block above synchronizes, ensuring safe access to
+        // max_id_src and max_id_dst.
+        if (num_bits == 0) {
+          index_t max_id = 0;
+          for (std::size_t i = 0; i < max_id_src.size(); i++) {
+            max_id = std::max(max_id, static_cast<index_t>(max_id_src[i]));
+            max_id = std::max(max_id, static_cast<index_t>(max_id_dst[i]));
+          }
+          num_bits = cuda::NumberOfBits(1ll + max_id);
+        }
+        // Sort the only_src tensor so that we can unique it later.
+        std::vector<torch::Tensor> sorted_only_src;
+        for (auto& only_src_i : only_src) {
+          sorted_only_src.emplace_back(Sort<false>(
+              only_src_i.data_ptr<index_t>(), only_src_i.size(0), num_bits));
+        }
+        std::vector<torch::Tensor> unique_only_src;
+        std::vector<index_t*> unique_only_src_ptr;
+        {
+          std::vector<cuda::CopyScalar<int64_t>> unique_only_src_size;
+          for (std::size_t i = 0; i < src_ids.size(); i++) {
+            // Compute the unique operation on the only_src tensor.
+            unique_only_src.emplace_back(
+                torch::empty(only_src[i].size(0), src_ids[i].options()));
+            unique_only_src_ptr.emplace_back(
+                unique_only_src[i].data_ptr<index_t>());
+            unique_only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});
+            CUB_CALL(
+                DeviceSelect::Unique, sorted_only_src[i].data_ptr<index_t>(),
+                unique_only_src_ptr[i], unique_only_src_size[i].get(),
+                only_src[i].size(0));
+          }
+          stream.synchronize();
+          for (std::size_t i = 0; i < unique_only_src.size(); i++) {
+            unique_only_src[i] = unique_only_src[i].slice(
+                0, 0, static_cast<int64_t>(unique_only_src_size[i]));
+          }
+        }
+        std::vector<torch::Tensor> real_order;
+        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {
+          real_order.emplace_back(
+              torch::cat({unique_dst_ids[i], unique_only_src[i]}));
+        }
+        // Sort here so that binary search can be used to lookup new_ids.
+        std::vector<torch::Tensor> sorted_order, new_ids;
+        std::vector<index_t*> sorted_order_ptr;
+        std::vector<int64_t*> new_ids_ptr;
+        for (std::size_t i = 0; i < real_order.size(); i++) {
+          auto [sorted_order_i, new_ids_i] = Sort(real_order[i], num_bits);
+          sorted_order_ptr.emplace_back(sorted_order_i.data_ptr<index_t>());
+          new_ids_ptr.emplace_back(new_ids_i.data_ptr<int64_t>());
+          sorted_order.emplace_back(std::move(sorted_order_i));
+          new_ids.emplace_back(std::move(new_ids_i));
+        }
+        // Holds the found locations of the src and dst ids in the
+        // sorted_order. Later is used to lookup the new ids of the src_ids
+        // and dst_ids tensors.
+        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>
+            new_dst_ids_loc;
+        for (std::size_t i = 0; i < sorted_order.size(); i++) {
+          new_dst_ids_loc.emplace_back(
+              allocator.AllocateStorage<index_t>(dst_ids[i].size(0)));
+          THRUST_CALL(
+              lower_bound, sorted_order_ptr[i],
+              sorted_order_ptr[i] + sorted_order[i].size(0), dst_ids_ptr[i],
+              dst_ids_ptr[i] + dst_ids[i].size(0), new_dst_ids_loc[i].get());
+        }
+        std::vector<cuda::CopyScalar<bool>> all_exist;
+        at::cuda::CUDAEvent all_exist_event;
+        bool should_record = false;
+        // Check if unique_dst_ids includes all dst_ids.
+        for (std::size_t i = 0; i < dst_ids.size(); i++) {
+          if (dst_ids[i].size(0) > 0) {
+            thrust::counting_iterator<int64_t> iota(0);
+            auto equal_it = thrust::make_transform_iterator(
+                iota, EqualityFunc<index_t>{
+                          sorted_order_ptr[i], new_dst_ids_loc[i].get(),
+                          dst_ids_ptr[i]});
+            all_exist.emplace_back(Min(equal_it, dst_ids[i].size(0)));
+            should_record = true;
+          } else {
+            all_exist.emplace_back(cuda::CopyScalar<bool>{});
+          }
+        }
+        if (should_record) all_exist_event.record();
+        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>
+            new_src_ids_loc;
+        for (std::size_t i = 0; i < sorted_order.size(); i++) {
+          new_src_ids_loc.emplace_back(
+              allocator.AllocateStorage<index_t>(src_ids[i].size(0)));
+          THRUST_CALL(
+              lower_bound, sorted_order_ptr[i],
+              sorted_order_ptr[i] + sorted_order[i].size(0), src_ids_ptr[i],
+              src_ids_ptr[i] + src_ids[i].size(0), new_src_ids_loc[i].get());
+        }
+        // Finally, lookup the new compact ids of the src and dst tensors
+        // via gather operations.
+        std::vector<torch::Tensor> new_src_ids;
+        for (std::size_t i = 0; i < src_ids.size(); i++) {
+          new_src_ids.emplace_back(torch::empty_like(src_ids[i]));
+          THRUST_CALL(
+              gather, new_src_ids_loc[i].get(),
+              new_src_ids_loc[i].get() + src_ids[i].size(0),
+              new_ids[i].data_ptr<int64_t>(),
+              new_src_ids[i].data_ptr<index_t>());
+        }
+        // Perform check before we gather for the dst indices.
+        for (std::size_t i = 0; i < dst_ids.size(); i++) {
+          if (dst_ids[i].size(0) > 0) {
+            if (should_record) {
+              all_exist_event.synchronize();
+              should_record = false;
+            }
+            if (!static_cast<bool>(all_exist[i])) {
+              throw std::out_of_range("Some ids not found.");
+            }
+          }
+        }
+        std::vector<torch::Tensor> new_dst_ids;
+        for (std::size_t i = 0; i < dst_ids.size(); i++) {
+          new_dst_ids.emplace_back(torch::empty_like(dst_ids[i]));
+          THRUST_CALL(
+              gather, new_dst_ids_loc[i].get(),
+              new_dst_ids_loc[i].get() + dst_ids[i].size(0),
+              new_ids[i].data_ptr<int64_t>(),
+              new_dst_ids[i].data_ptr<index_t>());
+        }
+        std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+            results;
+        for (std::size_t i = 0; i < src_ids.size(); i++) {
+          results.emplace_back(
+              std::move(real_order[i]), std::move(new_src_ids[i]),
+              std::move(new_dst_ids[i]));
+        }
+        return results;
+      }));
+}
+
+std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+UniqueAndCompactBatched(
+    const std::vector<torch::Tensor>& src_ids,
+    const std::vector<torch::Tensor>& dst_ids,
+    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits) {
+  auto dev_id = cuda::GetCurrentStream().device_index();
+  static std::mutex mtx;
+  static std::unordered_map<decltype(dev_id), int> compute_capability_cache;
+  const auto compute_capability_major = [&] {
+    std::lock_guard lock(mtx);
+    auto it = compute_capability_cache.find(dev_id);
+    if (it != compute_capability_cache.end()) {
+      return it->second;
+    } else {
+      int major;
+      CUDA_DRIVER_CHECK(cuDeviceGetAttribute(
+          &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev_id));
+      return compute_capability_cache[dev_id] = major;
+    }
+  }();
+  if (compute_capability_major >= 7) {
+    // Utilizes a hash table based implementation, the mapped id of a vertex
+    // will be monotonically increasing as the first occurrence index of it in
+    // torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic.
+    return UniqueAndCompactBatchedHashMapBased(
+        src_ids, dst_ids, unique_dst_ids);
+  }
+  // Utilizes a sort based algorithm, the mapped id of a vertex part of the
+  // src_ids but not part of the unique_dst_ids will be monotonically increasing
+  // as the actual vertex id increases. Thus, it is deterministic.
+  return UniqueAndCompactBatchedSortBased(
+      src_ids, dst_ids, unique_dst_ids, num_bits);
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
+    const torch::Tensor src_ids, const torch::Tensor dst_ids,
+    const torch::Tensor unique_dst_ids, int num_bits) {
+  return UniqueAndCompactBatched(
+      {src_ids}, {dst_ids}, {unique_dst_ids}, num_bits)[0];
+}

 }  // namespace ops
 }  // namespace graphbolt
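A sketch of the dispatch rule in UniqueAndCompactBatched above, written with the same torch.cuda.get_device_capability() check that the updated tests use; pick_unique_and_compact_impl is an illustrative name, not a GraphBolt function.

import torch

def pick_unique_and_compact_impl(device=None):
    major, _minor = torch.cuda.get_device_capability(device)
    # Volta (sm_70) and newer use the cuco hash-map kernels; older GPUs keep
    # the sort-based implementation. Both orderings are deterministic, but
    # they differ, which is why the tests branch on the same condition.
    return "hash-map" if major >= 7 else "sort-based"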
graphbolt/src/python_binding.cc · View file @ 78df8101

@@ -89,6 +89,7 @@ TORCH_LIBRARY(graphbolt, m) {
   m.def(
       "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
   m.def("unique_and_compact", &UniqueAndCompact);
+  m.def("unique_and_compact_batched", &UniqueAndCompactBatched);
   m.def("isin", &IsIn);
   m.def("index_select", &ops::IndexSelect);
   m.def("index_select_csc", &ops::IndexSelectCSC);
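Once the graphbolt extension is loaded (importing dgl.graphbolt does this), the newly registered operator can be called directly; the tensors below are made-up inputs that only illustrate the calling convention of lists-in, list-of-triples-out.

import torch
import dgl.graphbolt  # noqa: F401  (registers torch.ops.graphbolt.*)

src_ids = [torch.tensor([5, 2, 7, 2]), torch.tensor([1, 3, 3])]
dst_ids = [torch.tensor([9, 9, 4, 4]), torch.tensor([0, 0, 8])]
unique_dst_ids = [torch.tensor([9, 4]), torch.tensor([0, 8])]

results = torch.ops.graphbolt.unique_and_compact_batched(
    src_ids, dst_ids, unique_dst_ids
)
for unique_ids, compacted_src_ids, compacted_dst_ids in results:
    print(unique_ids, compacted_src_ids, compacted_dst_ids)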
graphbolt/src/unique_and_compact.cc · View file @ 78df8101

@@ -85,5 +85,37 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
 #endif
   return std::tuple(unique_ids, compacted_src_ids, compacted_dst_ids);
 }

+std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+UniqueAndCompactBatched(
+    const std::vector<torch::Tensor>& src_ids,
+    const std::vector<torch::Tensor>& dst_ids,
+    const std::vector<torch::Tensor> unique_dst_ids) {
+  TORCH_CHECK(
+      src_ids.size() == dst_ids.size() &&
+          dst_ids.size() == unique_dst_ids.size(),
+      "The batch dimension of the parameters need to be identical.");
+  bool all_on_gpu = true;
+  for (std::size_t i = 0; i < src_ids.size(); i++) {
+    all_on_gpu = all_on_gpu && utils::is_on_gpu(src_ids[i]) &&
+                 utils::is_on_gpu(dst_ids[i]) &&
+                 utils::is_on_gpu(unique_dst_ids[i]);
+    if (!all_on_gpu) break;
+  }
+  if (all_on_gpu) {
+    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
+        c10::DeviceType::CUDA, "unique_and_compact", {
+          return ops::UniqueAndCompactBatched(
+              src_ids, dst_ids, unique_dst_ids);
+        });
+  }
+  std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> results;
+  results.reserve(src_ids.size());
+  for (std::size_t i = 0; i < src_ids.size(); i++) {
+    results.emplace_back(
+        UniqueAndCompact(src_ids[i], dst_ids[i], unique_dst_ids[i]));
+  }
+  return results;
+}
+
 }  // namespace sampling
 }  // namespace graphbolt
python/dgl/graphbolt/internal/sample_utils.py · View file @ 78df8101

@@ -204,18 +204,19 @@ def unique_and_compact_csc_formats(
     compacted_indices = {}
     dtype = list(indices.values())[0].dtype
     default_tensor = torch.tensor([], dtype=dtype, device=device)
+    indice_list = []
+    unique_dst_list = []
     for ntype in ntypes:
-        indice = indices.get(ntype, default_tensor)
-        unique_dst = unique_dst_nodes.get(ntype, default_tensor)
-        (
-            unique_nodes[ntype],
-            compacted_indices[ntype],
-            _,
-        ) = torch.ops.graphbolt.unique_and_compact(
-            indice,
-            torch.tensor([], dtype=indice.dtype, device=device),
-            unique_dst,
-        )
+        indice_list.append(indices.get(ntype, default_tensor))
+        unique_dst_list.append(unique_dst_nodes.get(ntype, default_tensor))
+    dst_list = [torch.tensor([], dtype=dtype, device=device)] * len(
+        unique_dst_list
+    )
+    results = torch.ops.graphbolt.unique_and_compact_batched(
+        indice_list, dst_list, unique_dst_list
+    )
+    for i, ntype in enumerate(ntypes):
+        unique_nodes[ntype], compacted_indices[ntype], _ = results[i]

     compacted_csc_formats = {}
     # Map back with the same order.
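The refactor above is the motivating use case: instead of launching unique_and_compact once per node type, all per-type tensors are packed into lists and handled by a single batched call. A standalone sketch of that pattern with made-up heterogeneous inputs (it assumes dgl.graphbolt has been imported so the op is registered):

import torch
import dgl.graphbolt  # noqa: F401

indices = {"n1": torch.tensor([4, 2, 4]), "n2": torch.tensor([7, 7, 5])}
unique_dst_nodes = {"n1": torch.tensor([0, 1]), "n2": torch.tensor([3])}
ntypes = sorted(indices.keys())
dtype = list(indices.values())[0].dtype
empty = torch.tensor([], dtype=dtype)

# Pack per-type tensors into lists in a fixed ntype order.
indice_list = [indices[ntype] for ntype in ntypes]
unique_dst_list = [unique_dst_nodes[ntype] for ntype in ntypes]
dst_list = [empty] * len(unique_dst_list)

results = torch.ops.graphbolt.unique_and_compact_batched(
    indice_list, dst_list, unique_dst_list
)
# Unpack in the same order to rebuild the per-type dictionaries.
unique_nodes = {ntype: results[i][0] for i, ntype in enumerate(ntypes)}
compacted_indices = {ntype: results[i][1] for i, ntype in enumerate(ntypes)}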
tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py · View file @ 78df8101

@@ -244,7 +244,7 @@ def test_InSubgraphSampler_hetero():
             indices=torch.LongTensor([1, 2, 0]),
         ),
     }
-    if graph.csc_indptr.is_cuda:
+    if graph.csc_indptr.is_cuda and torch.cuda.get_device_capability()[0] < 7:
         expected_sampled_csc["N0:R1:N1"] = gb.CSCFormatBase(
             indptr=torch.LongTensor([0, 1, 2]), indices=torch.LongTensor([1, 0])
         )
tests/python/pytorch/graphbolt/internal/test_sample_utils.py · View file @ 78df8101

@@ -15,7 +15,7 @@ def test_unique_and_compact_hetero():
         "n2": torch.tensor([0, 3, 5, 2, 7, 8, 4, 9], device=F.ctx()),
         "n3": torch.tensor([1, 2, 6, 8, 3], device=F.ctx()),
     }
-    if N1.is_cuda:
+    if N1.is_cuda and torch.cuda.get_device_capability()[0] < 7:
         expected_reverse_id = {
             k: v.sort()[1] for k, v in expected_unique.items()
         }

@@ -70,7 +70,7 @@ def test_unique_and_compact_homo():
     expected_unique_N = torch.tensor(
         [0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()
     )
-    if N.is_cuda:
+    if N.is_cuda and torch.cuda.get_device_capability()[0] < 7:
         expected_reverse_id_N = expected_unique_N.sort()[1]
         expected_unique_N = expected_unique_N.sort()[0]
     else:
tests/python/pytorch/graphbolt/test_subgraph_sampler.py · View file @ 78df8101

@@ -792,22 +792,40 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor):
         deduplicate=True,
     )
-    original_row_node_ids = [
-        torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-    ]
-    compacted_indices = [
-        torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
-        torch.tensor([4, 3, 2]).to(F.ctx()),
-    ]
-    indptr = [
-        torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
-        torch.tensor([0, 1, 2, 3]).to(F.ctx()),
-    ]
-    seeds = [
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-        torch.tensor([0, 3, 4]).to(F.ctx()),
-    ]
+    if torch.cuda.get_device_capability()[0] < 7:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([4, 3, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
+    else:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([3, 4, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
     _assert_homo_values(
         datapipe, original_row_node_ids, compacted_indices, indptr, seeds
     )

@@ -1646,22 +1664,41 @@ def test_SubgraphSampler_unique_csc_format_Homo_Node_gpu(labor):
         deduplicate=True,
     )
-    original_row_node_ids = [
-        torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-    ]
-    compacted_indices = [
-        torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
-        torch.tensor([4, 3, 2]).to(F.ctx()),
-    ]
-    indptr = [
-        torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
-        torch.tensor([0, 1, 2, 3]).to(F.ctx()),
-    ]
-    seeds = [
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-        torch.tensor([0, 3, 4]).to(F.ctx()),
-    ]
+    if torch.cuda.get_device_capability()[0] < 7:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([4, 3, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
+    else:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([3, 4, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
     for data in datapipe:
         for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
             assert torch.equal(

@@ -2060,22 +2097,41 @@ def test_SubgraphSampler_unique_csc_format_Homo_Link_gpu(labor):
         deduplicate=True,
     )
-    original_row_node_ids = [
-        torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-    ]
-    compacted_indices = [
-        torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
-        torch.tensor([4, 3, 2]).to(F.ctx()),
-    ]
-    indptr = [
-        torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
-        torch.tensor([0, 1, 2, 3]).to(F.ctx()),
-    ]
-    seeds = [
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-        torch.tensor([0, 3, 4]).to(F.ctx()),
-    ]
+    if torch.cuda.get_device_capability()[0] < 7:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([4, 3, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
+    else:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([3, 4, 2]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
     for data in datapipe:
         for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
             assert torch.equal(
tests/python/pytorch/graphbolt/test_utils.py · View file @ 78df8101

@@ -175,22 +175,40 @@ def test_exclude_seed_edges_gpu():
         deduplicate=True,
     )
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
-    original_row_node_ids = [
-        torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-    ]
-    compacted_indices = [
-        torch.tensor([4, 3, 5, 5]).to(F.ctx()),
-        torch.tensor([4, 3]).to(F.ctx()),
-    ]
-    indptr = [
-        torch.tensor([0, 1, 2, 2, 4, 4]).to(F.ctx()),
-        torch.tensor([0, 1, 2, 2]).to(F.ctx()),
-    ]
-    seeds = [
-        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
-        torch.tensor([0, 3, 4]).to(F.ctx()),
-    ]
+    if torch.cuda.get_device_capability()[0] < 7:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([4, 3, 5, 5]).to(F.ctx()),
+            torch.tensor([4, 3]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 2, 5, 5]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 2]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
+    else:
+        original_row_node_ids = [
+            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+        ]
+        compacted_indices = [
+            torch.tensor([3, 4, 5, 5]).to(F.ctx()),
+            torch.tensor([3, 4]).to(F.ctx()),
+        ]
+        indptr = [
+            torch.tensor([0, 1, 2, 2, 2, 4]).to(F.ctx()),
+            torch.tensor([0, 1, 2, 2]).to(F.ctx()),
+        ]
+        seeds = [
+            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+            torch.tensor([0, 3, 4]).to(F.ctx()),
+        ]
     for data in datapipe:
         for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
             assert torch.equal(
third_party/cccl @ 64d3a5f0 · Compare c4eda1ae ... 64d3a5f0

-Subproject commit c4eda1aea304c012270dbd10235e60eaf47bd06f
+Subproject commit 64d3a5f0c1c83ed83be8c0a9a1f0cdb31f913e81

third_party/cuco @ 2101cb31 (new submodule)

+Subproject commit 2101cb31d0210b609cd02c88f9b538e10881d91d