Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e5b92d2b
You need to sign in or sign up before continuing.
Unverified
Commit
e5b92d2b
authored
Apr 06, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Apr 06, 2024
Browse files
[GraphBolt][CUDA] Remove overlap graph variable hacks. (#7263)
parent
d4a6f8a0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
8 deletions
+38
-8
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+33
-4
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+5
-4
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
e5b92d2b
...
@@ -290,6 +290,35 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -290,6 +290,35 @@ class FusedCSCSamplingGraph(SamplingGraph):
self
.
_c_csc_graph
.
set_node_type_offset
(
node_type_offset
)
self
.
_c_csc_graph
.
set_node_type_offset
(
node_type_offset
)
self
.
_node_type_offset_cached_list
=
None
self
.
_node_type_offset_cached_list
=
None
@
property
def
_indptr_node_type_offset_list
(
self
)
->
Optional
[
list
]:
"""Returns the indptr node type offset list which presents the column id
space when it does not match the global id space. It is useful when we
slice a subgraph from another FusedCSCSamplingGraph.
Returns
-------
list or None
If present, returns a 1D integer list of shape
`(num_node_types + 1,)`. The list is in ascending order as nodes
of the same type have continuous IDs, and larger node IDs are
paired with larger node type IDs. The first value is 0 and last
value is the number of nodes. And nodes with IDs between
`node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.
"""
return
(
self
.
_indptr_node_type_offset_list_
if
hasattr
(
self
,
"_indptr_node_type_offset_list_"
)
else
None
)
@
_indptr_node_type_offset_list
.
setter
def
_indptr_node_type_offset_list
(
self
,
indptr_node_type_offset_list
:
Optional
[
torch
.
Tensor
]
):
"""Sets the indptr node type offset list if present."""
self
.
_indptr_node_type_offset_list_
=
indptr_node_type_offset_list
@
property
@
property
def
type_per_edge
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
type_per_edge
(
self
)
->
Optional
[
torch
.
Tensor
]:
"""Returns the edge type tensor if present.
"""Returns the edge type tensor if present.
...
@@ -665,8 +694,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -665,8 +694,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
seed_offsets
=
None
seed_offsets
=
None
if
isinstance
(
seeds
,
dict
):
if
isinstance
(
seeds
,
dict
):
seeds
,
seed_offsets
=
self
.
_convert_to_homogeneous_nodes
(
seeds
)
seeds
,
seed_offsets
=
self
.
_convert_to_homogeneous_nodes
(
seeds
)
elif
seeds
is
None
and
hasattr
(
self
,
"_seed_offset_list"
)
:
elif
seeds
is
None
:
seed_offsets
=
self
.
_
seed_offset_list
# pylint: disable=no-member
seed_offsets
=
self
.
_
indptr_node_type_offset_list
C_sampled_subgraph
=
self
.
_sample_neighbors
(
C_sampled_subgraph
=
self
.
_sample_neighbors
(
seeds
,
seeds
,
seed_offsets
,
seed_offsets
,
...
@@ -914,8 +943,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -914,8 +943,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
seed_offsets
=
None
seed_offsets
=
None
if
isinstance
(
seeds
,
dict
):
if
isinstance
(
seeds
,
dict
):
seeds
,
seed_offsets
=
self
.
_convert_to_homogeneous_nodes
(
seeds
)
seeds
,
seed_offsets
=
self
.
_convert_to_homogeneous_nodes
(
seeds
)
elif
seeds
is
None
and
hasattr
(
self
,
"_seed_offset_list"
)
:
elif
seeds
is
None
:
seed_offsets
=
self
.
_
seed_offset_list
# pylint: disable=no-member
seed_offsets
=
self
.
_
indptr_node_type_offset_list
self
.
_check_sampler_arguments
(
seeds
,
fanouts
,
probs_name
)
self
.
_check_sampler_arguments
(
seeds
,
fanouts
,
probs_name
)
C_sampled_subgraph
=
self
.
_c_csc_graph
.
sample_neighbors
(
C_sampled_subgraph
=
self
.
_c_csc_graph
.
sample_neighbors
(
seeds
,
seeds
,
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
e5b92d2b
...
@@ -102,9 +102,9 @@ class FetchInsubgraphData(Mapper):
...
@@ -102,9 +102,9 @@ class FetchInsubgraphData(Mapper):
)
)
if
self
.
prob_name
is
not
None
and
probs_or_mask
is
not
None
:
if
self
.
prob_name
is
not
None
and
probs_or_mask
is
not
None
:
subgraph
.
edge_attributes
=
{
self
.
prob_name
:
probs_or_mask
}
subgraph
.
edge_attributes
=
{
self
.
prob_name
:
probs_or_mask
}
subgraph
.
_seed_offset_list
=
seed_offsets
minibatch
.
sampled_subgraphs
.
insert
(
0
,
subgraph
)
subgraph
.
_indptr_node_type_offset_list
=
seed_offsets
minibatch
.
_sliced_sampling_graph
=
subgraph
if
self
.
stream
is
not
None
:
if
self
.
stream
is
not
None
:
minibatch
.
wait
=
torch
.
cuda
.
current_stream
().
record_event
().
wait
minibatch
.
wait
=
torch
.
cuda
.
current_stream
().
record_event
().
wait
...
@@ -133,7 +133,8 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
...
@@ -133,7 +133,8 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
self
.
prob_name
=
sample_per_layer_obj
.
prob_name
self
.
prob_name
=
sample_per_layer_obj
.
prob_name
def
_sample_per_layer_from_fetched_subgraph
(
self
,
minibatch
):
def
_sample_per_layer_from_fetched_subgraph
(
self
,
minibatch
):
subgraph
=
minibatch
.
sampled_subgraphs
[
0
]
subgraph
=
minibatch
.
_sliced_sampling_graph
delattr
(
minibatch
,
"_sliced_sampling_graph"
)
kwargs
=
{
kwargs
=
{
key
[
1
:]:
getattr
(
minibatch
,
key
)
key
[
1
:]:
getattr
(
minibatch
,
key
)
for
key
in
[
"_random_seed"
,
"_seed2_contribution"
]
for
key
in
[
"_random_seed"
,
"_seed2_contribution"
]
...
@@ -146,7 +147,7 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
...
@@ -146,7 +147,7 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
self
.
prob_name
,
self
.
prob_name
,
**
kwargs
,
**
kwargs
,
)
)
minibatch
.
sampled_subgraphs
[
0
]
=
sampled_subgraph
minibatch
.
sampled_subgraphs
.
insert
(
0
,
sampled_subgraph
)
return
minibatch
return
minibatch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment