"docs/vscode:/vscode.git/clone" did not exist on "6f9ae8d61eca4a2841bb06b47a993c523de6f43c"
Unverified Commit dba36c87 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

Fix some minor problems in SSE (#309)

* fix SSE.

* fix.

* fix.
parent 0d0f4436
......@@ -25,10 +25,11 @@ Test convergence
```bash
DGLBACKEND=mxnet python3 sse_batch.py --dataset "pubmed" \
--n-epochs 100 \
--n-epochs 1000 \
--lr 0.001 \
--batch-size 1024 \
--batch-size 30 \
--dgl \
--use-spmv \
--neigh-expand 4
--neigh-expand 8 \
--n-hidden 16
```
......@@ -200,9 +200,14 @@ def main(args, data):
labels = data.labels
else:
labels = mx.nd.array(data.labels)
if data.train_mask is not None:
train_vs = mx.nd.array(np.nonzero(data.train_mask)[0], dtype='int64')
eval_vs = mx.nd.array(np.nonzero(data.train_mask == 0)[0], dtype='int64')
else:
train_size = len(labels) * args.train_percent
train_vs = mx.nd.arange(0, train_size, dtype='int64')
eval_vs = mx.nd.arange(train_size, len(labels), dtype='int64')
print("train size: " + str(len(train_vs)))
print("eval size: " + str(len(eval_vs)))
eval_labels = mx.nd.take(labels, eval_vs)
......@@ -305,9 +310,6 @@ def main(args, data):
+ " subgraphs takes " + str(end1 - start1))
start1 = end1
if i > num_batches / 3:
break
if args.cache_subgraph:
sampler.restart()
else:
......@@ -317,10 +319,12 @@ def main(args, data):
seed_nodes=train_vs, shuffle=True,
return_seed_id=True)
# prediction.
# test set accuracy
logits = model_infer(g, eval_vs)
eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
eval_loss = eval_loss.asnumpy()[0]
y_bar = mx.nd.argmax(logits, axis=1)
y = eval_labels
accuracy = mx.nd.sum(y_bar == y) / len(y)
accuracy = accuracy.asnumpy()[0]
# update the inference model.
infer_params = model_infer.collect_params()
......@@ -334,8 +338,8 @@ def main(args, data):
rets.append(all_hidden)
dur.append(time.time() - t0)
print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))
print("Epoch {:05d} | Train Loss {:.4f} | Test Accuracy {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, train_loss, accuracy, np.mean(dur), n_edges / np.mean(dur) / 1000))
return rets
......@@ -361,6 +365,7 @@ class GraphData:
self.graph = MXNetGraph(csr)
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=10, shape=(csr.shape[0])))
self.train_mask = None
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment