logging_util.py
import wandb
import tensorflow as tf

from . import env_util

# setup_environment() returns the logger used throughout this module.
logger = env_util.setup_environment()

class StepLossMetric(tf.keras.metrics.Metric):
    """Keras metric that holds the loss of the most recent training step.

    State is kept as a plain tensor rather than an `add_weight` variable,
    so updates are intended for eager execution.
    """

    def __init__(self, name='step_loss', **kwargs):
        super(StepLossMetric, self).__init__(name=name, **kwargs)
        self.loss = tf.zeros(())

    def update_state(self, loss):
        # Overwrite rather than accumulate, so result() always reflects
        # the latest step.
        self.loss = loss

    def result(self):
        return self.loss

    def reset_states(self):
        self.loss = tf.zeros(())
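
# A minimal usage sketch (kept as a comment so it does not run on import):
# updating the metric from a hand-written eager training step. `model`,
# `loss_fn`, `optimizer`, `x_batch`, and `y_batch` are hypothetical names.
#
#   step_loss = StepLossMetric()
#
#   def train_step(x, y):
#       with tf.GradientTape() as tape:
#           loss = loss_fn(y, model(x, training=True))
#       grads = tape.gradient(loss, model.trainable_variables)
#       optimizer.apply_gradients(zip(grads, model.trainable_variables))
#       step_loss.update_state(loss)
#
#   train_step(x_batch, y_batch)
#   float(step_loss.result())  # loss of the most recent step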


class LoggingCallback(tf.keras.callbacks.Callback):
    """Logs the current learning rate to the module logger and wandb."""

    def __init__(
        self,
        optimizer,
        model,
    ):
        super(LoggingCallback, self).__init__()
        self.optimizer = optimizer
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        logger.info("logging learning rate")
        iterations = self.optimizer.iterations
        # Assumes optimizer.learning_rate is a LearningRateSchedule, which is
        # callable on the current iteration count.
        lr = self.optimizer.learning_rate(iterations)
        logger.info(f"[LR Logger] Epoch: {epoch}, lr: {lr}")
        # Convert to plain Python/numpy values so wandb does not receive
        # raw tf.Tensor objects.
        wandb.log({"epoch": epoch, "lr": float(lr), "iterations": iterations.numpy()})
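
# A minimal wiring sketch (assumed names: `build_model`, `train_ds`; the
# schedule values are only examples). LoggingCallback expects
# optimizer.learning_rate to be a callable LearningRateSchedule and an
# active wandb run.
#
#   schedule = tf.keras.optimizers.schedules.ExponentialDecay(
#       initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.9)
#   optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
#
#   wandb.init(project="my-project")
#   model = build_model()
#   model.compile(optimizer=optimizer, loss="mse")
#   model.fit(train_ds, epochs=10,
#             callbacks=[LoggingCallback(optimizer, model)])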