Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
96c89c0b
Unverified
Commit
96c89c0b
authored
Jul 10, 2023
by
Rhett Ying
Committed by
GitHub
Jul 10, 2023
Browse files
[GraphBolt] init graph topology (#5972)
parent
e1781586
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
186 additions
and
54 deletions
+186
-54
python/dgl/graphbolt/impl/ondisk_dataset.py
python/dgl/graphbolt/impl/ondisk_dataset.py
+20
-2
python/dgl/graphbolt/impl/ondisk_metadata.py
python/dgl/graphbolt/impl/ondisk_metadata.py
+16
-0
tests/python/pytorch/graphbolt/gb_test_utils.py
tests/python/pytorch/graphbolt/gb_test_utils.py
+44
-2
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+10
-50
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
+96
-0
No files found.
python/dgl/graphbolt/impl/ondisk_dataset.py
View file @
96c89c0b
...
@@ -3,9 +3,11 @@
...
@@ -3,9 +3,11 @@
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
from
..dataset
import
Dataset
from
..dataset
import
Dataset
from
..graph_storage
import
CSCSamplingGraph
,
load_csc_sampling_graph
from
..itemset
import
ItemSet
,
ItemSetDict
from
..itemset
import
ItemSet
,
ItemSetDict
from
..utils
import
read_data
,
tensor_to_tuple
from
..utils
import
read_data
,
tensor_to_tuple
from
.ondisk_metadata
import
OnDiskMetaData
,
OnDiskTVTSet
from
.ondisk_metadata
import
OnDiskGraphTopology
,
OnDiskMetaData
,
OnDiskTVTSet
from
.torch_based_feature_store
import
(
from
.torch_based_feature_store
import
(
load_feature_stores
,
load_feature_stores
,
TorchBasedFeatureStore
,
TorchBasedFeatureStore
,
...
@@ -27,6 +29,9 @@ class OnDiskDataset(Dataset):
...
@@ -27,6 +29,9 @@ class OnDiskDataset(Dataset):
.. code-block:: yaml
.. code-block:: yaml
graph_topology:
type: CSCSamplingGraph
path: graph_topology/csc_sampling_graph.tar
feature_data:
feature_data:
- domain: node
- domain: node
type: paper
type: paper
...
@@ -65,6 +70,7 @@ class OnDiskDataset(Dataset):
...
@@ -65,6 +70,7 @@ class OnDiskDataset(Dataset):
def
__init__
(
self
,
path
:
str
)
->
None
:
def
__init__
(
self
,
path
:
str
)
->
None
:
with
open
(
path
,
"r"
)
as
f
:
with
open
(
path
,
"r"
)
as
f
:
self
.
_meta
=
OnDiskMetaData
.
parse_raw
(
f
.
read
(),
proto
=
"yaml"
)
self
.
_meta
=
OnDiskMetaData
.
parse_raw
(
f
.
read
(),
proto
=
"yaml"
)
self
.
_graph
=
self
.
_load_graph
(
self
.
_meta
.
graph_topology
)
self
.
_feature
=
load_feature_stores
(
self
.
_meta
.
feature_data
)
self
.
_feature
=
load_feature_stores
(
self
.
_meta
.
feature_data
)
self
.
_train_sets
=
self
.
_init_tvt_sets
(
self
.
_meta
.
train_sets
)
self
.
_train_sets
=
self
.
_init_tvt_sets
(
self
.
_meta
.
train_sets
)
self
.
_validation_sets
=
self
.
_init_tvt_sets
(
self
.
_meta
.
validation_sets
)
self
.
_validation_sets
=
self
.
_init_tvt_sets
(
self
.
_meta
.
validation_sets
)
...
@@ -84,12 +90,24 @@ class OnDiskDataset(Dataset):
...
@@ -84,12 +90,24 @@ class OnDiskDataset(Dataset):
def
graph
(
self
)
->
object
:
def
graph
(
self
)
->
object
:
"""Return the graph."""
"""Return the graph."""
r
aise
NotImplementedError
r
eturn
self
.
_graph
def
feature
(
self
)
->
Dict
[
Tuple
,
TorchBasedFeatureStore
]:
def
feature
(
self
)
->
Dict
[
Tuple
,
TorchBasedFeatureStore
]:
"""Return the feature."""
"""Return the feature."""
return
self
.
_feature
return
self
.
_feature
def
_load_graph
(
self
,
graph_topology
:
OnDiskGraphTopology
)
->
CSCSamplingGraph
:
"""Load the graph topology."""
if
graph_topology
is
None
:
return
None
if
graph_topology
.
type
==
"CSCSamplingGraph"
:
return
load_csc_sampling_graph
(
graph_topology
.
path
)
raise
NotImplementedError
(
f
"Graph topology type
{
graph_topology
.
type
}
is not supported."
)
def
_init_tvt_sets
(
def
_init_tvt_sets
(
self
,
tvt_sets
:
List
[
List
[
OnDiskTVTSet
]]
self
,
tvt_sets
:
List
[
List
[
OnDiskTVTSet
]]
)
->
List
[
ItemSet
]
or
List
[
ItemSetDict
]:
)
->
List
[
ItemSet
]
or
List
[
ItemSetDict
]:
...
...
python/dgl/graphbolt/impl/ondisk_metadata.py
View file @
96c89c0b
...
@@ -12,6 +12,8 @@ __all__ = [
...
@@ -12,6 +12,8 @@ __all__ = [
"OnDiskFeatureDataDomain"
,
"OnDiskFeatureDataDomain"
,
"OnDiskFeatureData"
,
"OnDiskFeatureData"
,
"OnDiskMetaData"
,
"OnDiskMetaData"
,
"OnDiskGraphTopologyType"
,
"OnDiskGraphTopology"
,
]
]
...
@@ -49,6 +51,19 @@ class OnDiskFeatureData(pydantic.BaseModel):
...
@@ -49,6 +51,19 @@ class OnDiskFeatureData(pydantic.BaseModel):
in_memory
:
Optional
[
bool
]
=
True
in_memory
:
Optional
[
bool
]
=
True
class
OnDiskGraphTopologyType
(
pydantic_yaml
.
YamlStrEnum
):
"""Enum of graph topology type."""
CSC_SAMPLING
=
"CSCSamplingGraph"
class
OnDiskGraphTopology
(
pydantic
.
BaseModel
):
"""The description of an on-disk graph topology."""
type
:
OnDiskGraphTopologyType
path
:
str
class
OnDiskMetaData
(
pydantic_yaml
.
YamlModel
):
class
OnDiskMetaData
(
pydantic_yaml
.
YamlModel
):
"""Metadata specification in YAML.
"""Metadata specification in YAML.
...
@@ -56,6 +71,7 @@ class OnDiskMetaData(pydantic_yaml.YamlModel):
...
@@ -56,6 +71,7 @@ class OnDiskMetaData(pydantic_yaml.YamlModel):
is a list of list of ``OnDiskTVTSet``.
is a list of list of ``OnDiskTVTSet``.
"""
"""
graph_topology
:
Optional
[
OnDiskGraphTopology
]
=
None
feature_data
:
Optional
[
List
[
OnDiskFeatureData
]]
=
[]
feature_data
:
Optional
[
List
[
OnDiskFeatureData
]]
=
[]
train_sets
:
Optional
[
List
[
List
[
OnDiskTVTSet
]]]
=
[]
train_sets
:
Optional
[
List
[
List
[
OnDiskTVTSet
]]]
=
[]
validation_sets
:
Optional
[
List
[
List
[
OnDiskTVTSet
]]]
=
[]
validation_sets
:
Optional
[
List
[
List
[
OnDiskTVTSet
]]]
=
[]
...
...
tests/python/pytorch/graphbolt/gb_test_utils.py
View file @
96c89c0b
import
dgl.graphbolt
import
dgl.graphbolt
as
gb
import
scipy.sparse
as
sp
import
scipy.sparse
as
sp
import
torch
import
torch
...
@@ -11,6 +11,48 @@ def rand_csc_graph(N, density):
...
@@ -11,6 +11,48 @@ def rand_csc_graph(N, density):
indptr
=
torch
.
LongTensor
(
adj
.
indptr
)
indptr
=
torch
.
LongTensor
(
adj
.
indptr
)
indices
=
torch
.
LongTensor
(
adj
.
indices
)
indices
=
torch
.
LongTensor
(
adj
.
indices
)
graph
=
dgl
.
graphbolt
.
from_csc
(
indptr
,
indices
)
graph
=
gb
.
from_csc
(
indptr
,
indices
)
return
graph
return
graph
def
random_homo_graph
(
num_nodes
,
num_edges
):
csc_indptr
=
torch
.
randint
(
0
,
num_edges
,
(
num_nodes
+
1
,))
csc_indptr
=
torch
.
sort
(
csc_indptr
)[
0
]
csc_indptr
[
0
]
=
0
csc_indptr
[
-
1
]
=
num_edges
indices
=
torch
.
randint
(
0
,
num_nodes
,
(
num_edges
,))
return
csc_indptr
,
indices
def
get_metadata
(
num_ntypes
,
num_etypes
):
ntypes
=
{
f
"n
{
i
}
"
:
i
for
i
in
range
(
num_ntypes
)}
etypes
=
{}
count
=
0
for
n1
in
range
(
num_ntypes
):
for
n2
in
range
(
n1
,
num_ntypes
):
if
count
>=
num_etypes
:
break
etypes
.
update
({(
f
"n
{
n1
}
"
,
f
"e
{
count
}
"
,
f
"n
{
n2
}
"
):
count
})
count
+=
1
return
gb
.
GraphMetadata
(
ntypes
,
etypes
)
def
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
):
csc_indptr
,
indices
=
random_homo_graph
(
num_nodes
,
num_edges
)
metadata
=
get_metadata
(
num_ntypes
,
num_etypes
)
# Randomly get node type split point.
node_type_offset
=
torch
.
sort
(
torch
.
randint
(
0
,
num_nodes
,
(
num_ntypes
+
1
,))
)[
0
]
node_type_offset
[
0
]
=
0
node_type_offset
[
-
1
]
=
num_nodes
type_per_edge
=
[]
for
i
in
range
(
num_nodes
):
num
=
csc_indptr
[
i
+
1
]
-
csc_indptr
[
i
]
type_per_edge
.
append
(
torch
.
sort
(
torch
.
randint
(
0
,
num_etypes
,
(
num
,)))[
0
]
)
type_per_edge
=
torch
.
cat
(
type_per_edge
,
dim
=
0
)
return
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
96c89c0b
...
@@ -7,6 +7,8 @@ import backend as F
...
@@ -7,6 +7,8 @@ import backend as F
import
dgl
import
dgl
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
gb_test_utils
as
gbt
import
pytest
import
pytest
import
torch
import
torch
from
scipy
import
sparse
as
spsp
from
scipy
import
sparse
as
spsp
...
@@ -14,19 +16,6 @@ from scipy import sparse as spsp
...
@@ -14,19 +16,6 @@ from scipy import sparse as spsp
torch
.
manual_seed
(
3407
)
torch
.
manual_seed
(
3407
)
def
get_metadata
(
num_ntypes
,
num_etypes
):
ntypes
=
{
f
"n
{
i
}
"
:
i
for
i
in
range
(
num_ntypes
)}
etypes
=
{}
count
=
0
for
n1
in
range
(
num_ntypes
):
for
n2
in
range
(
n1
,
num_ntypes
):
if
count
>=
num_etypes
:
break
etypes
.
update
({(
f
"n
{
n1
}
"
,
f
"e
{
count
}
"
,
f
"n
{
n2
}
"
):
count
})
count
+=
1
return
gb
.
GraphMetadata
(
ntypes
,
etypes
)
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
reason
=
"Graph is CPU only at present."
,
...
@@ -50,7 +39,7 @@ def test_empty_graph(num_nodes):
...
@@ -50,7 +39,7 @@ def test_empty_graph(num_nodes):
def
test_hetero_empty_graph
(
num_nodes
):
def
test_hetero_empty_graph
(
num_nodes
):
csc_indptr
=
torch
.
zeros
((
num_nodes
+
1
,),
dtype
=
int
)
csc_indptr
=
torch
.
zeros
((
num_nodes
+
1
,),
dtype
=
int
)
indices
=
torch
.
tensor
([])
indices
=
torch
.
tensor
([])
metadata
=
get_metadata
(
num_ntypes
=
3
,
num_etypes
=
5
)
metadata
=
gbt
.
get_metadata
(
num_ntypes
=
3
,
num_etypes
=
5
)
# Some node types have no nodes.
# Some node types have no nodes.
if
num_nodes
==
0
:
if
num_nodes
==
0
:
node_type_offset
=
torch
.
zeros
((
4
,),
dtype
=
int
)
node_type_offset
=
torch
.
zeros
((
4
,),
dtype
=
int
)
...
@@ -109,35 +98,6 @@ def test_metadata_with_etype_exception(etypes):
...
@@ -109,35 +98,6 @@ def test_metadata_with_etype_exception(etypes):
gb
.
GraphMetadata
({
"n1"
:
0
,
"n2"
:
1
,
"n3"
:
2
},
etypes
)
gb
.
GraphMetadata
({
"n1"
:
0
,
"n2"
:
1
,
"n3"
:
2
},
etypes
)
def
random_homo_graph
(
num_nodes
,
num_edges
):
csc_indptr
=
torch
.
randint
(
0
,
num_edges
,
(
num_nodes
+
1
,))
csc_indptr
=
torch
.
sort
(
csc_indptr
)[
0
]
csc_indptr
[
0
]
=
0
csc_indptr
[
-
1
]
=
num_edges
indices
=
torch
.
randint
(
0
,
num_nodes
,
(
num_edges
,))
return
csc_indptr
,
indices
def
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
):
csc_indptr
,
indices
=
random_homo_graph
(
num_nodes
,
num_edges
)
metadata
=
get_metadata
(
num_ntypes
,
num_etypes
)
# Randomly get node type split point.
node_type_offset
=
torch
.
sort
(
torch
.
randint
(
0
,
num_nodes
,
(
num_ntypes
+
1
,))
)[
0
]
node_type_offset
[
0
]
=
0
node_type_offset
[
-
1
]
=
num_nodes
type_per_edge
=
[]
for
i
in
range
(
num_nodes
):
num
=
csc_indptr
[
i
+
1
]
-
csc_indptr
[
i
]
type_per_edge
.
append
(
torch
.
sort
(
torch
.
randint
(
0
,
num_etypes
,
(
num
,)))[
0
]
)
type_per_edge
=
torch
.
cat
(
type_per_edge
,
dim
=
0
)
return
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
reason
=
"Graph is CPU only at present."
,
...
@@ -146,7 +106,7 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
...
@@ -146,7 +106,7 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
)
)
def
test_homo_graph
(
num_nodes
,
num_edges
):
def
test_homo_graph
(
num_nodes
,
num_edges
):
csc_indptr
,
indices
=
random_homo_graph
(
num_nodes
,
num_edges
)
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
num_nodes
,
num_edges
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
assert
graph
.
num_nodes
==
num_nodes
assert
graph
.
num_nodes
==
num_nodes
...
@@ -175,7 +135,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
...
@@ -175,7 +135,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
node_type_offset
,
node_type_offset
,
type_per_edge
,
type_per_edge
,
metadata
,
metadata
,
)
=
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
graph
=
gb
.
from_csc
(
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
)
...
@@ -205,7 +165,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
...
@@ -205,7 +165,7 @@ def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
)
)
def
test_node_type_offset_wrong_legnth
(
node_type_offset
):
def
test_node_type_offset_wrong_legnth
(
node_type_offset
):
num_ntypes
=
3
num_ntypes
=
3
csc_indptr
,
indices
,
_
,
type_per_edge
,
metadata
=
random_hetero_graph
(
csc_indptr
,
indices
,
_
,
type_per_edge
,
metadata
=
gbt
.
random_hetero_graph
(
10
,
50
,
num_ntypes
,
5
10
,
50
,
num_ntypes
,
5
)
)
with
pytest
.
raises
(
Exception
):
with
pytest
.
raises
(
Exception
):
...
@@ -222,7 +182,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
...
@@ -222,7 +182,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
)
)
def
test_load_save_homo_graph
(
num_nodes
,
num_edges
):
def
test_load_save_homo_graph
(
num_nodes
,
num_edges
):
csc_indptr
,
indices
=
random_homo_graph
(
num_nodes
,
num_edges
)
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
num_nodes
,
num_edges
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
...
@@ -256,7 +216,7 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
...
@@ -256,7 +216,7 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
node_type_offset
,
node_type_offset
,
type_per_edge
,
type_per_edge
,
metadata
,
metadata
,
)
=
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
graph
=
gb
.
from_csc
(
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
)
...
@@ -652,7 +612,7 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
...
@@ -652,7 +612,7 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
)
)
def
test_homo_graph_on_shared_memory
(
num_nodes
,
num_edges
):
def
test_homo_graph_on_shared_memory
(
num_nodes
,
num_edges
):
csc_indptr
,
indices
=
random_homo_graph
(
num_nodes
,
num_edges
)
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
num_nodes
,
num_edges
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
shm_name
=
"test_homo_g"
shm_name
=
"test_homo_g"
...
@@ -700,7 +660,7 @@ def test_hetero_graph_on_shared_memory(
...
@@ -700,7 +660,7 @@ def test_hetero_graph_on_shared_memory(
node_type_offset
,
node_type_offset
,
type_per_edge
,
type_per_edge
,
metadata
,
metadata
,
)
=
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
graph
=
gb
.
from_csc
(
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
)
...
...
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
View file @
96c89c0b
import
os
import
os
import
tempfile
import
tempfile
import
gb_test_utils
as
gbt
import
numpy
as
np
import
numpy
as
np
import
pydantic
import
pydantic
...
@@ -616,3 +618,97 @@ def test_OnDiskDataset_Feature_homograph():
...
@@ -616,3 +618,97 @@ def test_OnDiskDataset_Feature_homograph():
edge_label
=
None
edge_label
=
None
feature_data
=
None
feature_data
=
None
dataset
=
None
dataset
=
None
def
test_OnDiskDataset_Graph_Exceptions
():
"""Test exceptions in parsing graph topology."""
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
# Invalid graph type.
yaml_content
=
"""
graph_topology:
type: CSRSamplingGraph
path: /path/to/graph
"""
yaml_file
=
os
.
path
.
join
(
test_dir
,
"test.yaml"
)
with
open
(
yaml_file
,
"w"
)
as
f
:
f
.
write
(
yaml_content
)
with
pytest
.
raises
(
pydantic
.
ValidationError
,
match
=
"value is not a valid enumeration member"
,
):
_
=
gb
.
OnDiskDataset
(
yaml_file
)
def
test_OnDiskDataset_Graph_homogeneous
():
"""Test homogeneous graph topology."""
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
1000
,
10
*
1000
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
graph_path
=
os
.
path
.
join
(
test_dir
,
"csc_sampling_graph.tar"
)
gb
.
save_csc_sampling_graph
(
graph
,
graph_path
)
yaml_content
=
f
"""
graph_topology:
type: CSCSamplingGraph
path:
{
graph_path
}
"""
yaml_file
=
os
.
path
.
join
(
test_dir
,
"test.yaml"
)
with
open
(
yaml_file
,
"w"
)
as
f
:
f
.
write
(
yaml_content
)
dataset
=
gb
.
OnDiskDataset
(
yaml_file
)
graph2
=
dataset
.
graph
()
assert
graph
.
num_nodes
==
graph2
.
num_nodes
assert
graph
.
num_edges
==
graph2
.
num_edges
assert
torch
.
equal
(
graph
.
csc_indptr
,
graph2
.
csc_indptr
)
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
graph
.
metadata
is
None
and
graph2
.
metadata
is
None
assert
(
graph
.
node_type_offset
is
None
and
graph2
.
node_type_offset
is
None
)
assert
graph
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
def
test_OnDiskDataset_Graph_heterogeneous
():
"""Test heterogeneous graph topology."""
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
,
)
=
gbt
.
random_hetero_graph
(
1000
,
10
*
1000
,
3
,
4
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
graph_path
=
os
.
path
.
join
(
test_dir
,
"csc_sampling_graph.tar"
)
gb
.
save_csc_sampling_graph
(
graph
,
graph_path
)
yaml_content
=
f
"""
graph_topology:
type: CSCSamplingGraph
path:
{
graph_path
}
"""
yaml_file
=
os
.
path
.
join
(
test_dir
,
"test.yaml"
)
with
open
(
yaml_file
,
"w"
)
as
f
:
f
.
write
(
yaml_content
)
dataset
=
gb
.
OnDiskDataset
(
yaml_file
)
graph2
=
dataset
.
graph
()
assert
graph
.
num_nodes
==
graph2
.
num_nodes
assert
graph
.
num_edges
==
graph2
.
num_edges
assert
torch
.
equal
(
graph
.
csc_indptr
,
graph2
.
csc_indptr
)
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
torch
.
equal
(
graph
.
node_type_offset
,
graph2
.
node_type_offset
)
assert
torch
.
equal
(
graph
.
type_per_edge
,
graph2
.
type_per_edge
)
assert
graph
.
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
assert
graph
.
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment