Unverified Commit 3af0e91c authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Frame] Fix for Column Index Device (#2055)



* Update frame.py

* add unit test

* fix test

* fix

* fix

* fix
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
parent 4c5136c8
......@@ -143,6 +143,8 @@ class Column(object):
"""
col = self.clone()
col.device = (device, kwargs)
if self.index is not None:
col.index = F.copy_to(self.index, device)
return col
def __getitem__(self, rowids):
......
......@@ -366,3 +366,10 @@ def test_out_subgraph(idtype):
edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
assert edge_set == {(0,0),(1,0)}
assert F.array_equal(hg['flips'].edge_ids(u, v), subg['flips'].edata[dgl.EID])
def test_subgraph_message_passing():
# Unit test for PR #2055
g = dgl.graph(([0, 1, 2], [2, 3, 4])).to(F.cpu())
g.ndata['x'] = F.copy_to(F.randn((5, 6)), F.cpu())
sg = g.subgraph([1, 2, 3]).to(F.ctx())
sg.update_all(lambda edges: {'x': edges.src['x']}, lambda nodes: {'y': F.sum(nodes.mailbox['x'], 1)})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment