Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7439b7e7
Unverified
Commit
7439b7e7
authored
Oct 20, 2023
by
Ramon Zhou
Committed by
GitHub
Oct 20, 2023
Browse files
[GraphBolt] Add to function for CSCSamplingGraph (#6465)
parent
ea58090e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
0 deletions
+135
-0
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+24
-0
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+5
-0
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+53
-0
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+53
-0
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
7439b7e7
...
...
@@ -111,6 +111,30 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return
edge_attributes_
;
}
/** @brief Set the csc index pointer tensor. */
inline
void
SetCSCIndptr
(
const
torch
::
Tensor
&
indptr
)
{
indptr_
=
indptr
;
}
/** @brief Set the index tensor. */
inline
void
SetIndices
(
const
torch
::
Tensor
&
indices
)
{
indices_
=
indices
;
}
/** @brief Set the node type offset tensor for a heterogeneous graph. */
inline
void
SetNodeTypeOffset
(
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
)
{
node_type_offset_
=
node_type_offset
;
}
/** @brief Set the edge type tensor for a heterogeneous graph. */
inline
void
SetTypePerEdge
(
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
)
{
type_per_edge_
=
type_per_edge
;
}
/** @brief Set the edge attributes dictionary. */
inline
void
SetEdgeAttributes
(
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
edge_attributes_
=
edge_attributes
;
}
/**
* @brief Magic number to indicate graph version in serialize/deserialize
* stage.
...
...
graphbolt/src/python_binding.cc
View file @
7439b7e7
...
...
@@ -32,6 +32,11 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"node_type_offset"
,
&
CSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"type_per_edge"
,
&
CSCSamplingGraph
::
TypePerEdge
)
.
def
(
"edge_attributes"
,
&
CSCSamplingGraph
::
EdgeAttributes
)
.
def
(
"set_csc_indptr"
,
&
CSCSamplingGraph
::
SetCSCIndptr
)
.
def
(
"set_indices"
,
&
CSCSamplingGraph
::
SetIndices
)
.
def
(
"set_node_type_offset"
,
&
CSCSamplingGraph
::
SetNodeTypeOffset
)
.
def
(
"set_type_per_edge"
,
&
CSCSamplingGraph
::
SetTypePerEdge
)
.
def
(
"set_edge_attributes"
,
&
CSCSamplingGraph
::
SetEdgeAttributes
)
.
def
(
"in_subgraph"
,
&
CSCSamplingGraph
::
InSubgraph
)
.
def
(
"sample_neighbors"
,
&
CSCSamplingGraph
::
SampleNeighbors
)
.
def
(
...
...
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
7439b7e7
...
...
@@ -8,6 +8,8 @@ from typing import Dict, Optional, Union
import
torch
from
dgl.utils
import
recursive_apply
from
...base
import
EID
,
ETYPE
from
...convert
import
to_homogeneous
from
...heterograph
import
DGLGraph
...
...
@@ -181,6 +183,11 @@ class CSCSamplingGraph(SamplingGraph):
"""
return
self
.
_c_csc_graph
.
csc_indptr
()
@
csc_indptr
.
setter
def
csc_indptr
(
self
,
csc_indptr
:
torch
.
tensor
)
->
None
:
"""Sets the indices pointer in the CSC graph."""
self
.
_c_csc_graph
.
set_csc_indptr
(
csc_indptr
)
@
property
def
indices
(
self
)
->
torch
.
tensor
:
"""Returns the indices in the CSC graph.
...
...
@@ -198,6 +205,11 @@ class CSCSamplingGraph(SamplingGraph):
"""
return
self
.
_c_csc_graph
.
indices
()
@
indices
.
setter
def
indices
(
self
,
indices
:
torch
.
tensor
)
->
None
:
"""Sets the indices in the CSC graph."""
self
.
_c_csc_graph
.
set_indices
(
indices
)
@
property
def
node_type_offset
(
self
)
->
Optional
[
torch
.
Tensor
]:
"""Returns the node type offset tensor if present.
...
...
@@ -215,6 +227,13 @@ class CSCSamplingGraph(SamplingGraph):
"""
return
self
.
_c_csc_graph
.
node_type_offset
()
@
node_type_offset
.
setter
def
node_type_offset
(
self
,
node_type_offset
:
Optional
[
torch
.
Tensor
]
)
->
None
:
"""Sets the node type offset tensor if present."""
self
.
_c_csc_graph
.
set_node_type_offset
(
node_type_offset
)
@
property
def
type_per_edge
(
self
)
->
Optional
[
torch
.
Tensor
]:
"""Returns the edge type tensor if present.
...
...
@@ -227,6 +246,11 @@ class CSCSamplingGraph(SamplingGraph):
"""
return
self
.
_c_csc_graph
.
type_per_edge
()
@
type_per_edge
.
setter
def
type_per_edge
(
self
,
type_per_edge
:
Optional
[
torch
.
Tensor
])
->
None
:
"""Sets the edge type tensor if present."""
self
.
_c_csc_graph
.
set_type_per_edge
(
type_per_edge
)
@
property
def
edge_attributes
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns the edge attributes dictionary.
...
...
@@ -241,6 +265,13 @@ class CSCSamplingGraph(SamplingGraph):
"""
return
self
.
_c_csc_graph
.
edge_attributes
()
@
edge_attributes
.
setter
def
edge_attributes
(
self
,
edge_attributes
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
)
->
None
:
"""Sets the edge attributes dictionary."""
self
.
_c_csc_graph
.
set_edge_attributes
(
edge_attributes
)
@
property
def
metadata
(
self
)
->
Optional
[
GraphMetadata
]:
"""Returns the metadata of the graph.
...
...
@@ -674,6 +705,28 @@ class CSCSamplingGraph(SamplingGraph):
self
.
_metadata
,
)
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy `CSCSamplingGraph` to the specified device."""
def
_to
(
x
,
device
):
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
self
.
csc_indptr
=
recursive_apply
(
self
.
csc_indptr
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
indices
=
recursive_apply
(
self
.
indices
,
lambda
x
:
_to
(
x
,
device
))
self
.
node_type_offset
=
recursive_apply
(
self
.
node_type_offset
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
type_per_edge
=
recursive_apply
(
self
.
type_per_edge
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
edge_attributes
=
recursive_apply
(
self
.
edge_attributes
,
lambda
x
:
_to
(
x
,
device
)
)
return
self
def
from_csc
(
csc_indptr
:
torch
.
Tensor
,
...
...
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
7439b7e7
...
...
@@ -1604,3 +1604,56 @@ def test_sample_neighbors_hetero_pick_number(
else
:
# Etype 2: 0 valid neighbors.
assert
pairs
[
0
].
size
(
0
)
==
0
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
reason
=
"`to` function needs GPU to test."
,
)
def
test_csc_sampling_graph_to_device
():
# Initialize data.
total_num_nodes
=
10
total_num_edges
=
9
ntypes
=
{
"N0"
:
0
,
"N1"
:
1
,
"N2"
:
2
,
"N3"
:
3
}
etypes
=
{
"N0:R0:N1"
:
0
,
"N0:R1:N2"
:
1
,
"N0:R2:N3"
:
2
,
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
])
indices
=
torch
.
LongTensor
([
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
])
node_type_offset
=
torch
.
LongTensor
([
0
,
1
,
4
,
7
,
10
])
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
])
assert
indptr
[
-
1
]
==
total_num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
node_type_offset
[
-
1
]
==
total_num_nodes
assert
all
(
type_per_edge
<
len
(
etypes
))
edge_attributes
=
{
"mask"
:
torch
.
BoolTensor
([
1
,
1
,
0
,
1
,
1
,
1
,
0
,
0
,
0
]),
"all"
:
torch
.
BoolTensor
([
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
]),
"zero"
:
torch
.
BoolTensor
([
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]),
}
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
,
edge_attributes
=
edge_attributes
,
node_type_offset
=
node_type_offset
,
type_per_edge
=
type_per_edge
,
metadata
=
metadata
,
)
# Copy to device.
graph
=
graph
.
to
(
"cuda"
)
# Check.
assert
graph
.
csc_indptr
.
device
.
type
==
"cuda"
assert
graph
.
indices
.
device
.
type
==
"cuda"
assert
graph
.
node_type_offset
.
device
.
type
==
"cuda"
assert
graph
.
type_per_edge
.
device
.
type
==
"cuda"
assert
graph
.
csc_indptr
.
device
.
type
==
"cuda"
for
key
in
graph
.
edge_attributes
:
assert
graph
.
edge_attributes
[
key
].
device
.
type
==
"cuda"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment