"""
Graph Representation Learning via Hard Attention Networks in DGL using Adam optimization.
References
----------
Paper: https://arxiv.org/abs/1907.04652
"""

import numpy as np
import torch
import torch.nn as nn

class EarlyStopping:
    """Stop training when the tracked validation score has not improved for `patience` consecutive steps."""

    def __init__(self, patience=10):
        self.patience = patience  # number of consecutive non-improving steps to tolerate
        self.counter = 0  # steps since the last improvement
        self.best_score = None  # best validation score seen so far
        self.early_stop = False  # set to True once patience is exhausted

    def step(self, acc, model):
        """Record the latest validation score; return True once training should stop."""
        score = acc
        if self.best_score is None:
            # First call: take the score as the baseline and checkpoint the model.
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score:
            # No improvement: count this step against the patience budget.
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: record the new best score, checkpoint, and reset the counter.
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        """Saves model when validation loss decrease."""
        torch.save(model.state_dict(), "es_checkpoint.pt")
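

# A minimal usage sketch for EarlyStopping (not part of the original example):
# the toy model, optimizer, data, and evaluate() below are hypothetical
# placeholders standing in for the training loop defined elsewhere in this
# example; only the EarlyStopping calls reflect the class above.
if __name__ == "__main__":
    import torch.nn.functional as F

    # Toy model and data, for illustration only.
    model = nn.Linear(16, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    features = torch.randn(64, 16)
    labels = torch.randint(0, 2, (64,))

    def evaluate(m):
        """Hypothetical validation metric: accuracy on the toy data."""
        with torch.no_grad():
            pred = m(features).argmax(dim=1)
            return (pred == labels).float().mean().item()

    stopper = EarlyStopping(patience=10)
    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(features), labels)
        loss.backward()
        optimizer.step()

        # Stop once validation accuracy has not improved for `patience` epochs.
        if stopper.step(evaluate(model), model):
            break

    # Restore the best weights saved by EarlyStopping.
    model.load_state_dict(torch.load("es_checkpoint.pt"))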