OpenDAS / dgl · Commit 3a79f021 (unverified)
Authored Jan 24, 2024 by Muhammed Fatih BALIN; committed by GitHub on Jan 24, 2024

[GraphBolt][CUDA] Make nodes optional for sampling (#6993)

Parent: 365bb723
Showing 8 changed files with 149 additions and 85 deletions (+149, -85)
graphbolt/include/graphbolt/cuda_ops.h                                   +3  -2
graphbolt/include/graphbolt/cuda_sampling_ops.h                          +5  -4
graphbolt/include/graphbolt/fused_csc_sampling_graph.h                   +3  -2
graphbolt/src/cuda/neighbor_sampler.cu                                  +54 -30
graphbolt/src/cuda/sampling_utils.cu                                    +42 -24
graphbolt/src/fused_csc_sampling_graph.cc                               +22 -14
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py                    +6  -5
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py    +14  -4
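
The user-facing effect of the change, sketched below in Python. This is a hedged example based on the updated test at the bottom of this commit: the graph values are illustrative, the gb.fused_csc_sampling_graph factory and a CUDA build of GraphBolt are assumed, and only the GPU code path accepts nodes=None.

import torch
import dgl.graphbolt as gb

indptr = torch.LongTensor([0, 3, 5, 7, 9, 10])   # illustrative CSC graph
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0])
graph = gb.fused_csc_sampling_graph(indptr, indices).to("cuda")

# Before this commit, seed nodes were mandatory:
seeds = torch.LongTensor([1, 3, 4]).to("cuda")
subgraph = graph.sample_neighbors(seeds, fanouts=torch.LongTensor([2]))

# After this commit, passing None samples from every node of the graph
# (as if nodes were torch.arange(num_nodes)) without materializing the seed list:
full_subgraph = graph.sample_neighbors(None, fanouts=torch.LongTensor([2]))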

graphbolt/include/graphbolt/cuda_ops.h

@@ -113,7 +113,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
  * given nodes and their indptr values.
  *
  * @param indptr The indptr tensor.
- * @param nodes The nodes to read from indptr
+ * @param nodes The nodes to read from indptr. If not provided, assumed to be
+ * equal to torch.arange(indptr.size(0) - 1).
  *
  * @return Tuple of tensors with values:
  * (indptr[nodes + 1] - indptr[nodes], indptr[nodes]), the returned indegrees
@@ -121,7 +122,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
  * on it gives the output indptr.
  */
 std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
-    torch::Tensor indptr, torch::Tensor nodes);
+    torch::Tensor indptr, torch::optional<torch::Tensor> nodes);

 /**
  * @brief Given the compacted sub_indptr tensor, edge type tensor and
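
A plain-PyTorch sketch of what SliceCSCIndptr computes, following the doc comment above; this is not the CUDA implementation, and the Python function name is only a stand-in.

import torch

def slice_csc_indptr(indptr, nodes=None):
    # With nodes omitted, behave as if nodes == torch.arange(indptr.size(0) - 1).
    if nodes is None:
        nodes = torch.arange(indptr.size(0) - 1)
    in_degree = indptr[nodes + 1] - indptr[nodes]   # per-node in-degrees
    sliced_indptr = indptr[nodes]                   # indptr values at the nodes
    return in_degree, sliced_indptr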

graphbolt/include/graphbolt/cuda_sampling_ops.h

@@ -19,7 +19,8 @@ namespace ops {
  *
  * @param indptr Index pointer array of the CSC.
  * @param indices Indices array of the CSC.
- * @param nodes The nodes from which to sample neighbors.
+ * @param nodes The nodes from which to sample neighbors. If not provided,
+ * assumed to be equal to torch.arange(indptr.size(0) - 1).
  * @param fanouts The number of edges to be sampled for each node with or
  * without considering edge types.
  *   - When the length is 1, it indicates that the fanout applies to all
@@ -49,9 +50,9 @@ namespace ops {
  * the sampled graph's information.
  */
 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
-    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
-    const std::vector<int64_t>& fanouts, bool replace, bool layer,
-    bool return_eids,
+    torch::Tensor indptr, torch::Tensor indices,
+    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
+    bool replace, bool layer, bool return_eids,
     torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
     torch::optional<torch::Tensor> probs_or_mask = torch::nullopt);

graphbolt/include/graphbolt/fused_csc_sampling_graph.h

@@ -286,7 +286,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
  * @brief Sample neighboring edges of the given nodes and return the induced
  * subgraph.
  *
- * @param nodes The nodes from which to sample neighbors.
+ * @param nodes The nodes from which to sample neighbors. If not provided,
+ * assumed to be equal to torch.arange(NumNodes()).
  * @param fanouts The number of edges to be sampled for each node with or
  * without considering edge types.
  *   - When the length is 1, it indicates that the fanout applies to all
@@ -317,7 +318,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
  * the sampled graph's information.
  */
 c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
-    const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
     bool replace, bool layer, bool return_eids,
     torch::optional<std::string> probs_name) const;

graphbolt/src/cuda/neighbor_sampler.cu

@@ -130,16 +130,18 @@ struct SegmentEndFunc {
 };

 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
-    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
-    const std::vector<int64_t>& fanouts, bool replace, bool layer,
-    bool return_eids, torch::optional<torch::Tensor> type_per_edge,
+    torch::Tensor indptr, torch::Tensor indices,
+    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
+    bool replace, bool layer, bool return_eids,
+    torch::optional<torch::Tensor> type_per_edge,
     torch::optional<torch::Tensor> probs_or_mask) {
   TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
   // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
   // are all resident on the GPU. If not, it is better to first extract them
   // before calling this function.
   auto allocator = cuda::GetAllocator();
-  auto num_rows = nodes.size(0);
+  auto num_rows =
+      nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
   auto fanouts_pinned = torch::empty(
       fanouts.size(),
       c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
@@ -166,34 +168,49 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
             DeviceReduce::Max, in_degree.data_ptr<index_t>(),
             max_in_degree.data_ptr<index_t>(), num_rows);
       }));
-  torch::optional<int64_t> num_edges_;
+  // Protect access to max_in_degree with a CUDAEvent
+  at::cuda::CUDAEvent max_in_degree_event;
+  max_in_degree_event.record();
+  torch::optional<int64_t> num_edges;
+  torch::Tensor sub_indptr;
+  if (!nodes.has_value()) {
+    num_edges = indices.size(0);
+    sub_indptr = indptr;
+  }
   torch::optional<torch::Tensor> sliced_probs_or_mask;
   if (probs_or_mask.has_value()) {
+    if (nodes.has_value()) {
       torch::Tensor sliced_probs_or_mask_tensor;
       std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
-          in_degree, sliced_indptr, probs_or_mask.value(), nodes,
-          indptr.size(0) - 2, num_edges_);
+          in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
+          indptr.size(0) - 2, num_edges);
       sliced_probs_or_mask = sliced_probs_or_mask_tensor;
-      num_edges_ = sliced_probs_or_mask_tensor.size(0);
+      num_edges = sliced_probs_or_mask_tensor.size(0);
+    } else {
+      sliced_probs_or_mask = probs_or_mask;
+    }
   }
   if (fanouts.size() > 1) {
     torch::Tensor sliced_type_per_edge;
+    if (nodes.has_value()) {
       std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
-          in_degree, sliced_indptr, type_per_edge.value(), nodes,
-          indptr.size(0) - 2, num_edges_);
+          in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
+          indptr.size(0) - 2, num_edges);
+    } else {
+      sliced_type_per_edge = type_per_edge.value();
+    }
     std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
         sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
     num_rows = sliced_indptr.size(0);
-    num_edges_ = sliced_type_per_edge.size(0);
+    num_edges = sliced_type_per_edge.size(0);
   }
   // If sub_indptr was not computed in the two code blocks above:
-  if (!probs_or_mask.has_value() && fanouts.size() <= 1) {
+  if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
     sub_indptr = ExclusiveCumSum(in_degree);
   }
   auto coo_rows = ExpandIndptrImpl(
-      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges_);
-  const auto num_edges = coo_rows.size(0);
+      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
+  num_edges = coo_rows.size(0);
   const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
       static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
   auto output_indptr = torch::empty_like(sub_indptr);
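
The gist of the branching added above, as Python pseudocode: when nodes is absent, the original indptr already serves as sub_indptr and every edge is a candidate, so the slicing and cumulative-sum work is skipped. Variable names mirror the C++; the helper itself is hypothetical, not part of this commit.

import torch

def plan_candidate_edges(indptr, indices, nodes):
    if nodes is None:
        # No slicing needed: reuse indptr and consider all edges as candidates.
        return indptr, indices.size(0)
    # Otherwise slice indptr for the seed nodes and build the compact sub-indptr.
    in_degree = indptr[nodes + 1] - indptr[nodes]
    sub_indptr = torch.zeros(nodes.numel() + 1, dtype=indptr.dtype, device=indptr.device)
    sub_indptr[1:] = torch.cumsum(in_degree, 0)   # exclusive cumulative sum
    return sub_indptr, int(sub_indptr[-1])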
@@ -233,9 +250,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       auto num_sampled_edges = cuda::CopyScalar{
           output_indptr.data_ptr<indptr_t>() + num_rows};
-      // Find the smallest integer type to store the edge id offsets.
-      // ExpandIndptr or IndexSelectCSCImpl had synch inside, so it is safe to
-      // read max_in_degree now.
+      // Find the smallest integer type to store the edge id offsets. We synch
+      // the CUDAEvent so that the access is safe.
+      max_in_degree_event.synchronize();
       const int num_bits =
           cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
       std::array<int, 4> type_bits = {8, 16, 32, 64};
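
The reason for the new event, per the old and new comments above: the slicing calls that used to synchronize implicitly may no longer do so once their output sizes are known in advance, so the host-side read of max_in_degree now gets its own guard. A rough Python analogy using torch.cuda.Event (not the at::cuda::CUDAEvent C++ code in this commit, and assuming a CUDA device):

import torch

x = torch.randint(0, 100, (1000,), device="cuda")
max_val = x.max()                                  # async reduction on the GPU stream
pinned = torch.empty((), dtype=x.dtype, pin_memory=True)
pinned.copy_(max_val, non_blocking=True)           # async copy into pinned host memory

event = torch.cuda.Event()
event.record()                                     # mark the point after the copy
event.synchronize()                                # wait before touching the host copy
print(int(pinned))                                 # now safe to read on the host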
@@ -255,12 +272,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       // Using bfloat16 for random numbers works just as reliably as
       // float32 and provides around %30 percent speedup.
       using rnd_t = nv_bfloat16;
-      auto randoms = allocator.AllocateStorage<rnd_t>(num_edges);
-      auto randoms_sorted = allocator.AllocateStorage<rnd_t>(num_edges);
+      auto randoms = allocator.AllocateStorage<rnd_t>(num_edges.value());
+      auto randoms_sorted =
+          allocator.AllocateStorage<rnd_t>(num_edges.value());
       auto edge_id_segments =
-          allocator.AllocateStorage<edge_id_t>(num_edges);
+          allocator.AllocateStorage<edge_id_t>(num_edges.value());
       auto sorted_edge_id_segments =
-          allocator.AllocateStorage<edge_id_t>(num_edges);
+          allocator.AllocateStorage<edge_id_t>(num_edges.value());
       AT_DISPATCH_INDEX_TYPES(
           indices.scalar_type(), "SampleNeighborsIndices", ([&] {
             using indices_t = index_t;
@@ -282,10 +301,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
                 layer ? indices.data_ptr<indices_t>() : nullptr;
             const dim3 block(BLOCK_SIZE);
-            const dim3 grid((num_edges + BLOCK_SIZE - 1) / BLOCK_SIZE);
+            const dim3 grid(
+                (num_edges.value() + BLOCK_SIZE - 1) / BLOCK_SIZE);
             // Compute row and random number pairs.
             CUDA_KERNEL_CALL(
-                _ComputeRandoms, grid, block, 0, num_edges,
+                _ComputeRandoms, grid, block, 0, num_edges.value(),
                 sliced_indptr.data_ptr<indptr_t>(),
                 sub_indptr.data_ptr<indptr_t>(),
                 coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
@@ -300,13 +321,13 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
             CUB_CALL(
                 DeviceSegmentedSort::SortPairs, randoms.get(),
                 randoms_sorted.get(), edge_id_segments.get(),
-                sorted_edge_id_segments.get(), num_edges, num_rows,
+                sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                 sub_indptr.data_ptr<indptr_t>(),
                 sub_indptr.data_ptr<indptr_t>() + 1);
             picked_eids = torch::empty(
                 static_cast<indptr_t>(num_sampled_edges),
-                nodes.options().dtype(indptr.scalar_type()));
+                sub_indptr.options());
             // Need to sort the sampled edges only when fanouts.size() == 1
             // since multiple fanout sampling case is automatically going to
@@ -385,9 +406,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
   torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
   if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
+  if (!nodes.has_value()) {
+    nodes = torch::arange(indptr.size(0) - 1, indices.options());
+  }
   return c10::make_intrusive<sampling::FusedSampledSubgraph>(
-      output_indptr, output_indices, nodes, torch::nullopt,
+      output_indptr, output_indices, nodes.value(), torch::nullopt,
       subgraph_reverse_edge_ids, output_type_per_edge);
 }

graphbolt/src/cuda/sampling_utils.cu

@@ -35,14 +35,16 @@ struct SliceFunc {
 // Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])
 std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
-    torch::Tensor indptr, torch::Tensor nodes) {
+    torch::Tensor indptr, torch::optional<torch::Tensor> nodes_optional) {
+  if (nodes_optional.has_value()) {
+    auto nodes = nodes_optional.value();
     const int64_t num_nodes = nodes.size(0);
     // Read indptr only once in case it is pinned and access is slow.
     auto sliced_indptr =
         torch::empty(num_nodes, nodes.options().dtype(indptr.scalar_type()));
     // compute in-degrees
-    auto in_degree =
-        torch::empty(num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
+    auto in_degree = torch::empty(
+        num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
     thrust::counting_iterator<int64_t> iota(0);
     AT_DISPATCH_INTEGRAL_TYPES(
         indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
@@ -59,6 +61,22 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
         }));
       }));
-  return {in_degree, sliced_indptr};
+    return {in_degree, sliced_indptr};
+  } else {
+    const int64_t num_nodes = indptr.size(0) - 1;
+    auto sliced_indptr = indptr.slice(0, 0, num_nodes);
+    auto in_degree = torch::empty(
+        num_nodes + 2, indptr.options().dtype(indptr.scalar_type()));
+    AT_DISPATCH_INTEGRAL_TYPES(
+        indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
+          using indptr_t = scalar_t;
+          CUB_CALL(
+              DeviceAdjacentDifference::SubtractLeftCopy,
+              indptr.data_ptr<indptr_t>(), in_degree.data_ptr<indptr_t>(),
+              num_nodes + 1, cub::Difference{});
+        }));
+    in_degree = in_degree.slice(0, 1);
+    return {in_degree, sliced_indptr};
+  }
 }

 template <typename indptr_t, typename etype_t>
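
In plain PyTorch terms, the new else-branch above amounts to taking adjacent differences of indptr instead of gathering it with a node index tensor. A sketch, not the CUB-based implementation:

import torch

def slice_csc_indptr_all_nodes(indptr):
    # Equivalent of SubtractLeftCopy followed by slice(0, 1): per-node in-degrees.
    in_degree = torch.diff(indptr)
    # Equivalent of indptr.slice(0, 0, num_nodes): drop the last entry.
    sliced_indptr = indptr[:-1]
    return in_degree, sliced_indptr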

graphbolt/src/fused_csc_sampling_graph.cc

@@ -607,21 +607,28 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
 }

 c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
-    const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
     bool replace, bool layer, bool return_eids,
     torch::optional<std::string> probs_name) const {
-  torch::optional<torch::Tensor> probs_or_mask = torch::nullopt;
-  if (probs_name.has_value() && !probs_name.value().empty()) {
-    probs_or_mask = this->EdgeAttribute(probs_name);
-  }
-  if (!replace && utils::is_on_gpu(nodes) &&
-      utils::is_accessible_from_gpu(indptr_) &&
-      utils::is_accessible_from_gpu(indices_) &&
-      (!probs_or_mask.has_value() ||
-       utils::is_accessible_from_gpu(probs_or_mask.value())) &&
-      (!type_per_edge_.has_value() ||
-       utils::is_accessible_from_gpu(type_per_edge_.value()))) {
+  auto probs_or_mask = this->EdgeAttribute(probs_name);
+  // If nodes does not have a value, then we expect all arguments to be resident
+  // on the GPU. If nodes has a value, then we expect them to be accessible from
+  // GPU. This is required for the dispatch to work when CUDA is not available.
+  if (((!nodes.has_value() && utils::is_on_gpu(indptr_) &&
+        utils::is_on_gpu(indices_) &&
+        (!probs_or_mask.has_value() ||
+         utils::is_on_gpu(probs_or_mask.value())) &&
+        (!type_per_edge_.has_value() ||
+         utils::is_on_gpu(type_per_edge_.value()))) ||
+       (nodes.has_value() && utils::is_on_gpu(nodes.value()) &&
+        utils::is_accessible_from_gpu(indptr_) &&
+        utils::is_accessible_from_gpu(indices_) &&
+        (!probs_or_mask.has_value() ||
+         utils::is_accessible_from_gpu(probs_or_mask.value())) &&
+        (!type_per_edge_.has_value() ||
+         utils::is_accessible_from_gpu(type_per_edge_.value())))) &&
+      !replace) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
         c10::DeviceType::CUDA, "SampleNeighbors", {
           return ops::SampleNeighbors(
@@ -629,6 +636,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
               type_per_edge_, probs_or_mask);
         });
   }
+  TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU.");
   if (probs_or_mask.has_value()) {
     // Note probs will be passed as input for 'torch.multinomial' in deeper
@@ -645,7 +653,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
         static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
     SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
     return SampleNeighborsImpl(
-        nodes, return_eids,
+        nodes.value(), return_eids,
         GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
         GetPickFn(
             fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
@@ -653,7 +661,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
   } else {
     SamplerArgs<SamplerType::NEIGHBOR> args;
     return SampleNeighborsImpl(
-        nodes, return_eids,
+        nodes.value(), return_eids,
         GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
         GetPickFn(
             fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
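
The dispatch rule introduced above, restated as a Python sketch. The two predicates below only approximate graphbolt's utils::is_on_gpu and utils::is_accessible_from_gpu, and the wrapper function is hypothetical:

def is_on_gpu(t):
    return t.is_cuda

def is_accessible_from_gpu(t):
    # GPU-resident or pinned host memory (approximation of the C++ helper).
    return t.is_cuda or t.is_pinned()

def should_use_cuda_sampler(nodes, indptr, indices, probs_or_mask, type_per_edge, replace):
    if replace:
        return False
    optional = [probs_or_mask, type_per_edge]
    if nodes is None:
        # Without seed nodes, every tensor must already live on the GPU.
        return all(is_on_gpu(t) for t in [indptr, indices] + optional if t is not None)
    # With seed nodes on the GPU, the rest only has to be accessible from the GPU.
    return is_on_gpu(nodes) and all(
        is_accessible_from_gpu(t) for t in [indptr, indices] + optional if t is not None
    )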

python/dgl/graphbolt/impl/fused_csc_sampling_graph.py

@@ -597,6 +597,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
         return self._convert_to_sampled_subgraph(C_sampled_subgraph)

     def _check_sampler_arguments(self, nodes, fanouts, probs_name):
+        if nodes is not None:
             assert nodes.dim() == 1, "Nodes should be 1-D tensor."
             assert nodes.dtype == self.indices.dtype, (
                 f"Data type of nodes must be consistent with "

tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py

@@ -1615,7 +1615,10 @@ def test_csc_sampling_graph_to_pinned_memory():
 @pytest.mark.parametrize("labor", [False, True])
 @pytest.mark.parametrize("is_pinned", [False, True])
-def test_sample_neighbors_homo(labor, is_pinned):
+@pytest.mark.parametrize("nodes", [None, True])
+def test_sample_neighbors_homo(labor, is_pinned, nodes):
+    if is_pinned and nodes is None:
+        pytest.skip("Optional nodes and is_pinned is not supported together.")
     """Original graph in COO:
     1 0 1 0 1
     1 0 1 1 0
@@ -1638,13 +1641,20 @@ def test_sample_neighbors_homo(labor, is_pinned):
     )
     # Generate subgraph via sample neighbors.
+    if nodes:
         nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
+    elif F._default_context_str != "gpu":
+        pytest.skip("Optional nodes is supported only for the GPU.")
     sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
     subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))
     # Verify in subgraph.
+    sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)
+    sampled_num = subgraph.sampled_csc.indices.size(0)
+    if nodes is None:
+        assert sampled_indptr_num == indptr.shape[0]
+        assert sampled_num == 10
+    else:
+        assert sampled_indptr_num == 4
+        assert sampled_num == 6
     assert subgraph.original_column_node_ids is None