OpenDAS / dgl

Commit 83ea9a8d, authored Mar 01, 2025 by sangwz
code update for v2.2.1+dtk25.04
Parent: 74d88bf8

Showing 12 changed files with 392 additions and 81 deletions (+392 -81)
graphbolt/CMakeLists.txt  +20 -16
graphbolt/include/graphbolt/continuous_seed.h  +16 -16
graphbolt/src/cuda/common.h  +6 -6
graphbolt/src/cuda/extension/unique_and_compact_map.hip  +270 -0
graphbolt/src/cuda/gather.hip  +38 -0
graphbolt/src/cuda/neighbor_sampler.hip  +10 -16
graphbolt/src/cuda/unique_and_compact_impl.hip  +4 -4
python/setup.py  +8 -5
src/array/cuda/spmm.cuh  +11 -11
tensoradapter/pytorch/CMakeLists.txt  +4 -2
tests/python/pytorch/graphbolt/internal/test_sample_utils.py  +2 -2
tests/python/pytorch/graphbolt/test_subgraph_sampler.py  +3 -3
graphbolt/CMakeLists.txt
...
@@ -9,6 +9,9 @@ if(USE_CUDA)
endif()
if(USE_HIP)
  list(APPEND CMAKE_PREFIX_PATH $ENV{ROCM_PATH})
  set(HIP_PATH $ENV{ROCM_PATH}/hip)
  find_package(HIP REQUIRED PATHS ${HIP_PATH} NO_DEFAULT_PATH)
  message(STATUS "Build graphbolt with CUDA support")
  enable_language(HIP)
  add_definitions(-DGRAPHBOLT_USE_CUDA)
...
@@ -44,7 +44,7 @@ string(REPLACE "." ";" TORCH_VERSION_LIST ${TORCH_VER})
set(Torch_DIR "${TORCH_PREFIX}/Torch")
message(STATUS "Setting directory to ${Torch_DIR}")
find_package(Torch REQUIRED)
find_package(Torch REQUIRED PATH ${Torch_DIR})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${TORCH_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb")
...
@@ -90,34 +93,35 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
  endif(USE_LIBURING)
endif()
# if(USE_CUDA)
#   file(GLOB BOLT_CUDA_EXTENSION_SRC
#     ${BOLT_DIR}/cuda/extension/*.cu
#     ${BOLT_DIR}/cuda/extension/*.cc
#   )
#   # Until https://github.com/NVIDIA/cccl/issues/1083 is resolved, we need to
#   # compile the cuda/extension folder with Volta+ CUDA architectures.
#   add_library(${LIB_GRAPHBOLT_CUDA_NAME} STATIC ${BOLT_CUDA_EXTENSION_SRC} ${BOLT_HEADERS})
#   target_link_libraries(${LIB_GRAPHBOLT_CUDA_NAME} "${TORCH_LIBRARIES}")
#
if(USE_HIP)
  file(GLOB BOLT_CUDA_EXTENSION_SRC
    ${BOLT_DIR}/cuda/extension/*.hip
    ${BOLT_DIR}/cuda/extension/*.cc
  )
  # Until https://github.com/NVIDIA/cccl/issues/1083 is resolved, we need to
  # compile the cuda/extension folder with Volta+ CUDA architectures.
  #
  add_library(${LIB_GRAPHBOLT_CUDA_NAME} STATIC ${BOLT_CUDA_EXTENSION_SRC} ${BOLT_HEADERS})
  #
  target_link_libraries(${LIB_GRAPHBOLT_CUDA_NAME} "${TORCH_LIBRARIES}")
  # set_target_properties(${LIB_GRAPHBOLT_NAME} PROPERTIES CUDA_STANDARD 17)
  # set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES CUDA_STANDARD 17)
  # set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES_FILTERED}")
  #
  set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
  #set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
  # message(STATUS "Use external CCCL library for a consistent API and performance for graphbolt.")
  # include_directories(BEFORE
  #   "../third_party/cccl/thrust"
  #   "../third_party/cccl/cub"
  #   "../third_party/cccl/libcudacxx/include"
  #   "../third_party/cuco/include")
if(USE_HIP)
  # set_target_properties(${LIB_GRAPHBOLT_NAME} PROPERTIES CUDA_STANDARD 17)
  message(STATUS "Use external CCCL library for a consistent API and performance for graphbolt.")
  target_compile_options(${LIB_GRAPHBOLT_NAME} PRIVATE "--gpu-max-threads-per-block=1024")
  target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE
  include_directories(BEFORE
    # "${ROCM_PATH}/include/thrust"
    "${ROCM_PATH}/include/hipcub"
    "${ROCM_PATH}/include/rocprim"
    "../third_party/cuco/include")
  message(STATUS "Use HugeCTR gpu_cache for graphbolt with INCLUDE_DIRS $ENV{GPU_CACHE_INCLUDE_DIRS}.")
...
@@ -128,10 +132,10 @@ if(USE_HIP)
  # get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES)
  message(STATUS "CUDA_ARCHITECTURES for graphbolt: ${archs}")
  get_property(archs TARGET ${LIB_GRAPHBOLT_CUDA_NAME} PROPERTY CUDA_ARCHITECTURES)
  # get_property(archs TARGET ${LIB_GRAPHBOLT_CUDA_NAME} PROPERTY CUDA_ARCHITECTURES)
  message(STATUS "CUDA_ARCHITECTURES for graphbolt extension: ${archs}")
  target_link_libraries(${LIB_GRAPHBOLT_NAME} ${LIB_GRAPHBOLT_CUDA_NAME})
  # target_link_libraries(${LIB_GRAPHBOLT_NAME} ${LIB_GRAPHBOLT_CUDA_NAME})
endif()
# The Torch CMake configuration only sets up the path for the MKL library when
...
graphbolt/include/graphbolt/continuous_seed.h
...
@@ -24,12 +24,12 @@
#include <cmath>
#ifdef __CUDACC__
#include <curand_kernel.h>
#ifdef __HIP_DEVICE_COMPILE__
#include <hiprand/hiprand_kernel.h>
#else
#include <pcg_random.hpp>
#include <random>
#endif  // __CUDA_ARCH__
#endif  // __HIP_DEVICE_COMPILE__
#ifndef M_SQRT1_2
#define M_SQRT1_2 0.707106781186547524401
...
@@ -58,24 +58,24 @@ class continuous_seed {
  uint64_t get_seed(int i) const { return s[i != 0]; }
#ifdef __CUDACC__
#ifdef __HIP_DEVICE_COMPILE__
  __device__ inline float uniform(const uint64_t t) const {
    const uint64_t kCurandSeed = 999961;  // Could be any random number.
    curandStatePhilox4_32_10_t rng;
    curand_init(kCurandSeed, s[0], t, &rng);
    hiprandStatePhilox4_32_10_t rng;
    hiprand_init(kCurandSeed, s[0], t, &rng);
    float rnd;
    if (s[0] != s[1]) {
      rnd = c[0] * curand_normal(&rng);
      curand_init(kCurandSeed, s[1], t, &rng);
      rnd += c[1] * curand_normal(&rng);
      rnd = c[0] * hiprand_normal(&rng);
      hiprand_init(kCurandSeed, s[1], t, &rng);
      rnd += c[1] * hiprand_normal(&rng);
      rnd = normcdff(rnd);
    } else {
      rnd = curand_uniform(&rng);
      rnd = hiprand_uniform(&rng);
    }
    return rnd;
  }
#else
  inline float uniform(const uint64_t t) const {
  __host__ inline float uniform(const uint64_t t) const {
    pcg32 ng0(s[0], t);
    float rnd;
    if (s[0] != s[1]) {
...
@@ -91,7 +91,7 @@ class continuous_seed {
    }
    return rnd;
  }
#endif  // __CUDA_ARCH__
#endif  // __HIP_DEVICE_COMPILE__
};
class single_seed {
...
@@ -103,12 +103,12 @@ class single_seed {
  single_seed(torch::Tensor seed_arr)
      : seed_(seed_arr.data_ptr<int64_t>()[0]) {}
#ifdef __CUDACC__
#ifdef __HIP_DEVICE_COMPILE__
  __device__ inline float uniform(const uint64_t id) const {
    const uint64_t kCurandSeed = 999961;  // Could be any random number.
    curandStatePhilox4_32_10_t rng;
    curand_init(kCurandSeed, seed_, id, &rng);
    return curand_uniform(&rng);
    hiprandStatePhilox4_32_10_t rng;
    hiprand_init(kCurandSeed, seed_, id, &rng);
    return hiprand_uniform(&rng);
  }
#else
  inline float uniform(const uint64_t id) const {
...
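The device path of `continuous_seed::uniform` above draws one Philox normal per seed, blends the two draws with the weights `c[0]` and `c[1]`, and maps the sum back to (0, 1) through the normal CDF (`normcdff`). A minimal host-only sketch of that blending, assuming `std::mt19937_64` as a stand-in for the hiprand/pcg32 generators and free-standing parameters in place of the class members:

```cpp
// Host-only sketch of the two-seed blending in continuous_seed::uniform().
// Assumptions: std::mt19937_64 replaces the Philox/pcg32 generators, and
// s0/s1/c0/c1 stand in for the class members s[0], s[1], c[0], c[1].
#include <cmath>
#include <cstdint>
#include <random>

float blended_uniform(uint64_t s0, uint64_t s1, float c0, float c1, uint64_t t) {
  std::mt19937_64 rng0(s0 + t), rng1(s1 + t);  // one generator per (seed, t)
  if (s0 != s1) {
    std::normal_distribution<float> normal;
    // Blend one standard-normal draw per seed (the sum stays standard normal
    // when c0*c0 + c1*c1 == 1), then map it back to (0, 1) with the normal
    // CDF, as normcdff() does on the device.
    const float z = c0 * normal(rng0) + c1 * normal(rng1);
    return 0.5f * std::erfc(-z * 0.70710678f);  // Phi(z) = erfc(-z/sqrt(2))/2
  }
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  return uniform(rng0);  // Degenerate case: both seeds equal.
}
```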
graphbolt/src/cuda/common.h
...
@@ -54,13 +54,13 @@ struct CUDAWorkspaceAllocator {
  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;
  void operator()(void* ptr) const {
    c10::hip::HIPCachingAllocator::raw_delete(ptr);
    at::hip::HIPCachingAllocator::raw_delete(ptr);
  }
  // Required by thrust to satisfy allocator requirements.
  value_type* allocate(std::ptrdiff_t size) const {
    return reinterpret_cast<value_type*>(
        c10::hip::HIPCachingAllocator::raw_alloc(size * sizeof(value_type)));
        at::hip::HIPCachingAllocator::raw_alloc(size * sizeof(value_type)));
  }
  // Required by thrust to satisfy allocator requirements.
...
@@ -71,7 +71,7 @@ struct CUDAWorkspaceAllocator {
      std::size_t size) const {
    return std::unique_ptr<T, CUDAWorkspaceAllocator>(
        reinterpret_cast<T*>(
            c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(T) * size)),
            at::cuda::HIPCachingAllocator::raw_alloc(sizeof(T) * size)),
        *this);
  }
};
...
@@ -92,9 +92,9 @@ inline bool is_zero<dim3>(dim3 size) {
#define CUDA_RUNTIME_CHECK(EXPR)                          \
  do {                                                    \
    cudaError_t __err = EXPR;                             \
    if (__err != cudaSuccess) {                           \
      auto get_error_str_err = cudaGetErrorString(__err); \
    hipError_t __err = EXPR;                              \
    if (__err != hipSuccess) {                            \
      auto get_error_str_err = hipGetErrorString(__err);  \
      AT_ERROR("HIP runtime error: ", get_error_str_err); \
    }                                                     \
  } while (0)
...
graphbolt/src/cuda/extension/unique_and_compact_map.hip
0 → 100644
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
/**
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/unique_and_compact_map.cu
* @brief Unique and compact operator implementation on CUDA using hash table.
*/
#include <graphbolt/cuda_ops.h>
#include <thrust/gather.h>
#include <cuco/static_map.cuh>
#include <cuda/std/atomic>
#include <numeric>
#include "../common.h"
#include "../utils.h"
#include "unique_and_compact.h"
namespace graphbolt {
namespace ops {
// Support graphs with up to 2^kNodeIdBits nodes.
constexpr int kNodeIdBits = 40;
template <typename index_t, typename map_t>
__global__ void _InsertAndSetMinBatched(
const int64_t num_edges, const int32_t* const indexes, index_t** pointers,
const int64_t* const offsets, map_t map) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
while (i < num_edges) {
const int64_t tensor_index = indexes[i];
const auto tensor_offset = i - offsets[tensor_index];
const int64_t node_id = pointers[tensor_index][tensor_offset];
const auto batch_index = tensor_index / 2;
const int64_t key = node_id | (batch_index << kNodeIdBits);
auto [slot, is_new_key] = map.insert_and_find(cuco::pair{key, i});
if (!is_new_key) {
auto ref = ::cuda::atomic_ref<int64_t, ::cuda::thread_scope_device>{
slot->second};
ref.fetch_min(i, ::cuda::memory_order_relaxed);
}
i += stride;
}
}
template <typename index_t, typename map_t>
__global__ void _IsInsertedBatched(
const int64_t num_edges, const int32_t* const indexes, index_t** pointers,
const int64_t* const offsets, map_t map, int64_t* valid) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
while (i < num_edges) {
const int64_t tensor_index = indexes[i];
const auto tensor_offset = i - offsets[tensor_index];
const int64_t node_id = pointers[tensor_index][tensor_offset];
const auto batch_index = tensor_index / 2;
const int64_t key = node_id | (batch_index << kNodeIdBits);
auto slot = map.find(key);
valid[i] = slot->second == i;
i += stride;
}
}
template <typename index_t, typename map_t>
__global__ void _GetInsertedBatched(
const int64_t num_edges, const int32_t* const indexes, index_t** pointers,
const int64_t* const offsets, map_t map, const int64_t* const valid,
index_t* unique_ids) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
while (i < num_edges) {
const auto valid_i = valid[i];
if (valid_i + 1 == valid[i + 1]) {
const int64_t tensor_index = indexes[i];
const auto tensor_offset = i - offsets[tensor_index];
const int64_t node_id = pointers[tensor_index][tensor_offset];
const auto batch_index = tensor_index / 2;
const int64_t key = node_id | (batch_index << kNodeIdBits);
auto slot = map.find(key);
const auto batch_offset = offsets[batch_index * 2];
const auto new_id = valid_i - valid[batch_offset];
unique_ids[valid_i] = node_id;
slot->second = new_id;
}
i += stride;
}
}
template <typename index_t, typename map_t>
__global__ void _MapIdsBatched(
const int num_batches, const int64_t num_edges,
const int32_t* const indexes, index_t** pointers,
const int64_t* const offsets, map_t map, index_t* mapped_ids) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
while (i < num_edges) {
const int64_t tensor_index = indexes[i];
int64_t batch_index;
if (tensor_index >= 2 * num_batches) {
batch_index = tensor_index - 2 * num_batches;
} else if (tensor_index & 1) {
batch_index = tensor_index / 2;
} else {
batch_index = -1;
}
// Only map src or dst ids.
if (batch_index >= 0) {
const auto tensor_offset = i - offsets[tensor_index];
const int64_t node_id = pointers[tensor_index][tensor_offset];
const int64_t key = node_id | (batch_index << kNodeIdBits);
auto slot = map.find(key);
mapped_ids[i] = slot->second;
}
i += stride;
}
}
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> >
UniqueAndCompactBatchedHashMapBased(
const std::vector<torch::Tensor>& src_ids,
const std::vector<torch::Tensor>& dst_ids,
const std::vector<torch::Tensor>& unique_dst_ids) {
auto allocator = cuda::GetAllocator();
auto stream = cuda::GetCurrentStream();
auto scalar_type = src_ids.at(0).scalar_type();
constexpr int BLOCK_SIZE = 512;
const auto num_batches = src_ids.size();
static_assert(
sizeof(std::ptrdiff_t) == sizeof(int64_t),
"Need to be compiled on a 64-bit system.");
constexpr int batch_id_bits = sizeof(int64_t) * 8 - 1 - kNodeIdBits;
TORCH_CHECK(
num_batches <= (1 << batch_id_bits),
"UniqueAndCompactBatched supports a batch size of up to ",
1 << batch_id_bits);
return AT_DISPATCH_INDEX_TYPES(
scalar_type, "unique_and_compact", ([&] {
// For 2 batches of inputs, stores the input tensor pointers in the
// unique_dst, src, unique_dst, src, dst, dst order. Since there are
// 3 * num_batches input tensors, we need the first 3 * num_batches to
// store the input tensor pointers. Then, we store offsets in the rest
// of the 3 * num_batches + 1 space as if they were stored contiguously.
auto pointers_and_offsets = torch::empty(
6 * num_batches + 1,
c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));
// Points to the input tensor pointers.
auto pointers_ptr =
reinterpret_cast<index_t**>(pointers_and_offsets.data_ptr());
// Points to the input tensor storage logical offsets.
auto offsets_ptr =
pointers_and_offsets.data_ptr<int64_t>() + 3 * num_batches;
for (std::size_t i = 0; i < num_batches; i++) {
pointers_ptr[2 * i] = unique_dst_ids.at(i).data_ptr<index_t>();
offsets_ptr[2 * i] = unique_dst_ids[i].size(0);
pointers_ptr[2 * i + 1] = src_ids.at(i).data_ptr<index_t>();
offsets_ptr[2 * i + 1] = src_ids[i].size(0);
pointers_ptr[2 * num_batches + i] = dst_ids.at(i).data_ptr<index_t>();
offsets_ptr[2 * num_batches + i] = dst_ids[i].size(0);
}
// Finish computing the offsets by taking a cumulative sum.
std::exclusive_scan(
offsets_ptr, offsets_ptr + 3 * num_batches + 1, offsets_ptr, 0ll);
// Device version of the tensors defined above. We store the information
// initially on the CPU, which are later copied to the device.
auto pointers_and_offsets_dev = torch::empty(
pointers_and_offsets.size(0),
src_ids[0].options().dtype(pointers_and_offsets.scalar_type()));
auto offsets_dev = pointers_and_offsets_dev.slice(0, 3 * num_batches);
auto pointers_dev_ptr =
reinterpret_cast<index_t**>(pointers_and_offsets_dev.data_ptr());
auto offsets_dev_ptr = offsets_dev.data_ptr<int64_t>();
CUDA_CALL(hipMemcpyAsync(
pointers_dev_ptr, pointers_ptr,
sizeof(int64_t) * pointers_and_offsets.size(0),
hipMemcpyHostToDevice, stream));
auto indexes = ExpandIndptrImpl(
offsets_dev, torch::kInt32, torch::nullopt,
offsets_ptr[3 * num_batches]);
cuco::static_map map{
offsets_ptr[2 * num_batches],
0.5, // load_factor
cuco::empty_key{static_cast<int64_t>(-1)},
cuco::empty_value{static_cast<int64_t>(-1)},
{},
cuco::linear_probing<1, cuco::default_hash_function<int64_t> >{},
{},
{},
cuda::CUDAWorkspaceAllocator<cuco::pair<int64_t, int64_t> >{},
cuco::cuda_stream_ref{stream},
};
C10_HIP_KERNEL_LAUNCH_CHECK(); // Check the map constructor's success.
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(offsets_ptr[2 * num_batches] + BLOCK_SIZE - 1) / BLOCK_SIZE);
CUDA_KERNEL_CALL(
_InsertAndSetMinBatched, grid, block, 0,
offsets_ptr[2 * num_batches], indexes.data_ptr<int32_t>(),
pointers_dev_ptr, offsets_dev_ptr, map.ref(cuco::insert_and_find));
auto valid = torch::empty(
offsets_ptr[2 * num_batches] + 1,
src_ids[0].options().dtype(torch::kInt64));
CUDA_KERNEL_CALL(
_IsInsertedBatched, grid, block, 0, offsets_ptr[2 * num_batches],
indexes.data_ptr<int32_t>(), pointers_dev_ptr, offsets_dev_ptr,
map.ref(cuco::find), valid.data_ptr<int64_t>());
valid = ExclusiveCumSum(valid);
auto unique_ids_offsets = torch::empty(
num_batches + 1,
c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));
auto unique_ids_offsets_ptr = unique_ids_offsets.data_ptr<int64_t>();
for (int64_t i = 0; i <= num_batches; i++) {
unique_ids_offsets_ptr[i] = offsets_ptr[2 * i];
}
THRUST_CALL(
gather, unique_ids_offsets_ptr,
unique_ids_offsets_ptr + unique_ids_offsets.size(0),
valid.data_ptr<int64_t>(), unique_ids_offsets_ptr);
at::cuda::CUDAEvent unique_ids_offsets_event;
unique_ids_offsets_event.record();
auto unique_ids =
torch::empty(offsets_ptr[2 * num_batches], src_ids[0].options());
CUDA_KERNEL_CALL(
_GetInsertedBatched, grid, block, 0, offsets_ptr[2 * num_batches],
indexes.data_ptr<int32_t>(), pointers_dev_ptr, offsets_dev_ptr,
map.ref(cuco::find), valid.data_ptr<int64_t>(),
unique_ids.data_ptr<index_t>());
auto mapped_ids =
torch::empty(offsets_ptr[3 * num_batches], unique_ids.options());
CUDA_KERNEL_CALL(
_MapIdsBatched, grid, block, 0, num_batches,
offsets_ptr[3 * num_batches], indexes.data_ptr<int32_t>(),
pointers_dev_ptr, offsets_dev_ptr, map.ref(cuco::find),
mapped_ids.data_ptr<index_t>());
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> >
results;
unique_ids_offsets_event.synchronize();
for (int64_t i = 0; i < num_batches; i++) {
results.emplace_back(
unique_ids.slice(
0, unique_ids_offsets_ptr[i], unique_ids_offsets_ptr[i + 1]),
mapped_ids.slice(
0, offsets_ptr[2 * i + 1], offsets_ptr[2 * i + 2]),
mapped_ids.slice(
0, offsets_ptr[2 * num_batches + i],
offsets_ptr[2 * num_batches + i + 1]));
}
return results;
}));
}
} // namespace ops
} // namespace graphbolt
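The kernels in this new file key the cuco hash map on a single 64-bit value that packs the node id into the low `kNodeIdBits` bits and the batch index into the remaining non-sign bits, which is where the 2^40-node limit and the batch-size check in `UniqueAndCompactBatchedHashMapBased` come from. A small host-side sketch of that packing (illustrative only, reusing the constant from the file):

```cpp
// Sketch of the 64-bit key layout used by the batched hash-map kernels:
// low kNodeIdBits bits = node id, remaining non-sign bits = batch index.
#include <cassert>
#include <cstdint>

constexpr int kNodeIdBits = 40;  // as in unique_and_compact_map.hip
constexpr int kBatchIdBits = sizeof(int64_t) * 8 - 1 - kNodeIdBits;

int64_t pack_key(int64_t node_id, int64_t batch_index) {
  assert(node_id < (int64_t{1} << kNodeIdBits));
  assert(batch_index < (int64_t{1} << kBatchIdBits));
  return node_id | (batch_index << kNodeIdBits);
}

int64_t unpack_node_id(int64_t key) {
  return key & ((int64_t{1} << kNodeIdBits) - 1);
}

int64_t unpack_batch_index(int64_t key) { return key >> kNodeIdBits; }
```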
graphbolt/src/cuda/gather.hip
0 → 100644
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
/**
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/gather.cu
* @brief Gather operators implementation on CUDA.
*/
#include <thrust/gather.h>
#include "./common.h"
namespace graphbolt {
namespace ops {
torch::Tensor Gather(
torch::Tensor input, torch::Tensor index,
torch::optional<torch::ScalarType> dtype) {
if (!dtype.has_value()) dtype = input.scalar_type();
auto output = torch::empty(index.sizes(), index.options().dtype(*dtype));
AT_DISPATCH_INDEX_TYPES(
index.scalar_type(), "GatherIndexType", ([&] {
AT_DISPATCH_INTEGRAL_TYPES(
input.scalar_type(), "GatherInputType", ([&] {
using input_t = scalar_t;
AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] {
using output_t = scalar_t;
THRUST_CALL(
gather, index.data_ptr<index_t>(),
index.data_ptr<index_t>() + index.size(0),
input.data_ptr<input_t>(), output.data_ptr<output_t>());
}));
}));
}));
return output;
}
} // namespace ops
} // namespace graphbolt
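The new `Gather` operator dispatches over the index, input, and output dtypes and then runs `thrust::gather` on the device; semantically it is just `output[i] = static_cast<output_t>(input[index[i]])`. A host-side sketch of that contract (not the GPU path):

```cpp
// Host-side sketch of the Gather semantics implemented above with
// thrust::gather: output[i] = static_cast<output_t>(input[index[i]]).
#include <cstddef>
#include <vector>

template <typename output_t, typename input_t, typename index_t>
std::vector<output_t> gather(
    const std::vector<input_t>& input, const std::vector<index_t>& index) {
  std::vector<output_t> output(index.size());
  for (std::size_t i = 0; i < index.size(); ++i) {
    output[i] = static_cast<output_t>(input[index[i]]);
  }
  return output;
}
```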
graphbolt/src/cuda/neighbor_sampler.hip
...
@@ -92,8 +92,8 @@ __global__ void _ComputeRandomsNS(
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t rng;
  curand_init(random_seed, i, 0, &rng);
  hiprandStatePhilox4_32_10_t rng;
  hiprand_init(random_seed, i, 0, &rng);
  while (i < num_edges) {
    const auto row_position = csr_rows[i];
...
@@ -101,7 +101,7 @@ __global__ void _ComputeRandomsNS(
    const auto output_offset = output_indptr[row_position];
    const auto fanout = output_indptr[row_position + 1] - output_offset;
    const auto rnd =
        row_offset < fanout ? row_offset : curand(&rng) % (row_offset + 1);
        row_offset < fanout ? row_offset : hiprand(&rng) % (row_offset + 1);
    if (rnd < fanout) {
      const indptr_t edge_id =
          row_offset + (sliced_indptr ? sliced_indptr[row_position] : 0);
...
@@ -131,26 +131,18 @@ __global__ void _ComputeRandoms(
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const sliced_weights, const indices_t* const indices,
    const continuous_seed random_seed, float_t* random_arr,
    // const unsigned long long random_seed, float_t* random_arr,
    edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  hiprandStatePhilox4_32_10_t rng;
  const auto labor = indices != nullptr;
  if (!labor) {
    hiprand_init(random_seed, i, 0, &rng);
  }
  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;
    if (labor) {
      constexpr uint64_t kCurandSeed = 999961;
      hiprand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
    }
    const auto rnd = hiprand_uniform(&rng);
    const auto rnd = random_seed.uniform(labor ? indices[in_idx] : i);
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
    const auto exp_rnd = -__logf(rnd);
...
@@ -214,7 +206,7 @@ struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
    return indptr[i] + in_degree[i];
  }
};
...
@@ -381,7 +373,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
            "Selected edge_id_t must be capable of storing edge_ids.");
        // Using bfloat16 for random numbers works just as reliably as
        // float32 and provides around 30% speedup.
        using rnd_t = nv_bfloat16;
        using rnd_t = hip_bfloat16;
        auto randoms =
            allocator.AllocateStorage<rnd_t>(num_edges.value());
        auto randoms_sorted =
...
@@ -454,7 +446,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
            DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
            sorted_edge_id_segments.get(), picked_eids.size(0),
            num_rows, sub_indptr.data_ptr<indptr_t>(),
            sampled_segment_end_it);
            sub_indptr.data_ptr<indptr_t>()+1);
            // sub_indptr.data_ptr<indptr_t>()+1);
      }
      auto input_buffer_it = thrust::make_transform_iterator(
...
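The `_ComputeRandomsNS` hunks above keep the existing per-row reservoir sampling and only swap curand for hiprand: the first `fanout` edges of a row fill the reservoir directly, and every later edge overwrites a uniformly chosen slot, so each edge is retained with equal probability. A sequential host-side sketch of that scheme, with `std::mt19937_64` standing in for the Philox generator (the kernel evaluates the same expression independently per edge, in parallel):

```cpp
// Host-side sketch of the reservoir sampling done per row by _ComputeRandomsNS:
// keep `fanout` of `degree` edges, each with equal probability. A Mersenne
// Twister stands in for the hiprand Philox state used on the device.
#include <cstdint>
#include <random>
#include <vector>

std::vector<int64_t> sample_row(int64_t degree, int64_t fanout, uint64_t seed) {
  std::vector<int64_t> reservoir;
  std::mt19937_64 rng(seed);
  for (int64_t row_offset = 0; row_offset < degree; ++row_offset) {
    // Mirrors: rnd = row_offset < fanout ? row_offset
    //                                    : hiprand(&rng) % (row_offset + 1);
    const int64_t rnd =
        row_offset < fanout
            ? row_offset
            : static_cast<int64_t>(rng() % static_cast<uint64_t>(row_offset + 1));
    if (rnd < fanout) {
      if (row_offset < fanout) {
        reservoir.push_back(row_offset);  // fill phase
      } else {
        reservoir[rnd] = row_offset;      // replacement phase
      }
    }
  }
  return reservoir;
}
```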
graphbolt/src/cuda/unique_and_compact_impl.hip
...
@@ -281,8 +281,8 @@ UniqueAndCompactBatched(
      return it->second;
    } else {
      int major;
      CUDA_RUNTIME_CHECK(cudaDeviceGetAttribute(
          &major, cudaDevAttrComputeCapabilityMajor, dev_id));
      CUDA_RUNTIME_CHECK(hipDeviceGetAttribute(
          &major, hipDeviceAttributeComputeCapabilityMajor, dev_id));
      return compute_capability_cache[dev_id] = major;
    }
  }();
...
@@ -290,8 +290,8 @@ UniqueAndCompactBatched(
    // Utilizes a hash table based implementation, the mapped id of a vertex
    // will be monotonically increasing as the first occurrence index of it in
    // torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic.
    return UniqueAndCompactBatchedHashMapBased(
        src_ids, dst_ids, unique_dst_ids);
    // return UniqueAndCompactBatchedHashMapBased(
    //     src_ids, dst_ids, unique_dst_ids);
  }
  // Utilizes a sort based algorithm, the mapped id of a vertex part of the
  // src_ids but not part of the unique_dst_ids will be monotonically increasing
...
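The re-enabled hash-map path has to preserve the contract spelled out in the comment above: every vertex maps to the index of its first occurrence in `torch.cat([unique_dst_ids, src_ids])`, which is what makes the output deterministic. A host-side sketch of that contract for a single batch (illustrative only; the actual implementation does this on the GPU with the cuco static map):

```cpp
// Host-side sketch of the unique-and-compact contract described above: new ids
// are assigned by first occurrence in the concatenation of unique_dst_ids and
// src_ids, so the mapping is deterministic. Single batch, src side only.
#include <cstdint>
#include <unordered_map>
#include <vector>

struct UniqueAndCompactResult {
  std::vector<int64_t> unique_ids;   // new id -> original node id
  std::vector<int64_t> compact_src;  // src_ids rewritten to new ids
};

UniqueAndCompactResult unique_and_compact(
    const std::vector<int64_t>& unique_dst_ids,
    const std::vector<int64_t>& src_ids) {
  UniqueAndCompactResult result;
  std::unordered_map<int64_t, int64_t> new_id;
  auto assign = [&](int64_t node) {
    auto [it, inserted] =
        new_id.emplace(node, static_cast<int64_t>(result.unique_ids.size()));
    if (inserted) result.unique_ids.push_back(node);
    return it->second;
  };
  for (const auto node : unique_dst_ids) assign(node);  // dst ids come first
  for (const auto node : src_ids) result.compact_src.push_back(assign(node));
  return result;
}
```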
python/setup.py
...
@@ -157,10 +157,13 @@ def copy_lib(lib_name, backend=""):
            dst_dir_,
            exist_ok=True,
        )
        shutil.copy(
            os.path.join(dir_, lib_name, backend, lib_file_name),
            dst_dir_,
        )
        if (os.path.join(dir_, lib_name, backend) == dst_dir_):
            pass
        else:
            shutil.copy(
                os.path.join(dir_, lib_name, backend, lib_file_name),
                dst_dir_,
            )
        fo.write(f"include dgl/{lib_name}/{backend}/{lib_file_name}\n")
...
@@ -234,7 +237,7 @@ if "DGLBACKEND" in os.environ and os.environ["DGLBACKEND"] != "pytorch":
    setup(
        name="dgl" + os.getenv("DGL_PACKAGE_SUFFIX", ""),
        version=VERSION,
        version=VERSION + str('+das.opt1.dtk2504'),
        description="Deep Graph Library",
        zip_safe=False,
        maintainer="DGL Team",
...
src/array/cuda/spmm.cuh
...
@@ -248,18 +248,18 @@ void CusparseCsrmm2(
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  size_t workspace_size;
  cusparseSpMMAlg_t spmm_alg = use_deterministic_alg_only ? CUSPARSE_SPMM_CSR_ALG3 : CUSPARSE_SPMM_CSR_ALG2;
  hipsparseSpMMAlg_t spmm_alg = use_deterministic_alg_only ? HIPSPARSE_SPMM_CSR_ALG3 : HIPSPARSE_SPMM_CSR_ALG2;
  CUSPARSE_CALL(hipsparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, &workspace_size));
      matC, dtype, spmm_alg, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(hipsparseSpMM(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, workspace));
      matC, dtype, spmm_alg, workspace));
  device->FreeWorkspace(ctx, workspace);
...
@@ -294,7 +294,7 @@ template <typename DType, typename IdType>
void CusparseCsrmm2Hetero(
    const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
    const DType* A_data, DType* C_data, int64_t x_length, cudaStream_t strm_id,
    const DType* A_data, DType* C_data, int64_t x_length, hipStream_t strm_id,
    bool use_deterministic_alg_only = false) {
  // We use csrmm2 to perform following operation:
...
@@ -348,16 +348,16 @@ void CusparseCsrmm2Hetero(
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  size_t workspace_size;
  cusparseSpMMAlg_t spmm_alg = use_deterministic_alg_only ? CUSPARSE_SPMM_CSR_ALG3 : CUSPARSE_SPMM_CSR_ALG2;
  hipsparseSpMMAlg_t spmm_alg = use_deterministic_alg_only ? HIPSPARSE_SPMM_CSR_ALG3 : HIPSPARSE_SPMM_CSR_ALG2;
  CUSPARSE_CALL(hipsparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, &workspace_size));
      matC, dtype, spmm_alg, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(hipsparseSpMM(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, workspace));
      matC, dtype, spmm_alg, workspace));
  device->FreeWorkspace(ctx, workspace);
...
tensoradapter/pytorch/CMakeLists.txt
...
@@ -18,13 +18,15 @@ list(GET TORCH_PREFIX_VER 1 TORCH_VER)
message(STATUS "Configuring for PyTorch ${TORCH_VER}")
if(USE_HIP)
  message(STATUS "<<<<<<<<<<<<<< PYTORCH USE_HIP: ${USE_HIP}")
  list(APPEND CMAKE_PREFIX_PATH $ENV{ROCM_PATH})
  set(HIP_PATH $ENV{ROCM_PATH}/hip)
  find_package(HIP REQUIRED PATHS ${HIP_PATH} NO_DEFAULT_PATH)
  add_definitions(-DDGL_USE_CUDA)
endif()
set(Torch_DIR "${TORCH_PREFIX}/Torch")
message(STATUS "Setting directory to ${Torch_DIR}")
find_package(Torch REQUIRED)
find_package(Torch REQUIRED PATH Torch_DIR)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${TORCH_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb")
...
tests/python/pytorch/graphbolt/internal/test_sample_utils.py
...
@@ -15,7 +15,7 @@ def test_unique_and_compact_hetero():
        "n2": torch.tensor([0, 3, 5, 2, 7, 8, 4, 9], device=F.ctx()),
        "n3": torch.tensor([1, 2, 6, 8, 3], device=F.ctx()),
    }
    if N1.is_cuda and torch.cuda.get_device_capability()[0] < 7:
    if N1.is_cuda and torch.cuda.get_device_capability()[0] < 7 or torch.version.hip:
        expected_reverse_id = {
            k: v.sort()[1] for k, v in expected_unique.items()
        }
...
@@ -70,7 +70,7 @@ def test_unique_and_compact_homo():
    expected_unique_N = torch.tensor(
        [0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()
    )
    if N.is_cuda and torch.cuda.get_device_capability()[0] < 7:
    if N.is_cuda and torch.cuda.get_device_capability()[0] < 7 or torch.version.hip:
        expected_reverse_id_N = expected_unique_N.sort()[1]
        expected_unique_N = expected_unique_N.sort()[0]
    else:
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
...
@@ -893,7 +893,7 @@ def test_SubgraphSampler_unique_csc_format_Homo_Node_gpu(labor):
        deduplicate=True,
    )
    if torch.cuda.get_device_capability()[0] < 7:
    if torch.cuda.get_device_capability()[0] < 7 or torch.version.hip:
        original_row_node_ids = [
            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
...
@@ -1301,7 +1301,7 @@ def test_SubgraphSampler_unique_csc_format_Homo_Link_gpu(labor):
        deduplicate=True,
    )
    if torch.cuda.get_device_capability()[0] < 7:
    if torch.cuda.get_device_capability()[0] < 7 or torch.version.hip:
        original_row_node_ids = [
            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
...
@@ -1672,7 +1672,7 @@ def test_SubgraphSampler_unique_csc_format_Homo_HyperLink_gpu(labor):
        deduplicate=True,
    )
    if torch.cuda.get_device_capability()[0] < 7:
    if torch.cuda.get_device_capability()[0] < 7 or torch.version.hip:
        original_row_node_ids = [
            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
...