utils.py 922 Bytes
Newer Older
VoVAllen's avatar
VoVAllen committed
1
2
import numpy as np

3

VoVAllen's avatar
VoVAllen committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class EarlyStopping:
    """Stop training once a higher-is-better validation score stops improving.

    Call :meth:`step` once per evaluation with the current score (e.g.
    validation accuracy). Whenever the score does not drop below the best
    seen so far, the model is checkpointed and the patience counter resets.
    After ``patience`` consecutive evaluations with a strictly lower score,
    :attr:`early_stop` becomes ``True``.

    Args:
        patience: Number of consecutive non-improving evaluations tolerated
            before signalling an early stop.
        checkpoint_path: File the model's parameters are written to on each
            improvement (passed to ``model.save_parameters``).
    """

    def __init__(self, patience=10, checkpoint_path="model.param"):
        self.patience = patience
        self.checkpoint_path = checkpoint_path
        self.counter = 0  # consecutive evaluations with a worse score
        self.best_score = None  # best score seen so far; None before first step
        self.early_stop = False

    def step(self, acc, model):
        """Record a new score and checkpoint the model if it did not get worse.

        Args:
            acc: Current validation score; larger is better.
            model: Object exposing ``save_parameters(path)`` (e.g. a Gluon block).

        Returns:
            ``True`` once patience is exhausted, otherwise ``False``.
        """
        score = acc
        if self.best_score is None:
            # First evaluation: it is the best by definition.
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Ties count as an improvement: re-save and reset the counter.
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        """Save the model's parameters when the validation score has not decreased."""
        model.save_parameters(self.checkpoint_path)