Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c2134442
Unverified
Commit
c2134442
authored
Dec 07, 2023
by
Rhett Ying
Committed by
GitHub
Dec 07, 2023
Browse files
[GraphBolt] use torch.load/save instead of load/save_fused_xxx() (#6707)
parent
0348ad3d
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
54 additions
and
127 deletions
+54
-127
graphbolt/include/graphbolt/serialize.h
graphbolt/include/graphbolt/serialize.h
+11
-21
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+0
-2
graphbolt/src/serialize.cc
graphbolt/src/serialize.cc
+0
-13
python/dgl/distributed/partition.py
python/dgl/distributed/partition.py
+5
-2
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+7
-45
python/dgl/graphbolt/impl/ondisk_dataset.py
python/dgl/graphbolt/impl/ondisk_dataset.py
+4
-9
tests/distributed/test_partition.py
tests/distributed/test_partition.py
+4
-4
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+8
-8
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
+15
-23
No files found.
graphbolt/include/graphbolt/serialize.h
View file @
c2134442
...
@@ -21,11 +21,16 @@ namespace torch {
...
@@ -21,11 +21,16 @@ namespace torch {
/**
/**
* @brief Overload input stream operator for FusedCSCSamplingGraph
* @brief Overload input stream operator for FusedCSCSamplingGraph
* deserialization.
* deserialization. This enables `torch::load()` for FusedCSCSamplingGraph.
*
* @param archive Input stream for deserializing.
* @param archive Input stream for deserializing.
* @param graph FusedCSCSamplingGraph.
* @param graph FusedCSCSamplingGraph.
*
*
* @return archive
* @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::load(*graph, filename);
* @endcode
*/
*/
inline
serialize
::
InputArchive
&
operator
>>
(
inline
serialize
::
InputArchive
&
operator
>>
(
serialize
::
InputArchive
&
archive
,
serialize
::
InputArchive
&
archive
,
...
@@ -33,11 +38,15 @@ inline serialize::InputArchive& operator>>(
...
@@ -33,11 +38,15 @@ inline serialize::InputArchive& operator>>(
/**
/**
* @brief Overload output stream operator for FusedCSCSamplingGraph
* @brief Overload output stream operator for FusedCSCSamplingGraph
* serialization.
* serialization. This enables `torch::save()` for FusedCSCSamplingGraph.
* @param archive Output stream for serializing.
* @param archive Output stream for serializing.
* @param graph FusedCSCSamplingGraph.
* @param graph FusedCSCSamplingGraph.
*
*
* @return archive
* @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::save(*graph, filename);
* @endcode
*/
*/
inline
serialize
::
OutputArchive
&
operator
<<
(
inline
serialize
::
OutputArchive
&
operator
<<
(
serialize
::
OutputArchive
&
archive
,
serialize
::
OutputArchive
&
archive
,
...
@@ -47,25 +56,6 @@ inline serialize::OutputArchive& operator<<(
...
@@ -47,25 +56,6 @@ inline serialize::OutputArchive& operator<<(
namespace
graphbolt
{
namespace
graphbolt
{
/**
* @brief Load FusedCSCSamplingGraph from file.
* @param filename File name to read.
*
* @return FusedCSCSamplingGraph.
*/
c10
::
intrusive_ptr
<
sampling
::
FusedCSCSamplingGraph
>
LoadFusedCSCSamplingGraph
(
const
std
::
string
&
filename
);
/**
* @brief Save FusedCSCSamplingGraph to file.
* @param graph FusedCSCSamplingGraph to save.
* @param filename File name to save.
*
*/
void
SaveFusedCSCSamplingGraph
(
c10
::
intrusive_ptr
<
sampling
::
FusedCSCSamplingGraph
>
graph
,
const
std
::
string
&
filename
);
/**
/**
* @brief Read data from archive.
* @brief Read data from archive.
* @param archive Input archive.
* @param archive Input archive.
...
...
graphbolt/src/python_binding.cc
View file @
c2134442
...
@@ -66,8 +66,6 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -66,8 +66,6 @@ TORCH_LIBRARY(graphbolt, m) {
return
g
;
return
g
;
});
});
m
.
def
(
"from_fused_csc"
,
&
FusedCSCSamplingGraph
::
FromCSC
);
m
.
def
(
"from_fused_csc"
,
&
FusedCSCSamplingGraph
::
FromCSC
);
m
.
def
(
"load_fused_csc_sampling_graph"
,
&
LoadFusedCSCSamplingGraph
);
m
.
def
(
"save_fused_csc_sampling_graph"
,
&
SaveFusedCSCSamplingGraph
);
m
.
def
(
m
.
def
(
"load_from_shared_memory"
,
&
FusedCSCSamplingGraph
::
LoadFromSharedMemory
);
"load_from_shared_memory"
,
&
FusedCSCSamplingGraph
::
LoadFromSharedMemory
);
m
.
def
(
"unique_and_compact"
,
&
UniqueAndCompact
);
m
.
def
(
"unique_and_compact"
,
&
UniqueAndCompact
);
...
...
graphbolt/src/serialize.cc
View file @
c2134442
...
@@ -27,19 +27,6 @@ serialize::OutputArchive& operator<<(
...
@@ -27,19 +27,6 @@ serialize::OutputArchive& operator<<(
namespace
graphbolt
{
namespace
graphbolt
{
c10
::
intrusive_ptr
<
sampling
::
FusedCSCSamplingGraph
>
LoadFusedCSCSamplingGraph
(
const
std
::
string
&
filename
)
{
auto
&&
graph
=
c10
::
make_intrusive
<
sampling
::
FusedCSCSamplingGraph
>
();
torch
::
load
(
*
graph
,
filename
);
return
graph
;
}
void
SaveFusedCSCSamplingGraph
(
c10
::
intrusive_ptr
<
sampling
::
FusedCSCSamplingGraph
>
graph
,
const
std
::
string
&
filename
)
{
torch
::
save
(
*
graph
,
filename
);
}
torch
::
IValue
read_from_archive
(
torch
::
IValue
read_from_archive
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
)
{
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
)
{
torch
::
IValue
data
;
torch
::
IValue
data
;
...
...
python/dgl/distributed/partition.py
View file @
c2134442
...
@@ -7,6 +7,8 @@ import time
...
@@ -7,6 +7,8 @@ import time
import
numpy
as
np
import
numpy
as
np
import
torch
from
..
import
backend
as
F
from
..
import
backend
as
F
from
..base
import
DGLError
,
EID
,
ETYPE
,
NID
,
NTYPE
from
..base
import
DGLError
,
EID
,
ETYPE
,
NID
,
NTYPE
from
..convert
import
to_homogeneous
from
..convert
import
to_homogeneous
...
@@ -1236,6 +1238,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
...
@@ -1236,6 +1238,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_config : str
part_config : str
The partition configuration JSON file.
The partition configuration JSON file.
"""
"""
# As only this function requires GraphBolt for now, let's import here.
# As only this function requires GraphBolt for now, let's import here.
from
..
import
graphbolt
from
..
import
graphbolt
...
@@ -1279,6 +1282,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
...
@@ -1279,6 +1282,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_meta
[
f
"part-
{
part_id
}
"
][
"part_graph"
],
part_meta
[
f
"part-
{
part_id
}
"
][
"part_graph"
],
)
)
csc_graph_path
=
os
.
path
.
join
(
csc_graph_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
orig_graph_path
),
"fused_csc_sampling_graph.t
ar
"
os
.
path
.
dirname
(
orig_graph_path
),
"fused_csc_sampling_graph.
p
t"
)
)
graphbolt
.
save_fused_csc_sampling_graph
(
csc_graph
,
csc_graph_path
)
torch
.
save
(
csc_graph
,
csc_graph_path
)
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
c2134442
"""CSC format sampling graph."""
"""CSC format sampling graph."""
# pylint: disable= invalid-name
# pylint: disable= invalid-name
import
os
import
tarfile
import
tempfile
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
...
@@ -27,8 +24,6 @@ __all__ = [
...
@@ -27,8 +24,6 @@ __all__ = [
"FusedCSCSamplingGraph"
,
"FusedCSCSamplingGraph"
,
"from_fused_csc"
,
"from_fused_csc"
,
"load_from_shared_memory"
,
"load_from_shared_memory"
,
"load_fused_csc_sampling_graph"
,
"save_fused_csc_sampling_graph"
,
"from_dglgraph"
,
"from_dglgraph"
,
]
]
...
@@ -99,11 +94,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -99,11 +94,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
return
_csc_sampling_graph_str
(
self
)
return
_csc_sampling_graph_str
(
self
)
def
__init__
(
def
__init__
(
self
,
c_csc_graph
:
torch
.
ScriptObject
,
metadata
:
Optional
[
GraphMetadata
]
self
,
c_csc_graph
:
torch
.
ScriptObject
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
_c_csc_graph
=
c_csc_graph
self
.
_c_csc_graph
=
c_csc_graph
self
.
_metadata
=
metadata
@
property
@
property
def
total_num_nodes
(
self
)
->
int
:
def
total_num_nodes
(
self
)
->
int
:
...
@@ -318,12 +313,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -318,12 +313,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
def
metadata
(
self
)
->
Optional
[
GraphMetadata
]:
def
metadata
(
self
)
->
Optional
[
GraphMetadata
]:
"""Returns the metadata of the graph.
"""Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns
Returns
-------
-------
GraphMetadata or None
GraphMetadata or None
If present, returns the metadata of the graph.
If present, returns the metadata of the graph.
"""
"""
return
self
.
_metadata
if
self
.
node_type_to_id
is
None
or
self
.
edge_type_to_id
is
None
:
return
None
return
GraphMetadata
(
self
.
node_type_to_id
,
self
.
edge_type_to_id
)
def
in_subgraph
(
def
in_subgraph
(
self
,
nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
self
,
nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
...
@@ -884,7 +883,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -884,7 +883,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
"""
return
FusedCSCSamplingGraph
(
return
FusedCSCSamplingGraph
(
self
.
_c_csc_graph
.
copy_to_shared_memory
(
shared_memory_name
),
self
.
_c_csc_graph
.
copy_to_shared_memory
(
shared_memory_name
),
self
.
_metadata
,
)
)
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
...
@@ -975,13 +973,11 @@ def from_fused_csc(
...
@@ -975,13 +973,11 @@ def from_fused_csc(
edge_type_to_id
,
edge_type_to_id
,
edge_attributes
,
edge_attributes
,
),
),
metadata
,
)
)
def
load_from_shared_memory
(
def
load_from_shared_memory
(
shared_memory_name
:
str
,
shared_memory_name
:
str
,
metadata
:
Optional
[
GraphMetadata
]
=
None
,
)
->
FusedCSCSamplingGraph
:
)
->
FusedCSCSamplingGraph
:
"""Load a FusedCSCSamplingGraph object from shared memory.
"""Load a FusedCSCSamplingGraph object from shared memory.
...
@@ -997,7 +993,6 @@ def load_from_shared_memory(
...
@@ -997,7 +993,6 @@ def load_from_shared_memory(
"""
"""
return
FusedCSCSamplingGraph
(
return
FusedCSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
load_from_shared_memory
(
shared_memory_name
),
torch
.
ops
.
graphbolt
.
load_from_shared_memory
(
shared_memory_name
),
metadata
,
)
)
...
@@ -1033,38 +1028,6 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
...
@@ -1033,38 +1028,6 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
return
final_str
return
final_str
def
load_fused_csc_sampling_graph
(
filename
):
"""Load FusedCSCSamplingGraph from tar file."""
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
with
tarfile
.
open
(
filename
,
"r"
)
as
archive
:
archive
.
extractall
(
temp_dir
)
graph_filename
=
os
.
path
.
join
(
temp_dir
,
"fused_csc_sampling_graph.pt"
)
metadata_filename
=
os
.
path
.
join
(
temp_dir
,
"metadata.pt"
)
return
FusedCSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
load_fused_csc_sampling_graph
(
graph_filename
),
torch
.
load
(
metadata_filename
),
)
def
save_fused_csc_sampling_graph
(
graph
,
filename
):
"""Save FusedCSCSamplingGraph to tar file."""
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
graph_filename
=
os
.
path
.
join
(
temp_dir
,
"fused_csc_sampling_graph.pt"
)
torch
.
ops
.
graphbolt
.
save_fused_csc_sampling_graph
(
graph
.
_c_csc_graph
,
graph_filename
)
metadata_filename
=
os
.
path
.
join
(
temp_dir
,
"metadata.pt"
)
torch
.
save
(
graph
.
metadata
,
metadata_filename
)
with
tarfile
.
open
(
filename
,
"w"
)
as
archive
:
archive
.
add
(
graph_filename
,
arcname
=
os
.
path
.
basename
(
graph_filename
)
)
archive
.
add
(
metadata_filename
,
arcname
=
os
.
path
.
basename
(
metadata_filename
)
)
print
(
f
"FusedCSCSamplingGraph has been saved to
{
filename
}
."
)
def
from_dglgraph
(
def
from_dglgraph
(
g
:
DGLGraph
,
g
:
DGLGraph
,
is_homogeneous
:
bool
=
False
,
is_homogeneous
:
bool
=
False
,
...
@@ -1114,5 +1077,4 @@ def from_dglgraph(
...
@@ -1114,5 +1077,4 @@ def from_dglgraph(
edge_type_to_id
,
edge_type_to_id
,
edge_attributes
,
edge_attributes
,
),
),
metadata
,
)
)
python/dgl/graphbolt/impl/ondisk_dataset.py
View file @
c2134442
...
@@ -17,12 +17,7 @@ from ..dataset import Dataset, Task
...
@@ -17,12 +17,7 @@ from ..dataset import Dataset, Task
from
..internal
import
copy_or_convert_data
,
read_data
from
..internal
import
copy_or_convert_data
,
read_data
from
..itemset
import
ItemSet
,
ItemSetDict
from
..itemset
import
ItemSet
,
ItemSetDict
from
..sampling_graph
import
SamplingGraph
from
..sampling_graph
import
SamplingGraph
from
.fused_csc_sampling_graph
import
(
from
.fused_csc_sampling_graph
import
from_dglgraph
,
FusedCSCSamplingGraph
from_dglgraph
,
FusedCSCSamplingGraph
,
load_fused_csc_sampling_graph
,
save_fused_csc_sampling_graph
,
)
from
.ondisk_metadata
import
(
from
.ondisk_metadata
import
(
OnDiskGraphTopology
,
OnDiskGraphTopology
,
OnDiskMetaData
,
OnDiskMetaData
,
...
@@ -147,10 +142,10 @@ def preprocess_ondisk_dataset(
...
@@ -147,10 +142,10 @@ def preprocess_ondisk_dataset(
output_config
[
"graph_topology"
]
=
{}
output_config
[
"graph_topology"
]
=
{}
output_config
[
"graph_topology"
][
"type"
]
=
"FusedCSCSamplingGraph"
output_config
[
"graph_topology"
][
"type"
]
=
"FusedCSCSamplingGraph"
output_config
[
"graph_topology"
][
"path"
]
=
os
.
path
.
join
(
output_config
[
"graph_topology"
][
"path"
]
=
os
.
path
.
join
(
processed_dir_prefix
,
"fused_csc_sampling_graph.t
ar
"
processed_dir_prefix
,
"fused_csc_sampling_graph.
p
t"
)
)
save_fused_csc_sampling_graph
(
torch
.
save
(
fused_csc_sampling_graph
,
fused_csc_sampling_graph
,
os
.
path
.
join
(
os
.
path
.
join
(
dataset_dir
,
dataset_dir
,
...
@@ -452,7 +447,7 @@ class OnDiskDataset(Dataset):
...
@@ -452,7 +447,7 @@ class OnDiskDataset(Dataset):
if
graph_topology
is
None
:
if
graph_topology
is
None
:
return
None
return
None
if
graph_topology
.
type
==
"FusedCSCSamplingGraph"
:
if
graph_topology
.
type
==
"FusedCSCSamplingGraph"
:
return
load_fused_csc_sampling_graph
(
graph_topology
.
path
)
return
torch
.
load
(
graph_topology
.
path
)
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"Graph topology type
{
graph_topology
.
type
}
is not supported."
f
"Graph topology type
{
graph_topology
.
type
}
is not supported."
)
)
...
...
tests/distributed/test_partition.py
View file @
c2134442
...
@@ -695,9 +695,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
...
@@ -695,9 +695,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
orig_g
=
dgl
.
load_graphs
(
orig_g
=
dgl
.
load_graphs
(
os
.
path
.
join
(
test_dir
,
f
"part
{
part_id
}
/graph.dgl"
)
os
.
path
.
join
(
test_dir
,
f
"part
{
part_id
}
/graph.dgl"
)
)[
0
][
0
]
)[
0
][
0
]
new_g
=
dgl
.
graphbolt
.
load_fused_csc_sampling_graph
(
new_g
=
th
.
load
(
os
.
path
.
join
(
os
.
path
.
join
(
test_dir
,
f
"part
{
part_id
}
/fused_csc_sampling_graph.t
ar
"
test_dir
,
f
"part
{
part_id
}
/fused_csc_sampling_graph.
p
t"
)
)
)
)
orig_indptr
,
orig_indices
,
_
=
orig_g
.
adj
().
csc
()
orig_indptr
,
orig_indices
,
_
=
orig_g
.
adj
().
csc
()
...
@@ -728,9 +728,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
...
@@ -728,9 +728,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_g
=
dgl
.
load_graphs
(
orig_g
=
dgl
.
load_graphs
(
os
.
path
.
join
(
test_dir
,
f
"part
{
part_id
}
/graph.dgl"
)
os
.
path
.
join
(
test_dir
,
f
"part
{
part_id
}
/graph.dgl"
)
)[
0
][
0
]
)[
0
][
0
]
new_g
=
dgl
.
graphbolt
.
load_fused_csc_sampling_graph
(
new_g
=
th
.
load
(
os
.
path
.
join
(
os
.
path
.
join
(
test_dir
,
f
"part
{
part_id
}
/fused_csc_sampling_graph.t
ar
"
test_dir
,
f
"part
{
part_id
}
/fused_csc_sampling_graph.
p
t"
)
)
)
)
orig_indptr
,
orig_indices
,
_
=
orig_g
.
adj
().
csc
()
orig_indptr
,
orig_indices
,
_
=
orig_g
.
adj
().
csc
()
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
c2134442
...
@@ -297,9 +297,9 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
...
@@ -297,9 +297,9 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
graph
=
gb
.
from_fused_csc
(
csc_indptr
,
indices
)
graph
=
gb
.
from_fused_csc
(
csc_indptr
,
indices
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
filename
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.t
ar
"
)
filename
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.
p
t"
)
gb
.
save_fused_csc_sampling_graph
(
graph
,
filename
)
torch
.
save
(
graph
,
filename
)
graph2
=
gb
.
load_fused_csc_sampling_graph
(
filename
)
graph2
=
torch
.
load
(
filename
)
assert
graph
.
total_num_nodes
==
graph2
.
total_num_nodes
assert
graph
.
total_num_nodes
==
graph2
.
total_num_nodes
assert
graph
.
total_num_edges
==
graph2
.
total_num_edges
assert
graph
.
total_num_edges
==
graph2
.
total_num_edges
...
@@ -338,9 +338,9 @@ def test_load_save_hetero_graph(
...
@@ -338,9 +338,9 @@ def test_load_save_hetero_graph(
)
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
filename
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.t
ar
"
)
filename
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.
p
t"
)
gb
.
save_fused_csc_sampling_graph
(
graph
,
filename
)
torch
.
save
(
graph
,
filename
)
graph2
=
gb
.
load_fused_csc_sampling_graph
(
filename
)
graph2
=
torch
.
load
(
filename
)
assert
graph
.
total_num_nodes
==
graph2
.
total_num_nodes
assert
graph
.
total_num_nodes
==
graph2
.
total_num_nodes
assert
graph
.
total_num_edges
==
graph2
.
total_num_edges
assert
graph
.
total_num_edges
==
graph2
.
total_num_edges
...
@@ -1103,7 +1103,7 @@ def test_homo_graph_on_shared_memory(
...
@@ -1103,7 +1103,7 @@ def test_homo_graph_on_shared_memory(
shm_name
=
"test_homo_g"
shm_name
=
"test_homo_g"
graph1
=
graph
.
copy_to_shared_memory
(
shm_name
)
graph1
=
graph
.
copy_to_shared_memory
(
shm_name
)
graph2
=
gb
.
load_from_shared_memory
(
shm_name
,
graph
.
metadata
)
graph2
=
gb
.
load_from_shared_memory
(
shm_name
)
assert
graph1
.
total_num_nodes
==
total_num_nodes
assert
graph1
.
total_num_nodes
==
total_num_nodes
assert
graph1
.
total_num_nodes
==
total_num_nodes
assert
graph1
.
total_num_nodes
==
total_num_nodes
...
@@ -1181,7 +1181,7 @@ def test_hetero_graph_on_shared_memory(
...
@@ -1181,7 +1181,7 @@ def test_hetero_graph_on_shared_memory(
shm_name
=
"test_hetero_g"
shm_name
=
"test_hetero_g"
graph1
=
graph
.
copy_to_shared_memory
(
shm_name
)
graph1
=
graph
.
copy_to_shared_memory
(
shm_name
)
graph2
=
gb
.
load_from_shared_memory
(
shm_name
,
graph
.
metadata
)
graph2
=
gb
.
load_from_shared_memory
(
shm_name
)
assert
graph1
.
total_num_nodes
==
total_num_nodes
assert
graph1
.
total_num_nodes
==
total_num_nodes
assert
graph1
.
total_num_nodes
==
total_num_nodes
assert
graph1
.
total_num_nodes
==
total_num_nodes
...
...
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
View file @
c2134442
...
@@ -1008,8 +1008,8 @@ def test_OnDiskDataset_Graph_homogeneous():
...
@@ -1008,8 +1008,8 @@ def test_OnDiskDataset_Graph_homogeneous():
graph
=
gb
.
from_fused_csc
(
csc_indptr
,
indices
)
graph
=
gb
.
from_fused_csc
(
csc_indptr
,
indices
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.t
ar
"
)
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.
p
t"
)
gb
.
save_fused_csc_sampling_graph
(
graph
,
graph_path
)
torch
.
save
(
graph
,
graph_path
)
yaml_content
=
f
"""
yaml_content
=
f
"""
graph_topology:
graph_topology:
...
@@ -1046,8 +1046,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
...
@@ -1046,8 +1046,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
)
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.t
ar
"
)
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.
p
t"
)
gb
.
save_fused_csc_sampling_graph
(
graph
,
graph_path
)
torch
.
save
(
graph
,
graph_path
)
yaml_content
=
f
"""
yaml_content
=
f
"""
graph_topology:
graph_topology:
...
@@ -1119,12 +1119,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
...
@@ -1119,12 +1119,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
assert
"graph"
not
in
processed_dataset
assert
"graph"
not
in
processed_dataset
assert
"graph_topology"
in
processed_dataset
assert
"graph_topology"
in
processed_dataset
fused_csc_sampling_graph
=
(
fused_csc_sampling_graph
=
torch
.
load
(
gb
.
fused_csc_sampling_graph
.
load_fused_csc_sampling_graph
(
os
.
path
.
join
(
test_dir
,
processed_dataset
[
"graph_topology"
][
"path"
])
os
.
path
.
join
(
test_dir
,
processed_dataset
[
"graph_topology"
][
"path"
]
)
)
)
)
assert
fused_csc_sampling_graph
.
total_num_nodes
==
num_nodes
assert
fused_csc_sampling_graph
.
total_num_nodes
==
num_nodes
assert
fused_csc_sampling_graph
.
total_num_edges
==
num_edges
assert
fused_csc_sampling_graph
.
total_num_edges
==
num_edges
...
@@ -1166,12 +1162,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
...
@@ -1166,12 +1162,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
)
)
with
open
(
output_file
,
"rb"
)
as
f
:
with
open
(
output_file
,
"rb"
)
as
f
:
processed_dataset
=
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
)
processed_dataset
=
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
)
fused_csc_sampling_graph
=
(
fused_csc_sampling_graph
=
torch
.
load
(
gb
.
fused_csc_sampling_graph
.
load_fused_csc_sampling_graph
(
os
.
path
.
join
(
test_dir
,
processed_dataset
[
"graph_topology"
][
"path"
])
os
.
path
.
join
(
test_dir
,
processed_dataset
[
"graph_topology"
][
"path"
]
)
)
)
)
assert
(
assert
(
fused_csc_sampling_graph
.
edge_attributes
is
not
None
fused_csc_sampling_graph
.
edge_attributes
is
not
None
...
@@ -1325,7 +1317,7 @@ def test_OnDiskDataset_preprocess_yaml_content_unix():
...
@@ -1325,7 +1317,7 @@ def test_OnDiskDataset_preprocess_yaml_content_unix():
dataset_name:
{
dataset_name
}
dataset_name:
{
dataset_name
}
graph_topology:
graph_topology:
type: FusedCSCSamplingGraph
type: FusedCSCSamplingGraph
path: preprocessed/fused_csc_sampling_graph.t
ar
path: preprocessed/fused_csc_sampling_graph.
p
t
feature_data:
feature_data:
- domain: node
- domain: node
type: null
type: null
...
@@ -1479,7 +1471,7 @@ def test_OnDiskDataset_preprocess_yaml_content_windows():
...
@@ -1479,7 +1471,7 @@ def test_OnDiskDataset_preprocess_yaml_content_windows():
dataset_name:
{
dataset_name
}
dataset_name:
{
dataset_name
}
graph_topology:
graph_topology:
type: FusedCSCSamplingGraph
type: FusedCSCSamplingGraph
path: preprocessed
\\
fused_csc_sampling_graph.t
ar
path: preprocessed
\\
fused_csc_sampling_graph.
p
t
feature_data:
feature_data:
- domain: node
- domain: node
type: null
type: null
...
@@ -1836,8 +1828,8 @@ def test_OnDiskDataset_all_nodes_set_homo():
...
@@ -1836,8 +1828,8 @@ def test_OnDiskDataset_all_nodes_set_homo():
graph
=
gb
.
from_fused_csc
(
csc_indptr
,
indices
)
graph
=
gb
.
from_fused_csc
(
csc_indptr
,
indices
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.t
ar
"
)
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.
p
t"
)
gb
.
save_fused_csc_sampling_graph
(
graph
,
graph_path
)
torch
.
save
(
graph
,
graph_path
)
yaml_content
=
f
"""
yaml_content
=
f
"""
graph_topology:
graph_topology:
...
@@ -1873,8 +1865,8 @@ def test_OnDiskDataset_all_nodes_set_hetero():
...
@@ -1873,8 +1865,8 @@ def test_OnDiskDataset_all_nodes_set_hetero():
)
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.t
ar
"
)
graph_path
=
os
.
path
.
join
(
test_dir
,
"fused_csc_sampling_graph.
p
t"
)
gb
.
save_fused_csc_sampling_graph
(
graph
,
graph_path
)
torch
.
save
(
graph
,
graph_path
)
yaml_content
=
f
"""
yaml_content
=
f
"""
graph_topology:
graph_topology:
...
@@ -1999,7 +1991,7 @@ def test_BuiltinDataset():
...
@@ -1999,7 +1991,7 @@ def test_BuiltinDataset():
"""Test BuiltinDataset."""
"""Test BuiltinDataset."""
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
# Case 1: download from DGL S3 storage.
# Case 1: download from DGL S3 storage.
dataset_name
=
"test-dataset-23120
4
"
dataset_name
=
"test-dataset-23120
7
"
# Add dataset to the builtin dataset list for testing only.
# Add dataset to the builtin dataset list for testing only.
gb
.
BuiltinDataset
.
_all_datasets
.
append
(
dataset_name
)
gb
.
BuiltinDataset
.
_all_datasets
.
append
(
dataset_name
)
dataset
=
gb
.
BuiltinDataset
(
name
=
dataset_name
,
root
=
test_dir
).
load
()
dataset
=
gb
.
BuiltinDataset
(
name
=
dataset_name
,
root
=
test_dir
).
load
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment