Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
96ddf410
Unverified
Commit
96ddf410
authored
Aug 07, 2023
by
Rhett Ying
Committed by
GitHub
Aug 07, 2023
Browse files
[GraphBolt] refactor FeatureStore and related impl (#6103)
parent
8e86c89c
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
201 additions
and
133 deletions
+201
-133
python/dgl/graphbolt/dataset.py
python/dgl/graphbolt/dataset.py
+1
-3
python/dgl/graphbolt/feature_store.py
python/dgl/graphbolt/feature_store.py
+27
-2
python/dgl/graphbolt/impl/ondisk_dataset.py
python/dgl/graphbolt/impl/ondisk_dataset.py
+4
-7
python/dgl/graphbolt/impl/torch_based_feature_store.py
python/dgl/graphbolt/impl/torch_based_feature_store.py
+127
-69
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+2
-4
tests/python/pytorch/graphbolt/test_feature_store.py
tests/python/pytorch/graphbolt/test_feature_store.py
+11
-9
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
...python/pytorch/graphbolt/test_multi_process_dataloader.py
+2
-2
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
+25
-35
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
...ython/pytorch/graphbolt/test_single_process_dataloader.py
+2
-2
No files found.
python/dgl/graphbolt/dataset.py
View file @
96ddf410
"""GraphBolt Dataset."""
from
typing
import
Dict
from
.feature_store
import
FeatureStore
from
.itemset
import
ItemSet
,
ItemSetDict
...
...
@@ -52,7 +50,7 @@ class Dataset:
raise
NotImplementedError
@
property
def
feature
(
self
)
->
Dict
[
object
,
FeatureStore
]
:
def
feature
(
self
)
->
FeatureStore
:
"""Return the feature."""
raise
NotImplementedError
...
...
python/dgl/graphbolt/feature_store.py
View file @
96ddf410
...
...
@@ -11,11 +11,23 @@ class FeatureStore:
def
__init__
(
self
):
pass
def
read
(
self
,
ids
:
torch
.
Tensor
=
None
):
def
read
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
ids
:
torch
.
Tensor
=
None
,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
...
...
@@ -27,11 +39,24 @@ class FeatureStore:
"""
raise
NotImplementedError
def
update
(
self
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
):
def
update
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
...
...
python/dgl/graphbolt/impl/ondisk_dataset.py
View file @
96ddf410
...
...
@@ -5,7 +5,7 @@ import shutil
from
copy
import
deepcopy
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
List
import
pandas
as
pd
import
torch
...
...
@@ -24,10 +24,7 @@ from .csc_sampling_graph import (
save_csc_sampling_graph
,
)
from
.ondisk_metadata
import
OnDiskGraphTopology
,
OnDiskMetaData
,
OnDiskTVTSet
from
.torch_based_feature_store
import
(
load_feature_stores
,
TorchBasedFeatureStore
,
)
from
.torch_based_feature_store
import
TorchBasedFeatureStore
__all__
=
[
"OnDiskDataset"
,
"preprocess_ondisk_dataset"
]
...
...
@@ -281,7 +278,7 @@ class OnDiskDataset(Dataset):
self
.
_num_classes
=
self
.
_meta
.
num_classes
self
.
_num_labels
=
self
.
_meta
.
num_labels
self
.
_graph
=
self
.
_load_graph
(
self
.
_meta
.
graph_topology
)
self
.
_feature
=
load_f
eature
_s
tore
s
(
self
.
_meta
.
feature_data
)
self
.
_feature
=
TorchBasedF
eature
S
tore
(
self
.
_meta
.
feature_data
)
self
.
_train_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
train_set
)
self
.
_validation_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
validation_set
)
self
.
_test_set
=
self
.
_init_tvt_set
(
self
.
_meta
.
test_set
)
...
...
@@ -307,7 +304,7 @@ class OnDiskDataset(Dataset):
return
self
.
_graph
@
property
def
feature
(
self
)
->
Dict
[
Tuple
,
TorchBasedFeatureStore
]
:
def
feature
(
self
)
->
TorchBasedFeatureStore
:
"""Return the feature."""
return
self
.
_feature
...
...
python/dgl/graphbolt/impl/torch_based_feature_store.py
View file @
96ddf410
...
...
@@ -8,11 +8,11 @@ import torch
from
..feature_store
import
FeatureStore
from
.ondisk_metadata
import
OnDiskFeatureData
__all__
=
[
"TorchBasedFeature
Store"
,
"load_f
eature
_s
tore
s
"
]
__all__
=
[
"TorchBasedFeature
"
,
"TorchBasedF
eature
S
tore"
]
class
TorchBasedFeature
Store
(
FeatureStore
)
:
r
"""Torch based feature
store
."""
class
TorchBasedFeature
:
r
"""Torch based feature."""
def
__init__
(
self
,
torch_feature
:
torch
.
Tensor
):
"""Initialize a torch based feature store by a torch feature.
...
...
@@ -28,7 +28,7 @@ class TorchBasedFeatureStore(FeatureStore):
--------
>>> import torch
>>> torch_feat = torch.arange(0, 5)
>>> feature_store = TorchBasedFeature
Store
(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
...
...
@@ -43,15 +43,14 @@ class TorchBasedFeatureStore(FeatureStore):
>>> np.save("/tmp/arr.npy", arr)
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
... mmap_mode="r+"))
>>> feature_store = TorchBasedFeature
Store
(torch_feat)
>>> feature_store = TorchBasedFeature(torch_feat)
>>> feature_store.read()
tensor([0, 1, 2, 3, 4])
>>> feature_store.read(torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
"""
super
(
TorchBasedFeatureStore
,
self
).
__init__
()
assert
isinstance
(
torch_feature
,
torch
.
Tensor
),
(
f
"torch_feature in TorchBasedFeature
Store
must be torch.Tensor, "
f
"torch_feature in TorchBasedFeature must be torch.Tensor, "
f
"but got
{
type
(
torch_feature
)
}
."
)
self
.
_tensor
=
torch_feature
...
...
@@ -106,7 +105,10 @@ class TorchBasedFeatureStore(FeatureStore):
self
.
_tensor
[
ids
]
=
value
def
load_feature_stores
(
feat_data
:
List
[
OnDiskFeatureData
]):
class
TorchBasedFeatureStore
(
FeatureStore
):
r
"""Torch based feature store."""
def
__init__
(
self
,
feat_data
:
List
[
OnDiskFeatureData
]):
r
"""Load feature stores from disk.
The feature stores are described by the `feat_data`. The `feat_data` is a
...
...
@@ -144,14 +146,10 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]):
... gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
... format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... ]
>>> gb.load_feature_stores(feat_data)
... {("edge", "author:writes:paper", "label"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4df0>, ("node", "paper", "feat"):
... <dgl.graphbolt.feature_store.TorchBasedFeatureStore object at
... 0x7ff093cb4dc0>}
>>> feature_sotre = gb.TorchBasedFeatureStore(feat_data)
"""
feat_stores
=
{}
super
().
__init__
()
self
.
_features
=
{}
for
spec
in
feat_data
:
key
=
(
spec
.
domain
,
spec
.
type
,
spec
.
name
)
if
spec
.
format
==
"torch"
:
...
...
@@ -159,12 +157,72 @@ def load_feature_stores(feat_data: List[OnDiskFeatureData]):
f
"Pytorch tensor can only be loaded in memory, "
f
"but the feature
{
key
}
is loaded on disk."
)
feat_sto
res
[
key
]
=
TorchBasedFeature
Store
(
torch
.
load
(
spec
.
path
))
self
.
_featu
res
[
key
]
=
TorchBasedFeature
(
torch
.
load
(
spec
.
path
))
elif
spec
.
format
==
"numpy"
:
mmap_mode
=
"r+"
if
not
spec
.
in_memory
else
None
feat_sto
res
[
key
]
=
TorchBasedFeature
Store
(
self
.
_featu
res
[
key
]
=
TorchBasedFeature
(
torch
.
as_tensor
(
np
.
load
(
spec
.
path
,
mmap_mode
=
mmap_mode
))
)
else
:
raise
ValueError
(
f
"Unknown feature format
{
spec
.
format
}
"
)
return
feat_stores
def
read
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
ids
:
torch
.
Tensor
=
None
,
):
"""Read from the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
Returns
-------
torch.Tensor
The read feature.
"""
return
self
.
_features
[(
domain
,
type_name
,
feature_name
)].
read
(
ids
)
def
update
(
self
,
domain
:
str
,
type_name
:
str
,
feature_name
:
str
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
,
):
"""Update the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
The indices of the feature to update. If specified, only the
specified indices of the feature will be updated. For the feature,
the `ids[i]` row is updated to `value[i]`. So the indices and value
must have the same length. If None, the entire feature will be
updated.
"""
self
.
_features
[(
domain
,
type_name
,
feature_name
)].
update
(
value
,
ids
)
def
__len__
(
self
):
"""Return the number of features."""
return
len
(
self
.
_features
)
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
96ddf410
...
...
@@ -6,10 +6,8 @@ import torch
def
get_graphbolt_fetch_func
():
feature_store
=
{
"feature"
:
dgl
.
graphbolt
.
TorchBasedFeatureStore
(
torch
.
randn
(
200
,
4
)),
"label"
:
dgl
.
graphbolt
.
TorchBasedFeatureStore
(
torch
.
randint
(
0
,
10
,
(
200
,))
),
"feature"
:
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randn
(
200
,
4
)),
"label"
:
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randint
(
0
,
10
,
(
200
,))),
}
def
fetch_func
(
data
):
...
...
tests/python/pytorch/graphbolt/test_feature_store.py
View file @
96ddf410
...
...
@@ -19,7 +19,7 @@ def to_on_disk_tensor(test_dir, name, t):
@
pytest
.
mark
.
parametrize
(
"in_memory"
,
[
True
,
False
])
def
test_torch_based_feature
_store
(
in_memory
):
def
test_torch_based_feature
(
in_memory
):
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
a
=
torch
.
tensor
([
1
,
2
,
3
])
b
=
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
...
...
@@ -27,8 +27,8 @@ def test_torch_based_feature_store(in_memory):
a
=
to_on_disk_tensor
(
test_dir
,
"a"
,
a
)
b
=
to_on_disk_tensor
(
test_dir
,
"b"
,
b
)
feat_store_a
=
gb
.
TorchBasedFeature
Store
(
a
)
feat_store_b
=
gb
.
TorchBasedFeature
Store
(
b
)
feat_store_a
=
gb
.
TorchBasedFeature
(
a
)
feat_store_b
=
gb
.
TorchBasedFeature
(
b
)
assert
torch
.
equal
(
feat_store_a
.
read
(),
torch
.
tensor
([
1
,
2
,
3
]))
assert
torch
.
equal
(
...
...
@@ -71,7 +71,7 @@ def write_tensor_to_disk(dir, name, t, fmt="torch"):
@
pytest
.
mark
.
parametrize
(
"in_memory"
,
[
True
,
False
])
def
test_
loa
d_feature_store
s
(
in_memory
):
def
test_
torch_base
d_feature_store
(
in_memory
):
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
a
=
torch
.
tensor
([
1
,
2
,
3
])
b
=
torch
.
tensor
([
2
,
5
,
3
])
...
...
@@ -95,12 +95,12 @@ def test_load_feature_stores(in_memory):
in_memory
=
in_memory
,
),
]
feat_store
s
=
gb
.
load_f
eature
_s
tore
s
(
feat_data
)
feat_store
=
gb
.
TorchBasedF
eature
S
tore
(
feat_data
)
assert
torch
.
equal
(
feat_store
s
[
(
"node"
,
"paper"
,
"a"
)
].
read
()
,
torch
.
tensor
([
1
,
2
,
3
])
feat_store
.
read
(
"node"
,
"paper"
,
"a"
),
torch
.
tensor
([
1
,
2
,
3
])
)
assert
torch
.
equal
(
feat_store
s
[
(
"edge"
,
"paper-cites-paper"
,
"b"
)
].
read
()
,
feat_store
.
read
(
"edge"
,
"paper-cites-paper"
,
"b"
),
torch
.
tensor
([
2
,
5
,
3
]),
)
...
...
@@ -130,6 +130,8 @@ def test_load_feature_stores(in_memory):
in_memory
=
True
,
),
]
feat_stores
=
gb
.
load_feature_stores
(
feat_data
)
assert
(
"node"
,
None
,
"a"
)
in
feat_stores
feat_store
=
gb
.
TorchBasedFeatureStore
(
feat_data
)
assert
torch
.
equal
(
feat_store
.
read
(
"node"
,
None
,
"a"
),
torch
.
tensor
([
1
,
2
,
3
])
)
feat_stores
=
None
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
View file @
96ddf410
...
...
@@ -27,8 +27,8 @@ def test_DataLoader():
# TODO(BarclayII): temporarily using DGLGraph. Should test using
# GraphBolt's storage as well once issue #5953 is resolved.
graph
=
dgl
.
add_reverse_edges
(
dgl
.
rand_graph
(
200
,
6000
))
features
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randn
(
200
,
4
))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randint
(
0
,
10
,
(
200
,)))
features
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randn
(
200
,
4
))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randint
(
0
,
10
,
(
200
,)))
minibatch_sampler
=
dgl
.
graphbolt
.
MinibatchSampler
(
itemset
,
batch_size
=
B
)
subgraph_sampler
=
dgl
.
graphbolt
.
SubgraphSampler
(
...
...
tests/python/pytorch/graphbolt/test_ondisk_dataset.py
View file @
96ddf410
...
...
@@ -647,35 +647,25 @@ def test_OnDiskDataset_Feature_heterograph():
assert
len
(
feature_data
)
==
4
# Verify node feature data.
node_paper_feat
=
feature_data
[(
"node"
,
"paper"
,
"feat"
)]
assert
isinstance
(
node_paper_feat
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
node_paper_feat
.
read
(),
torch
.
tensor
(
node_data_paper
)
feature_data
.
read
(
"node"
,
"paper"
,
"feat"
),
torch
.
tensor
(
node_data_paper
),
)
node_paper_label
=
feature_data
[(
"node"
,
"paper"
,
"label"
)]
assert
isinstance
(
node_paper_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
node_paper_label
.
read
(),
torch
.
tensor
(
node_data_label
)
feature_data
.
read
(
"node"
,
"paper"
,
"label"
),
torch
.
tensor
(
node_data_label
),
)
# Verify edge feature data.
edge_writes_feat
=
feature_data
[(
"edge"
,
"author:writes:paper"
,
"feat"
)]
assert
isinstance
(
edge_writes_feat
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
edge_writes_feat
.
read
(),
torch
.
tensor
(
edge_data_writes
)
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"feat"
),
torch
.
tensor
(
edge_data_writes
),
)
edge_writes_label
=
feature_data
[
(
"edge"
,
"author:writes:paper"
,
"label"
)
]
assert
isinstance
(
edge_writes_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
edge_writes_label
.
read
(),
torch
.
tensor
(
edge_data_label
)
feature_data
.
read
(
"edge"
,
"author:writes:paper"
,
"label"
),
torch
.
tensor
(
edge_data_label
),
)
node_paper_feat
=
None
node_paper_label
=
None
edge_writes_feat
=
None
edge_writes_label
=
None
feature_data
=
None
dataset
=
None
...
...
@@ -735,25 +725,25 @@ def test_OnDiskDataset_Feature_homograph():
assert
len
(
feature_data
)
==
4
# Verify node feature data.
node_feat
=
feature_data
[(
"node"
,
None
,
"feat"
)]
assert
isinstance
(
node_feat
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
node_feat
.
read
(),
torch
.
tensor
(
node_data_feat
))
node_label
=
feature_data
[(
"node"
,
None
,
"label"
)]
assert
isinstance
(
node_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
node_label
.
read
(),
torch
.
tensor
(
node_data_label
))
assert
torch
.
equal
(
feature_data
.
read
(
"node"
,
None
,
"feat"
),
torch
.
tensor
(
node_data_feat
),
)
assert
torch
.
equal
(
feature_data
.
read
(
"node"
,
None
,
"label"
),
torch
.
tensor
(
node_data_label
),
)
# Verify edge feature data.
edge_feat
=
feature_data
[(
"edge"
,
None
,
"feat"
)]
assert
isinstance
(
edge_feat
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
edge_feat
.
read
(),
torch
.
tensor
(
edge_data_feat
))
edge_label
=
feature_data
[(
"edge"
,
None
,
"label"
)]
assert
isinstance
(
edge_label
,
gb
.
TorchBasedFeatureStore
)
assert
torch
.
equal
(
edge_label
.
read
(),
torch
.
tensor
(
edge_data_label
))
node_feat
=
None
node_label
=
None
edge_feat
=
None
edge_label
=
None
assert
torch
.
equal
(
feature_data
.
read
(
"edge"
,
None
,
"feat"
),
torch
.
tensor
(
edge_data_feat
),
)
assert
torch
.
equal
(
feature_data
.
read
(
"edge"
,
None
,
"label"
),
torch
.
tensor
(
edge_data_label
),
)
feature_data
=
None
dataset
=
None
...
...
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
View file @
96ddf410
...
...
@@ -10,8 +10,8 @@ def test_DataLoader():
B
=
4
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
))
graph
=
gb_test_utils
.
rand_csc_graph
(
200
,
0.15
)
features
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randn
(
200
,
4
))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
Store
(
torch
.
randint
(
0
,
10
,
(
200
,)))
features
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randn
(
200
,
4
))
labels
=
dgl
.
graphbolt
.
TorchBasedFeature
(
torch
.
randint
(
0
,
10
,
(
200
,)))
def
sampler_func
(
data
):
adjs
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment