"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c1e6a32ae46594c6ba8cb1d4690f70755389aacb"
Unverified Commit 0f127637 authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

[Model] Early stop GAT (#750)

* Add early stop

* add mxnet version

* Poke ci
parent 77c58289
......@@ -18,7 +18,7 @@ import numpy as np
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from gat import GAT
from utils import EarlyStopping
def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu')
......@@ -75,6 +75,7 @@ def main(args):
args.alpha,
args.residual)
stopper = EarlyStopping(patience=100)
model.initialize(ctx=ctx)
# use optimizer
......@@ -95,10 +96,11 @@ def main(args):
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
if epoch % 100 == 0:
val_accuracy = evaluate(model, features, labels, val_mask)
print("Validation Accuracy {:.4f}".format(val_accuracy))
val_accuracy = evaluate(model, features, labels, val_mask)
print("Validation Accuracy {:.4f}".format(val_accuracy))
if stopper.step(val_accuracy, model):
break
model.load_parameters('model.param')
test_accuracy = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(test_accuracy))
......
import numpy as np
import torch
class EarlyStopping:
def __init__(self, patience=10):
self.patience = patience
self.counter = 0
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif score < self.best_score:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when validation loss decrease.'''
model.save_parameters('model.param')
......@@ -18,12 +18,15 @@ import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from gat import GAT
from utils import EarlyStopping
def accuracy(logits, labels):
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def evaluate(model, features, labels, mask):
model.eval()
with torch.no_grad():
......@@ -32,6 +35,7 @@ def evaluate(model, features, labels, mask):
labels = labels[mask]
return accuracy(logits, labels)
def main(args):
# load and preprocess dataset
data = load_data(args)
......@@ -45,7 +49,7 @@ def main(args):
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
......@@ -85,12 +89,14 @@ def main(args):
args.alpha,
args.residual)
print(model)
stopper = EarlyStopping(patience=100)
if cuda:
model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# initialize graph
dur = []
......@@ -115,6 +121,8 @@ def main(args):
val_acc = accuracy(logits[val_mask], labels[val_mask])
else:
val_acc = evaluate(model, features, labels, val_mask)
if stopper.step(val_acc, model):
break
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
......@@ -122,9 +130,11 @@ def main(args):
val_acc, n_edges / np.mean(dur) / 1000))
print()
model.load_state_dict(torch.load('es_checkpoint.pt'))
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
......
import numpy as np
import torch
class EarlyStopping:
def __init__(self, patience=10):
self.patience = patience
self.counter = 0
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif score < self.best_score:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when validation loss decrease.'''
torch.save(model.state_dict(), 'es_checkpoint.pt')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment