WandbLog.py 1.16 KB
Newer Older
1
import os
mandoxzhang's avatar
mandoxzhang committed
2
import time
3

mandoxzhang's avatar
mandoxzhang committed
4
5
6
import wandb
from torch.utils.tensorboard import SummaryWriter

7

mandoxzhang's avatar
mandoxzhang committed
8
9
10
11
12
13
14
15
16
17
18
class WandbLog:
    """Class-level helpers around wandb's global run (``wandb.init`` / ``wandb.log`` / ``wandb.watch``).

    All methods are classmethods; no instance is needed.
    """

    # id(obj) -> obj for every object already handed to ``wandb.watch`` in the
    # current run. Keeping the object itself (not just its id) prevents the id
    # from being recycled by a different object while it is still recorded.
    _watched = {}

    @classmethod
    def init_wandb(cls, project, notes=None, name=None, config=None):
        """Start a wandb run.

        Args:
            project: wandb project name.
            notes: optional free-text notes for the run.
            name: run name. When ``None``, the local time at the moment of
                this call is used, formatted as ``%Y-%m-%d %H:%M:%S``.
            config: optional config forwarded to ``wandb.init``.
        """
        if name is None:
            # Computed here rather than as a default argument: a default
            # expression is evaluated once at import time and would give every
            # run in the process the same timestamp.
            name = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        # A new run starts with nothing watched, so models passed to ``log``
        # afterwards get watched again for this run.
        cls._watched.clear()
        wandb.init(project=project, notes=notes, name=name, config=config)

    @classmethod
    def log(cls, result, model=None, gradient=None):
        """Log *result* to the active run and watch *model* / *gradient* once.

        Args:
            result: mapping passed straight to ``wandb.log``.
            model: optional object to pass to ``wandb.watch``. It is watched only
                the first time it is seen in the current run, because every
                ``wandb.watch`` call registers a fresh set of hooks.
            gradient: optional object also passed to ``wandb.watch``.
                NOTE(review): ``wandb.watch`` expects a model
                (``torch.nn.Module``); confirm what callers pass here.
        """
        wandb.log(result)

        for target in (model, gradient):
            # ``is not None`` instead of truthiness: container modules define
            # ``__len__``, so an empty one would otherwise be skipped.
            if target is None or id(target) in cls._watched:
                continue
            wandb.watch(target)
            cls._watched[id(target)] = target


class TensorboardLog:
    """Write scalar metrics to TensorBoard through a ``SummaryWriter``."""

    def __init__(self, location, name=None, config=None):
        """Ensure *location* exists and open a ``SummaryWriter`` on it.

        Args:
            location: directory for the event files. It is created, including
                any missing parent directories, if it does not exist.
            name: forwarded to ``SummaryWriter`` as ``comment``. When ``None``,
                the local time at the moment of this call is used.
                NOTE(review): per the PyTorch docs, ``comment`` only suffixes the
                auto-generated log dir and has no effect when ``log_dir`` is
                given, as it is here. Confirm whether ``name`` should instead
                become part of the path.
            config: currently unused.
        """
        if name is None:
            # Evaluated per call. A default argument would freeze the timestamp
            # at import time.
            name = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        # makedirs creates missing parents, and exist_ok avoids the race between
        # checking for the directory and creating it.
        os.makedirs(location, exist_ok=True)
        self.writer = SummaryWriter(location, comment=name)

    def _log_scalars(self, result, step, suffix):
        """Write every ``key -> value`` in *result* as scalar ``f"{key}{suffix}"`` at *step*."""
        for key, value in result.items():
            self.writer.add_scalar(f"{key}{suffix}", value, step)

    def log_train(self, result, step):
        """Log each metric in *result* under the tag ``"<key>/train"``."""
        self._log_scalars(result, step, "/train")

    def log_eval(self, result, step):
        """Log each metric in *result* under the tag ``"<key>/eval"``."""
        self._log_scalars(result, step, "/eval")

    def log_zeroshot(self, result, step):
        """Log each metric in *result* under the tag ``"<key>_acc/eval"``."""
        self._log_scalars(result, step, "_acc/eval")

    def close(self):
        """Close the underlying ``SummaryWriter``, which flushes pending events."""
        self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        # Never suppress an exception raised inside the ``with`` body.
        return False