Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
81ac9d27
Unverified
Commit
81ac9d27
authored
Nov 23, 2023
by
czkkkkkk
Committed by
GitHub
Nov 23, 2023
Browse files
[Graphbolt] Add MiniBatchBase (#6531)
parent
c08f77bf
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
220 additions
and
26 deletions
+220
-26
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+28
-15
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+128
-2
python/dgl/graphbolt/minibatch_transformer.py
python/dgl/graphbolt/minibatch_transformer.py
+2
-2
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+62
-7
No files found.
python/dgl/graphbolt/feature_fetcher.py
View file @
81ac9d27
...
...
@@ -4,6 +4,8 @@ from typing import Dict
from
torch.utils.data
import
functional_datapipe
from
.base
import
etype_tuple_to_str
from
.minibatch_transformer
import
MiniBatchTransformer
...
...
@@ -67,21 +69,22 @@ class FeatureFetcher(MiniBatchTransformer):
MiniBatch
An instance of :class:`MiniBatch` filled with required features.
"""
data
.
node_features
=
{}
num_layer
=
len
(
data
.
sampled_subgraphs
)
if
data
.
sampled_subgraphs
else
0
data
.
edge_features
=
[{}
for
_
in
range
(
num_layer
)]
node_features
=
{}
num_layer
s
=
data
.
num_layers
()
edge_features
=
[{}
for
_
in
range
(
num_layer
s
)]
is_heterogeneous
=
isinstance
(
self
.
node_feature_keys
,
Dict
)
or
isinstance
(
self
.
edge_feature_keys
,
Dict
)
# Read Node features.
if
self
.
node_feature_keys
and
data
.
input_nodes
is
not
None
:
input_nodes
=
data
.
node_ids
()
if
self
.
node_feature_keys
and
input_nodes
is
not
None
:
if
is_heterogeneous
:
for
type_name
,
feature_names
in
self
.
node_feature_keys
.
items
():
nodes
=
data
.
input_nodes
[
type_name
]
nodes
=
input_nodes
[
type_name
]
if
nodes
is
None
:
continue
for
feature_name
in
feature_names
:
data
.
node_features
[
node_features
[
(
type_name
,
feature_name
)
]
=
self
.
feature_store
.
read
(
"node"
,
...
...
@@ -91,39 +94,49 @@ class FeatureFetcher(MiniBatchTransformer):
)
else
:
for
feature_name
in
self
.
node_feature_keys
:
data
.
node_features
[
feature_name
]
=
self
.
feature_store
.
read
(
node_features
[
feature_name
]
=
self
.
feature_store
.
read
(
"node"
,
None
,
feature_name
,
data
.
input_nodes
,
input_nodes
,
)
# Read Edge features.
if
self
.
edge_feature_keys
and
data
.
sampled_subgraphs
:
for
i
,
subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
if
subgraph
.
original_edge_ids
is
None
:
if
self
.
edge_feature_keys
and
num_layers
>
0
:
for
i
in
range
(
num_layers
):
original_edge_ids
=
data
.
edge_ids
(
i
)
if
original_edge_ids
is
None
:
continue
if
is_heterogeneous
:
# Convert edge type to string for DGLMiniBatch.
original_edge_ids
=
{
etype_tuple_to_str
(
key
)
if
isinstance
(
key
,
tuple
)
else
key
:
value
for
key
,
value
in
original_edge_ids
.
items
()
}
for
(
type_name
,
feature_names
,
)
in
self
.
edge_feature_keys
.
items
():
edges
=
subgraph
.
original_edge_ids
.
get
(
type_name
,
None
)
edges
=
original_edge_ids
.
get
(
type_name
,
None
)
if
edges
is
None
:
continue
for
feature_name
in
feature_names
:
data
.
edge_features
[
i
][
edge_features
[
i
][
(
type_name
,
feature_name
)
]
=
self
.
feature_store
.
read
(
"edge"
,
type_name
,
feature_name
,
edges
)
else
:
for
feature_name
in
self
.
edge_feature_keys
:
data
.
edge_features
[
i
][
edge_features
[
i
][
feature_name
]
=
self
.
feature_store
.
read
(
"edge"
,
None
,
feature_name
,
subgraph
.
original_edge_ids
,
original_edge_ids
,
)
data
.
set_node_features
(
node_features
)
data
.
set_edge_features
(
edge_features
)
return
data
python/dgl/graphbolt/minibatch.py
View file @
81ac9d27
"""Unified data structure for input and ouput of all the stages in loading process."""
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -16,7 +16,53 @@ __all__ = ["DGLMiniBatch", "MiniBatch"]
@
dataclass
class
DGLMiniBatch
:
class
MiniBatchBase
(
object
):
"""Base class for `MiniBatch` and `DGLMiniBatch`."""
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the MiniBatch.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
raise
NotImplementedError
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
raise
NotImplementedError
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
raise
NotImplementedError
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
raise
NotImplementedError
def
edge_ids
(
self
,
layer_id
:
int
)
->
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]:
"""Get the edge ids of a layer."""
raise
NotImplementedError
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy MiniBatch to the specified device."""
raise
NotImplementedError
@
dataclass
class
DGLMiniBatch
(
MiniBatchBase
):
r
"""A data class designed for the DGL library, encompassing all the
necessary fields for computation using the DGL library."""
...
...
@@ -99,6 +145,47 @@ class DGLMiniBatch:
def
__repr__
(
self
)
->
str
:
return
_dgl_minibatch_str
(
self
)
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the `blocks`.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
return
self
.
input_nodes
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
if
self
.
blocks
is
None
:
return
0
return
len
(
self
.
blocks
)
def
edge_ids
(
self
,
layer_id
:
int
)
->
Optional
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]]:
"""Get edge ids of a layer."""
if
dgl
.
EID
not
in
self
.
blocks
[
layer_id
].
edata
:
return
None
return
self
.
blocks
[
layer_id
].
edata
[
dgl
.
EID
]
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
self
.
node_features
=
node_features
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
self
.
edge_features
=
edge_features
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy `DGLMiniBatch` to the specified device using reflection."""
...
...
@@ -236,6 +323,45 @@ class MiniBatch:
def
__repr__
(
self
)
->
str
:
return
_minibatch_str
(
self
)
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the `sampled_subgraphs`.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
return
self
.
input_nodes
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
if
self
.
sampled_subgraphs
is
None
:
return
0
return
len
(
self
.
sampled_subgraphs
)
def
edge_ids
(
self
,
layer_id
:
int
)
->
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]:
"""Get the edge ids of a layer."""
return
self
.
sampled_subgraphs
[
layer_id
].
original_edge_ids
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
self
.
node_features
=
node_features
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
self
.
edge_features
=
edge_features
def
_to_dgl_blocks
(
self
):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
a graphical structure and ID mappings.
...
...
python/dgl/graphbolt/minibatch_transformer.py
View file @
81ac9d27
...
...
@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
.minibatch
import
MiniBatch
from
.minibatch
import
DGLMiniBatch
,
MiniBatch
__all__
=
[
"MiniBatchTransformer"
,
...
...
@@ -37,7 +37,7 @@ class MiniBatchTransformer(Mapper):
def
_transformer
(
self
,
minibatch
):
minibatch
=
self
.
transformer
(
minibatch
)
assert
isinstance
(
minibatch
,
MiniBatch
minibatch
,
(
MiniBatch
,
DGLMiniBatch
)
),
"The transformer output should be an instance of MiniBatch"
return
minibatch
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
81ac9d27
import
random
from
enum
import
Enum
import
dgl.graphbolt
as
gb
import
gb_test_utils
import
pytest
import
torch
from
torchdata.datapipes.iter
import
Mapper
def
test_FeatureFetcher_invoke
():
class
MiniBatchType
(
Enum
):
MiniBatch
=
1
DGLMiniBatch
=
2
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_invoke
(
minibatch_type
):
# Prepare graph and required datapipes.
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
a
=
torch
.
tensor
(
...
...
@@ -29,6 +39,9 @@ def test_FeatureFetcher_invoke():
# Invoke FeatureFetcher via class constructor.
datapipe
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
datapipe
=
datapipe
.
to_dgl
()
datapipe
=
gb
.
FeatureFetcher
(
datapipe
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
datapipe
))
==
5
...
...
@@ -39,7 +52,10 @@ def test_FeatureFetcher_invoke():
assert
len
(
list
(
datapipe
))
==
5
def
test_FeatureFetcher_homo
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_homo
(
minibatch_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
a
=
torch
.
tensor
(
[[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
graph
.
total_num_nodes
)]
...
...
@@ -59,12 +75,17 @@ def test_FeatureFetcher_homo():
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
sampler_dp
=
sampler_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
fetcher_dp
))
==
5
def
test_FeatureFetcher_with_edges_homo
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_with_edges_homo
(
minibatch_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
a
=
torch
.
tensor
(
[[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
graph
.
total_num_nodes
)]
...
...
@@ -76,9 +97,12 @@ def test_FeatureFetcher_with_edges_homo():
def
add_node_and_edge_ids
(
seeds
):
subgraphs
=
[]
for
_
in
range
(
3
):
range_tensor
=
torch
.
arange
(
10
)
subgraphs
.
append
(
gb
.
FusedSampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
node_pairs
=
(
range_tensor
,
range_tensor
),
original_column_node_ids
=
range_tensor
,
original_row_node_ids
=
range_tensor
,
original_edge_ids
=
torch
.
randint
(
0
,
graph
.
total_num_edges
,
(
10
,)
),
...
...
@@ -96,6 +120,8 @@ def test_FeatureFetcher_with_edges_homo():
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
converter_dp
=
Mapper
(
item_sampler_dp
,
add_node_and_edge_ids
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
converter_dp
=
converter_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
fetcher_dp
))
==
5
...
...
@@ -128,7 +154,10 @@ def get_hetero_graph():
)
def
test_FeatureFetcher_hetero
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_hetero
(
minibatch_type
):
graph
=
get_hetero_graph
()
a
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
2
)])
b
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
3
)])
...
...
@@ -149,6 +178,8 @@ def test_FeatureFetcher_hetero():
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
sampler_dp
=
sampler_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
{
"n1"
:
[
"a"
],
"n2"
:
[
"a"
]}
)
...
...
@@ -156,7 +187,10 @@ def test_FeatureFetcher_hetero():
assert
len
(
list
(
fetcher_dp
))
==
3
def
test_FeatureFetcher_with_edges_hetero
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_with_edges_hetero
(
minibatch_type
):
a
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
20
)])
b
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
50
)])
...
...
@@ -166,10 +200,29 @@ def test_FeatureFetcher_with_edges_hetero():
"n1:e1:n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2:e2:n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
}
original_column_node_ids
=
{
"n1"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
"n2"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
}
original_row_node_ids
=
{
"n1"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
"n2"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
}
for
_
in
range
(
3
):
subgraphs
.
append
(
gb
.
FusedSampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
node_pairs
=
{
"n1:e1:n2"
:
(
torch
.
arange
(
10
),
torch
.
arange
(
10
),
),
"n2:e2:n1"
:
(
torch
.
arange
(
10
),
torch
.
arange
(
10
),
),
},
original_column_node_ids
=
original_column_node_ids
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
original_edge_ids
,
)
)
...
...
@@ -189,6 +242,8 @@ def test_FeatureFetcher_with_edges_hetero():
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
converter_dp
=
Mapper
(
item_sampler_dp
,
add_node_and_edge_ids
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
converter_dp
=
converter_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
{
"n1"
:
[
"a"
]},
{
"n1:e1:n2"
:
[
"a"
]}
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment