Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b9cf36c3
Unverified
Commit
b9cf36c3
authored
Jan 05, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 05, 2024
Browse files
[GraphBolt][CUDA] Add a `node_type_offset_list` property. (#6886)
parent
557a8f81
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
11 deletions
+38
-11
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+38
-11
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
b9cf36c3
...
@@ -92,7 +92,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -92,7 +92,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
{'N0': 2, 'N1': 3}
{'N0': 2, 'N1': 3}
"""
"""
offset
=
self
.
node_type_offset
offset
=
self
.
_
node_type_offset
_list
# Homogeneous.
# Homogeneous.
if
offset
is
None
or
self
.
node_type_to_id
is
None
:
if
offset
is
None
or
self
.
node_type_to_id
is
None
:
...
@@ -101,7 +101,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -101,7 +101,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Heterogeneous
# Heterogeneous
else
:
else
:
num_nodes_per_type
=
{
num_nodes_per_type
=
{
_type
:
(
offset
[
_idx
+
1
]
-
offset
[
_idx
])
.
item
()
_type
:
(
offset
[
_idx
+
1
]
-
offset
[
_idx
])
for
_type
,
_idx
in
self
.
node_type_to_id
.
items
()
for
_type
,
_idx
in
self
.
node_type_to_id
.
items
()
}
}
...
@@ -197,7 +197,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -197,7 +197,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
@
property
@
property
def
node_type_offset
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
node_type_offset
(
self
)
->
Optional
[
torch
.
Tensor
]:
"""Returns the node type offset tensor if present.
"""Returns the node type offset tensor if present. Do not modify the
returned tensor in place.
Returns
Returns
-------
-------
...
@@ -212,12 +213,39 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -212,12 +213,39 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
"""
return
self
.
_c_csc_graph
.
node_type_offset
()
return
self
.
_c_csc_graph
.
node_type_offset
()
@
property
def
_node_type_offset_list
(
self
)
->
Optional
[
list
]:
"""Returns the node type offset list if present.
Returns
-------
list or None
If present, returns a 1D integer list of length
`num_node_types + 1`. The list is in ascending order as nodes
of the same type have contiguous IDs, and larger node IDs are
paired with larger node type IDs. The first value is 0 and the last
value is the number of nodes. Nodes with IDs in the range
`[node_type_offset[i], node_type_offset[i+1])` are of type id `i`.
"""
if
(
not
hasattr
(
self
,
"_node_type_offset_cached_list"
)
or
self
.
_node_type_offset_cached_list
is
None
):
self
.
_node_type_offset_cached_list
=
self
.
node_type_offset
if
self
.
_node_type_offset_cached_list
is
not
None
:
self
.
_node_type_offset_cached_list
=
(
self
.
_node_type_offset_cached_list
.
tolist
()
)
return
self
.
_node_type_offset_cached_list
@
node_type_offset
.
setter
@
node_type_offset
.
setter
def
node_type_offset
(
def
node_type_offset
(
self
,
node_type_offset
:
Optional
[
torch
.
Tensor
]
self
,
node_type_offset
:
Optional
[
torch
.
Tensor
]
)
->
None
:
)
->
None
:
"""Sets the node type offset tensor if present."""
"""Sets the node type offset tensor if present."""
self
.
_c_csc_graph
.
set_node_type_offset
(
node_type_offset
)
self
.
_c_csc_graph
.
set_node_type_offset
(
node_type_offset
)
self
.
_node_type_offset_cached_list
=
None
@
property
@
property
def
type_per_edge
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
type_per_edge
(
self
)
->
Optional
[
torch
.
Tensor
]:
...
@@ -387,11 +415,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -387,11 +415,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
def
_convert_to_homogeneous_nodes
(
self
,
nodes
,
timestamps
=
None
):
def
_convert_to_homogeneous_nodes
(
self
,
nodes
,
timestamps
=
None
):
homogeneous_nodes
=
[]
homogeneous_nodes
=
[]
homogeneous_timestamps
=
[]
homogeneous_timestamps
=
[]
offset
=
self
.
_node_type_offset_list
for
ntype
,
ids
in
nodes
.
items
():
for
ntype
,
ids
in
nodes
.
items
():
ntype_id
=
self
.
node_type_to_id
[
ntype
]
ntype_id
=
self
.
node_type_to_id
[
ntype
]
homogeneous_nodes
.
append
(
homogeneous_nodes
.
append
(
ids
+
offset
[
ntype_id
])
ids
+
self
.
node_type_offset
[
ntype_id
].
item
()
)
if
timestamps
is
not
None
:
if
timestamps
is
not
None
:
homogeneous_timestamps
.
append
(
timestamps
[
ntype
])
homogeneous_timestamps
.
append
(
timestamps
[
ntype
])
if
timestamps
is
not
None
:
if
timestamps
is
not
None
:
...
@@ -424,6 +451,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -424,6 +451,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# The sampled graph is already a homogeneous graph.
# The sampled graph is already a homogeneous graph.
sampled_csc
=
CSCFormatBase
(
indptr
=
indptr
,
indices
=
indices
)
sampled_csc
=
CSCFormatBase
(
indptr
=
indptr
,
indices
=
indices
)
else
:
else
:
# UVA sampling requires us to move node_type_offset to GPU.
self
.
node_type_offset
=
self
.
node_type_offset
.
to
(
column
.
device
)
self
.
node_type_offset
=
self
.
node_type_offset
.
to
(
column
.
device
)
# 1. Find the node type of each node in column.
# 1. Find the node type of each node in column.
node_types
=
(
node_types
=
(
...
@@ -434,6 +462,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -434,6 +462,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids
=
{}
original_hetero_edge_ids
=
{}
sub_indices
=
{}
sub_indices
=
{}
sub_indptr
=
{}
sub_indptr
=
{}
offset
=
self
.
_node_type_offset_list
# 2. Loop over each node type.
# 2. Loop over each node type.
for
ntype
,
ntype_id
in
self
.
node_type_to_id
.
items
():
for
ntype
,
ntype_id
in
self
.
node_type_to_id
.
items
():
# Get all nodes of a specific node type in column.
# Get all nodes of a specific node type in column.
...
@@ -446,9 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -446,9 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Get all edge ids of a specific edge type.
# Get all edge ids of a specific edge type.
eids
=
torch
.
nonzero
(
type_per_edge
==
etype_id
).
view
(
-
1
)
eids
=
torch
.
nonzero
(
type_per_edge
==
etype_id
).
view
(
-
1
)
src_ntype_id
=
self
.
node_type_to_id
[
src_ntype
]
src_ntype_id
=
self
.
node_type_to_id
[
src_ntype
]
sub_indices
[
etype
]
=
(
sub_indices
[
etype
]
=
indices
[
eids
]
-
offset
[
src_ntype_id
]
indices
[
eids
]
-
self
.
node_type_offset
[
src_ntype_id
]
)
cum_edges
=
torch
.
searchsorted
(
cum_edges
=
torch
.
searchsorted
(
eids
,
nids_original_indptr
,
right
=
False
eids
,
nids_original_indptr
,
right
=
False
)
)
...
@@ -882,9 +909,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -882,9 +909,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
negative sampling by edge type."
negative sampling by edge type."
_
,
_
,
dst_node_type
=
etype_str_to_tuple
(
edge_type
)
_
,
_
,
dst_node_type
=
etype_str_to_tuple
(
edge_type
)
dst_node_type_id
=
self
.
node_type_to_id
[
dst_node_type
]
dst_node_type_id
=
self
.
node_type_to_id
[
dst_node_type
]
offset
=
self
.
_node_type_offset_list
max_node_id
=
(
max_node_id
=
(
self
.
node_type_offset
[
dst_node_type_id
+
1
]
offset
[
dst_node_type_id
+
1
]
-
offset
[
dst_node_type_id
]
-
self
.
node_type_offset
[
dst_node_type_id
]
)
)
else
:
else
:
max_node_id
=
self
.
total_num_nodes
max_node_id
=
self
.
total_num_nodes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment