"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "fb9dcc51ff14e7cec8fc453cf7bfe44ee9349858"
Unverified Commit 0f127637 authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

[Model] Early stop GAT (#750)

* Add early stop

* add mxnet version

* Poke ci
parent 77c58289
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
from gat import GAT from gat import GAT
from utils import EarlyStopping
def elu(data): def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu') return mx.nd.LeakyReLU(data, act_type='elu')
...@@ -75,6 +75,7 @@ def main(args): ...@@ -75,6 +75,7 @@ def main(args):
args.alpha, args.alpha,
args.residual) args.residual)
stopper = EarlyStopping(patience=100)
model.initialize(ctx=ctx) model.initialize(ctx=ctx)
# use optimizer # use optimizer
...@@ -95,10 +96,11 @@ def main(args): ...@@ -95,10 +96,11 @@ def main(args):
dur.append(time.time() - t0) dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format( print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000)) epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
if epoch % 100 == 0: val_accuracy = evaluate(model, features, labels, val_mask)
val_accuracy = evaluate(model, features, labels, val_mask) print("Validation Accuracy {:.4f}".format(val_accuracy))
print("Validation Accuracy {:.4f}".format(val_accuracy)) if stopper.step(val_accuracy, model):
break
model.load_parameters('model.param')
test_accuracy = evaluate(model, features, labels, test_mask) test_accuracy = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(test_accuracy)) print("Test Accuracy {:.4f}".format(test_accuracy))
......
import numpy as np
import torch
class EarlyStopping:
def __init__(self, patience=10):
self.patience = patience
self.counter = 0
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif score < self.best_score:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when validation loss decrease.'''
model.save_parameters('model.param')
...@@ -18,12 +18,15 @@ import torch.nn.functional as F ...@@ -18,12 +18,15 @@ import torch.nn.functional as F
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
from gat import GAT from gat import GAT
from utils import EarlyStopping
def accuracy(logits, labels): def accuracy(logits, labels):
_, indices = torch.max(logits, dim=1) _, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels) correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels) return correct.item() * 1.0 / len(labels)
def evaluate(model, features, labels, mask): def evaluate(model, features, labels, mask):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
...@@ -32,6 +35,7 @@ def evaluate(model, features, labels, mask): ...@@ -32,6 +35,7 @@ def evaluate(model, features, labels, mask):
labels = labels[mask] labels = labels[mask]
return accuracy(logits, labels) return accuracy(logits, labels)
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
data = load_data(args) data = load_data(args)
...@@ -45,7 +49,7 @@ def main(args): ...@@ -45,7 +49,7 @@ def main(args):
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print("""----Data statistics------' print("""----Data statistics------'
#Edges %d #Edges %d
#Classes %d #Classes %d
#Train samples %d #Train samples %d
#Val samples %d #Val samples %d
#Test samples %d""" % #Test samples %d""" %
...@@ -85,12 +89,14 @@ def main(args): ...@@ -85,12 +89,14 @@ def main(args):
args.alpha, args.alpha,
args.residual) args.residual)
print(model) print(model)
stopper = EarlyStopping(patience=100)
if cuda: if cuda:
model.cuda() model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss() loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer # use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# initialize graph # initialize graph
dur = [] dur = []
...@@ -115,6 +121,8 @@ def main(args): ...@@ -115,6 +121,8 @@ def main(args):
val_acc = accuracy(logits[val_mask], labels[val_mask]) val_acc = accuracy(logits[val_mask], labels[val_mask])
else: else:
val_acc = evaluate(model, features, labels, val_mask) val_acc = evaluate(model, features, labels, val_mask)
if stopper.step(val_acc, model):
break
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}". " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
...@@ -122,9 +130,11 @@ def main(args): ...@@ -122,9 +130,11 @@ def main(args):
val_acc, n_edges / np.mean(dur) / 1000)) val_acc, n_edges / np.mean(dur) / 1000))
print() print()
model.load_state_dict(torch.load('es_checkpoint.pt'))
acc = evaluate(model, features, labels, test_mask) acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc)) print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT') parser = argparse.ArgumentParser(description='GAT')
......
import numpy as np
import torch
class EarlyStopping:
def __init__(self, patience=10):
self.patience = patience
self.counter = 0
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif score < self.best_score:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when validation loss decrease.'''
torch.save(model.state_dict(), 'es_checkpoint.pt')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment