Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
219c9f1a
Unverified
Commit
219c9f1a
authored
Aug 31, 2023
by
peizhou001
Committed by
GitHub
Aug 31, 2023
Browse files
[Graphbolt]Refactor feature fetcher (#6245)
parent
155608d3
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
111 additions
and
54 deletions
+111
-54
python/dgl/graphbolt/data_block.py
python/dgl/graphbolt/data_block.py
+15
-8
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+77
-33
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+15
-11
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
...python/pytorch/graphbolt/test_multi_process_dataloader.py
+1
-1
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
...ython/pytorch/graphbolt/test_single_process_dataloader.py
+3
-1
No files found.
python/dgl/graphbolt/data_block.py
View file @
219c9f1a
...
@@ -23,18 +23,25 @@ class DataBlock:
...
@@ -23,18 +23,25 @@ class DataBlock:
representing a subset of a larger graph structure.
representing a subset of a larger graph structure.
"""
"""
node_feature
:
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
=
None
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
]
=
None
"""A representation of node features.
"""A representation of node features.
Keys are tuples of '(node_type, feature_name)' and the values are
- If keys are single strings: It means the graph is homogeneous, and the
corresponding features. Note that for a homogeneous graph, where there are
keys are feature names.
no node types, 'node_type' should be None.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(node_type, feature_name)'.
"""
"""
edge_feature
:
List
[
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
=
None
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
]
=
None
"""Edge features associated with the 'sampled_subgraphs'.
"""Edge features associated with the 'sampled_subgraphs'.
The keys are tuples in the format '(edge_type, feature_name)', and the
- If keys are single strings: It means the graph is homogeneous, and the
values represent the corresponding features. In the case of a homogeneous
keys are feature names.
graph where no edge types exist, 'edge_type' should be set to None.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(edge_type, feature_name)'. Note that an edge type is a
single string of the format 'str:str:str'.
"""
"""
input_nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
input_nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
...
...
python/dgl/graphbolt/feature_fetcher.py
View file @
219c9f1a
"""Feature fetchers"""
"""Feature fetchers"""
from
typing
import
Dict
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
class
FeatureFetcher
(
Mapper
):
class
FeatureFetcher
(
Mapper
):
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
def
__init__
(
self
,
datapipe
,
feature_store
,
feature_keys
):
def
__init__
(
self
,
datapipe
,
feature_store
,
node_feature_keys
=
None
,
edge_feature_keys
=
None
,
):
"""
"""
Initialization for a feature fetcher.
Initialization for a feature fetcher.
...
@@ -16,13 +24,24 @@ class FeatureFetcher(Mapper):
...
@@ -16,13 +24,24 @@ class FeatureFetcher(Mapper):
The datapipe.
The datapipe.
feature_store : FeatureStore
feature_store : FeatureStore
A storage for features that supports reading and updating.
A storage for features that supports reading and updating.
feature_keys : (str, str, str)
node_feature_keys : List[str] or Dict[str, List[str]]
Features need to be read, with each feature being uniquely identified
Node feature keys indicating which node features need to be read.
by a triplet '(domain, type_name, feature_name)'.
- If `node_feature_keys` is a list: It means the graph is a homogeneous
graph, and the strings inside are feature names.
- If `node_feature_keys` is a dictionary: The keys are node types
and the values are lists of feature names.
edge_feature_keys : List[str] or Dict[str, List[str]]
Edge feature keys indicating which edge features need to be read.
- If `edge_feature_keys` is a list: It means the graph is a homogeneous
graph, and the strings inside are feature names.
- If `edge_feature_keys` is a dictionary: The keys are edge types,
following the format 'str:str:str', and the values are lists of
feature names.
"""
"""
super
().
__init__
(
datapipe
,
self
.
_read
)
super
().
__init__
(
datapipe
,
self
.
_read
)
self
.
feature_store
=
feature_store
self
.
feature_store
=
feature_store
self
.
feature_keys
=
feature_keys
self
.
node_feature_keys
=
node_feature_keys
self
.
edge_feature_keys
=
edge_feature_keys
def
_read
(
self
,
data
):
def
_read
(
self
,
data
):
"""
"""
...
@@ -40,38 +59,63 @@ class FeatureFetcher(Mapper):
...
@@ -40,38 +59,63 @@ class FeatureFetcher(Mapper):
DataBlock
DataBlock
An instance of 'DataBlock' filled with required features.
An instance of 'DataBlock' filled with required features.
"""
"""
data
.
node_feature
=
{}
data
.
node_feature
s
=
{}
num_layer
=
len
(
data
.
sampled_subgraphs
)
if
data
.
sampled_subgraphs
else
0
num_layer
=
len
(
data
.
sampled_subgraphs
)
if
data
.
sampled_subgraphs
else
0
data
.
edge_feature
=
[{}
for
_
in
range
(
num_layer
)]
data
.
edge_features
=
[{}
for
_
in
range
(
num_layer
)]
for
key
in
self
.
feature_keys
:
is_heterogeneous
=
isinstance
(
domain
,
type_name
,
feature_name
=
key
self
.
node_feature_keys
,
Dict
if
domain
==
"node"
and
data
.
input_nodes
is
not
None
:
)
or
isinstance
(
self
.
edge_feature_keys
,
Dict
)
nodes
=
(
# Read Node features.
data
.
input_nodes
if
self
.
node_feature_keys
and
data
.
input_nodes
is
not
None
:
if
not
type_name
if
is_heterogeneous
:
else
data
.
input_nodes
[
type_name
]
for
type_name
,
feature_names
in
self
.
node_feature_keys
.
items
():
)
nodes
=
data
.
input_nodes
[
type_name
]
if
nodes
is
not
None
:
if
nodes
is
None
:
data
.
node_feature
[
continue
for
feature_name
in
feature_names
:
data
.
node_features
[
(
type_name
,
feature_name
)
(
type_name
,
feature_name
)
]
=
self
.
feature_store
.
read
(
]
=
self
.
feature_store
.
read
(
domain
,
"node"
,
type_name
,
type_name
,
feature_name
,
feature_name
,
nodes
,
nodes
,
)
)
el
if
domain
==
"edge"
and
data
.
sampled_subgraphs
is
not
Non
e
:
el
s
e
:
for
i
,
subgraph
in
enumerate
(
data
.
sampled_subgraphs
)
:
for
feature_name
in
self
.
node_feature_keys
:
if
subgraph
.
reverse_edge_ids
is
not
None
:
data
.
node_features
[
feature_name
]
=
self
.
feature_store
.
read
(
edges
=
(
"node"
,
subgraph
.
reverse_edge_ids
None
,
if
not
typ
e_name
featur
e_name
,
else
subgraph
.
reverse_edge_ids
.
get
(
type_name
,
None
)
data
.
input_nodes
,
)
)
if
edges
is
not
None
:
# Read Edge features.
data
.
edge_feature
[
i
][
if
self
.
edge_feature_keys
and
data
.
sampled_subgraphs
:
for
i
,
subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
if
subgraph
.
reverse_edge_ids
is
None
:
continue
if
is_heterogeneous
:
for
(
type_name
,
feature_names
,
)
in
self
.
edge_feature_keys
.
items
():
edges
=
subgraph
.
reverse_edge_ids
.
get
(
type_name
,
None
)
if
edges
is
None
:
continue
for
feature_name
in
feature_names
:
data
.
edge_features
[
i
][
(
type_name
,
feature_name
)
(
type_name
,
feature_name
)
]
=
self
.
feature_store
.
read
(
]
=
self
.
feature_store
.
read
(
domain
,
type_name
,
feature_name
,
edges
"edge"
,
type_name
,
feature_name
,
edges
)
else
:
for
feature_name
in
self
.
edge_feature_keys
:
data
.
edge_features
[
i
][
feature_name
]
=
self
.
feature_store
.
read
(
"edge"
,
None
,
feature_name
,
subgraph
.
reverse_edge_ids
,
)
)
return
data
return
data
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
219c9f1a
...
@@ -21,7 +21,7 @@ def test_FeatureFetcher_homo():
...
@@ -21,7 +21,7 @@ def test_FeatureFetcher_homo():
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
minibatch_dp
,
gb_test_utils
.
to_node_block
)
data_block_converter
=
Mapper
(
minibatch_dp
,
gb_test_utils
.
to_node_block
)
sampler_dp
=
gb
.
NeighborSampler
(
data_block_converter
,
graph
,
fanouts
)
sampler_dp
=
gb
.
NeighborSampler
(
data_block_converter
,
graph
,
fanouts
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
keys
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
[
"a"
],
[
"b"
]
)
assert
len
(
list
(
fetcher_dp
))
==
5
assert
len
(
list
(
fetcher_dp
))
==
5
...
@@ -54,14 +54,14 @@ def test_FeatureFetcher_with_edges_homo():
...
@@ -54,14 +54,14 @@ def test_FeatureFetcher_with_edges_homo():
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
minibatch_dp
=
gb
.
MinibatchSampler
(
itemset
,
batch_size
=
2
)
minibatch_dp
=
gb
.
MinibatchSampler
(
itemset
,
batch_size
=
2
)
converter_dp
=
Mapper
(
minibatch_dp
,
add_node_and_edge_ids
)
converter_dp
=
Mapper
(
minibatch_dp
,
add_node_and_edge_ids
)
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
keys
)
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
[
"a"
],
[
"b"
]
)
assert
len
(
list
(
fetcher_dp
))
==
5
assert
len
(
list
(
fetcher_dp
))
==
5
for
data
in
fetcher_dp
:
for
data
in
fetcher_dp
:
assert
data
.
node_feature
[
(
None
,
"a"
)
].
size
(
0
)
==
2
assert
data
.
node_feature
s
[
"a"
].
size
(
0
)
==
2
assert
len
(
data
.
edge_feature
)
==
3
assert
len
(
data
.
edge_feature
s
)
==
3
for
edge_feature
in
data
.
edge_feature
:
for
edge_feature
in
data
.
edge_feature
s
:
assert
edge_feature
[
(
None
,
"b"
)
].
size
(
0
)
==
10
assert
edge_feature
[
"b"
].
size
(
0
)
==
10
def
get_hetero_graph
():
def
get_hetero_graph
():
...
@@ -108,7 +108,9 @@ def test_FeatureFetcher_hetero():
...
@@ -108,7 +108,9 @@ def test_FeatureFetcher_hetero():
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
minibatch_dp
,
gb_test_utils
.
to_node_block
)
data_block_converter
=
Mapper
(
minibatch_dp
,
gb_test_utils
.
to_node_block
)
sampler_dp
=
gb
.
NeighborSampler
(
data_block_converter
,
graph
,
fanouts
)
sampler_dp
=
gb
.
NeighborSampler
(
data_block_converter
,
graph
,
fanouts
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
keys
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
{
"n1"
:
[
"a"
],
"n2"
:
[
"a"
]}
)
assert
len
(
list
(
fetcher_dp
))
==
3
assert
len
(
list
(
fetcher_dp
))
==
3
...
@@ -148,11 +150,13 @@ def test_FeatureFetcher_with_edges_hetero():
...
@@ -148,11 +150,13 @@ def test_FeatureFetcher_with_edges_hetero():
)
)
minibatch_dp
=
gb
.
MinibatchSampler
(
itemset
,
batch_size
=
2
)
minibatch_dp
=
gb
.
MinibatchSampler
(
itemset
,
batch_size
=
2
)
converter_dp
=
Mapper
(
minibatch_dp
,
add_node_and_edge_ids
)
converter_dp
=
Mapper
(
minibatch_dp
,
add_node_and_edge_ids
)
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
keys
)
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
{
"n1"
:
[
"a"
]},
{
"n1:e1:n2"
:
[
"a"
]}
)
assert
len
(
list
(
fetcher_dp
))
==
5
assert
len
(
list
(
fetcher_dp
))
==
5
for
data
in
fetcher_dp
:
for
data
in
fetcher_dp
:
assert
data
.
node_feature
[(
"n1"
,
"a"
)].
size
(
0
)
==
2
assert
data
.
node_feature
s
[(
"n1"
,
"a"
)].
size
(
0
)
==
2
assert
len
(
data
.
edge_feature
)
==
3
assert
len
(
data
.
edge_feature
s
)
==
3
for
edge_feature
in
data
.
edge_feature
:
for
edge_feature
in
data
.
edge_feature
s
:
assert
edge_feature
[(
"n1:e1:n2"
,
"a"
)].
size
(
0
)
==
10
assert
edge_feature
[(
"n1:e1:n2"
,
"a"
)].
size
(
0
)
==
10
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
View file @
219c9f1a
...
@@ -32,7 +32,7 @@ def test_DataLoader():
...
@@ -32,7 +32,7 @@ def test_DataLoader():
feature_fetcher
=
dgl
.
graphbolt
.
FeatureFetcher
(
feature_fetcher
=
dgl
.
graphbolt
.
FeatureFetcher
(
subgraph_sampler
,
subgraph_sampler
,
feature_store
,
feature_store
,
keys
,
[
"a"
,
"b"
]
,
)
)
device_transferrer
=
dgl
.
graphbolt
.
CopyTo
(
feature_fetcher
,
F
.
ctx
())
device_transferrer
=
dgl
.
graphbolt
.
CopyTo
(
feature_fetcher
,
F
.
ctx
())
...
...
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
View file @
219c9f1a
...
@@ -32,7 +32,9 @@ def test_DataLoader():
...
@@ -32,7 +32,9 @@ def test_DataLoader():
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
)
)
feature_fetcher
=
dgl
.
graphbolt
.
FeatureFetcher
(
feature_fetcher
=
dgl
.
graphbolt
.
FeatureFetcher
(
subgraph_sampler
,
feature_store
,
keys
subgraph_sampler
,
feature_store
,
[
"a"
],
)
)
device_transferrer
=
dgl
.
graphbolt
.
CopyTo
(
feature_fetcher
,
F
.
ctx
())
device_transferrer
=
dgl
.
graphbolt
.
CopyTo
(
feature_fetcher
,
F
.
ctx
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment