OpenDAS / dgl · Commits · 9632ab1d (Unverified)
Authored Mar 18, 2024 by Muhammed Fatih BALIN; committed by GitHub on Mar 18, 2024.

[GraphBolt][CUDA] Specialize non-weighted neighbor sampling impl (#7215)
Parent: 7129905e

Showing 2 changed files with 242 additions and 112 deletions (+242, -112):

  graphbolt/include/graphbolt/continuous_seed.h   +2    -0
  graphbolt/src/cuda/neighbor_sampler.cu          +240  -112
graphbolt/include/graphbolt/continuous_seed.h  (+2, -0)

@@ -56,6 +56,8 @@ class continuous_seed {
     c[1] = std::sin(pi * r / 2);
   }
 
+  uint64_t get_seed(int i) const { return s[i != 0]; }
+
 #ifdef __CUDACC__
   __device__ inline float uniform(const uint64_t t) const {
     const uint64_t kCurandSeed = 999961;  // Could be any random number.
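The added get_seed accessor returns the first stored seed for index 0 and the second for any non-zero index; the non-weighted sampling path in neighbor_sampler.cu below passes random_seed.get_seed(0) straight to its kernel. A minimal standalone sketch of that selection rule (the TwoSeeds struct here is illustrative only, not part of the header):

    #include <cstdint>

    // Hypothetical stand-in for the two-seed storage in continuous_seed.
    struct TwoSeeds {
      uint64_t s[2];
      // Same rule as the added accessor: s[0] for i == 0, s[1] otherwise.
      uint64_t get_seed(int i) const { return s[i != 0]; }
    };

    // Example: seeds.get_seed(0) yields the plain uint64_t seed that the
    // sampling kernel later hands to curand_init.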
graphbolt/src/cuda/neighbor_sampler.cu  (+240, -112)

@@ -17,6 +17,9 @@
 #include <algorithm>
 #include <array>
 #include <cub/cub.cuh>
+#if __CUDA_ARCH__ >= 700
+#include <cuda/atomic>
+#endif  // __CUDA_ARCH__ >= 700
 #include <limits>
 #include <numeric>
 #include <type_traits>
@@ -30,6 +33,64 @@ namespace ops {
 constexpr int BLOCK_SIZE = 128;
 
+inline __device__ int64_t AtomicMax(
+    int64_t* const address, const int64_t val) {
+  // To match the type of "::atomicCAS", ignore lint warning.
+  using Type = unsigned long long int;  // NOLINT
+  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
+  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
+}
+
+inline __device__ int32_t AtomicMax(
+    int32_t* const address, const int32_t val) {
+  // To match the type of "::atomicCAS", ignore lint warning.
+  using Type = int;  // NOLINT
+  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
+  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
+}
+
+/**
+ * @brief Performs neighbor sampling and fills the edge_ids array with
+ * original edge ids if sliced_indptr is valid. If not, then it fills the edge
+ * ids array with numbers upto the node degree.
+ */
+template <typename indptr_t, typename indices_t>
+__global__ void _ComputeRandomsNS(
+    const int64_t num_edges, const indptr_t* const sliced_indptr,
+    const indptr_t* const sub_indptr, const indptr_t* const output_indptr,
+    const indices_t* const csr_rows, const uint64_t random_seed,
+    indptr_t* edge_ids) {
+  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  const int stride = gridDim.x * blockDim.x;
+  curandStatePhilox4_32_10_t rng;
+  curand_init(random_seed, i, 0, &rng);
+  while (i < num_edges) {
+    const auto row_position = csr_rows[i];
+    const auto row_offset = i - sub_indptr[row_position];
+    const auto output_offset = output_indptr[row_position];
+    const auto fanout = output_indptr[row_position + 1] - output_offset;
+    const auto rnd =
+        row_offset < fanout ? row_offset : curand(&rng) % (row_offset + 1);
+    if (rnd < fanout) {
+      const indptr_t edge_id =
+          row_offset + (sliced_indptr ? sliced_indptr[row_position] : 0);
+#if __CUDA_ARCH__ >= 700
+      ::cuda::atomic_ref<indptr_t, ::cuda::thread_scope_device> a(
+          edge_ids[output_offset + rnd]);
+      a.fetch_max(edge_id, ::cuda::std::memory_order_relaxed);
+#else
+      AtomicMax(edge_ids + output_offset + rnd, edge_id);
+#endif  // __CUDA_ARCH__
+    }
+    i += stride;
+  }
+}
+
 /**
  * @brief Fills the random_arr with random numbers and the edge_ids array with
  * original edge ids. When random_arr is sorted along with edge_ids, the first
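The _ComputeRandomsNS kernel added above applies a parallel reservoir-style selection per row: edge k keeps its own slot while k < fanout, otherwise it draws a slot uniformly from [0, k] and claims it only if that slot index is below fanout, with concurrent claims on the same slot resolved by taking the maximum edge offset (the fetch_max / AtomicMax above). A rough sequential sketch of that per-row rule, where the function name and the sequential driver are illustrative only and not part of the commit:

    #include <algorithm>
    #include <cstdint>
    #include <random>
    #include <vector>

    // Sequential sketch of the per-row selection rule in _ComputeRandomsNS.
    // Returns the local edge offsets picked for one row; the kernel also adds
    // sliced_indptr[row] (when provided) to turn them into global edge ids.
    std::vector<int64_t> SampleRowWithoutWeights(
        int64_t degree, int64_t fanout, std::mt19937_64& rng) {
      std::vector<int64_t> slots(std::min(degree, fanout), int64_t{-1});
      for (int64_t k = 0; k < degree; ++k) {
        // Mirrors: rnd = row_offset < fanout ? row_offset
        //                                    : curand(&rng) % (row_offset + 1);
        const int64_t rnd =
            k < fanout ? k : static_cast<int64_t>(rng() % (k + 1));
        if (rnd < fanout) {
          // Mirrors the fetch_max / AtomicMax on edge_ids[output_offset + rnd]:
          // among all edges claiming the same slot, the largest offset wins.
          slots[rnd] = std::max(slots[rnd], k);
        }
      }
      return slots;
    }

Every slot below min(degree, fanout) is claimed at least once (slot r is claimed by edge r itself), so each row ends up fully filled; that is what lets the non-weighted branch below skip the per-edge random keys and segmented sort used by the weighted path.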
@@ -251,119 +312,186 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(

This hunk restructures the edge-id sampling block of SampleNeighbors: the existing weighted/layer sampling path is moved into an if (layer || probs_or_mask.has_value()) branch, and a new else branch implements the specialized non-weighted path using the _ComputeRandomsNS kernel above. The block after the change:

      // Find the smallest integer type to store the edge id offsets. We synch
      // the CUDAEvent so that the access is safe.
      auto compute_num_bits = [&] {
        max_in_degree_event.synchronize();
        return cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
      };
      if (layer || probs_or_mask.has_value()) {
        const int num_bits = compute_num_bits();
        std::array<int, 4> type_bits = {8, 16, 32, 64};
        const auto type_index =
            std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
            type_bits.begin();
        std::array<torch::ScalarType, 5> types = {
            torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
            torch::kLong};
        auto edge_id_dtype = types[type_index];
        AT_DISPATCH_INTEGRAL_TYPES(
            edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
              using edge_id_t = std::make_unsigned_t<scalar_t>;
              TORCH_CHECK(
                  num_bits <= sizeof(edge_id_t) * 8,
                  "Selected edge_id_t must be capable of storing edge_ids.");
              // Using bfloat16 for random numbers works just as reliably as
              // float32 and provides around 30% speedup.
              using rnd_t = nv_bfloat16;
              auto randoms =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto randoms_sorted =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              auto sorted_edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                    using indices_t = index_t;
                    auto probs_or_mask_scalar_type = torch::kFloat32;
                    if (probs_or_mask.has_value()) {
                      probs_or_mask_scalar_type =
                          probs_or_mask.value().scalar_type();
                    }
                    GRAPHBOLT_DISPATCH_ALL_TYPES(
                        probs_or_mask_scalar_type, "SampleNeighborsProbs",
                        ([&] {
                          using probs_t = scalar_t;
                          probs_t* sliced_probs_ptr = nullptr;
                          if (sliced_probs_or_mask.has_value()) {
                            sliced_probs_ptr = sliced_probs_or_mask.value()
                                                   .data_ptr<probs_t>();
                          }
                          const indices_t* indices_ptr =
                              layer ? indices.data_ptr<indices_t>() : nullptr;
                          const dim3 block(BLOCK_SIZE);
                          const dim3 grid(
                              (num_edges.value() + BLOCK_SIZE - 1) /
                              BLOCK_SIZE);
                          // Compute row and random number pairs.
                          CUDA_KERNEL_CALL(
                              _ComputeRandoms, grid, block, 0,
                              num_edges.value(),
                              sliced_indptr.data_ptr<indptr_t>(),
                              sub_indptr.data_ptr<indptr_t>(),
                              coo_rows.data_ptr<indices_t>(),
                              sliced_probs_ptr, indices_ptr, random_seed,
                              randoms.get(), edge_id_segments.get());
                        }));
                  }));
              // Sort the random numbers along with edge ids, after
              // sorting the first fanout elements of each row will
              // give us the sampled edges.
              CUB_CALL(
                  DeviceSegmentedSort::SortPairs, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1);
              picked_eids = torch::empty(
                  static_cast<indptr_t>(num_sampled_edges),
                  sub_indptr.options());
              // Need to sort the sampled edges only when fanouts.size() == 1
              // since multiple fanout sampling case is automatically going to
              // be sorted.
              if (type_per_edge && fanouts.size() == 1) {
                // Ensuring sort result still ends up in
                // sorted_edge_id_segments
                std::swap(edge_id_segments, sorted_edge_id_segments);
                auto sampled_segment_end_it = thrust::make_transform_iterator(
                    iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                              sub_indptr.data_ptr<indptr_t>(),
                              sampled_degree});
                CUB_CALL(
                    DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                    sorted_edge_id_segments.get(), picked_eids.size(0),
                    num_rows, sub_indptr.data_ptr<indptr_t>(),
                    sampled_segment_end_it);
              }
              auto input_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFunc<indptr_t, edge_id_t>{
                            sub_indptr.data_ptr<indptr_t>(),
                            sorted_edge_id_segments.get()});
              auto output_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                            output_indptr.data_ptr<indptr_t>(),
                            sliced_indptr.data_ptr<indptr_t>(),
                            picked_eids.data_ptr<indptr_t>()});
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();
              // Copy the sampled edge ids into picked_eids tensor.
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                CUB_CALL(
                    DeviceCopy::Batched, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    std::min(num_rows - i, max_copy_at_once));
              }
            }));
      } else {
        // Non-weighted neighbor sampling.
        picked_eids = torch::zeros(num_edges.value(), sub_indptr.options());
        const auto sort_needed = type_per_edge && fanouts.size() == 1;
        const auto sliced_indptr_ptr =
            sort_needed ? nullptr : sliced_indptr.data_ptr<indptr_t>();
        const dim3 block(BLOCK_SIZE);
        const dim3 grid(
            (std::min(num_edges.value(), static_cast<int64_t>(1 << 20)) +
             BLOCK_SIZE - 1) /
            BLOCK_SIZE);
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "SampleNeighborsIndices", ([&] {
              using indices_t = index_t;
              // Compute row and random number pairs.
              CUDA_KERNEL_CALL(
                  _ComputeRandomsNS, grid, block, 0, num_edges.value(),
                  sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
                  output_indptr.data_ptr<indptr_t>(),
                  coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0),
                  picked_eids.data_ptr<indptr_t>());
            }));
        picked_eids =
            picked_eids.slice(0, 0, static_cast<indptr_t>(num_sampled_edges));
        // Need to sort the sampled edges only when fanouts.size() == 1
        // since multiple fanout sampling case is automatically going to
        // be sorted.
        if (sort_needed) {
          const int num_bits = compute_num_bits();
          std::array<int, 4> type_bits = {8, 15, 31, 63};
          const auto type_index =
              std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
              type_bits.begin();
          std::array<torch::ScalarType, 5> types = {
              torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
              torch::kLong};
          auto edge_id_dtype = types[type_index];
          AT_DISPATCH_INTEGRAL_TYPES(
              edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
                using edge_id_t = scalar_t;
                TORCH_CHECK(
                    num_bits <= sizeof(edge_id_t) * 8,
                    "Selected edge_id_t must be capable of storing "
                    "edge_ids.");
                auto picked_offsets = picked_eids.to(edge_id_dtype);
                auto sorted_offsets = torch::empty_like(picked_offsets);
                CUB_CALL(
                    DeviceSegmentedSort::SortKeys,
                    picked_offsets.data_ptr<edge_id_t>(),
                    sorted_offsets.data_ptr<edge_id_t>(),
                    picked_eids.size(0), num_rows,
                    output_indptr.data_ptr<indptr_t>(),
                    output_indptr.data_ptr<indptr_t>() + 1);
                auto edge_id_offsets = ExpandIndptrImpl(
                    output_indptr, picked_eids.scalar_type(), sliced_indptr,
                    picked_eids.size(0));
                picked_eids =
                    sorted_offsets.to(picked_eids.scalar_type()) +
                    edge_id_offsets;
              }));
        }
      }
      output_indices = torch::empty(
          picked_eids.size(0),
      ...
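Both branches above pick the narrowest torch integer dtype that can hold the largest edge-id offset (num_bits comes from compute_num_bits). The weighted branch stores offsets as unsigned values, so it compares against full widths {8, 16, 32, 64}; the non-weighted branch keeps them in the signed scalar_t and therefore budgets one less usable bit per signed type, {8, 15, 31, 63}. A standalone sketch of that selection, where PickEdgeIdDtype is an illustrative name and not a function in the file:

    #include <algorithm>
    #include <array>

    #include <torch/torch.h>

    // Sketch of the dtype-selection logic: find the first bit budget that can
    // hold num_bits and map it to a torch scalar type. The trailing kLong
    // entry catches num_bits larger than every threshold.
    torch::ScalarType PickEdgeIdDtype(int num_bits, bool unsigned_edge_ids) {
      const std::array<int, 4> type_bits =
          unsigned_edge_ids ? std::array<int, 4>{8, 16, 32, 64}
                            : std::array<int, 4>{8, 15, 31, 63};
      const auto type_index =
          std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
          type_bits.begin();
      const std::array<torch::ScalarType, 5> types = {
          torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
          torch::kLong};
      return types[type_index];
    }

In both branches the TORCH_CHECK shown above then verifies that the selected edge_id_t is in fact wide enough before it is used.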