Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e362f023
Unverified
Commit
e362f023
authored
Aug 22, 2023
by
peizhou001
Committed by
GitHub
Aug 22, 2023
Browse files
[Graohbolt] Fix negative sampler hetero bugs (#6185)
parent
f2d42266
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
1 deletion
+72
-1
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+9
-1
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+63
-0
No files found.
python/dgl/graphbolt/negative_sampler.py
View file @
e362f023
...
@@ -58,10 +58,18 @@ class NegativeSampler(Mapper):
...
@@ -58,10 +58,18 @@ class NegativeSampler(Mapper):
"""
"""
node_pairs
=
data
.
node_pair
node_pairs
=
data
.
node_pair
if
isinstance
(
node_pairs
,
Mapping
):
if
isinstance
(
node_pairs
,
Mapping
):
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
data
.
label
=
{}
else
:
data
.
negative_head
,
data
.
negative_tail
=
{},
{}
for
etype
,
pos_pairs
in
node_pairs
.
items
():
for
etype
,
pos_pairs
in
node_pairs
.
items
():
self
.
_collate
(
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
data
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
)
)
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
data
.
negative_tail
=
None
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
data
.
negative_head
=
None
else
:
else
:
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
node_pairs
))
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
node_pairs
))
return
data
return
data
...
@@ -101,7 +109,7 @@ class NegativeSampler(Mapper):
...
@@ -101,7 +109,7 @@ class NegativeSampler(Mapper):
etype : (str, str, str)
etype : (str, str, str)
Canonical edge type.
Canonical edge type.
"""
"""
pos_src
,
pos_dst
=
data
.
node_pair
pos_src
,
pos_dst
=
data
.
node_pair
[
etype
]
if
etype
else
data
.
node_pair
neg_src
,
neg_dst
=
neg_pairs
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
pos_label
=
torch
.
ones_like
(
pos_src
)
pos_label
=
torch
.
ones_like
(
pos_src
)
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
e362f023
...
@@ -142,3 +142,66 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
...
@@ -142,3 +142,66 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
neg_dst
)
==
batch_size
assert
len
(
neg_dst
)
==
batch_size
assert
neg_dst
.
numel
()
==
batch_size
*
negative_ratio
assert
neg_dst
.
numel
()
==
batch_size
*
negative_ratio
def
get_hetero_graph
():
# COO graph:
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
# [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{(
"n1"
,
"e1"
,
"n2"
):
0
,
(
"n2"
,
"e2"
,
"n1"
):
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
type_per_edge
=
torch
.
LongTensor
([
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
])
node_type_offset
=
torch
.
LongTensor
([
0
,
2
,
5
])
return
gb
.
from_csc
(
indptr
,
indices
,
node_type_offset
=
node_type_offset
,
type_per_edge
=
type_per_edge
,
metadata
=
metadata
,
)
def
to_link_block
(
data
):
block
=
gb
.
LinkPredictionBlock
(
node_pair
=
data
)
return
block
@
pytest
.
mark
.
parametrize
(
"format"
,
[
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
],
)
def
test_NegativeSampler_Hetero_Data
(
format
):
graph
=
get_hetero_graph
()
itemset
=
gb
.
ItemSetDict
(
{
(
"n1"
,
"e1"
,
"n2"
):
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
)
),
(
"n2"
,
"e2"
,
"n1"
):
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
)
),
}
)
minibatch_dp
=
gb
.
MinibatchSampler
(
itemset
,
batch_size
=
2
)
data_block_converter
=
Mapper
(
minibatch_dp
,
to_link_block
)
negative_dp
=
gb
.
UniformNegativeSampler
(
data_block_converter
,
1
,
format
,
graph
)
assert
len
(
list
(
negative_dp
))
==
5
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment