Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a272efed
Unverified
Commit
a272efed
authored
Mar 13, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Mar 13, 2024
Browse files
[GraphBolt] Implement labor dependent minibatching - python side. (#7208)
parent
93990a90
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
187 additions
and
16 deletions
+187
-16
graphbolt/include/graphbolt/cuda_sampling_ops.h
graphbolt/include/graphbolt/cuda_sampling_ops.h
+6
-1
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+6
-1
graphbolt/src/cuda/neighbor_sampler.cu
graphbolt/src/cuda/neighbor_sampler.cu
+11
-3
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+18
-5
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+7
-1
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+83
-5
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
+56
-0
No files found.
graphbolt/include/graphbolt/cuda_sampling_ops.h
View file @
a272efed
...
@@ -45,6 +45,9 @@ namespace ops {
...
@@ -45,6 +45,9 @@ namespace ops {
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
* corresponding to each neighboring edge of a node. It must be
* a 1D tensor, with the number of elements equaling the total number of edges.
* a 1D tensor, with the number of elements equaling the total number of edges.
* @param random_seed The random seed for the sampler for layer=True.
* @param seed2_contribution The contribution of the second random seed, [0, 1)
* for layer=True.
*
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
* the sampled graph's information.
...
@@ -54,7 +57,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -54,7 +57,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
bool
return_eids
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
probs_or_mask
=
torch
::
nullopt
);
torch
::
optional
<
torch
::
Tensor
>
probs_or_mask
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
random_seed
=
torch
::
nullopt
,
float
seed2_contribution
=
.0
f
);
/**
/**
* @brief Return the subgraph induced on the inbound edges of the given nodes.
* @brief Return the subgraph induced on the inbound edges of the given nodes.
...
...
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
a272efed
...
@@ -314,6 +314,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -314,6 +314,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* probabilities corresponding to each neighboring edge of a node. It must be
* probabilities corresponding to each neighboring edge of a node. It must be
* a 1D floating-point or boolean tensor, with the number of elements
* a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
* equalling the total number of edges.
* @param random_seed The random seed for the sampler for layer=True.
* @param seed2_contribution The contribution of the second random seed,
* [0, 1) for layer=True.
*
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
* the sampled graph's information.
...
@@ -321,7 +324,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -321,7 +324,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighbors
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighbors
(
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
bool
return_eids
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
)
const
;
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
torch
::
Tensor
>
random_seed
,
double
seed2_contribution
)
const
;
/**
/**
* @brief Sample neighboring edges of the given nodes with a temporal
* @brief Sample neighboring edges of the given nodes with a temporal
...
...
graphbolt/src/cuda/neighbor_sampler.cu
View file @
a272efed
...
@@ -125,7 +125,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -125,7 +125,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
bool
return_eids
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
,
torch
::
optional
<
torch
::
Tensor
>
probs_or_mask
)
{
torch
::
optional
<
torch
::
Tensor
>
probs_or_mask
,
torch
::
optional
<
torch
::
Tensor
>
random_seed_tensor
,
float
seed2_contribution
)
{
TORCH_CHECK
(
!
replace
,
"Sampling with replacement is not supported yet!"
);
TORCH_CHECK
(
!
replace
,
"Sampling with replacement is not supported yet!"
);
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them
// are all resident on the GPU. If not, it is better to first extract them
...
@@ -202,8 +204,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -202,8 +204,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto
coo_rows
=
ExpandIndptrImpl
(
auto
coo_rows
=
ExpandIndptrImpl
(
sub_indptr
,
indices
.
scalar_type
(),
torch
::
nullopt
,
num_edges
);
sub_indptr
,
indices
.
scalar_type
(),
torch
::
nullopt
,
num_edges
);
num_edges
=
coo_rows
.
size
(
0
);
num_edges
=
coo_rows
.
size
(
0
);
const
continuous_seed
random_seed
(
RandomEngine
::
ThreadLocal
()
->
RandInt
(
const
continuous_seed
random_seed
=
[
&
]
{
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
()));
if
(
random_seed_tensor
.
has_value
())
{
return
continuous_seed
(
random_seed_tensor
.
value
(),
seed2_contribution
);
}
else
{
return
continuous_seed
{
RandomEngine
::
ThreadLocal
()
->
RandInt
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
())};
}
}();
auto
output_indptr
=
torch
::
empty_like
(
sub_indptr
);
auto
output_indptr
=
torch
::
empty_like
(
sub_indptr
);
torch
::
Tensor
picked_eids
;
torch
::
Tensor
picked_eids
;
torch
::
Tensor
output_indices
;
torch
::
Tensor
output_indices
;
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
a272efed
...
@@ -618,7 +618,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -618,7 +618,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
torch
::
optional
<
torch
::
Tensor
>
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
bool
return_eids
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
)
const
{
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
torch
::
Tensor
>
random_seed
,
double
seed2_contribution
)
const
{
auto
probs_or_mask
=
this
->
EdgeAttribute
(
probs_name
);
auto
probs_or_mask
=
this
->
EdgeAttribute
(
probs_name
);
// If nodes does not have a value, then we expect all arguments to be resident
// If nodes does not have a value, then we expect all arguments to be resident
...
@@ -642,7 +644,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -642,7 +644,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10
::
DeviceType
::
CUDA
,
"SampleNeighbors"
,
{
c10
::
DeviceType
::
CUDA
,
"SampleNeighbors"
,
{
return
ops
::
SampleNeighbors
(
return
ops
::
SampleNeighbors
(
indptr_
,
indices_
,
nodes
,
fanouts
,
replace
,
layer
,
return_eids
,
indptr_
,
indices_
,
nodes
,
fanouts
,
replace
,
layer
,
return_eids
,
type_per_edge_
,
probs_or_mask
);
type_per_edge_
,
probs_or_mask
,
random_seed
,
seed2_contribution
);
});
});
}
}
TORCH_CHECK
(
nodes
.
has_value
(),
"Nodes can not be None on the CPU."
);
TORCH_CHECK
(
nodes
.
has_value
(),
"Nodes can not be None on the CPU."
);
...
@@ -658,9 +660,20 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -658,9 +660,20 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
if
(
layer
)
{
if
(
layer
)
{
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
SamplerArgs
<
SamplerType
::
LABOR
>
args
=
[
&
]
{
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
if
(
random_seed
.
has_value
())
{
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
return
SamplerArgs
<
SamplerType
::
LABOR
>
{
indices_
,
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
NumNodes
()};
}
else
{
return
SamplerArgs
<
SamplerType
::
LABOR
>
{
indices_
,
RandomEngine
::
ThreadLocal
()
->
RandInt
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
()),
NumNodes
()};
}
}();
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
(
nodes
.
value
(),
return_eids
,
nodes
.
value
(),
return_eids
,
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
a272efed
...
@@ -735,9 +735,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -735,9 +735,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
nodes
,
nodes
,
fanouts
.
tolist
(),
fanouts
.
tolist
(),
replace
,
replace
,
False
,
False
,
# is_labor
return_eids
,
return_eids
,
probs_name
,
probs_name
,
None
,
# random_seed, labor parameter
0
,
# seed2_contribution, labor_parameter
)
)
def
sample_layer_neighbors
(
def
sample_layer_neighbors
(
...
@@ -746,6 +748,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -746,6 +748,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts
:
torch
.
Tensor
,
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
replace
:
bool
=
False
,
probs_name
:
Optional
[
str
]
=
None
,
probs_name
:
Optional
[
str
]
=
None
,
random_seed
:
torch
.
Tensor
=
None
,
seed2_contribution
:
float
=
0.0
,
)
->
SampledSubgraphImpl
:
)
->
SampledSubgraphImpl
:
"""Sample neighboring edges of the given nodes and return the induced
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
...
@@ -833,6 +837,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -833,6 +837,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
True
,
True
,
has_original_eids
,
has_original_eids
,
probs_name
,
probs_name
,
random_seed
,
seed2_contribution
,
)
)
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
a272efed
...
@@ -146,12 +146,17 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
...
@@ -146,12 +146,17 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
def
_sample_per_layer_from_fetched_subgraph
(
self
,
minibatch
):
def
_sample_per_layer_from_fetched_subgraph
(
self
,
minibatch
):
subgraph
=
minibatch
.
sampled_subgraphs
[
0
]
subgraph
=
minibatch
.
sampled_subgraphs
[
0
]
kwargs
=
{
key
[
1
:]:
getattr
(
minibatch
,
key
)
for
key
in
[
"_random_seed"
,
"_seed2_contribution"
]
if
hasattr
(
minibatch
,
key
)
}
sampled_subgraph
=
getattr
(
subgraph
,
self
.
sampler_name
)(
sampled_subgraph
=
getattr
(
subgraph
,
self
.
sampler_name
)(
minibatch
.
_subgraph_seed_nodes
,
minibatch
.
_subgraph_seed_nodes
,
self
.
fanout
,
self
.
fanout
,
self
.
replace
,
self
.
replace
,
self
.
prob_name
,
self
.
prob_name
,
**
kwargs
,
)
)
delattr
(
minibatch
,
"_subgraph_seed_nodes"
)
delattr
(
minibatch
,
"_subgraph_seed_nodes"
)
sampled_subgraph
.
original_column_node_ids
=
minibatch
.
_seed_nodes
sampled_subgraph
.
original_column_node_ids
=
minibatch
.
_seed_nodes
...
@@ -172,8 +177,17 @@ class SamplePerLayer(MiniBatchTransformer):
...
@@ -172,8 +177,17 @@ class SamplePerLayer(MiniBatchTransformer):
self
.
prob_name
=
prob_name
self
.
prob_name
=
prob_name
def
_sample_per_layer
(
self
,
minibatch
):
def
_sample_per_layer
(
self
,
minibatch
):
kwargs
=
{
key
[
1
:]:
getattr
(
minibatch
,
key
)
for
key
in
[
"_random_seed"
,
"_seed2_contribution"
]
if
hasattr
(
minibatch
,
key
)
}
subgraph
=
self
.
sampler
(
subgraph
=
self
.
sampler
(
minibatch
.
_seed_nodes
,
self
.
fanout
,
self
.
replace
,
self
.
prob_name
minibatch
.
_seed_nodes
,
self
.
fanout
,
self
.
replace
,
self
.
prob_name
,
**
kwargs
,
)
)
minibatch
.
sampled_subgraphs
.
insert
(
0
,
subgraph
)
minibatch
.
sampled_subgraphs
.
insert
(
0
,
subgraph
)
return
minibatch
return
minibatch
...
@@ -244,11 +258,57 @@ class NeighborSamplerImpl(SubgraphSampler):
...
@@ -244,11 +258,57 @@ class NeighborSamplerImpl(SubgraphSampler):
prob_name
,
prob_name
,
deduplicate
,
deduplicate
,
sampler
,
sampler
,
layer_dependency
=
None
,
batch_dependency
=
None
,
):
):
if
sampler
.
__name__
==
"sample_layer_neighbors"
:
self
.
_init_seed
(
batch_dependency
)
super
().
__init__
(
super
().
__init__
(
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
,
layer_dependency
,
)
)
def
_init_seed
(
self
,
batch_dependency
):
self
.
rng
=
torch
.
random
.
manual_seed
(
torch
.
randint
(
0
,
int
(
1e18
),
size
=
tuple
())
)
self
.
cnt
=
[
-
1
,
int
(
batch_dependency
)]
self
.
random_seed
=
torch
.
empty
(
2
if
self
.
cnt
[
1
]
>
1
else
1
,
dtype
=
torch
.
int64
)
self
.
random_seed
.
random_
(
generator
=
self
.
rng
)
def
_set_seed
(
self
,
minibatch
):
self
.
cnt
[
0
]
+=
1
if
self
.
cnt
[
1
]
>
0
and
self
.
cnt
[
0
]
%
self
.
cnt
[
1
]
==
0
:
self
.
random_seed
[
0
]
=
self
.
random_seed
[
-
1
]
self
.
random_seed
[
-
1
:].
random_
(
generator
=
self
.
rng
)
minibatch
.
_random_seed
=
self
.
random_seed
.
clone
()
minibatch
.
_seed2_contribution
=
(
0.0
if
self
.
cnt
[
1
]
<=
1
else
(
self
.
cnt
[
0
]
%
self
.
cnt
[
1
])
/
self
.
cnt
[
1
]
)
minibatch
.
_iter
=
self
.
cnt
[
0
]
return
minibatch
@
staticmethod
def
_increment_seed
(
minibatch
):
minibatch
.
_random_seed
=
1
+
minibatch
.
_random_seed
return
minibatch
@
staticmethod
def
_delattr_dependency
(
minibatch
):
delattr
(
minibatch
,
"_random_seed"
)
delattr
(
minibatch
,
"_seed2_contribution"
)
return
minibatch
@
staticmethod
@
staticmethod
def
_prepare
(
node_type_to_id
,
minibatch
):
def
_prepare
(
node_type_to_id
,
minibatch
):
seeds
=
minibatch
.
_seed_nodes
seeds
=
minibatch
.
_seed_nodes
...
@@ -277,11 +337,22 @@ class NeighborSamplerImpl(SubgraphSampler):
...
@@ -277,11 +337,22 @@ class NeighborSamplerImpl(SubgraphSampler):
# pylint: disable=arguments-differ
# pylint: disable=arguments-differ
def
sampling_stages
(
def
sampling_stages
(
self
,
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
self
,
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
,
layer_dependency
,
):
):
datapipe
=
datapipe
.
transform
(
datapipe
=
datapipe
.
transform
(
partial
(
self
.
_prepare
,
graph
.
node_type_to_id
)
partial
(
self
.
_prepare
,
graph
.
node_type_to_id
)
)
)
is_labor
=
sampler
.
__name__
==
"sample_layer_neighbors"
if
is_labor
:
datapipe
=
datapipe
.
transform
(
self
.
_set_seed
)
for
fanout
in
reversed
(
fanouts
):
for
fanout
in
reversed
(
fanouts
):
# Convert fanout to tensor.
# Convert fanout to tensor.
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
...
@@ -290,7 +361,10 @@ class NeighborSamplerImpl(SubgraphSampler):
...
@@ -290,7 +361,10 @@ class NeighborSamplerImpl(SubgraphSampler):
sampler
,
fanout
,
replace
,
prob_name
sampler
,
fanout
,
replace
,
prob_name
)
)
datapipe
=
datapipe
.
compact_per_layer
(
deduplicate
)
datapipe
=
datapipe
.
compact_per_layer
(
deduplicate
)
if
is_labor
and
not
layer_dependency
:
datapipe
=
datapipe
.
transform
(
self
.
_increment_seed
)
if
is_labor
:
datapipe
=
datapipe
.
transform
(
self
.
_delattr_dependency
)
return
datapipe
.
transform
(
self
.
_set_input_nodes
)
return
datapipe
.
transform
(
self
.
_set_input_nodes
)
...
@@ -504,6 +578,8 @@ class LayerNeighborSampler(NeighborSamplerImpl):
...
@@ -504,6 +578,8 @@ class LayerNeighborSampler(NeighborSamplerImpl):
replace
=
False
,
replace
=
False
,
prob_name
=
None
,
prob_name
=
None
,
deduplicate
=
True
,
deduplicate
=
True
,
layer_dependency
=
False
,
batch_dependency
=
1
,
):
):
super
().
__init__
(
super
().
__init__
(
datapipe
,
datapipe
,
...
@@ -513,4 +589,6 @@ class LayerNeighborSampler(NeighborSamplerImpl):
...
@@ -513,4 +589,6 @@ class LayerNeighborSampler(NeighborSamplerImpl):
prob_name
,
prob_name
,
deduplicate
,
deduplicate
,
graph
.
sample_layer_neighbors
,
graph
.
sample_layer_neighbors
,
layer_dependency
,
batch_dependency
,
)
)
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
View file @
a272efed
...
@@ -75,3 +75,59 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
...
@@ -75,3 +75,59 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
assert
len
(
expected_results
)
==
len
(
new_results
)
assert
len
(
expected_results
)
==
len
(
new_results
)
for
a
,
b
in
zip
(
expected_results
,
new_results
):
for
a
,
b
in
zip
(
expected_results
,
new_results
):
assert
repr
(
a
)
==
repr
(
b
)
assert
repr
(
a
)
==
repr
(
b
)
@
pytest
.
mark
.
parametrize
(
"layer_dependency"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"overlap_graph_fetch"
,
[
False
,
True
])
def
test_labor_dependent_minibatching
(
layer_dependency
,
overlap_graph_fetch
):
num_edges
=
200
csc_indptr
=
torch
.
cat
(
(
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
),
torch
.
ones
(
num_edges
+
1
,
dtype
=
torch
.
int64
)
*
num_edges
,
)
)
indices
=
torch
.
arange
(
1
,
num_edges
+
1
)
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
.
int
(),
indices
.
int
(),
).
to
(
F
.
ctx
())
torch
.
random
.
set_rng_state
(
torch
.
manual_seed
(
123
).
get_state
())
batch_dependency
=
100
itemset
=
gb
.
ItemSet
(
torch
.
zeros
(
batch_dependency
+
1
).
int
(),
names
=
"seed_nodes"
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
1
).
copy_to
(
F
.
ctx
())
fanouts
=
[
5
,
5
]
datapipe
=
datapipe
.
sample_layer_neighbor
(
graph
,
fanouts
,
layer_dependency
=
layer_dependency
,
batch_dependency
=
batch_dependency
,
)
dataloader
=
gb
.
DataLoader
(
datapipe
,
overlap_graph_fetch
=
overlap_graph_fetch
)
res
=
list
(
dataloader
)
assert
len
(
res
)
==
batch_dependency
+
1
if
layer_dependency
:
assert
torch
.
equal
(
res
[
0
].
input_nodes
,
res
[
0
].
sampled_subgraphs
[
1
].
original_row_node_ids
,
)
else
:
assert
res
[
0
].
input_nodes
.
size
(
0
)
>
res
[
0
].
sampled_subgraphs
[
1
].
original_row_node_ids
.
size
(
0
)
delta
=
0
for
i
in
range
(
batch_dependency
):
res_current
=
(
res
[
i
].
sampled_subgraphs
[
-
1
].
original_row_node_ids
.
tolist
()
)
res_next
=
(
res
[
i
+
1
].
sampled_subgraphs
[
-
1
].
original_row_node_ids
.
tolist
()
)
intersect_len
=
len
(
set
(
res_current
).
intersection
(
set
(
res_next
)))
assert
intersect_len
>=
fanouts
[
-
1
]
delta
+=
1
+
fanouts
[
-
1
]
-
intersect_len
assert
delta
>=
fanouts
[
-
1
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment