Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d873acc2
Unverified
Commit
d873acc2
authored
Dec 11, 2023
by
Rhett Ying
Committed by
GitHub
Dec 11, 2023
Browse files
[GraphBolt] add more APIs for SamplingGraph (#6719)
parent
e6f78c10
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
120 additions
and
4 deletions
+120
-4
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+49
-0
python/dgl/graphbolt/sampling_graph.py
python/dgl/graphbolt/sampling_graph.py
+58
-0
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+13
-4
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
d873acc2
...
...
@@ -170,6 +170,55 @@ class FusedCSCSamplingGraph(SamplingGraph):
return
num_nodes_per_type
@
property
def
num_edges
(
self
)
->
Union
[
int
,
Dict
[
str
,
int
]]:
"""The number of edges in the graph.
- If the graph is homogenous, returns an integer.
- If the graph is heterogenous, returns a dictionary.
Returns
-------
Union[int, Dict[str, int]]
The number of edges. Integer indicates the total edges number of a
homogenous graph; dict indicates edges number per edge types of a
heterogenous graph.
Examples
--------
>>> import dgl.graphbolt as gb, torch
>>> total_num_nodes = 5
>>> total_num_edges = 12
>>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
... "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> print(graph.num_edges)
{'N0:R0:N0': 2, 'N0:R1:N1': 1, 'N1:R2:N0': 2, 'N1:R3:N1': 3}
"""
type_per_edge
=
self
.
type_per_edge
# Homogenous.
if
type_per_edge
is
None
or
self
.
edge_type_to_id
is
None
:
return
self
.
_c_csc_graph
.
num_edges
()
# Heterogenous
bincount
=
torch
.
bincount
(
type_per_edge
)
num_edges_per_type
=
{}
for
etype
,
etype_id
in
self
.
edge_type_to_id
.
items
():
if
etype_id
<
len
(
bincount
):
num_edges_per_type
[
etype
]
=
bincount
[
etype_id
].
item
()
else
:
num_edges_per_type
[
etype
]
=
0
return
num_edges_per_type
@
property
def
csc_indptr
(
self
)
->
torch
.
tensor
:
"""Returns the indices pointer in the CSC graph.
...
...
python/dgl/graphbolt/sampling_graph.py
View file @
d873acc2
...
...
@@ -2,6 +2,8 @@
from
typing
import
Dict
,
Union
import
torch
__all__
=
[
"SamplingGraph"
]
...
...
@@ -12,6 +14,16 @@ class SamplingGraph:
def
__init__
(
self
):
pass
def
__repr__
(
self
)
->
str
:
"""Return a string representation of the graph.
Returns
-------
str
String representation of the graph.
"""
raise
NotImplementedError
@
property
def
num_nodes
(
self
)
->
Union
[
int
,
Dict
[
str
,
int
]]:
"""The number of nodes in the graph.
...
...
@@ -26,3 +38,49 @@ class SamplingGraph:
heterogenous graph.
"""
raise
NotImplementedError
@
property
def
num_edges
(
self
)
->
Union
[
int
,
Dict
[
str
,
int
]]:
"""The number of edges in the graph.
- If the graph is homogenous, returns an integer.
- If the graph is heterogenous, returns a dictionary.
Returns
-------
Union[int, Dict[str, int]]
The number of edges. Integer indicates the total edges number of a
homogenous graph; dict indicates edges number per edge types of a
heterogenous graph.
"""
raise
NotImplementedError
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
)
->
"SamplingGraph"
:
"""Copy the graph to shared memory.
Parameters
----------
shared_memory_name : str
Name of the shared memory.
Returns
-------
SamplingGraph
The copied SamplingGraph object on shared memory.
"""
raise
NotImplementedError
# pylint: disable=invalid-name
def
to
(
self
,
device
:
torch
.
device
)
->
"SamplingGraph"
:
"""Copy graph to the specified device.
Parameters
----------
device : torch.device
The destination device.
Returns
-------
SamplingGraph
The graph on the specified device.
"""
raise
NotImplementedError
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
d873acc2
...
...
@@ -187,7 +187,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
"total_num_nodes, total_num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)],
)
def
test_num_nodes_homo
(
total_num_nodes
,
total_num_edges
):
def
test_num_nodes_
edges_
homo
(
total_num_nodes
,
total_num_edges
):
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
total_num_nodes
,
total_num_edges
)
...
...
@@ -200,6 +200,7 @@ def test_num_nodes_homo(total_num_nodes, total_num_edges):
)
assert
graph
.
num_nodes
==
total_num_nodes
assert
graph
.
num_edges
==
total_num_edges
@
unittest
.
skipIf
(
...
...
@@ -233,6 +234,7 @@ def test_num_nodes_hetero():
"N0:R1:N1"
:
1
,
"N1:R2:N0"
:
2
,
"N1:R3:N1"
:
3
,
"N1:R4:N0"
:
4
,
}
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
...
...
@@ -254,9 +256,16 @@ def test_num_nodes_hetero():
"N0"
:
2
,
"N1"
:
3
,
}
assert
graph
.
num_nodes
[
"N0"
]
==
2
assert
graph
.
num_nodes
[
"N1"
]
==
3
assert
"N2"
not
in
graph
.
num_nodes
assert
sum
(
graph
.
num_nodes
.
values
())
==
total_num_nodes
# Verify edges number per edge types.
assert
graph
.
num_edges
==
{
"N0:R0:N0"
:
2
,
"N0:R1:N1"
:
4
,
"N1:R2:N0"
:
3
,
"N1:R3:N1"
:
3
,
"N1:R4:N0"
:
0
,
}
assert
sum
(
graph
.
num_edges
.
values
())
==
total_num_edges
@
unittest
.
skipIf
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment