Unverified Commit efe0b061 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

fix test dataloader (#3482)


Co-authored-by: default avatarRhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent c5ae54bf
......@@ -217,8 +217,8 @@ def _check_device(data):
def test_node_dataloader():
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).to(F.ctx())
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.ctx())
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
dataloader = dgl.dataloading.NodeDataLoader(
g1, g1.nodes(), sampler, device=F.ctx(), batch_size=g1.num_nodes())
......@@ -232,9 +232,9 @@ def test_node_dataloader():
('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
}).to(F.ctx())
})
for ntype in g2.ntypes:
g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.ctx())
g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
dataloader = dgl.dataloading.NodeDataLoader(
......@@ -250,8 +250,8 @@ def test_edge_dataloader():
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
neg_sampler = dgl.dataloading.negative_sampler.Uniform(2)
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).to(F.ctx())
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.ctx())
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
# no negative sampler
dataloader = dgl.dataloading.EdgeDataLoader(
......@@ -276,9 +276,9 @@ def test_edge_dataloader():
('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
}).to(F.ctx())
})
for ntype in g2.ntypes:
g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.ctx())
g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
# no negative sampler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment