Unverified Commit fdc58a89 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub

[Hotfix] Fix GAT example and clarify the usage of early stop (#1065)

* upd

* rm redundancy:

* upd
parent a2b8d8e4
@@ -19,5 +19,7 @@ pip install requests
 ### Usage (make sure that DGLBACKEND is changed into mxnet)
 ```bash
-DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --num-heads 8
+DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0
+DGLBACKEND=mxnet python3 train.py --dataset citeseer --gpu 0 --early-stop
+DGLBACKEND=mxnet python3 train.py --dataset pubmed --gpu 0 --early-stop
 ```
""" """
Graph Attention Networks in DGL using SPMV optimization. Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training. Multiple heads are also batched together for faster training.
Compared with the original paper, this code does not implement
early stopping.
References References
---------- ----------
Paper: https://arxiv.org/abs/1710.10903 Paper: https://arxiv.org/abs/1710.10903
@@ -76,6 +74,7 @@ def main(args):
                 args.alpha,
                 args.residual)
-    stopper = EarlyStopping(patience=100)
+    if args.early_stop:
+        stopper = EarlyStopping(patience=100)
     model.initialize(ctx=ctx)
@@ -99,8 +98,12 @@ def main(args):
             epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
         val_accuracy = evaluate(model, features, labels, val_mask)
         print("Validation Accuracy {:.4f}".format(val_accuracy))
-        if stopper.step(val_accuracy, model):
-            break
+        if args.early_stop:
+            if stopper.step(val_accuracy, model):
+                break
+    print()
-    model.load_parameters('model.param')
+    if args.early_stop:
+        model.load_parameters('model.param')
     test_accuracy = evaluate(model, features, labels, test_mask)
     print("Test Accuracy {:.4f}".format(test_accuracy))
@@ -134,6 +137,8 @@ if __name__ == '__main__':
                         help="weight decay")
     parser.add_argument('--alpha', type=float, default=0.2,
                         help="the negative slope of leaky relu")
+    parser.add_argument('--early-stop', action='store_true', default=False,
+                        help="indicates whether to use early stop or not")
     args = parser.parse_args()
     print(args)
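Both scripts in this commit rely on an `EarlyStopping` utility whose `step(accuracy, model)` method checkpoints the best model and signals when to stop. A minimal sketch of such a helper, assuming a patience counter and a `save_fn` callback (both illustrative, not the repo's actual implementation):

```python
class EarlyStopping:
    """Stop training once validation accuracy has not improved for `patience` epochs."""

    def __init__(self, patience=100, save_fn=None):
        self.patience = patience
        self.save_fn = save_fn   # e.g. lambda m: m.save_parameters('model.param')
        self.best = None
        self.counter = 0

    def step(self, acc, model):
        # Returns True when training should stop.
        if self.best is None or acc > self.best:
            self.best = acc
            self.counter = 0
            if self.save_fn is not None:
                self.save_fn(model)  # checkpoint the best model seen so far
        else:
            self.counter += 1
        return self.counter >= self.patience
```

With `patience=100`, training continues until the validation accuracy has failed to beat its best value for 100 consecutive epochs, and the parameters reloaded at the end are those of the best epoch, not the last one.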
......
@@ -27,11 +27,11 @@ python3 train.py --dataset=cora --gpu=0
 ```
 ```bash
-python3 train.py --dataset=citeseer --gpu=0
+python3 train.py --dataset=citeseer --gpu=0 --early-stop
 ```
 ```bash
-python3 train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001
+python3 train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001 --early-stop
 ```
 ```bash
```bash ```bash
@@ -43,9 +43,9 @@ Results
 | Dataset  | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) |
 | -------- | ------------- | ------- | ------------------- | ------------------- |
-| Cora     | 84.0%         | 0.0113  | 0.0982 (**8.7x**)   | 0.0424 (**3.8x**)   |
-| Citeseer | 70.7%         | 0.0111  | n/a                 | n/a                 |
-| Pubmed   | 78.0%         | 0.0115  | n/a                 | n/a                 |
+| Cora     | 84.02(0.40)   | 0.0113  | 0.0982 (**8.7x**)   | 0.0424 (**3.8x**)   |
+| Citeseer | 70.91(0.79)   | 0.0111  | n/a                 | n/a                 |
+| Pubmed   | 78.57(0.75)   | 0.0115  | n/a                 | n/a                 |
 * All the accuracy numbers are obtained after 300 epochs.
 * The time measures how long it takes to train one epoch.
......
""" """
Graph Attention Networks in DGL using SPMV optimization. Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training. Multiple heads are also batched together for faster training.
Compared with the original paper, this code does not implement
early stopping.
References References
---------- ----------
Paper: https://arxiv.org/abs/1710.10903 Paper: https://arxiv.org/abs/1710.10903
@@ -95,6 +93,7 @@ def main(args):
                 args.negative_slope,
                 args.residual)
     print(model)
-    stopper = EarlyStopping(patience=100)
+    if args.early_stop:
+        stopper = EarlyStopping(patience=100)
     if cuda:
         model.cuda()
@@ -127,6 +126,7 @@ def main(args):
             val_acc = accuracy(logits[val_mask], labels[val_mask])
         else:
             val_acc = evaluate(model, features, labels, val_mask)
-        if stopper.step(val_acc, model):
-            break
+        if args.early_stop:
+            if stopper.step(val_acc, model):
+                break
@@ -136,6 +136,7 @@ def main(args):
                val_acc, n_edges / np.mean(dur) / 1000))
     print()
-    model.load_state_dict(torch.load('es_checkpoint.pt'))
+    if args.early_stop:
+        model.load_state_dict(torch.load('es_checkpoint.pt'))
     acc = evaluate(model, features, labels, test_mask)
     print("Test Accuracy {:.4f}".format(acc))
@@ -169,6 +170,8 @@ if __name__ == '__main__':
                         help="weight decay")
     parser.add_argument('--negative-slope', type=float, default=0.2,
                         help="the negative slope of leaky relu")
+    parser.add_argument('--early-stop', action='store_true', default=False,
+                        help="indicates whether to use early stop or not")
     parser.add_argument('--fastmode', action="store_true", default=False,
                         help="skip re-evaluate the validation set")
     args = parser.parse_args()
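The gating pattern the diff establishes — create the stopper only when `--early-stop` is given, check it only then, and reload the best checkpoint only then — can be sketched in isolation. The `_Stopper` stand-in and the `train` wrapper below are illustrative, not code from the repo; validation accuracies are passed in as a list so the sketch stays self-contained:

```python
import argparse

class _Stopper:
    # Minimal stand-in for the example's EarlyStopping(patience=100).
    def __init__(self, patience):
        self.best, self.bad, self.patience = None, 0, patience

    def step(self, acc, model):
        if self.best is None or acc > self.best:
            self.best, self.bad = acc, 0  # a real stopper would checkpoint here
        else:
            self.bad += 1
        return self.bad >= self.patience

def train(argv, val_accs):
    """Run a dummy training loop; returns how many epochs actually ran."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--early-stop', action='store_true', default=False)
    args = parser.parse_args(argv)

    if args.early_stop:
        stopper = _Stopper(patience=2)  # the example uses patience=100

    epochs_run = 0
    for val_acc in val_accs:
        epochs_run += 1
        # The stopper only exists (and is only consulted) when the flag is set.
        if args.early_stop and stopper.step(val_acc, model=None):
            break

    if args.early_stop:
        pass  # reload the best checkpoint here before the final test evaluation
    return epochs_run
```

Guarding every use of `stopper` behind `args.early_stop` is what the hotfix adds: before it, the script referenced the stopper (and the checkpoint file) even when early stopping was never requested.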
......