Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
96ddf410
Unverified
Commit
96ddf410
authored
Aug 07, 2023
by
Rhett Ying
Committed by
GitHub
Aug 07, 2023
Browse files
[GraphBolt] refactor FeatureStore and related impl (#6103)
parent
8e86c89c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
201 additions
and
133 deletions
+201
-133
python/dgl/graphbolt/dataset.py
python/dgl/graphbolt/dataset.py
+1
-3
python/dgl/graphbolt/feature_store.py
python/dgl/graphbolt/feature_store.py
+27
-2
python/dgl/graphbolt/impl/ondisk_dataset.py
python/dgl/graphbolt/impl/ondisk_dataset.py
+4
-7
python/dgl/graphbolt/impl/torch_based_feature_store.py
python/dgl/graphbolt/impl/torch_based_feature_store.py
+127
-69
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+2
-4
tests/python/pytorch/graphbolt/test_feature_store.py
tests/python/pytorch/graphbolt/test_feature_store.py
+11
-9
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
...python/pytorch/graphbolt/test_multi_process_dataloader.py
+2
-2
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
+25
-35
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
...ython/pytorch/graphbolt/test_single_process_dataloader.py
+2
-2
No files found.
python/dgl/graphbolt/dataset.py
View file @
96ddf410
"""GraphBolt Dataset."""
"""GraphBolt Dataset."""
from
typing
import
Dict
from
.feature_store
import
FeatureStore
from
.feature_store
import
FeatureStore
from
.itemset
import
ItemSet
,
ItemSetDict
from
.itemset
import
ItemSet
,
ItemSetDict
...
@@ -52,7 +50,7 @@ class Dataset:
...
@@ -52,7 +50,7 @@ class Dataset:
raise
NotImplementedError
raise
NotImplementedError
@
property
@
property
def
feature
(
self
)
->
Dict
[
object
,
FeatureStore
]
:
def
feature
(
self
)
->
FeatureStore
:
"""Return the feature."""
"""Return the feature."""
raise
NotImplementedError
raise
NotImplementedError
...
...
python/dgl/graphbolt/feature_store.py
View file @
96ddf410
...
@@ -11,11 +11,23 @@ class FeatureStore:
...
@@ -11,11 +11,23 @@ class FeatureStore:
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
read
(
self
,
ids
:
torch
.
Tensor
=
None
):
def
read
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
ids
:
torch
.
Tensor
=
None
,
):
"""Read from the feature store.
"""Read from the feature store.
Parameters
Parameters
----------
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
of the feature are read. If None, the entire feature is returned.
...
@@ -27,11 +39,24 @@ class FeatureStore:
...
@@ -27,11 +39,24 @@ class FeatureStore:
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
update
(
self
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
):
def
update
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
,
):
"""Update the feature store.
"""Update the feature store.
Parameters
Parameters
----------
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
value : torch.Tensor
The updated value of the feature.
The updated value of the feature.
ids : torch.Tensor, optional
ids : torch.Tensor, optional
...
...
python/dgl/graphbolt/impl/ondisk_dataset.py
View file @
96ddf410
...
@@ -5,7 +5,7 @@ import shutil
...
@@ -5,7 +5,7 @@ import shutil
from
copy
import
deepcopy
from
copy
import
deepcopy
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
List
import
pandas
as
pd
import
pandas
as
pd
import
torch
import
torch
...
@@ -24,10 +24,7 @@ from .csc_sampling_graph import (
...
@@ -24,10 +24,7 @@ from .csc_sampling_graph import (
save_csc_sampling_graph
,
save_csc_sampling_graph
,
)
)
from
.ondisk_metadata
import
OnDiskGraphTopology
,
OnDiskMetaData
,
OnDiskTVTSet
from
.ondisk_metadata
import
OnDiskGraphTopology
,
OnDiskMetaData
,
OnDiskTVTSet
from
.torch_based_feature_store
import
(
from
.torch_based_feature_store
import
TorchBasedFeatureStore
load_feature_stores
,
TorchBasedFeatureStore
,
)
__all__
=
[
"OnDiskDataset"
,
"preprocess_ondisk_dataset"
]
__all__
=
[
"OnDiskDataset"
,
"preprocess_ondisk_dataset"
]
...
@@ -281,7 +278,7 @@ class OnDiskDataset(Dataset):
...
@@ -281,7 +278,7 @@ class OnDiskDataset(Dataset):
self
.
_num_classes
=
self
.
_meta
.
num_classes
self
.
_num_classes
=
self
.
_meta
.
num_classes
self
.
_num_labels
=
self
.
_meta
.
num_labels
self
.
_num_labels
=
self
.
_meta
.
num_labels
self
.
_graph
=
self
.
_load_graph
(
self
.
_meta
.
graph_topology
)
self
.
_graph
=
self
.
_load_graph
(
self
.
_meta
.
graph_topology
)
self
.
_feature
=
load_f
eature
_s
tore
s
(
self
.
_meta
.
feature_data
)
self
.
_feature
=
TorchBasedF
eature
S
tore
(
self
.
_meta
.
feature_data
)
self
.
_train_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
train_set
)
self
.
_train_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
train_set
)
self
.
_validation_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
validation_set
)
self
.
_validation_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
validation_set
)
self
.
_test_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
test_set
)
self
.
_test_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
test_set
)
...
@@ -307,7 +304,7 @@ class OnDiskDataset(Dataset):
...
@@ -307,7 +304,7 @@ class OnDiskDataset(Dataset):
return
self
.
_graph
return
self
.
_graph
@
property
@
property
def
feature
(
self
)
->
Dict
[
Tuple
,
TorchBasedFeatureStore
]
:
def
feature
(
self
)
->
TorchBasedFeatureStore
:
"""Return the feature."""
"""Return the feature."""
return
self
.
_feature
return
self
.
_feature
...
...
python/dgl/graphbolt/impl/torch_based_feature_store.py
View file @
96ddf410
...
@@ -8,11 +8,11 @@ import torch
...
@@ -8,11 +8,11 @@ import torch
from
..feature_store
import
FeatureStore
from
..feature_store
import
FeatureStore
from
.ondisk_metadata
import
OnDiskFeatureData
from
.ondisk_metadata
import
OnDiskFeatureData
__all__
=
[
"TorchBasedFeature
Store"
,
"load_f
eature
_s
tore
s
"
]
__all__
=
[
"TorchBasedFeature
"
,
"TorchBasedF
eature
S
tore"
]
class
TorchBasedFeature
Store
(
FeatureStore
)
:
class
TorchBasedFeature
:
r
"""Torch based feature
store
."""
r
"""Torch based feature."""
def
__init__
(
self
,
torch_feature
:
torch
.
Tensor
):
def
__init__
(
self
,
torch_feature
:
torch
.
Tensor
):
"""Initialize a torch based feature store by a torch feature.
"""Initialize a torch based feature store by a torch feature.
...
@@ -28,7 +28,7 @@ class TorchBasedFeatureStore(FeatureStore):
...
@@ -28,7 +28,7 @@ class TorchBasedFeatureStore(FeatureStore):
--------
--------
>>> import torch
>>> import torch
>>> torch_feat = torch.arange(0, 5)
>>> torch_feat = torch.arange(0, 5)
>>> feature_store = TorchBasedFeature
Store
(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
>>> feature_store.read(torch.tensor([0, 1, 2]))
...
@@ -43,15 +43,14 @@ class TorchBasedFeatureStore(FeatureStore):
...
@@ -43,15 +43,14 @@ class TorchBasedFeatureStore(FeatureStore):
>>> np.save("/tmp/arr.npy", arr)
>>> np.save("/tmp/arr.npy", arr)
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
... mmap_mode="r+"))
... mmap_mode="r+"))
>>> feature_store = TorchBasedFeature
Store
(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
>>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
tensor([0, 1, 2])
"""
"""
super
(
TorchBasedFeatureStore
,
self
).
__init__
()
assert
isinstance
(
torch_feature
,
torch
.
Tensor
),
(
assert
isinstance
(
torch_feature
,
torch
.
Tensor
),
(
f
"torch_feature in TorchBasedFeature
Store
must be torch.Tensor, "
f
"torch_feature in TorchBasedFeature must be torch.Tensor, "
f
"but got
{
type
(
torch_feature
)
}
."
f
"but got
{
type
(
torch_feature
)
}
."
)
)
self
.
_tensor
=
torch_feature
self
.
_tensor
=
torch_feature
...
@@ -106,65 +105,124 @@ class TorchBasedFeatureStore(FeatureStore):
...
@@ -106,65 +105,124 @@ class TorchBasedFeatureStore(FeatureStore):
self
.
_tensor
[
ids
]
=
value
self
.
_tensor
[
ids
]
=
value
def
load_feature_stores
(
feat_data
:
List
[
OnDiskFeatureData
]):
class
TorchBasedFeatureStore
(
FeatureStore
):
r
"""Load feature stores from disk.
r
"""Torch based feature store."""
The feature stores are described by the `feat_data`. The `feat_data` is a
def
__init__
(
self
,
feat_data
:
List
[
OnDiskFeatureData
]):
list of `OnDiskFeatureData`.
r
"""Load feature stores from disk.
For a feature store, its format must be either "pt" or "npy" for Pytorch or
The feature stores are described by the `feat_data`. The `feat_data` is a
Numpy formats. If the format is "pt", the feature store must be loaded in
list of `OnDiskFeatureData`.
memory. If the format is "npy", the feature store can be loaded in memory or
on disk.
For a feature store, its format must be either "pt" or "npy" for Pytorch or
Numpy formats. If the format is "pt", the feature store must be loaded in
Parameters
memory. If the format is "npy", the feature store can be loaded in memory or
----------
on disk.
feat_data : List[OnDiskFeatureData]
The description of the feature stores.
Parameters
----------
Returns
feat_data : List[OnDiskFeatureData]
-------
The description of the feature stores.
dict
The loaded feature stores. The keys are the names of the feature stores,
Returns
and the values are the feature stores.
-------
dict
Examples
The loaded feature stores. The keys are the names of the feature stores,
--------
and the values are the feature stores.
>>> import torch
>>> import numpy as np
Examples
>>> from dgl import graphbolt as gb
--------
>>> edge_label = torch.tensor([1, 2, 3])
>>> import torch
>>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> import numpy as np
>>> torch.save(edge_label, "/tmp/edge_label.pt")
>>> from dgl import graphbolt as gb
>>> np.save("/tmp/node_feat.npy", node_feat.numpy())
>>> edge_label = torch.tensor([1, 2, 3])
>>> feat_data = [
>>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])
... gb.OnDiskFeatureData(domain="edge", type="author:writes:paper",
>>> torch.save(edge_label, "/tmp/edge_label.pt")
... name="label", format="torch", path="/tmp/edge_label.pt",
>>> np.save("/tmp/node_feat.npy", node_feat.numpy())
... in_memory=True),
>>> feat_data = [
... gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
... gb.OnDiskFeatureData(domain="edge", type="author:writes:paper",
... format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... name="label", format="torch", path="/tmp/edge_label.pt",
... ]
... in_memory=True),
>>> gb.load_feature_stores(feat_data)
... gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
... {("edge", "author:writes:paper", "label"):
... format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... ]
... 0x7ff093cb4df0>, ("node", "paper", "feat"):
>>> feature_sotre = gb.TorchBasedFeatureStore(feat_data)
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
"""
... 0x7ff093cb4dc0>}
super
().
__init__
()
"""
self
.
_features
=
{}
feat_stores
=
{}
for
spec
in
feat_data
:
for
spec
in
feat_data
:
key
=
(
spec
.
domain
,
spec
.
type
,
spec
.
name
)
key
=
(
spec
.
domain
,
spec
.
type
,
spec
.
name
)
if
spec
.
format
==
"torch"
:
if
spec
.
format
==
"torch"
:
assert
spec
.
in_memory
,
(
assert
spec
.
in_memory
,
(
f
"Pytorch tensor can only be loaded in memory, "
f
"Pytorch tensor can only be loaded in memory, "
f
"but the feature
{
key
}
is loaded on disk."
f
"but the feature
{
key
}
is loaded on disk."
)
)
self
.
_features
[
key
]
=
TorchBasedFeature
(
torch
.
load
(
spec
.
path
))
feat_stores
[
key
]
=
TorchBasedFeatureStore
(
torch
.
load
(
spec
.
path
))
elif
spec
.
format
==
"numpy"
:
elif
spec
.
format
==
"numpy"
:
mmap_mode
=
"r+"
if
not
spec
.
in_memory
else
None
mmap_mode
=
"r+"
if
not
spec
.
in_memory
else
None
self
.
_features
[
key
]
=
TorchBasedFeature
(
feat_stores
[
key
]
=
TorchBasedFeatureStore
(
torch
.
as_tensor
(
np
.
load
(
spec
.
path
,
mmap_mode
=
mmap_mode
))
torch
.
as_tensor
(
np
.
load
(
spec
.
path
,
mmap_mode
=
mmap_mode
))
)
)
else
:
else
:
raise
ValueError
(
f
"Unknown feature format
{
spec
.
format
}
"
)
raise
ValueError
(
f
"Unknown feature format
{
spec
.
format
}
"
)
return
feat_stores
def
read
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
ids
:
torch
.
Tensor
=
None
,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
Returns
-------
torch.Tensor
The read feature.
"""
return
self
.
_features
[(
domain
,
type_name
,
feature_name
)].
read
(
ids
)
def
update
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
The indices of the feature to update. If specified, only the
specified indices of the feature will be updated. For the feature,
the `ids[i]` row is updated to `value[i]`. So the indices and value
must have the same length. If None, the entire feature will be
updated.
"""
self
.
_features
[(
domain
,
type_name
,
feature_name
)].
update
(
value
,
ids
)
def
__len__
(
self
):
"""Return the number of features."""
return
len
(
self
.
_features
)
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
96ddf410
...
@@ -6,10 +6,8 @@ import torch
...
@@ -6,10 +6,8 @@ import torch
def
get_graphbolt_fetch_func
():
def
get_graphbolt_fetch_func
():
feature_store
=
{
feature_store
=
{
"feature"
:
dgl
.
graphbolt
.
TorchBasedFeatureStore
(
torch
.
randn
(
200
,
4
)),
"feature"
:
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randn
(
200
,
4
)),
"label"
:
dgl
.
graphbolt
.
TorchBasedFeatureStore
(
"label"
:
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randint
(
0
,
10
,
(
200
,))),
torch
.
randint
(
0
,
10
,
(
200
,))
),
}
}
def
fetch_func
(
data
):
def
fetch_func
(
data
):
...
...
tests/python/pytorch/graphbolt/test_feature_store.py
View file @
96ddf410
...
@@ -19,7 +19,7 @@ def to_on_disk_tensor(test_dir, name, t):
...
@@ -19,7 +19,7 @@ def to_on_disk_tensor(test_dir, name, t):
@
pytest
.
mark
.
parametrize
(
"in_memory"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"in_memory"
,
[
True
,
False
])
def
test_torch_based_feature
_store
(
in_memory
):
def
test_torch_based_feature
(
in_memory
):
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
a
=
torch
.
tensor
([
1
,
2
,
3
])
a
=
torch
.
tensor
([
1
,
2
,
3
])
b
=
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
b
=
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
...
@@ -27,8 +27,8 @@ def test_torch_based_feature_store(in_memory):
...
@@ -27,8 +27,8 @@ def test_torch_based_feature_store(in_memory):
a
=
to_on_disk_tensor
(
test_dir
,
"a"
,
a
)
a
=
to_on_disk_tensor
(
test_dir
,
"a"
,
a
)
b
=
to_on_disk_tensor
(
test_dir
,
"b"
,
b
)
b
=
to_on_disk_tensor
(
test_dir
,
"b"
,
b
)
feat_store_a
=
gb
.
TorchBasedFeature
Store
(
a
)
feat_store_a
=
gb
.
TorchBasedFeature
(
a
)
feat_store_b
=
gb
.
TorchBasedFeature
Store
(
b
)
feat_store_b
=
gb
.
TorchBasedFeature
(
b
)
assert
torch
.
equal
(
feat_store_a
.
read
(),
torch
.
tensor
([
1
,
2
,
3
]))
assert
torch
.
equal
(
feat_store_a
.
read
(),
torch
.
tensor
([
1
,
2
,
3
]))
assert
torch
.
equal
(
assert
torch
.
equal
(
...
@@ -71,7 +71,7 @@ def write_tensor_to_disk(dir, name, t, fmt="torch"):
...
@@ -71,7 +71,7 @@ def write_tensor_to_disk(dir, name, t, fmt="torch"):
@
pytest
.
mark
.
parametrize
(
"in_memory"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"in_memory"
,
[
True
,
False
])
def
test_
loa
d_feature_store
s
(
in_memory
):
def
test_
torch_base
d_feature_store
(
in_memory
):
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
a
=
torch
.
tensor
([
1
,
2
,
3
])
a
=
torch
.
tensor
([
1
,
2
,
3
])
b
=
torch
.
tensor
([
2
,
5
,
3
])
b
=
torch
.
tensor
([
2
,
5
,
3
])
...
@@ -95,12 +95,12 @@ def test_load_feature_stores(in_memory):
...
@@ -95,12 +95,12 @@ def test_load_feature_stores(in_memory):
in_memory
=
in_memory
,
in_memory
=
in_memory
,
),
),
]
]
feat_store
s
=
gb
.
load_f
eature
_s
tore
s
(
feat_data
)
feat_store
=
gb
.
TorchBasedF
eature
S
tore
(
feat_data
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feat_store
s
[
(
"node"
,
"paper"
,
"a"
)
].
read
()
,
torch
.
tensor
([
1
,
2
,
3
])
feat_store
.
read
(
"node"
,
"paper"
,
"a"
),
torch
.
tensor
([
1
,
2
,
3
])
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feat_store
s
[
(
"edge"
,
"paper-cites-paper"
,
"b"
)
].
read
()
,
feat_store
.
read
(
"edge"
,
"paper-cites-paper"
,
"b"
),
torch
.
tensor
([
2
,
5
,
3
]),
torch
.
tensor
([
2
,
5
,
3
]),
)
)
...
@@ -130,6 +130,8 @@ def test_load_feature_stores(in_memory):
...
@@ -130,6 +130,8 @@ def test_load_feature_stores(in_memory):
in_memory
=
True
,
in_memory
=
True
,
),
),
]
]
feat_stores
=
gb
.
load_feature_stores
(
feat_data
)
feat_store
=
gb
.
TorchBasedFeatureStore
(
feat_data
)
assert
(
"node"
,
None
,
"a"
)
in
feat_stores
assert
torch
.
equal
(
feat_store
.
read
(
"node"
,
None
,
"a"
),
torch
.
tensor
([
1
,
2
,
3
])
)
feat_stores
=
None
feat_stores
=
None
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
View file @
96ddf410
...
@@ -27,8 +27,8 @@ def test_DataLoader():
...
@@ -27,8 +27,8 @@ def test_DataLoader():
# TODO(BarclayII): temporarily using DGLGraph. Should test using
# TODO(BarclayII): temporarily using DGLGraph. Should test using
# GraphBolt's storage as well once issue #5953 is resolved.
# GraphBolt's storage as well once issue #5953 is resolved.
graph
=
dgl
.
add_reverse_edges
(
dgl
.
rand_graph
(
200
,
6000
))
graph
=
dgl
.
add_reverse_edges
(
dgl
.
rand_graph
(
200
,
6000
))
features
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randn
(
200
,
4
))
features
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randn
(
200
,
4
))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randint
(
0
,
10
,
(
200
,)))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randint
(
0
,
10
,
(
200
,)))
minibatch_sampler
=
dgl
.
graphbolt
.
MinibatchSampler
(
itemset
,
batch_size
=
B
)
minibatch_sampler
=
dgl
.
graphbolt
.
MinibatchSampler
(
itemset
,
batch_size
=
B
)
subgraph_sampler
=
dgl
.
graphbolt
.
SubgraphSampler
(
subgraph_sampler
=
dgl
.
graphbolt
.
SubgraphSampler
(
...
...
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
View file @
96ddf410
...
@@ -647,35 +647,25 @@ def test_OnDiskDataset_Feature_heterograph():
...
@@ -647,35 +647,25 @@ def test_OnDiskDataset_Feature_heterograph():
assert
len
(
feature_data
)
==
4
assert
len
(
feature_data
)
==
4
# Verify node feature data.
# Verify node feature data.
node_paper_feat
=
feature_data
[(
"node"
,
"paper"
,
"feat"
)]
assert
isinstance
(
node_paper_feat
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
assert
torch
.
equal
(
node_paper_feat
.
read
(),
torch
.
tensor
(
node_data_paper
)
feature_data
.
read
(
"node"
,
"paper"
,
"feat"
),
torch
.
tensor
(
node_data_paper
),
)
)
node_paper_label
=
feature_data
[(
"node"
,
"paper"
,
"label"
)]
assert
isinstance
(
node_paper_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
assert
torch
.
equal
(
node_paper_label
.
read
(),
torch
.
tensor
(
node_data_label
)
feature_data
.
read
(
"node"
,
"paper"
,
"label"
),
torch
.
tensor
(
node_data_label
),
)
)
# Verify edge feature data.
# Verify edge feature data.
edge_writes_feat
=
feature_data
[(
"edge"
,
"author:writes:paper"
,
"feat"
)]
assert
isinstance
(
edge_writes_feat
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
assert
torch
.
equal
(
edge_writes_feat
.
read
(),
torch
.
tensor
(
edge_data_writes
)
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"feat"
),
torch
.
tensor
(
edge_data_writes
),
)
)
edge_writes_label
=
feature_data
[
(
"edge"
,
"author:writes:paper"
,
"label"
)
]
assert
isinstance
(
edge_writes_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
assert
torch
.
equal
(
edge_writes_label
.
read
(),
torch
.
tensor
(
edge_data_label
)
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"label"
),
torch
.
tensor
(
edge_data_label
),
)
)
node_paper_feat
=
None
node_paper_label
=
None
edge_writes_feat
=
None
edge_writes_label
=
None
feature_data
=
None
feature_data
=
None
dataset
=
None
dataset
=
None
...
@@ -735,25 +725,25 @@ def test_OnDiskDataset_Feature_homograph():
...
@@ -735,25 +725,25 @@ def test_OnDiskDataset_Feature_homograph():
assert
len
(
feature_data
)
==
4
assert
len
(
feature_data
)
==
4
# Verify node feature data.
# Verify node feature data.
node_feat
=
feature_data
[(
"node"
,
None
,
"feat"
)]
assert
torch
.
equal
(
assert
isinstance
(
node_feat
,
gb
.
TorchBasedFeatureStore
)
feature_data
.
read
(
"node"
,
None
,
"feat"
),
assert
torch
.
equal
(
node_feat
.
read
(),
torch
.
tensor
(
node_data_feat
))
torch
.
tensor
(
node_data_feat
),
node_label
=
feature_data
[(
"node"
,
None
,
"label"
)]
)
assert
isinstance
(
node_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
assert
torch
.
equal
(
node_label
.
read
(),
torch
.
tensor
(
node_data_label
))
feature_data
.
read
(
"node"
,
None
,
"label"
),
torch
.
tensor
(
node_data_label
),
)
# Verify edge feature data.
# Verify edge feature data.
edge_feat
=
feature_data
[(
"edge"
,
None
,
"feat"
)]
assert
torch
.
equal
(
assert
isinstance
(
edge_feat
,
gb
.
TorchBasedFeatureStore
)
feature_data
.
read
(
"edge"
,
None
,
"feat"
),
assert
torch
.
equal
(
edge_feat
.
read
(),
torch
.
tensor
(
edge_data_feat
))
torch
.
tensor
(
edge_data_feat
),
edge_label
=
feature_data
[(
"edge"
,
None
,
"label"
)]
)
assert
isinstance
(
edge_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
assert
torch
.
equal
(
edge_label
.
read
(),
torch
.
tensor
(
edge_data_label
))
feature_data
.
read
(
"edge"
,
None
,
"label"
),
torch
.
tensor
(
edge_data_label
),
node_feat
=
None
)
node_label
=
None
edge_feat
=
None
edge_label
=
None
feature_data
=
None
feature_data
=
None
dataset
=
None
dataset
=
None
...
...
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
View file @
96ddf410
...
@@ -10,8 +10,8 @@ def test_DataLoader():
...
@@ -10,8 +10,8 @@ def test_DataLoader():
B
=
4
B
=
4
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
))
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
))
graph
=
gb_test_utils
.
rand_csc_graph
(
200
,
0.15
)
graph
=
gb_test_utils
.
rand_csc_graph
(
200
,
0.15
)
features
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randn
(
200
,
4
))
features
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randn
(
200
,
4
))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randint
(
0
,
10
,
(
200
,)))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randint
(
0
,
10
,
(
200
,)))
def
sampler_func
(
data
):
def
sampler_func
(
data
):
adjs
=
[]
adjs
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment