Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ff9f05c5
Unverified
Commit
ff9f05c5
authored
Jul 03, 2023
by
czkkkkkk
Committed by
GitHub
Jul 03, 2023
Browse files
[Graphbolt] Remove keys in feature store. (#5938)
parent
28578137
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
117 deletions
+74
-117
python/dgl/graphbolt/feature_store.py
python/dgl/graphbolt/feature_store.py
+46
-90
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+10
-8
tests/python/pytorch/graphbolt/test_feature_store.py
tests/python/pytorch/graphbolt/test_feature_store.py
+18
-19
No files found.
python/dgl/graphbolt/feature_store.py
View file @
ff9f05c5
...
@@ -8,13 +8,11 @@ class FeatureStore:
...
@@ -8,13 +8,11 @@ class FeatureStore:
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
read
(
self
,
key
:
str
,
ids
:
torch
.
Tensor
=
None
):
def
read
(
self
,
ids
:
torch
.
Tensor
=
None
):
"""Read
a feature
from the feature store.
"""Read from the feature store.
Parameters
Parameters
----------
----------
key : str
The key that uniquely identifies the feature in the feature store.
ids : torch.Tensor, optional
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
of the feature are read. If None, the entire feature is returned.
...
@@ -26,17 +24,11 @@ class FeatureStore:
...
@@ -26,17 +24,11 @@ class FeatureStore:
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
update
(
self
,
key
:
str
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
):
def
update
(
self
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
):
"""Update a feature in the feature store.
"""Update the feature store.
This function is used to update a feature in the feature store. The
feature is identified by a unique key, and its value is specified using
a tensor.
Parameters
Parameters
----------
----------
key : str
The key that uniquely identifies the feature in the feature store.
value : torch.Tensor
value : torch.Tensor
The updated value of the feature.
The updated value of the feature.
ids : torch.Tensor, optional
ids : torch.Tensor, optional
...
@@ -50,87 +42,58 @@ class FeatureStore:
...
@@ -50,87 +42,58 @@ class FeatureStore:
class
TorchBasedFeatureStore
(
FeatureStore
):
class
TorchBasedFeatureStore
(
FeatureStore
):
r
"""Torch based key-value feature store, where the keys are strings and
r
"""Torch based feature store."""
values are PyTorch tensors."""
def
__init__
(
self
,
feature_dict
:
dict
):
"""Initialize a torch based feature store.
The feature store is initialized with a dictionary of tensors, where the
def
__init__
(
self
,
torch_feature
:
torch
.
Tensor
):
key is the name of a feature and the value is the tensor. The value can
"""Initialize a torch based feature store by a torch feature.
be multi-dimensional, where the first dimension is the index of the
feature.
Note that the
values can be
in memory or on disk.
Note that the
feature can be either
in memory or on disk.
Parameters
Parameters
----------
----------
feature
_dict : dict, optional
torch_
feature
: torch.Tensor
A dictionary of tensors
.
The torch feature
.
Examples
Examples
--------
--------
>>> import torch
>>> import torch
>>> feature_dict = {
>>> torch_feat = torch.arange(0, 5)
... "user": torch.arange(0, 5),
>>> feature_store = TorchBasedFeatureStore(torch_feat)
... "item": torch.arange(0, 6),
>>> feature_store.read()
... "rel": torch.arange(0, 6).view(2, 3),
tensor([0, 1, 2, 3, 4])
... }
>>> feature_store.read(torch.tensor([0, 1, 2]))
>>> feature_store = TorchBasedFeatureStore(feature_dict)
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
tensor([0, 1, 2])
>>> feature_store.read("item", torch.tensor([0, 1, 2]))
>>> feature_store.update(torch.ones(3, dtype=torch.long),
tensor([0, 1, 2])
... torch.tensor([0, 1, 2]))
>>> feature_store.read("rel", torch.tensor([0]))
>>> feature_store.read(torch.tensor([0, 1, 2, 3]))
tensor([[0, 1, 2]])
tensor([1, 1, 1, 3])
>>> feature_store.update("user",
... torch.ones(3, dtype=torch.long), torch.tensor([0, 1, 2]))
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
tensor([1, 1, 1])
>>> import numpy as np
>>> import numpy as np
>>> user = np.arange(0, 5)
>>> arr = np.arange(0, 5)
>>> item = np.arange(0, 6)
>>> np.save("/tmp/arr.npy", arr)
>>> np.save("/tmp/user.npy", user)
>>> torch_feat = torch.as_tensor(np.load("/tmp/arr.npy",
>>> np.save("/tmp/item.npy", item)
... mmap_mode="r+"))
>>> feature_dict = {
>>> feature_store = TorchBasedFeatureStore(torch_feat)
... "user": torch.as_tensor(np.load("/tmp/user.npy",
>>> feature_store.read()
... mmap_mode="r+")),
tensor([0, 1, 2, 3, 4])
... "item": torch.as_tensor(np.load("/tmp/item.npy",
>>> feature_store.read(torch.tensor([0, 1, 2]))
... mmap_mode="r+")),
... }
>>> feature_store = TorchBasedFeatureStore(feature_dict)
>>> feature_store.read("user", torch.tensor([0, 1, 2]))
tensor([0, 1, 2])
tensor([0, 1, 2])
>>> feature_store.read("item", torch.tensor([3, 4, 2]))
tensor([3, 4, 2])
"""
"""
super
(
TorchBasedFeatureStore
,
self
).
__init__
()
super
(
TorchBasedFeatureStore
,
self
).
__init__
()
assert
isinstance
(
feature_dict
,
dict
),
(
assert
isinstance
(
torch_feature
,
torch
.
Tensor
),
(
f
"feature_dict in TorchBasedFeatureStore must be dict, "
f
"torch_feature in TorchBasedFeatureStore must be torch.Tensor, "
f
"but got
{
type
(
feature_dict
)
}
."
f
"but got
{
type
(
torch_feature
)
}
."
)
for
k
,
v
in
feature_dict
.
items
():
assert
isinstance
(
k
,
str
),
f
"Key in TorchBasedFeatureStore must be str, but got
{
k
}
."
assert
isinstance
(
v
,
torch
.
Tensor
),
(
f
"Value in TorchBasedFeatureStore must be torch.Tensor,"
f
"but got
{
v
}
."
)
)
self
.
_tensor
=
torch_feature
self
.
_feature_dict
=
feature_dict
def
read
(
self
,
ids
:
torch
.
Tensor
=
None
):
"""Read the feature by index.
def
read
(
self
,
key
:
str
,
ids
:
torch
.
Tensor
=
None
):
The returned tensor is always in memory, no matter whether the feature
"""Read a feature from the feature store by index.
store is in memory or on disk.
The returned feature is always in memory, no matter whether the feature
to read is in memory or on disk.
Parameters
Parameters
----------
----------
key : str
The key of the feature.
ids : torch.Tensor, optional
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.
of the feature are read. If None, the entire feature is returned.
...
@@ -140,24 +103,15 @@ class TorchBasedFeatureStore(FeatureStore):
...
@@ -140,24 +103,15 @@ class TorchBasedFeatureStore(FeatureStore):
torch.Tensor
torch.Tensor
The read feature.
The read feature.
"""
"""
assert
(
key
in
self
.
_feature_dict
),
f
"key
{
key
}
not in
{
self
.
_feature_dict
.
keys
()
}
"
if
ids
is
None
:
if
ids
is
None
:
return
self
.
_feature_dict
[
key
]
return
self
.
_tensor
return
self
.
_feature_dict
[
key
][
ids
]
return
self
.
_tensor
[
ids
]
def
update
(
self
,
key
:
str
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
):
"""Update a feature in the feature store.
This function is used to update a feature in the feature store. The
def
update
(
self
,
value
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
=
None
):
feature is identified by a unique key, and its value is specified using
"""Update the feature store.
a tensor.
Parameters
Parameters
----------
----------
key : str
The key that uniquely identifies the feature in the feature store.
value : torch.Tensor
value : torch.Tensor
The updated value of the feature.
The updated value of the feature.
ids : torch.Tensor, optional
ids : torch.Tensor, optional
...
@@ -167,14 +121,16 @@ class TorchBasedFeatureStore(FeatureStore):
...
@@ -167,14 +121,16 @@ class TorchBasedFeatureStore(FeatureStore):
must have the same length. If None, the entire feature will be
must have the same length. If None, the entire feature will be
updated.
updated.
"""
"""
assert
(
key
in
self
.
_feature_dict
),
f
"key
{
key
}
not in
{
self
.
_feature_dict
.
keys
()
}
"
if
ids
is
None
:
if
ids
is
None
:
self
.
_feature_dict
[
key
]
=
value
assert
self
.
_tensor
.
shape
==
value
.
shape
,
(
f
"ids is None, so the entire feature will be updated. "
f
"But the shape of the feature is
{
self
.
_tensor
.
shape
}
, "
f
"while the shape of the value is
{
value
.
shape
}
."
)
self
.
_tensor
[:]
=
value
else
:
else
:
assert
ids
.
shape
[
0
]
==
value
.
shape
[
0
],
(
assert
ids
.
shape
[
0
]
==
value
.
shape
[
0
],
(
f
"ids and value must have the same length, "
f
"ids and value must have the same length, "
f
"but got
{
ids
.
shape
[
0
]
}
and
{
value
.
shape
[
0
]
}
."
f
"but got
{
ids
.
shape
[
0
]
}
and
{
value
.
shape
[
0
]
}
."
)
)
self
.
_
feature_dict
[
key
]
[
ids
]
=
value
self
.
_
tensor
[
ids
]
=
value
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
ff9f05c5
...
@@ -5,16 +5,18 @@ import torch
...
@@ -5,16 +5,18 @@ import torch
def
get_graphbolt_fetch_func
():
def
get_graphbolt_fetch_func
():
feature_store
=
dgl
.
graphbolt
.
feature_store
.
TorchBasedFeatureStore
(
feature_store
=
{
{
"feature"
:
dgl
.
graphbolt
.
feature_store
.
TorchBasedFeatureStore
(
"feature"
:
torch
.
randn
(
200
,
4
),
torch
.
randn
(
200
,
4
)
"label"
:
torch
.
randint
(
0
,
10
,
(
200
,)),
),
"label"
:
dgl
.
graphbolt
.
feature_store
.
TorchBasedFeatureStore
(
torch
.
randint
(
0
,
10
,
(
200
,))
),
}
}
)
def
fetch_func
(
data
):
def
fetch_func
(
data
):
return
feature_store
.
read
(
"feature"
,
data
),
feature_store
.
read
(
return
feature_store
[
"feature"
].
read
(
data
),
feature_store
[
"label"
]
.
read
(
"label"
,
data
data
)
)
return
fetch_func
return
fetch_func
...
...
tests/python/pytorch/graphbolt/test_feature_store.py
View file @
ff9f05c5
...
@@ -21,40 +21,39 @@ def to_on_disk_tensor(test_dir, name, t):
...
@@ -21,40 +21,39 @@ def to_on_disk_tensor(test_dir, name, t):
def
test_torch_based_feature_store
(
in_memory
):
def
test_torch_based_feature_store
(
in_memory
):
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
a
=
torch
.
tensor
([
1
,
2
,
3
])
a
=
torch
.
tensor
([
1
,
2
,
3
])
b
=
torch
.
tensor
([
3
,
4
,
5
])
b
=
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
c
=
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
if
not
in_memory
:
if
not
in_memory
:
a
=
to_on_disk_tensor
(
test_dir
,
"a"
,
a
)
a
=
to_on_disk_tensor
(
test_dir
,
"a"
,
a
)
b
=
to_on_disk_tensor
(
test_dir
,
"b"
,
b
)
b
=
to_on_disk_tensor
(
test_dir
,
"b"
,
b
)
c
=
to_on_disk_tensor
(
test_dir
,
"c"
,
c
)
feature_store
=
gb
.
TorchBasedFeatureStore
({
"a"
:
a
,
"b"
:
b
,
"c"
:
c
})
feat_store_a
=
gb
.
TorchBasedFeatureStore
(
a
)
assert
torch
.
equal
(
feature_store
.
read
(
"a"
),
torch
.
tensor
([
1
,
2
,
3
]))
feat_store_b
=
gb
.
TorchBasedFeatureStore
(
b
)
assert
torch
.
equal
(
feature_store
.
read
(
"b"
),
torch
.
tensor
([
3
,
4
,
5
]))
assert
torch
.
equal
(
feat_store_a
.
read
(),
torch
.
tensor
([
1
,
2
,
3
]))
assert
torch
.
equal
(
feat_store_b
.
read
(),
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feat
ure
_store
.
read
(
"a"
,
torch
.
tensor
([
0
,
2
])),
feat_store
_a
.
read
(
torch
.
tensor
([
0
,
2
])),
torch
.
tensor
([
1
,
3
]),
torch
.
tensor
([
1
,
3
]),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feat
ure
_store
.
read
(
"a"
,
torch
.
tensor
([
1
,
1
])),
feat_store
_a
.
read
(
torch
.
tensor
([
1
,
1
])),
torch
.
tensor
([
2
,
2
]),
torch
.
tensor
([
2
,
2
]),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
feat
ure
_store
.
read
(
"c"
,
torch
.
tensor
([
1
])),
feat_store
_b
.
read
(
torch
.
tensor
([
1
])),
torch
.
tensor
([[
4
,
5
,
6
]]),
torch
.
tensor
([[
4
,
5
,
6
]]),
)
)
feature_store
.
update
(
"a"
,
torch
.
tensor
([
0
,
1
,
2
]))
feat_store_a
.
update
(
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
0
,
1
,
2
]))
assert
torch
.
equal
(
feature_store
.
read
(
"a"
),
torch
.
tensor
([
0
,
1
,
2
]))
assert
torch
.
equal
(
feat_store_a
.
read
(),
torch
.
tensor
([
0
,
1
,
2
]))
assert
torch
.
equal
(
feat_store_a
.
update
(
torch
.
tensor
([
2
,
0
]),
torch
.
tensor
([
0
,
2
]))
feature_store
.
read
(
"a"
,
torch
.
tensor
([
0
,
2
])),
assert
torch
.
equal
(
feat_store_a
.
read
(),
torch
.
tensor
([
2
,
1
,
0
]))
torch
.
tensor
([
0
,
2
]),
)
with
pytest
.
raises
(
AssertionError
):
feature_store
.
read
(
"d"
)
with
pytest
.
raises
(
IndexError
):
with
pytest
.
raises
(
IndexError
):
feat
ure
_store
.
read
(
"a"
,
torch
.
tensor
([
0
,
1
,
2
,
3
]))
feat_store
_a
.
read
(
torch
.
tensor
([
0
,
1
,
2
,
3
]))
# On Windows, the file is locked by numpy.load. We need to release
# On Windows, the file is locked by numpy.load. We need to release
# it before closing the temporary directory.
# it before closing the temporary directory.
a
=
b
=
c
=
feature_store
=
None
a
=
b
=
None
feat_store_a
=
feat_store_b
=
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment