Unverified Commit 99f726ff authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

Fix MNIST examples (#4632)

parent 7f50a6da
...@@ -35,6 +35,8 @@ L, perm = coarsen(A, coarsening_levels) ...@@ -35,6 +35,8 @@ L, perm = coarsen(A, coarsening_levels)
g_arr = [dgl.from_scipy(csr) for csr in L] g_arr = [dgl.from_scipy(csr) for csr in L]
coordinate_arr = get_coordinates(g_arr, grid_side, coarsening_levels, perm) coordinate_arr = get_coordinates(g_arr, grid_side, coarsening_levels, perm)
str_to_torch_dtype = {'float16':torch.half, 'float32':torch.float32, 'float64':torch.float64}
coordinate_arr = [coord.to(dtype=str_to_torch_dtype[str(A.dtype)]) for coord in coordinate_arr]
for g, coordinate_arr in zip(g_arr, coordinate_arr): for g, coordinate_arr in zip(g_arr, coordinate_arr):
g.ndata['xy'] = coordinate_arr g.ndata['xy'] = coordinate_arr
g.apply_edges(z2polar) g.apply_edges(z2polar)
...@@ -99,8 +101,6 @@ class MoNet(nn.Module): ...@@ -99,8 +101,6 @@ class MoNet(nn.Module):
u = g.edata['u'] u = g.edata['u']
feat = self.pool(layer(g, feat, u).transpose(-1, -2).unsqueeze(0))\ feat = self.pool(layer(g, feat, u).transpose(-1, -2).unsqueeze(0))\
.squeeze(0).transpose(-1, -2) .squeeze(0).transpose(-1, -2)
print(feat.shape)
print(g_arr[-1].batch_size)
return self.cls(self.readout(g_arr[-1], feat)) return self.cls(self.readout(g_arr[-1], feat))
class ChebNet(nn.Module): class ChebNet(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment