Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
81ac9d27
Unverified
Commit
81ac9d27
authored
Nov 23, 2023
by
czkkkkkk
Committed by
GitHub
Nov 23, 2023
Browse files
[Graphbolt] Add MiniBatchBase (#6531)
parent
c08f77bf
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
220 additions
and
26 deletions
+220
-26
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+28
-15
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+128
-2
python/dgl/graphbolt/minibatch_transformer.py
python/dgl/graphbolt/minibatch_transformer.py
+2
-2
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+62
-7
No files found.
python/dgl/graphbolt/feature_fetcher.py
View file @
81ac9d27
...
@@ -4,6 +4,8 @@ from typing import Dict
...
@@ -4,6 +4,8 @@ from typing import Dict
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
.base
import
etype_tuple_to_str
from
.minibatch_transformer
import
MiniBatchTransformer
from
.minibatch_transformer
import
MiniBatchTransformer
...
@@ -67,21 +69,22 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -67,21 +69,22 @@ class FeatureFetcher(MiniBatchTransformer):
MiniBatch
MiniBatch
An instance of :class:`MiniBatch` filled with required features.
An instance of :class:`MiniBatch` filled with required features.
"""
"""
data
.
node_features
=
{}
node_features
=
{}
num_layer
=
len
(
data
.
sampled_subgraphs
)
if
data
.
sampled_subgraphs
else
0
num_layer
s
=
data
.
num_layers
()
data
.
edge_features
=
[{}
for
_
in
range
(
num_layer
)]
edge_features
=
[{}
for
_
in
range
(
num_layer
s
)]
is_heterogeneous
=
isinstance
(
is_heterogeneous
=
isinstance
(
self
.
node_feature_keys
,
Dict
self
.
node_feature_keys
,
Dict
)
or
isinstance
(
self
.
edge_feature_keys
,
Dict
)
)
or
isinstance
(
self
.
edge_feature_keys
,
Dict
)
# Read Node features.
# Read Node features.
if
self
.
node_feature_keys
and
data
.
input_nodes
is
not
None
:
input_nodes
=
data
.
node_ids
()
if
self
.
node_feature_keys
and
input_nodes
is
not
None
:
if
is_heterogeneous
:
if
is_heterogeneous
:
for
type_name
,
feature_names
in
self
.
node_feature_keys
.
items
():
for
type_name
,
feature_names
in
self
.
node_feature_keys
.
items
():
nodes
=
data
.
input_nodes
[
type_name
]
nodes
=
input_nodes
[
type_name
]
if
nodes
is
None
:
if
nodes
is
None
:
continue
continue
for
feature_name
in
feature_names
:
for
feature_name
in
feature_names
:
data
.
node_features
[
node_features
[
(
type_name
,
feature_name
)
(
type_name
,
feature_name
)
]
=
self
.
feature_store
.
read
(
]
=
self
.
feature_store
.
read
(
"node"
,
"node"
,
...
@@ -91,39 +94,49 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -91,39 +94,49 @@ class FeatureFetcher(MiniBatchTransformer):
)
)
else
:
else
:
for
feature_name
in
self
.
node_feature_keys
:
for
feature_name
in
self
.
node_feature_keys
:
data
.
node_features
[
feature_name
]
=
self
.
feature_store
.
read
(
node_features
[
feature_name
]
=
self
.
feature_store
.
read
(
"node"
,
"node"
,
None
,
None
,
feature_name
,
feature_name
,
data
.
input_nodes
,
input_nodes
,
)
)
# Read Edge features.
# Read Edge features.
if
self
.
edge_feature_keys
and
data
.
sampled_subgraphs
:
if
self
.
edge_feature_keys
and
num_layers
>
0
:
for
i
,
subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
for
i
in
range
(
num_layers
):
if
subgraph
.
original_edge_ids
is
None
:
original_edge_ids
=
data
.
edge_ids
(
i
)
if
original_edge_ids
is
None
:
continue
continue
if
is_heterogeneous
:
if
is_heterogeneous
:
# Convert edge type to string for DGLMiniBatch.
original_edge_ids
=
{
etype_tuple_to_str
(
key
)
if
isinstance
(
key
,
tuple
)
else
key
:
value
for
key
,
value
in
original_edge_ids
.
items
()
}
for
(
for
(
type_name
,
type_name
,
feature_names
,
feature_names
,
)
in
self
.
edge_feature_keys
.
items
():
)
in
self
.
edge_feature_keys
.
items
():
edges
=
subgraph
.
original_edge_ids
.
get
(
type_name
,
None
)
edges
=
original_edge_ids
.
get
(
type_name
,
None
)
if
edges
is
None
:
if
edges
is
None
:
continue
continue
for
feature_name
in
feature_names
:
for
feature_name
in
feature_names
:
data
.
edge_features
[
i
][
edge_features
[
i
][
(
type_name
,
feature_name
)
(
type_name
,
feature_name
)
]
=
self
.
feature_store
.
read
(
]
=
self
.
feature_store
.
read
(
"edge"
,
type_name
,
feature_name
,
edges
"edge"
,
type_name
,
feature_name
,
edges
)
)
else
:
else
:
for
feature_name
in
self
.
edge_feature_keys
:
for
feature_name
in
self
.
edge_feature_keys
:
data
.
edge_features
[
i
][
edge_features
[
i
][
feature_name
feature_name
]
=
self
.
feature_store
.
read
(
]
=
self
.
feature_store
.
read
(
"edge"
,
"edge"
,
None
,
None
,
feature_name
,
feature_name
,
subgraph
.
original_edge_ids
,
original_edge_ids
,
)
)
data
.
set_node_features
(
node_features
)
data
.
set_edge_features
(
edge_features
)
return
data
return
data
python/dgl/graphbolt/minibatch.py
View file @
81ac9d27
"""Unified data structure for input and ouput of all the stages in loading process."""
"""Unified data structure for input and ouput of all the stages in loading process."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -16,7 +16,53 @@ __all__ = ["DGLMiniBatch", "MiniBatch"]
...
@@ -16,7 +16,53 @@ __all__ = ["DGLMiniBatch", "MiniBatch"]
@
dataclass
@
dataclass
class
DGLMiniBatch
:
class
MiniBatchBase
(
object
):
"""Base class for `MiniBatch` and `DGLMiniBatch`."""
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the MiniBatch.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node IDs.
"""
raise
NotImplementedError
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
raise
NotImplementedError
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
raise
NotImplementedError
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
raise
NotImplementedError
def
edge_ids
(
self
,
layer_id
:
int
)
->
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]:
"""Get the edge ids of a layer."""
raise
NotImplementedError
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy MiniBatch to the specified device."""
raise
NotImplementedError
@
dataclass
class
DGLMiniBatch
(
MiniBatchBase
):
r
"""A data class designed for the DGL library, encompassing all the
r
"""A data class designed for the DGL library, encompassing all the
necessary fields for computation using the DGL library."""
necessary fields for computation using the DGL library."""
...
@@ -99,6 +145,47 @@ class DGLMiniBatch:
...
@@ -99,6 +145,47 @@ class DGLMiniBatch:
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
_dgl_minibatch_str
(
self
)
return
_dgl_minibatch_str
(
self
)
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the `blocks`.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node IDs.
"""
return
self
.
input_nodes
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
if
self
.
blocks
is
None
:
return
0
return
len
(
self
.
blocks
)
def
edge_ids
(
self
,
layer_id
:
int
)
->
Optional
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]]:
"""Get edge ids of a layer."""
if
dgl
.
EID
not
in
self
.
blocks
[
layer_id
].
edata
:
return
None
return
self
.
blocks
[
layer_id
].
edata
[
dgl
.
EID
]
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
self
.
node_features
=
node_features
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
self
.
edge_features
=
edge_features
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy `DGLMiniBatch` to the specified device using reflection."""
"""Copy `DGLMiniBatch` to the specified device using reflection."""
...
@@ -236,6 +323,45 @@ class MiniBatch:
...
@@ -236,6 +323,45 @@ class MiniBatch:
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
_minibatch_str
(
self
)
return
_minibatch_str
(
self
)
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the `sampled_subgraphs`.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node IDs.
"""
return
self
.
input_nodes
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
if
self
.
sampled_subgraphs
is
None
:
return
0
return
len
(
self
.
sampled_subgraphs
)
def
edge_ids
(
self
,
layer_id
:
int
)
->
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]:
"""Get the edge ids of a layer."""
return
self
.
sampled_subgraphs
[
layer_id
].
original_edge_ids
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
self
.
node_features
=
node_features
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
self
.
edge_features
=
edge_features
def
_to_dgl_blocks
(
self
):
def
_to_dgl_blocks
(
self
):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
a graphical structure and ID mappings.
a graphical structure and ID mappings.
...
...
python/dgl/graphbolt/minibatch_transformer.py
View file @
81ac9d27
...
@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
...
@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
from
.minibatch
import
MiniBatch
from
.minibatch
import
DGLMiniBatch
,
MiniBatch
__all__
=
[
__all__
=
[
"MiniBatchTransformer"
,
"MiniBatchTransformer"
,
...
@@ -37,7 +37,7 @@ class MiniBatchTransformer(Mapper):
...
@@ -37,7 +37,7 @@ class MiniBatchTransformer(Mapper):
def
_transformer
(
self
,
minibatch
):
def
_transformer
(
self
,
minibatch
):
minibatch
=
self
.
transformer
(
minibatch
)
minibatch
=
self
.
transformer
(
minibatch
)
assert
isinstance
(
assert
isinstance
(
minibatch
,
MiniBatch
minibatch
,
(
MiniBatch
,
DGLMiniBatch
)
),
"The transformer output should be an instance of MiniBatch"
),
"The transformer output should be an instance of MiniBatch"
return
minibatch
return
minibatch
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
81ac9d27
import
random
import
random
from
enum
import
Enum
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
gb_test_utils
import
gb_test_utils
import
pytest
import
torch
import
torch
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
def
test_FeatureFetcher_invoke
():
class
MiniBatchType
(
Enum
):
MiniBatch
=
1
DGLMiniBatch
=
2
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_invoke
(
minibatch_type
):
# Prepare graph and required datapipes.
# Prepare graph and required datapipes.
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
a
=
torch
.
tensor
(
a
=
torch
.
tensor
(
...
@@ -29,6 +39,9 @@ def test_FeatureFetcher_invoke():
...
@@ -29,6 +39,9 @@ def test_FeatureFetcher_invoke():
# Invoke FeatureFetcher via class constructor.
# Invoke FeatureFetcher via class constructor.
datapipe
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
datapipe
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
datapipe
=
datapipe
.
to_dgl
()
datapipe
=
gb
.
FeatureFetcher
(
datapipe
,
feature_store
,
[
"a"
],
[
"b"
])
datapipe
=
gb
.
FeatureFetcher
(
datapipe
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
...
@@ -39,7 +52,10 @@ def test_FeatureFetcher_invoke():
...
@@ -39,7 +52,10 @@ def test_FeatureFetcher_invoke():
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
def
test_FeatureFetcher_homo
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_homo
(
minibatch_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
a
=
torch
.
tensor
(
a
=
torch
.
tensor
(
[[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
graph
.
total_num_nodes
)]
[[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
graph
.
total_num_nodes
)]
...
@@ -59,12 +75,17 @@ def test_FeatureFetcher_homo():
...
@@ -59,12 +75,17 @@ def test_FeatureFetcher_homo():
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
sampler_dp
=
sampler_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
[
"a"
],
[
"b"
])
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
fetcher_dp
))
==
5
assert
len
(
list
(
fetcher_dp
))
==
5
def
test_FeatureFetcher_with_edges_homo
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_with_edges_homo
(
minibatch_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
a
=
torch
.
tensor
(
a
=
torch
.
tensor
(
[[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
graph
.
total_num_nodes
)]
[[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
graph
.
total_num_nodes
)]
...
@@ -76,9 +97,12 @@ def test_FeatureFetcher_with_edges_homo():
...
@@ -76,9 +97,12 @@ def test_FeatureFetcher_with_edges_homo():
def
add_node_and_edge_ids
(
seeds
):
def
add_node_and_edge_ids
(
seeds
):
subgraphs
=
[]
subgraphs
=
[]
for
_
in
range
(
3
):
for
_
in
range
(
3
):
range_tensor
=
torch
.
arange
(
10
)
subgraphs
.
append
(
subgraphs
.
append
(
gb
.
FusedSampledSubgraphImpl
(
gb
.
FusedSampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
node_pairs
=
(
range_tensor
,
range_tensor
),
original_column_node_ids
=
range_tensor
,
original_row_node_ids
=
range_tensor
,
original_edge_ids
=
torch
.
randint
(
original_edge_ids
=
torch
.
randint
(
0
,
graph
.
total_num_edges
,
(
10
,)
0
,
graph
.
total_num_edges
,
(
10
,)
),
),
...
@@ -96,6 +120,8 @@ def test_FeatureFetcher_with_edges_homo():
...
@@ -96,6 +120,8 @@ def test_FeatureFetcher_with_edges_homo():
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
converter_dp
=
Mapper
(
item_sampler_dp
,
add_node_and_edge_ids
)
converter_dp
=
Mapper
(
item_sampler_dp
,
add_node_and_edge_ids
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
converter_dp
=
converter_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
[
"a"
],
[
"b"
])
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
fetcher_dp
))
==
5
assert
len
(
list
(
fetcher_dp
))
==
5
...
@@ -128,7 +154,10 @@ def get_hetero_graph():
...
@@ -128,7 +154,10 @@ def get_hetero_graph():
)
)
def
test_FeatureFetcher_hetero
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_hetero
(
minibatch_type
):
graph
=
get_hetero_graph
()
graph
=
get_hetero_graph
()
a
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
2
)])
a
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
2
)])
b
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
3
)])
b
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
3
)])
...
@@ -149,6 +178,8 @@ def test_FeatureFetcher_hetero():
...
@@ -149,6 +178,8 @@ def test_FeatureFetcher_hetero():
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
sampler_dp
=
sampler_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
{
"n1"
:
[
"a"
],
"n2"
:
[
"a"
]}
sampler_dp
,
feature_store
,
{
"n1"
:
[
"a"
],
"n2"
:
[
"a"
]}
)
)
...
@@ -156,7 +187,10 @@ def test_FeatureFetcher_hetero():
...
@@ -156,7 +187,10 @@ def test_FeatureFetcher_hetero():
assert
len
(
list
(
fetcher_dp
))
==
3
assert
len
(
list
(
fetcher_dp
))
==
3
def
test_FeatureFetcher_with_edges_hetero
():
@
pytest
.
mark
.
parametrize
(
"minibatch_type"
,
[
MiniBatchType
.
MiniBatch
,
MiniBatchType
.
DGLMiniBatch
]
)
def
test_FeatureFetcher_with_edges_hetero
(
minibatch_type
):
a
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
20
)])
a
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
20
)])
b
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
50
)])
b
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
50
)])
...
@@ -166,10 +200,29 @@ def test_FeatureFetcher_with_edges_hetero():
...
@@ -166,10 +200,29 @@ def test_FeatureFetcher_with_edges_hetero():
"n1:e1:n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n1:e1:n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2:e2:n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2:e2:n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
}
}
original_column_node_ids
=
{
"n1"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
"n2"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
}
original_row_node_ids
=
{
"n1"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
"n2"
:
torch
.
randint
(
0
,
20
,
(
10
,)),
}
for
_
in
range
(
3
):
for
_
in
range
(
3
):
subgraphs
.
append
(
subgraphs
.
append
(
gb
.
FusedSampledSubgraphImpl
(
gb
.
FusedSampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
node_pairs
=
{
"n1:e1:n2"
:
(
torch
.
arange
(
10
),
torch
.
arange
(
10
),
),
"n2:e2:n1"
:
(
torch
.
arange
(
10
),
torch
.
arange
(
10
),
),
},
original_column_node_ids
=
original_column_node_ids
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
original_edge_ids
,
original_edge_ids
=
original_edge_ids
,
)
)
)
)
...
@@ -189,6 +242,8 @@ def test_FeatureFetcher_with_edges_hetero():
...
@@ -189,6 +242,8 @@ def test_FeatureFetcher_with_edges_hetero():
)
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
converter_dp
=
Mapper
(
item_sampler_dp
,
add_node_and_edge_ids
)
converter_dp
=
Mapper
(
item_sampler_dp
,
add_node_and_edge_ids
)
if
minibatch_type
==
MiniBatchType
.
DGLMiniBatch
:
converter_dp
=
converter_dp
.
to_dgl
()
fetcher_dp
=
gb
.
FeatureFetcher
(
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
{
"n1"
:
[
"a"
]},
{
"n1:e1:n2"
:
[
"a"
]}
converter_dp
,
feature_store
,
{
"n1"
:
[
"a"
]},
{
"n1:e1:n2"
:
[
"a"
]}
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment