"""
Utilities adapted from GAT utils.
"""
import torch


class EarlyStopping:
    """Signal when a monitored, higher-is-better score stops improving.

    Each call to :meth:`step` compares the new score with the best one seen
    so far. When the new score is at least as good as the best (ties count
    as improvements), the model is checkpointed and the patience counter is
    reset. When it is strictly worse, the counter is incremented. Once the
    counter reaches ``patience``, ``early_stop`` becomes True.

    Args:
        patience: Number of consecutive strictly-worse steps tolerated before
            ``early_stop`` is set.
        path: File the best model's ``state_dict`` is written to. Defaults to
            ``"es_checkpoint.pt"`` in the current working directory.
    """

    def __init__(self, patience=10, path="es_checkpoint.pt"):
        self.patience = patience
        self.path = path
        self.counter = 0          # consecutive strictly-worse steps
        self.best_score = None    # None until the first step() call
        self.early_stop = False

    def step(self, acc, model):
        """Record a new score and checkpoint ``model`` if it is the best so far.

        Args:
            acc: The monitored score for this step; larger is better.
            model: Object exposing ``state_dict()``, saved via ``torch.save``
                whenever ``acc`` ties or beats the best score.

        Returns:
            The current value of ``self.early_stop``.
        """
        if self.best_score is None:
            # First observation always becomes the baseline and is saved.
            self.best_score = acc
            self.save_checkpoint(model)
        elif acc < self.best_score:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Equal scores land here too, so a plateau resets patience.
            self.best_score = acc
            self.save_checkpoint(model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        """Save ``model.state_dict()`` to ``self.path``.

        Called whenever the monitored score improves on, or ties, the best
        score seen so far.
        """
        torch.save(model.state_dict(), self.path)