Unverified Commit baa92928 authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[bugfix] Fix EdgePredictionSampler for UVA sampling (#3904)



* Add failing unit test

* Fix negative sampler edge types

* fix test

* oops

* revert
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
parent f931c6ba
......@@ -6,7 +6,7 @@ from ..convert import heterograph
from .. import backend as F
from ..transforms import compact_graphs
from ..frame import LazyFeature
from ..utils import recursive_apply
from ..utils import recursive_apply, context_of
def _set_lazy_features(x, xdata, feature_names):
if feature_names is None:
......@@ -373,8 +373,11 @@ class EdgePredictionSampler(Sampler):
neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}
dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = context_of(seed_edges) if seed_edges is not None else g.device
neg_edges = {
etype: neg_srcdst.get(etype, (F.tensor([], dtype), F.tensor([], dtype)))
etype: neg_srcdst.get(etype,
(F.copy_to(F.tensor([], dtype), ctx=ctx),
F.copy_to(F.tensor([], dtype), ctx=ctx)))
for etype in g.canonical_etypes}
neg_pair_graph = heterograph(
neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes})
......
import unittest
import backend as F
import dgl
from dgl.dataloading import NeighborSampler, negative_sampler, \
as_edge_prediction_sampler
from test_utils import parametrize_dtype
def create_test_graph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
# metagraph:
# ('user', 'follows', 'user'),
# ('user', 'plays', 'game'),
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
('user', 'wishes', 'game'): ([0, 2], [1, 0]),
('developer', 'develops', 'game'): ([0, 1], [0, 1])
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@parametrize_dtype
def test_edge_prediction_sampler(idtype):
g = create_test_graph(idtype)
sampler = NeighborSampler([10,10])
sampler = as_edge_prediction_sampler(
sampler, negative_sampler=negative_sampler.Uniform(1))
seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())
# just a smoke test to make sure we don't fail internal assertions
result = sampler.sample(g, {'follows': seeds})
if __name__ == '__main__':
test_edge_prediction_sampler()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment