Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7bcc27ff
Unverified
Commit
7bcc27ff
authored
Sep 27, 2023
by
peizhou001
Committed by
GitHub
Sep 27, 2023
Browse files
[Graphbolt] Add to dgl (#6381)
parent
17198e9e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
308 additions
and
120 deletions
+308
-120
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+94
-33
tests/python/pytorch/graphbolt/impl/test_minibatch.py
tests/python/pytorch/graphbolt/impl/test_minibatch.py
+213
-85
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+1
-2
No files found.
python/dgl/graphbolt/minibatch.py
View file @
7bcc27ff
...
@@ -209,10 +209,12 @@ class MiniBatch:
...
@@ -209,10 +209,12 @@ class MiniBatch:
all node ids inside are compacted.
all node ids inside are compacted.
"""
"""
def
to_dgl_blocks
(
self
):
def
__repr__
(
self
)
->
str
:
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing a
return
_minibatch_str
(
self
)
graphical structure and assigning features to the nodes and edges within
the blocks.
def
_to_dgl_blocks
(
self
):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
a graphical structure and ID mappings.
"""
"""
if
not
self
.
sampled_subgraphs
:
if
not
self
.
sampled_subgraphs
:
return
None
return
None
...
@@ -257,23 +259,6 @@ class MiniBatch:
...
@@ -257,23 +259,6 @@ class MiniBatch:
)
)
if
is_heterogeneous
:
if
is_heterogeneous
:
# Assign node features to the outermost layer's source nodes.
if
self
.
node_features
:
for
(
node_type
,
feature_name
,
),
feature
in
self
.
node_features
.
items
():
blocks
[
0
].
srcnodes
[
node_type
].
data
[
feature_name
]
=
feature
# Assign edge features.
if
self
.
edge_features
:
for
block
,
edge_feature
in
zip
(
blocks
,
self
.
edge_features
):
for
(
edge_type
,
feature_name
,
),
feature
in
edge_feature
.
items
():
block
.
edges
[
etype_str_to_tuple
(
edge_type
)].
data
[
feature_name
]
=
feature
# Assign reverse node ids to the outermost layer's source nodes.
# Assign reverse node ids to the outermost layer's source nodes.
for
node_type
,
reverse_ids
in
self
.
sampled_subgraphs
[
for
node_type
,
reverse_ids
in
self
.
sampled_subgraphs
[
0
0
...
@@ -290,15 +275,6 @@ class MiniBatch:
...
@@ -290,15 +275,6 @@ class MiniBatch:
dgl
.
EID
dgl
.
EID
]
=
reverse_ids
]
=
reverse_ids
else
:
else
:
# Assign node features to the outermost layer's source nodes.
if
self
.
node_features
:
for
feature_name
,
feature
in
self
.
node_features
.
items
():
blocks
[
0
].
srcdata
[
feature_name
]
=
feature
# Assign edge features.
if
self
.
edge_features
:
for
block
,
edge_feature
in
zip
(
blocks
,
self
.
edge_features
):
for
feature_name
,
feature
in
edge_feature
.
items
():
block
.
edata
[
feature_name
]
=
feature
blocks
[
0
].
srcdata
[
dgl
.
NID
]
=
self
.
sampled_subgraphs
[
blocks
[
0
].
srcdata
[
dgl
.
NID
]
=
self
.
sampled_subgraphs
[
0
0
].
original_row_node_ids
].
original_row_node_ids
...
@@ -306,11 +282,96 @@ class MiniBatch:
...
@@ -306,11 +282,96 @@ class MiniBatch:
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
if
subgraph
.
original_edge_ids
is
not
None
:
if
subgraph
.
original_edge_ids
is
not
None
:
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
original_edge_ids
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
original_edge_ids
return
blocks
return
blocks
def
__repr__
(
self
)
->
str
:
def
to_dgl
(
self
):
return
_minibatch_str
(
self
)
"""Converting a `MiniBatch` into a DGL MiniBatch that contains
everything necessary for computation."
"""
minibatch
=
DGLMiniBatch
(
blocks
=
self
.
_to_dgl_blocks
(),
input_nodes
=
self
.
input_nodes
,
output_nodes
=
self
.
seed_nodes
,
node_features
=
self
.
node_features
,
edge_features
=
self
.
edge_features
,
labels
=
self
.
labels
,
)
assert
(
minibatch
.
blocks
is
not
None
),
"Sampled subgraphs for computation are missing."
# For link prediction tasks.
if
self
.
compacted_node_pairs
is
not
None
:
minibatch
.
positive_node_pairs
=
self
.
compacted_node_pairs
# Build negative graph.
if
(
self
.
compacted_negative_srcs
is
not
None
and
self
.
compacted_negative_dsts
is
not
None
):
# For homogeneous graph.
if
isinstance
(
self
.
compacted_negative_srcs
,
torch
.
Tensor
):
minibatch
.
negative_node_pairs
=
(
self
.
compacted_negative_srcs
.
view
(
-
1
),
self
.
compacted_negative_dsts
.
view
(
-
1
),
)
# For heterogeneous graph.
else
:
minibatch
.
negative_node_pairs
=
{
etype
:
(
neg_src
.
view
(
-
1
),
self
.
compacted_negative_dsts
[
etype
].
view
(
-
1
),
)
for
etype
,
neg_src
in
self
.
compacted_negative_srcs
.
items
()
}
elif
self
.
compacted_negative_srcs
is
not
None
:
# For homogeneous graph.
if
isinstance
(
self
.
compacted_negative_srcs
,
torch
.
Tensor
):
negative_ratio
=
self
.
compacted_negative_srcs
.
size
(
1
)
minibatch
.
negative_node_pairs
=
(
self
.
compacted_negative_srcs
.
view
(
-
1
),
self
.
compacted_node_pairs
[
1
].
repeat_interleave
(
negative_ratio
),
)
# For heterogeneous graph.
else
:
negative_ratio
=
list
(
self
.
compacted_negative_srcs
.
values
()
)[
0
].
size
(
1
)
minibatch
.
negative_node_pairs
=
{
etype
:
(
neg_src
.
view
(
-
1
),
self
.
compacted_node_pairs
[
etype
][
1
].
repeat_interleave
(
negative_ratio
),
)
for
etype
,
neg_src
in
self
.
compacted_negative_srcs
.
items
()
}
elif
self
.
compacted_negative_dsts
is
not
None
:
# For homogeneous graph.
if
isinstance
(
self
.
compacted_negative_dsts
,
torch
.
Tensor
):
negative_ratio
=
self
.
compacted_negative_dsts
.
size
(
1
)
minibatch
.
negative_node_pairs
=
(
self
.
compacted_node_pairs
[
0
].
repeat_interleave
(
negative_ratio
),
self
.
compacted_negative_dsts
.
view
(
-
1
),
)
# For heterogeneous graph.
else
:
negative_ratio
=
list
(
self
.
compacted_negative_dsts
.
values
()
)[
0
].
size
(
1
)
minibatch
.
negative_node_pairs
=
{
etype
:
(
self
.
compacted_node_pairs
[
etype
][
0
].
repeat_interleave
(
negative_ratio
),
neg_dst
.
view
(
-
1
),
)
for
etype
,
neg_dst
in
self
.
compacted_negative_dsts
.
items
()
}
return
minibatch
def
_minibatch_str
(
minibatch
:
MiniBatch
)
->
str
:
def
_minibatch_str
(
minibatch
:
MiniBatch
)
->
str
:
...
...
tests/python/pytorch/graphbolt/impl/test_minibatch.py
View file @
7bcc27ff
import
dgl
import
dgl
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
pytest
import
torch
import
torch
def
test_to_dgl_blocks_hetero
():
relation
=
"A:r:B"
relation
=
"A:r:B"
reverse_relation
=
"B:rr:A"
reverse_relation
=
"B:rr:A"
def
create_homo_minibatch
():
node_pairs
=
[
node_pairs
=
[
{
(
relation
:
(
torch
.
tensor
([
0
,
1
,
1
]),
torch
.
tensor
([
0
,
1
,
2
])),
torch
.
tensor
([
0
,
1
,
2
,
2
,
2
,
1
]),
reverse_relation
:
(
torch
.
tensor
([
1
,
0
]),
torch
.
tensor
([
2
,
3
])),
torch
.
tensor
([
0
,
1
,
1
,
2
,
3
,
2
]),
},
),
{
relation
:
(
torch
.
tensor
([
0
,
1
]),
torch
.
tensor
([
1
,
0
]))},
(
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
1
,
0
,
0
]),
),
]
]
original_column_node_ids
=
[
original_column_node_ids
=
[
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
])
}
,
torch
.
tensor
([
10
,
11
,
12
,
13
]),
{
"B"
:
torch
.
tensor
([
10
,
11
])
}
,
torch
.
tensor
([
10
,
11
]),
]
]
original_row_node_ids
=
[
original_row_node_ids
=
[
{
torch
.
tensor
([
10
,
11
,
12
,
13
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
]),
torch
.
tensor
([
10
,
11
,
12
]),
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
},
{
"A"
:
torch
.
tensor
([
5
,
7
]),
"B"
:
torch
.
tensor
([
10
,
11
]),
},
]
]
original_edge_ids
=
[
original_edge_ids
=
[
{
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
relation
:
torch
.
tensor
([
19
,
20
,
21
]),
torch
.
tensor
([
10
,
15
,
17
]),
reverse_relation
:
torch
.
tensor
([
23
,
26
]),
},
{
relation
:
torch
.
tensor
([
10
,
12
])},
]
]
node_features
=
{
node_features
=
{
"x"
:
torch
.
randint
(
0
,
10
,
(
4
,))}
(
"A"
,
"x"
):
torch
.
randint
(
0
,
10
,
(
4
,)),
}
edge_features
=
[
edge_features
=
[
{
(
relation
,
"x"
)
:
torch
.
randint
(
0
,
10
,
(
3
,))},
{
"x"
:
torch
.
randint
(
0
,
10
,
(
6
,))},
{
(
relation
,
"x"
)
:
torch
.
randint
(
0
,
10
,
(
2
,))},
{
"x"
:
torch
.
randint
(
0
,
10
,
(
3
,))},
]
]
subgraphs
=
[]
subgraphs
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
...
@@ -51,68 +46,48 @@ def test_to_dgl_blocks_hetero():
...
@@ -51,68 +46,48 @@ def test_to_dgl_blocks_hetero():
original_edge_ids
=
original_edge_ids
[
i
],
original_edge_ids
=
original_edge_ids
[
i
],
)
)
)
)
blocks
=
gb
.
MiniBatch
(
return
gb
.
MiniBatch
(
sampled_subgraphs
=
subgraphs
,
sampled_subgraphs
=
subgraphs
,
node_features
=
node_features
,
node_features
=
node_features
,
edge_features
=
edge_features
,
edge_features
=
edge_features
,
).
to_dgl_blocks
()
etype
=
gb
.
etype_str_to_tuple
(
relation
)
for
i
,
block
in
enumerate
(
blocks
):
edges
=
block
.
edges
(
etype
=
etype
)
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
i
][
relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
i
][
relation
][
1
])
assert
torch
.
equal
(
block
.
edges
[
etype
].
data
[
dgl
.
EID
],
original_edge_ids
[
i
][
relation
]
)
assert
torch
.
equal
(
block
.
edges
[
etype
].
data
[
"x"
],
edge_features
[
i
][(
relation
,
"x"
)],
)
edges
=
blocks
[
0
].
edges
(
etype
=
gb
.
etype_str_to_tuple
(
reverse_relation
))
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
0
][
reverse_relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
0
][
reverse_relation
][
1
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"A"
],
original_row_node_ids
[
0
][
"A"
]
)
)
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"B"
],
original_row_node_ids
[
0
][
"B"
]
)
assert
torch
.
equal
(
blocks
[
0
].
srcnodes
[
"A"
].
data
[
"x"
],
node_features
[(
"A"
,
"x"
)]
)
test_to_dgl_blocks_hetero
()
def
test_to_dgl_blocks_homo
():
def
create_hetero_minibatch
():
node_pairs
=
[
node_pairs
=
[
(
{
torch
.
tensor
([
0
,
1
,
2
,
2
,
2
,
1
]),
relation
:
(
torch
.
tensor
([
0
,
1
,
1
]),
torch
.
tensor
([
0
,
1
,
2
])),
torch
.
tensor
([
0
,
1
,
1
,
2
,
3
,
2
]),
reverse_relation
:
(
torch
.
tensor
([
1
,
0
]),
torch
.
tensor
([
2
,
3
])),
),
},
(
{
relation
:
(
torch
.
tensor
([
0
,
1
]),
torch
.
tensor
([
1
,
0
]))},
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
1
,
0
,
0
]),
),
]
]
original_column_node_ids
=
[
original_column_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
1
3
]),
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
1
1
])
}
,
torch
.
tensor
([
10
,
11
]),
{
"B"
:
torch
.
tensor
([
10
,
11
])
}
,
]
]
original_row_node_ids
=
[
original_row_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
{
torch
.
tensor
([
10
,
11
,
12
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
]),
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
},
{
"A"
:
torch
.
tensor
([
5
,
7
]),
"B"
:
torch
.
tensor
([
10
,
11
]),
},
]
]
original_edge_ids
=
[
original_edge_ids
=
[
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
{
torch
.
tensor
([
10
,
15
,
17
]),
relation
:
torch
.
tensor
([
19
,
20
,
21
]),
reverse_relation
:
torch
.
tensor
([
23
,
26
]),
},
{
relation
:
torch
.
tensor
([
10
,
12
])},
]
]
node_features
=
{
"x"
:
torch
.
randint
(
0
,
10
,
(
4
,))}
node_features
=
{
(
"A"
,
"x"
):
torch
.
randint
(
0
,
10
,
(
4
,)),
}
edge_features
=
[
edge_features
=
[
{
"x"
:
torch
.
randint
(
0
,
10
,
(
6
,))},
{
(
relation
,
"x"
)
:
torch
.
randint
(
0
,
10
,
(
3
,))},
{
"x"
:
torch
.
randint
(
0
,
10
,
(
3
,))},
{
(
relation
,
"x"
)
:
torch
.
randint
(
0
,
10
,
(
2
,))},
]
]
subgraphs
=
[]
subgraphs
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
...
@@ -124,19 +99,11 @@ def test_to_dgl_blocks_homo():
...
@@ -124,19 +99,11 @@ def test_to_dgl_blocks_homo():
original_edge_ids
=
original_edge_ids
[
i
],
original_edge_ids
=
original_edge_ids
[
i
],
)
)
)
)
blocks
=
gb
.
MiniBatch
(
return
gb
.
MiniBatch
(
sampled_subgraphs
=
subgraphs
,
sampled_subgraphs
=
subgraphs
,
node_features
=
node_features
,
node_features
=
node_features
,
edge_features
=
edge_features
,
edge_features
=
edge_features
,
).
to_dgl_blocks
()
)
for
i
,
block
in
enumerate
(
blocks
):
assert
torch
.
equal
(
block
.
edges
()[
0
],
node_pairs
[
i
][
0
])
assert
torch
.
equal
(
block
.
edges
()[
1
],
node_pairs
[
i
][
1
])
assert
torch
.
equal
(
block
.
edata
[
dgl
.
EID
],
original_edge_ids
[
i
])
assert
torch
.
equal
(
block
.
edata
[
"x"
],
edge_features
[
i
][
"x"
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
original_row_node_ids
[
0
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
"x"
],
node_features
[
"x"
])
def
test_representation
():
def
test_representation
():
...
@@ -251,3 +218,164 @@ def test_representation():
...
@@ -251,3 +218,164 @@ def test_representation():
)
)
result
=
str
(
minibatch
)
result
=
str
(
minibatch
)
assert
result
==
expect_result
,
print
(
expect_result
,
result
)
assert
result
==
expect_result
,
print
(
expect_result
,
result
)
def
check_dgl_blocks_hetero
(
minibatch
,
blocks
):
etype
=
gb
.
etype_str_to_tuple
(
relation
)
node_pairs
=
[
subgraph
.
node_pairs
for
subgraph
in
minibatch
.
sampled_subgraphs
]
original_edge_ids
=
[
subgraph
.
original_edge_ids
for
subgraph
in
minibatch
.
sampled_subgraphs
]
original_row_node_ids
=
[
subgraph
.
original_row_node_ids
for
subgraph
in
minibatch
.
sampled_subgraphs
]
for
i
,
block
in
enumerate
(
blocks
):
edges
=
block
.
edges
(
etype
=
etype
)
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
i
][
relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
i
][
relation
][
1
])
assert
torch
.
equal
(
block
.
edges
[
etype
].
data
[
dgl
.
EID
],
original_edge_ids
[
i
][
relation
]
)
edges
=
blocks
[
0
].
edges
(
etype
=
gb
.
etype_str_to_tuple
(
reverse_relation
))
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
0
][
reverse_relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
0
][
reverse_relation
][
1
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"A"
],
original_row_node_ids
[
0
][
"A"
]
)
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"B"
],
original_row_node_ids
[
0
][
"B"
]
)
def
check_dgl_blocks_homo
(
minibatch
,
blocks
):
node_pairs
=
[
subgraph
.
node_pairs
for
subgraph
in
minibatch
.
sampled_subgraphs
]
original_edge_ids
=
[
subgraph
.
original_edge_ids
for
subgraph
in
minibatch
.
sampled_subgraphs
]
original_row_node_ids
=
[
subgraph
.
original_row_node_ids
for
subgraph
in
minibatch
.
sampled_subgraphs
]
for
i
,
block
in
enumerate
(
blocks
):
assert
torch
.
equal
(
block
.
edges
()[
0
],
node_pairs
[
i
][
0
])
assert
torch
.
equal
(
block
.
edges
()[
1
],
node_pairs
[
i
][
1
])
assert
torch
.
equal
(
block
.
edata
[
dgl
.
EID
],
original_edge_ids
[
i
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
original_row_node_ids
[
0
])
def
test_to_dgl_node_classification_homo
():
# Arrange
minibatch
=
create_homo_minibatch
()
minibatch
.
seed_nodes
=
torch
.
tensor
([
10
,
15
])
minibatch
.
labels
=
torch
.
tensor
([
2
,
5
])
# Act
dgl_minibatch
=
minibatch
.
to_dgl
()
# Assert
assert
len
(
dgl_minibatch
.
blocks
)
==
2
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
labels
is
dgl_minibatch
.
labels
assert
minibatch
.
seed_nodes
is
dgl_minibatch
.
output_nodes
check_dgl_blocks_homo
(
minibatch
,
dgl_minibatch
.
blocks
)
def
test_to_dgl_node_classification_hetero
():
minibatch
=
create_hetero_minibatch
()
minibatch
.
labels
=
{
"B"
:
torch
.
tensor
([
2
,
5
])}
minibatch
.
seed_nodes
=
{
"B"
:
torch
.
tensor
([
10
,
15
])}
dgl_minibatch
=
minibatch
.
to_dgl
()
# Assert
assert
len
(
dgl_minibatch
.
blocks
)
==
2
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
labels
is
dgl_minibatch
.
labels
assert
minibatch
.
seed_nodes
is
dgl_minibatch
.
output_nodes
check_dgl_blocks_hetero
(
minibatch
,
dgl_minibatch
.
blocks
)
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
"neg_graph"
,
"neg_src"
,
"neg_dst"
])
def
test_to_dgl_link_predication_homo
(
mode
):
# Arrange
minibatch
=
create_homo_minibatch
()
minibatch
.
compacted_node_pairs
=
(
torch
.
tensor
([
0
,
1
]),
torch
.
tensor
([
1
,
0
]),
)
if
mode
==
"neg_graph"
or
mode
==
"neg_src"
:
minibatch
.
compacted_negative_srcs
=
torch
.
tensor
([[
0
,
0
],
[
1
,
1
]])
if
mode
==
"neg_graph"
or
mode
==
"neg_dst"
:
minibatch
.
compacted_negative_dsts
=
torch
.
tensor
([[
1
,
0
],
[
0
,
1
]])
# Act
dgl_minibatch
=
minibatch
.
to_dgl
()
# Assert
assert
len
(
dgl_minibatch
.
blocks
)
==
2
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
compacted_node_pairs
is
dgl_minibatch
.
positive_node_pairs
check_dgl_blocks_homo
(
minibatch
,
dgl_minibatch
.
blocks
)
if
mode
==
"neg_graph"
or
mode
==
"neg_src"
:
assert
torch
.
equal
(
dgl_minibatch
.
negative_node_pairs
[
0
],
minibatch
.
compacted_negative_srcs
.
view
(
-
1
),
)
if
mode
==
"neg_graph"
or
mode
==
"neg_dst"
:
assert
torch
.
equal
(
dgl_minibatch
.
negative_node_pairs
[
1
],
minibatch
.
compacted_negative_dsts
.
view
(
-
1
),
)
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
"neg_graph"
,
"neg_src"
,
"neg_dst"
])
def
test_to_dgl_link_predication_hetero
(
mode
):
# Arrange
minibatch
=
create_hetero_minibatch
()
minibatch
.
compacted_node_pairs
=
{
relation
:
(
torch
.
tensor
([
1
,
1
]),
torch
.
tensor
([
1
,
0
]),
),
reverse_relation
:
(
torch
.
tensor
([
0
,
1
]),
torch
.
tensor
([
1
,
0
]),
),
}
if
mode
==
"neg_graph"
or
mode
==
"neg_src"
:
minibatch
.
compacted_negative_srcs
=
{
relation
:
torch
.
tensor
([[
2
,
0
],
[
1
,
2
]]),
reverse_relation
:
torch
.
tensor
([[
1
,
2
],
[
0
,
2
]]),
}
if
mode
==
"neg_graph"
or
mode
==
"neg_dst"
:
minibatch
.
compacted_negative_dsts
=
{
relation
:
torch
.
tensor
([[
1
,
3
],
[
2
,
1
]]),
reverse_relation
:
torch
.
tensor
([[
2
,
1
],
[
3
,
1
]]),
}
# Act
dgl_minibatch
=
minibatch
.
to_dgl
()
# Assert
assert
len
(
dgl_minibatch
.
blocks
)
==
2
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
compacted_node_pairs
is
dgl_minibatch
.
positive_node_pairs
check_dgl_blocks_hetero
(
minibatch
,
dgl_minibatch
.
blocks
)
if
mode
==
"neg_graph"
or
mode
==
"neg_src"
:
for
etype
,
src
in
minibatch
.
compacted_negative_srcs
.
items
():
assert
torch
.
equal
(
dgl_minibatch
.
negative_node_pairs
[
etype
][
0
],
src
.
view
(
-
1
),
)
if
mode
==
"neg_graph"
or
mode
==
"neg_dst"
:
for
etype
,
dst
in
minibatch
.
compacted_negative_dsts
.
items
():
assert
torch
.
equal
(
dgl_minibatch
.
negative_node_pairs
[
etype
][
1
],
minibatch
.
compacted_negative_dsts
[
etype
].
view
(
-
1
),
)
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
7bcc27ff
...
@@ -142,8 +142,7 @@ def test_SubgraphSampler_Node_Hetero(labor):
...
@@ -142,8 +142,7 @@ def test_SubgraphSampler_Node_Hetero(labor):
sampler_dp
=
Sampler
(
item_sampler
,
graph
,
fanouts
)
sampler_dp
=
Sampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
2
assert
len
(
list
(
sampler_dp
))
==
2
for
minibatch
in
sampler_dp
:
for
minibatch
in
sampler_dp
:
blocks
=
minibatch
.
to_dgl_blocks
()
assert
len
(
minibatch
.
sampled_subgraphs
)
==
num_layer
assert
len
(
blocks
)
==
num_layer
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment