OpenDAS / dgl / Commits / e7f0c3a1

Commit e7f0c3a1 (unverified), authored Dec 28, 2023 by Muhammed Fatih BALIN, committed by GitHub on Dec 28, 2023.

[GraphBolt][CUDA] SampleNeighbors (Without replacement for now) (#6770)

parent 01df9bad

Showing 8 changed files with 502 additions and 48 deletions (+502 -48).
graphbolt/include/graphbolt/cuda_sampling_ops.h            +42  -0
graphbolt/src/cuda/index_select_impl.cu                    +1   -1
graphbolt/src/cuda/neighbor_sampler.cu                     +319 -0
graphbolt/src/fused_csc_sampling_graph.cc                  +19  -2
python/dgl/graphbolt/impl/neighbor_sampler.py              +9   -2
python/dgl/graphbolt/internal/sample_utils.py              +16  -3
tests/python/pytorch/graphbolt/gb_test_utils.py            +1   -1
tests/python/pytorch/graphbolt/test_subgraph_sampler.py    +95  -39
graphbolt/include/graphbolt/cuda_sampling_ops.h
...
@@ -11,6 +11,48 @@
namespace graphbolt {
namespace ops {

/**
 * @brief Sample neighboring edges of the given nodes and return the induced
 * subgraph.
 *
 * @param indptr Index pointer array of the CSC.
 * @param indices Indices array of the CSC.
 * @param nodes The nodes from which to sample neighbors.
 * @param fanouts The number of edges to be sampled for each node with or
 * without considering edge types.
 *   - When the length is 1, it indicates that the fanout applies to all
 *     neighbors of the node as a collective, regardless of the edge type.
 *   - Otherwise, the length should equal the number of edge types, and
 *     each fanout value corresponds to a specific edge type of the node.
 *     The value of each fanout should be >= 0 or = -1.
 *   - When the value is -1, all neighbors will be chosen for sampling. It is
 *     equivalent to selecting all neighbors with non-zero probability when the
 *     fanout is >= the number of neighbors (and replacement is set to false).
 *   - When the value is a non-negative integer, it serves as a minimum
 *     threshold for selecting neighbors.
 * @param replace Boolean indicating whether the sampling is performed with or
 * without replacement. If True, a value can be selected multiple times.
 * Otherwise, each value can be selected only once.
 * @param layer Boolean indicating whether neighbors should be sampled in a
 * layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
 * sampled edges, see arXiv:2210.13339.
 * @param return_eids Boolean indicating whether edge IDs need to be returned,
 * typically used when edge features are required.
 * @param type_per_edge A tensor representing the type of each edge, if present.
 * @param probs_or_mask An optional tensor with (unnormalized) probabilities
 * corresponding to each neighboring edge of a node. It must be a 1D tensor,
 * with the number of elements equaling the total number of edges.
 *
 * @return An intrusive pointer to a FusedSampledSubgraph object containing
 * the sampled graph's information.
 */
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
    const std::vector<int64_t>& fanouts, bool replace, bool layer,
    bool return_eids,
    torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
    torch::optional<torch::Tensor> probs_or_mask = torch::nullopt);

/**
 * @brief Return the subgraph induced on the inbound edges of the given nodes.
 * @param nodes Type agnostic node IDs to form the subgraph.
...
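For orientation, here is a minimal sketch of how these fanout semantics look from the Python side. It assumes a homogeneous FusedCSCSamplingGraph built via dgl.graphbolt.fused_csc_sampling_graph and its sample_neighbors method, which eventually dispatches to the declaration above; treat the exact factory name and call shape as illustrative, not as part of this commit:

    import torch
    import dgl.graphbolt as gb

    # A tiny homogeneous graph in CSC form: 4 nodes, 6 in-edges.
    indptr = torch.tensor([0, 2, 4, 5, 6])
    indices = torch.tensor([1, 2, 0, 3, 0, 2])
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    nodes = torch.tensor([0, 1])
    # Homogeneous graph, so `fanouts` has length 1: keep at most 2
    # in-edges per seed node, sampled without replacement.
    sub = graph.sample_neighbors(nodes, fanouts=torch.tensor([2]))
    # fanout == -1 keeps every neighbor; with replace=False this is
    # equivalent to passing any fanout >= the maximum in-degree.
    full = graph.sample_neighbors(nodes, fanouts=torch.tensor([-1]))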
graphbolt/src/cuda/index_select_impl.cu
...
@@ -124,7 +124,7 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
   const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
   const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();
-  cudaStream_t stream = cuda::GetCurrentStream();
+  auto stream = cuda::GetCurrentStream();
   if (aligned_feature_size == 1) {
     // Use a single thread to process each output row to avoid wasting threads.
...
graphbolt/src/cuda/neighbor_sampler.cu  (new file, 0 → 100644)

/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/neighbor_sampler.cu
 * @brief SampleNeighbors operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAStream.h>
#include <curand_kernel.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <cub/cub.cuh>
#include <cuda/std/tuple>
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;
/**
 * @brief Fills the random_arr with random numbers and the edge_ids array with
 * original edge ids. When random_arr is sorted along with edge_ids, the first
 * fanout elements of each row give us the sampled edges.
 */
template <
    typename float_t, typename indptr_t, typename indices_t,
    typename weights_t, typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const weights, const indices_t* const indices,
    const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t rng;
  const auto labor = indices != nullptr;

  if (!labor) {
    curand_init(random_seed, i, 0, &rng);
  }

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;

    if (labor) {
      constexpr uint64_t kCurandSeed = 999961;
      curand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
    }

    const auto rnd = curand_uniform(&rng);
    const auto prob = weights ? weights[in_idx] : static_cast<weights_t>(1);
    const auto exp_rnd = -__logf(rnd);
    const float_t adjusted_rnd =
        prob > 0 ? static_cast<float_t>(exp_rnd / prob)
                 : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
    edge_ids[i] = row_offset;

    i += stride;
  }
}
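The key written here is the standard exponential-race construction for weighted sampling without replacement: for an edge with weight w, -log(U)/w is an Exponential(w) variate, so the fanout edges with the smallest keys in a row form a sample drawn proportionally to the weights. For layer (LABOR-0) sampling the RNG is instead seeded per neighbor ID (indices[in_idx]), so the same neighbor draws the same uniform in every row, which is what makes sampled edges overlap across seeds. A small PyTorch sketch of the race itself; the helper below is hypothetical, written only to illustrate the math:

    import torch

    def race_sample(weights: torch.Tensor, fanout: int) -> torch.Tensor:
        """Sample `fanout` indices proportionally to `weights`, without
        replacement, via exponential race keys (as in _ComputeRandoms)."""
        u = torch.rand_like(weights)
        keys = -torch.log(u) / weights    # Exponential(w) variates
        keys[weights <= 0] = float("inf")  # zero-weight edges never win
        return torch.topk(keys, fanout, largest=False).indices

    torch.manual_seed(0)
    print(race_sample(torch.tensor([1.0, 2.0, 4.0, 0.5]), 2))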
template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
  int64_t fanout;
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
        min(static_cast<int64_t>(in_degree[i]), fanout));
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
    const std::vector<int64_t>& fanouts, bool replace, bool layer,
    bool return_eids, torch::optional<torch::Tensor> type_per_edge,
    torch::optional<torch::Tensor> probs_or_mask) {
  TORCH_CHECK(
      fanouts.size() == 1, "Heterogeneous sampling is not supported yet!");
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
  // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
  // are all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
  const auto stream = cuda::GetCurrentStream();
  const auto num_rows = nodes.size(0);
  const auto fanout =
      fanouts[0] >= 0 ? fanouts[0] : std::numeric_limits<int64_t>::max();
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "SampleNeighborsInDegree", ([&] {
        // Standard two-phase CUB pattern: the first call only queries the
        // temporary storage size, the second performs the reduction.
        size_t tmp_storage_size = 0;
        cub::DeviceReduce::Max(
            nullptr, tmp_storage_size, in_degree.data_ptr<scalar_t>(),
            max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
        cub::DeviceReduce::Max(
            tmp_storage.get(), tmp_storage_size, in_degree.data_ptr<scalar_t>(),
            max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
      }));
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
  auto sub_indptr = ExclusiveCumSum(in_degree);
  auto output_indptr = torch::empty_like(sub_indptr);
  auto coo_rows = CSRToCOO(sub_indptr, indices.scalar_type());
  const auto num_edges = coo_rows.size(0);
  const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
      static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
  torch::Tensor picked_eids;
  torch::Tensor output_indices;
  AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
        using indptr_t = scalar_t;
        thrust::counting_iterator<int64_t> iota(0);
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
                      in_degree.data_ptr<indptr_t>(), fanout});
        {
          // Compute output_indptr.
          size_t tmp_storage_size = 0;
          cub::DeviceScan::ExclusiveSum(
              nullptr, tmp_storage_size, sampled_degree,
              output_indptr.data_ptr<indptr_t>(), num_rows + 1, stream);
          auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
          cub::DeviceScan::ExclusiveSum(
              tmp_storage.get(), tmp_storage_size, sampled_degree,
              output_indptr.data_ptr<indptr_t>(), num_rows + 1, stream);
        }
        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
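In tensor terms, output_indptr is just an exclusive prefix sum over the per-row degrees clamped to the fanout. A PyTorch sketch of the same computation, with illustrative values not taken from the commit:

    import torch

    in_degree = torch.tensor([5, 0, 3, 9])
    fanout = 4
    # Per-row sampled degree: min(in_degree, fanout), as MinInDegreeFanout does.
    sampled_degree = torch.minimum(in_degree, torch.tensor(fanout))
    # Exclusive cumulative sum over num_rows + 1 entries, matching the
    # cub::DeviceScan::ExclusiveSum call above.
    output_indptr = torch.cat(
        (torch.zeros(1, dtype=in_degree.dtype), sampled_degree.cumsum(0)))
    print(output_indptr)  # tensor([ 0,  4,  4,  7, 11])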
        // Find the smallest integer type to store the edge id offsets.
        // CSRToCOO synchronizes internally, so it is safe to read
        // max_in_degree now.
        const int num_bits =
            cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        std::array<int, 4> type_bits = {8, 16, 32, 64};
        const auto type_index =
            std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
            type_bits.begin();
        std::array<torch::ScalarType, 5> types = {
            torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
            torch::kLong};
        auto edge_id_dtype = types[type_index];
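This narrowing keeps the sort keys' companion values as small as possible: edge id offsets only need to span a row's maximum in-degree, not the global edge count, which cuts memory traffic in the segmented sort below. The same selection logic expressed in Python, as an illustrative stand-in:

    import bisect
    import torch

    def smallest_edge_id_dtype(max_in_degree: int) -> torch.dtype:
        num_bits = int(max_in_degree).bit_length()  # bits the offsets need
        type_bits = [8, 16, 32, 64]
        types = [torch.uint8, torch.int16, torch.int32, torch.int64,
                 torch.int64]
        return types[bisect.bisect_left(type_bits, num_bits)]

    assert smallest_edge_id_dtype(255) == torch.uint8
    assert smallest_edge_id_dtype(256) == torch.int16
    assert smallest_edge_id_dtype(70_000) == torch.int32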
        AT_DISPATCH_INTEGRAL_TYPES(
            edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
              using edge_id_t = std::make_unsigned_t<scalar_t>;
              TORCH_CHECK(
                  num_bits <= sizeof(edge_id_t) * 8,
                  "Selected edge_id_t must be capable of storing edge_ids.");
              // Using bfloat16 for random numbers works just as reliably as
              // float32 and provides around a 30% speedup.
              using rnd_t = nv_bfloat16;
              auto randoms = allocator.AllocateStorage<rnd_t>(num_edges);
              auto randoms_sorted = allocator.AllocateStorage<rnd_t>(num_edges);
              auto edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges);
              auto sorted_edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges);
              AT_DISPATCH_INTEGRAL_TYPES(
                  indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                    using indices_t = scalar_t;
                    auto probs_or_mask_scalar_type = torch::kFloat32;
                    if (probs_or_mask.has_value()) {
                      probs_or_mask_scalar_type =
                          probs_or_mask.value().scalar_type();
                    }
                    GRAPHBOLT_DISPATCH_ALL_TYPES(
                        probs_or_mask_scalar_type, "SampleNeighborsProbs",
                        ([&] {
                          using probs_t = scalar_t;
                          probs_t* probs_ptr = nullptr;
                          if (probs_or_mask.has_value()) {
                            probs_ptr =
                                probs_or_mask.value().data_ptr<probs_t>();
                          }
                          const indices_t* indices_ptr =
                              layer ? indices.data_ptr<indices_t>() : nullptr;
                          const dim3 block(BLOCK_SIZE);
                          const dim3 grid(
                              (num_edges + BLOCK_SIZE - 1) / BLOCK_SIZE);
                          // Compute row and random number pairs.
                          CUDA_KERNEL_CALL(
                              _ComputeRandoms, grid, block, 0, stream,
                              num_edges, sliced_indptr.data_ptr<indptr_t>(),
                              sub_indptr.data_ptr<indptr_t>(),
                              coo_rows.data_ptr<indices_t>(), probs_ptr,
                              indices_ptr, random_seed, randoms.get(),
                              edge_id_segments.get());
                        }));
                  }));
              // Sort the random numbers along with the edge ids; after
              // sorting, the first fanout elements of each row give us the
              // sampled edges.
              size_t tmp_storage_size = 0;
              CUDA_CALL(cub::DeviceSegmentedSort::SortPairs(
                  nullptr, tmp_storage_size, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges, num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1, stream));
              auto tmp_storage =
                  allocator.AllocateStorage<char>(tmp_storage_size);
              CUDA_CALL(cub::DeviceSegmentedSort::SortPairs(
                  tmp_storage.get(), tmp_storage_size, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges, num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1, stream));
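Conceptually, the segmented sort plus the batched copy that follows amount to "argsort each row by its key and keep the first fanout entries". A plain PyTorch sketch of that selection, in loop form for clarity (the device-side version is fully parallel; values are illustrative):

    import torch

    indptr = torch.tensor([0, 3, 5])  # two rows: 3 and 2 candidate edges
    keys = torch.tensor([0.7, 0.1, 0.4, 0.9, 0.2])  # race keys per edge
    fanout = 2
    picked = []
    for r in range(len(indptr) - 1):
        s, e = int(indptr[r]), int(indptr[r + 1])
        order = keys[s:e].argsort()            # per-segment sort of the keys
        picked.append(s + order[:fanout])      # first `fanout` win the race
    print(torch.cat(picked))                   # tensor([1, 2, 4, 3])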
              picked_eids = torch::empty(
                  static_cast<indptr_t>(num_sampled_edges),
                  nodes.options().dtype(indptr.scalar_type()));
              auto input_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFunc<indptr_t, edge_id_t>{
                            sub_indptr.data_ptr<indptr_t>(),
                            sorted_edge_id_segments.get()});
              auto output_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                            output_indptr.data_ptr<indptr_t>(),
                            sliced_indptr.data_ptr<indptr_t>(),
                            picked_eids.data_ptr<indptr_t>()});
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();
              // Copy the sampled edge ids into the picked_eids tensor.
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                size_t tmp_storage_size = 0;
                CUDA_CALL(cub::DeviceCopy::Batched(
                    nullptr, tmp_storage_size, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    std::min(num_rows - i, max_copy_at_once), stream));
                auto tmp_storage =
                    allocator.AllocateStorage<char>(tmp_storage_size);
                CUDA_CALL(cub::DeviceCopy::Batched(
                    tmp_storage.get(), tmp_storage_size, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    std::min(num_rows - i, max_copy_at_once), stream));
              }
            }));
        output_indices = torch::empty(
            picked_eids.size(0),
            picked_eids.options().dtype(indices.scalar_type()));
        // Compute: output_indices = indices.gather(0, picked_eids);
        AT_DISPATCH_INTEGRAL_TYPES(
            indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
              using indices_t = scalar_t;
              const auto exec_policy =
                  thrust::cuda::par_nosync(allocator).on(stream);
              thrust::gather(
                  exec_policy, picked_eids.data_ptr<indptr_t>(),
                  picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                  indices.data_ptr<indices_t>(),
                  output_indices.data_ptr<indices_t>());
            }));
      }));
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);

  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes, torch::nullopt,
      subgraph_reverse_edge_ids, torch::nullopt);
}

}  // namespace ops
}  // namespace graphbolt
graphbolt/src/fused_csc_sampling_graph.cc
...
@@ -609,8 +609,25 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
     const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
     bool replace, bool layer, bool return_eids,
     torch::optional<std::string> probs_name) const {
-  auto probs_or_mask = this->EdgeAttribute(probs_name);
-  if (probs_name.has_value()) {
+  torch::optional<torch::Tensor> probs_or_mask = torch::nullopt;
+  if (probs_name.has_value() && !probs_name.value().empty()) {
+    probs_or_mask = this->EdgeAttribute(probs_name);
+  }
+  if (!replace && utils::is_accessible_from_gpu(indptr_) &&
+      utils::is_accessible_from_gpu(indices_) &&
+      utils::is_accessible_from_gpu(nodes) &&
+      (!probs_or_mask.has_value() ||
+       utils::is_accessible_from_gpu(probs_or_mask.value()))) {
+    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
+        c10::DeviceType::CUDA, "SampleNeighbors", {
+          return ops::SampleNeighbors(
+              indptr_, indices_, nodes, fanouts, replace, layer, return_eids,
+              type_per_edge_, probs_or_mask);
+        });
+  }
+  if (probs_or_mask.has_value()) {
     // Note probs will be passed as input for 'torch.multinomial' in deeper
     // stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
     // avoid crashes, convert 'probs_or_mask' to 'float32' data type.
...
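The float32 conversion mentioned in the note above exists because torch.multinomial rejects half and bool inputs on the CPU path. A minimal standalone illustration of the workaround (not code from this commit):

    import torch

    mask = torch.tensor([True, False, True])  # boolean edge mask
    # torch.multinomial does not accept torch.bool / torch.half inputs,
    # so cast the (unnormalized) probabilities to float32 first.
    picked = torch.multinomial(mask.to(torch.float32), num_samples=1)
    print(picked)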
python/dgl/graphbolt/impl/neighbor_sampler.py
...
@@ -118,9 +118,16 @@ class NeighborSampler(SubgraphSampler):
         # Enrich seeds with all node types.
         if isinstance(seeds, dict):
             ntypes = list(self.graph.node_type_to_id.keys())
+            # Loop over different seeds to extract the device they are on.
+            device = None
+            dtype = None
+            for _, seed in seeds.items():
+                device = seed.device
+                dtype = seed.dtype
+                break
+            default_tensor = torch.tensor([], dtype=dtype, device=device)
             seeds = {
-                ntype: seeds.get(ntype, torch.LongTensor([]))
+                ntype: seeds.get(ntype, default_tensor)
                 for ntype in ntypes
             }
         for hop in range(num_layers):
             subgraph = self.sampler(
...
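This change makes the placeholder for missing node types follow the device and dtype of the seeds that are present, so a GPU pipeline never mixes CPU and GPU tensors. The pattern in isolation, with a hypothetical seed dict for illustration:

    import torch

    seeds = {"user": torch.tensor([0, 1, 2])}  # possibly on a CUDA device
    ntypes = ["user", "item"]
    # Derive the empty default from an existing seed instead of using
    # torch.LongTensor([]), which is always a CPU tensor.
    ref = next(iter(seeds.values()))
    default = torch.tensor([], dtype=ref.dtype, device=ref.device)
    seeds = {ntype: seeds.get(ntype, default) for ntype in ntypes}
    print(seeds["item"].device == seeds["user"].device)  # True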
python/dgl/graphbolt/internal/sample_utils.py
...
@@ -368,7 +368,14 @@ def compact_csc_format(
         original_row_ids = torch.cat((dst_nodes, csc_formats.indices))
         compacted_csc_formats = CSCFormatBase(
             indptr=csc_formats.indptr,
-            indices=(torch.arange(0, csc_formats.indices.size(0)) + offset),
+            indices=(
+                torch.arange(
+                    0,
+                    csc_formats.indices.size(0),
+                    device=csc_formats.indices.device,
+                )
+                + offset
+            ),
         )
     else:
         compacted_csc_formats = {}
...
@@ -381,12 +388,17 @@ def compact_csc_format(
             assert len(dst_nodes.get(dst_type, [])) + 1 == len(
                 csc_format.indptr
             ), "The seed nodes should correspond to indptr."
-            offset = original_row_ids.get(src_type, torch.tensor([])).size(0)
+            device = csc_format.indices.device
+            offset = original_row_ids.get(
+                src_type, torch.tensor([], device=device)
+            ).size(0)
             original_row_ids[src_type] = torch.cat(
                 (
                     original_row_ids.get(
                         src_type,
-                        torch.tensor([], dtype=csc_format.indices.dtype),
+                        torch.tensor(
+                            [],
+                            dtype=csc_format.indices.dtype,
+                            device=device,
+                        ),
                     ),
                     csc_format.indices,
                 )
...
@@ -398,6 +410,7 @@ def compact_csc_format(
                     0,
                     csc_format.indices.size(0),
                     dtype=csc_format.indices.dtype,
+                    device=device,
                 )
                 + offset
             ),
...
tests/python/pytorch/graphbolt/gb_test_utils.py
...
@@ -235,7 +235,7 @@ def genereate_raw_data_for_hetero_dataset(
     # Generate train/test/valid set.
     os.makedirs(os.path.join(test_dir, "set"), exist_ok=True)
-    user_ids = np.arange(num_nodes["user"])
+    user_ids = torch.arange(num_nodes["user"])
     np.random.shuffle(user_ids)
     num_train = int(num_nodes["user"] * 0.6)
     num_validation = int(num_nodes["user"] * 0.2)
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py

import unittest
from functools import partial

import backend as F

import dgl
import dgl.graphbolt as gb
import pytest
...
@@ -11,7 +14,7 @@ from . import gb_test_utils
 def test_SubgraphSampler_invoke():
     itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())

     # Invoke via class constructor.
     datapipe = gb.SubgraphSampler(item_sampler)
...
@@ -26,9 +29,11 @@ def test_SubgraphSampler_invoke():
 @pytest.mark.parametrize("labor", [False, True])
 def test_NeighborSampler_invoke(labor):
-    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
+    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
+        F.ctx()
+    )
     itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
...
@@ -47,9 +52,11 @@ def test_NeighborSampler_invoke(labor):
 @pytest.mark.parametrize("labor", [False, True])
 def test_NeighborSampler_fanouts(labor):
-    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
+    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
+        F.ctx()
+    )
     itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     # `fanouts` is a list of tensors.
...
@@ -71,9 +78,11 @@ def test_NeighborSampler_fanouts(labor):
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Node(labor):
-    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
+    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
+        F.ctx()
+    )
     itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
...
@@ -88,9 +97,11 @@ def to_link_batch(data):
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link(labor):
-    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
+    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
+        F.ctx()
+    )
     itemset = gb.ItemSet(
         torch.arange(0, 20).reshape(-1, 2), names="node_pairs"
     )
-    datapipe = gb.ItemSampler(itemset, batch_size=2)
+    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
...
@@ -101,9 +112,11 @@ def test_SubgraphSampler_Link(labor):
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link_With_Negative(labor):
-    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
+    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
+        F.ctx()
+    )
     itemset = gb.ItemSet(
         torch.arange(0, 20).reshape(-1, 2), names="node_pairs"
     )
-    datapipe = gb.ItemSampler(itemset, batch_size=2)
+    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
...
@@ -135,13 +148,17 @@ def get_hetero_graph():
 )


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Heterogeneous sampling not yet supported on GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Node_Hetero(labor):
-    graph = get_hetero_graph()
+    graph = get_hetero_graph().to(F.ctx())
     itemset = gb.ItemSetDict(
         {"n2": gb.ItemSet(torch.arange(3), names="seed_nodes")}
     )
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
...
@@ -151,9 +168,13 @@ def test_SubgraphSampler_Node_Hetero(labor):
             assert len(minibatch.sampled_subgraphs) == num_layer


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Heterogeneous sampling not yet supported on GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link_Hetero(labor):
-    graph = get_hetero_graph()
+    graph = get_hetero_graph().to(F.ctx())
     itemset = gb.ItemSetDict(
         {
             "n1:e1:n2": gb.ItemSet(
...
@@ -167,7 +188,7 @@ def test_SubgraphSampler_Link_Hetero(labor):
         }
     )
-    datapipe = gb.ItemSampler(itemset, batch_size=2)
+    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
...
@@ -176,9 +197,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
     assert len(list(datapipe)) == 5


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Heterogeneous sampling not yet supported on GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
-    graph = get_hetero_graph()
+    graph = get_hetero_graph().to(F.ctx())
     itemset = gb.ItemSetDict(
         {
             "n1:e1:n2": gb.ItemSet(
...
@@ -192,7 +217,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
         }
     )
-    datapipe = gb.ItemSampler(itemset, batch_size=2)
+    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
...
@@ -202,6 +227,10 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
     assert len(list(datapipe)) == 5


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Sampling with replacement not yet supported on GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_Random_Hetero_Graph(labor):
     num_nodes = 5
...
@@ -230,7 +259,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
         node_type_to_id=node_type_to_id,
         edge_type_to_id=edge_type_to_id,
         edge_attributes=edge_attributes,
-    )
+    ).to(F.ctx())
     itemset = gb.ItemSetDict(
         {
             "n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
...
@@ -238,10 +267,11 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
         }
     )
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     sampler_dp = Sampler(item_sampler, graph, fanouts, replace=True)
     for data in sampler_dp:
...
@@ -267,16 +297,22 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
 )


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Fails due to randomness on the GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_without_dedpulication_Homo(labor):
     graph = dgl.graph(
         ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
     )
-    graph = gb.from_dglgraph(graph, True)
+    graph = gb.from_dglgraph(graph, True).to(F.ctx())
     seed_nodes = torch.LongTensor([0, 3, 4])
     itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
-    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes))
+    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
+        F.ctx()
+    )
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
...
@@ -285,14 +321,17 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
     length = [17, 7]
     compacted_indices = [
-        torch.arange(0, 10) + 7,
-        torch.arange(0, 4) + 3,
+        (torch.arange(0, 10) + 7).to(F.ctx()),
+        (torch.arange(0, 4) + 3).to(F.ctx()),
     ]
     indptr = [
-        torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]),
-        torch.tensor([0, 1, 2, 4]),
+        torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]).to(F.ctx()),
+        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
     ]
-    seeds = [torch.tensor([0, 3, 4, 5, 2, 2, 4]), torch.tensor([0, 3, 4])]
+    seeds = [
+        torch.tensor([0, 3, 4, 5, 2, 2, 4]).to(F.ctx()),
+        torch.tensor([0, 3, 4]).to(F.ctx()),
+    ]
     for data in datapipe:
         for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
             assert len(sampled_subgraph.original_row_node_ids) == length[step]
...
@@ -307,13 +346,17 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
 )


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Heterogeneous sampling not yet supported on GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_without_dedpulication_Hetero(labor):
-    graph = get_hetero_graph()
+    graph = get_hetero_graph().to(F.ctx())
     itemset = gb.ItemSetDict(
         {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
     )
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
...
@@ -383,15 +426,21 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor):
 )


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Fails due to randomness on the GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_unique_csc_format_Homo(labor):
     torch.manual_seed(1205)
     graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
-    graph = gb.from_dglgraph(graph, True)
+    graph = gb.from_dglgraph(graph, True).to(F.ctx())
     seed_nodes = torch.LongTensor([0, 3, 4])
     itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
-    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes))
+    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
+        F.ctx()
+    )
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
...
@@ -405,18 +454,21 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
     )
     original_row_node_ids = [
-        torch.tensor([0, 3, 4, 5, 2, 6, 7]),
-        torch.tensor([0, 3, 4, 5, 2]),
+        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
+        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
     ]
     compacted_indices = [
-        torch.tensor([3, 4, 4, 2, 5, 6]),
-        torch.tensor([3, 4, 4, 2]),
+        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
+        torch.tensor([3, 4, 4, 2]).to(F.ctx()),
     ]
     indptr = [
-        torch.tensor([0, 1, 2, 4, 4, 6]),
-        torch.tensor([0, 1, 2, 4]),
+        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
+        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
     ]
-    seeds = [torch.tensor([0, 3, 4, 5, 2]), torch.tensor([0, 3, 4])]
+    seeds = [
+        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
+        torch.tensor([0, 3, 4]).to(F.ctx()),
+    ]
     for data in datapipe:
         for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
             assert torch.equal(
...
@@ -434,13 +486,17 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
 )


+@unittest.skipIf(
+    F._default_context_str != "cpu",
+    reason="Heterogeneous sampling not yet supported on GPU.",
+)
 @pytest.mark.parametrize("labor", [False, True])
 def test_SubgraphSampler_unique_csc_format_Hetero(labor):
-    graph = get_hetero_graph()
+    graph = get_hetero_graph().to(F.ctx())
     itemset = gb.ItemSetDict(
         {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
     )
-    item_sampler = gb.ItemSampler(itemset, batch_size=2)
+    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
...