Unverified Commit 3d654843 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix mxnet gin (#2048)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 65866989
......@@ -30,7 +30,8 @@ def collate(samples):
g.ndata[key] = nd.array(g.ndata[key])
# no edge feats
batched_graph = dgl.batch(graphs)
labels = nd.array(labels)
labels = [nd.reshape(label, (1,)) for label in labels]
labels = nd.concat(*labels, dim=0)
return batched_graph, labels
class GraphDataLoader():
......@@ -78,7 +79,7 @@ class GraphDataLoader():
skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
idx_list = []
for idx in skf.split(np.zeros(len(labels)), labels): # split(x, y)
for idx in skf.split(np.zeros(len(labels)), [label.asnumpy() for label in labels]): # split(x, y)
idx_list.append(idx)
train_idx, valid_idx = idx_list[fold_idx]
......
......@@ -24,6 +24,7 @@ def train(args, net, trainloader, trainer, criterion, epoch):
feat = graphs.ndata['attr'].astype('float32').as_in_context(args.device)
with mx.autograd.record():
graphs = graphs.to(args.device)
outputs = net(graphs, feat)
loss = criterion(outputs, labels)
loss = loss.sum() / len(labels)
......@@ -54,9 +55,10 @@ def eval_net(args, net, dataloader, criterion):
feat = graphs.ndata['attr'].astype('float32').as_in_context(args.device)
total += len(labels)
graphs = graphs.to(args.device)
outputs = net(graphs, feat)
predicted = nd.argmax(outputs, axis=1)
predicted = predicted.astype('int64')
total_correct += (predicted == labels).sum().asscalar()
loss = criterion(outputs, labels)
......@@ -155,4 +157,4 @@ if __name__ == '__main__':
print('show all arguments configuration...')
print(args)
main(args)
\ No newline at end of file
main(args)
......@@ -257,7 +257,7 @@ class GINDataset(DGLBuiltinDataset):
for g in self.graphs:
g.ndata['attr'] = F.tensor(np.zeros((
g.number_of_nodes(), len(label2idx))))
g.ndata['attr'][range(g.number_of_nodes()), [label2idx[F.as_scalar(nl)] for nl in g.ndata['label']]] = 1
g.ndata['attr'][range(g.number_of_nodes()), [label2idx[F.as_scalar(F.reshape(nl, (1,)))] for nl in g.ndata['label']]] = 1
# after load, get the #classes and #dim
self.gclasses = len(self.glabel_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment