wandblogger.py 1.95 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT

class WandbLogger(object):
    """Log metrics to Weights and Biases (W&B).

    Every public method is a no-op when ``config.log_wandb`` is falsy, so the
    ``wandb`` package is only imported (and therefore only required) when
    logging is enabled.
    """

    def __init__(self, config):
        """
        Args:
            config: Configuration object used by RecBole. It is read through
                attribute access: ``config.log_wandb`` (truthy enables
                logging) and, only when logging is enabled,
                ``config.wandb_project`` (the W&B project name). The object
                itself is passed to ``wandb.init`` as the run config.
        """
        self.config = config
        self.log_wandb = config.log_wandb
        self.setup()

    def setup(self):
        """Import ``wandb``, start a run if none is active, and register step metrics.

        Does nothing when logging is disabled.

        Raises:
            ImportError: if logging is enabled but ``wandb`` cannot be imported.
        """
        if not self.log_wandb:
            return

        try:
            # Imported lazily so wandb is an optional dependency.
            import wandb
        except ImportError as err:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err
        self._wandb = wandb

        # Reuse an already-active W&B run instead of initializing a new one.
        if self._wandb.run is None:
            self._wandb.init(
                project=self.config.wandb_project,
                config=self.config
            )

        self._set_steps()

    def log_metrics(self, metrics, head='train', commit=True):
        """Log a dict of metrics through ``wandb.log``.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix applied as ``'{head}/{name}'`` to every key that
                does not contain ``'_step'``. Pass a falsy value to log the
                keys unchanged.
            commit (bool): Forwarded unchanged to ``wandb.log``.
        """
        if not self.log_wandb:
            return
        if head:
            metrics = self._add_head_to_metrics(metrics, head)
        self._wandb.log(metrics, commit=commit)

    def log_eval_metrics(self, metrics, head='eval'):
        """Write metrics into the active run's summary under prefixed keys.

        Values are assigned directly to ``wandb.run.summary`` rather than sent
        through ``wandb.log``.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix applied as ``'{head}/{name}'`` to every key that
                does not contain ``'_step'``.
        """
        if not self.log_wandb:
            return
        summary = self._wandb.run.summary
        for key, value in self._add_head_to_metrics(metrics, head).items():
            summary[key] = value

    def _set_steps(self):
        """Register custom step metrics.

        ``train/*`` metrics use ``train_step`` and ``valid/*`` metrics use
        ``valid_step``.
        """
        self._wandb.define_metric('train/*', step_metric='train_step')
        self._wandb.define_metric('valid/*', step_metric='valid_step')

    def _add_head_to_metrics(self, metrics, head):
        """Return a new dict with non-step keys prefixed by ``'{head}/'``.

        Keys containing ``'_step'`` (e.g. ``train_step``) are kept verbatim so
        they match the step-metric names registered in ``_set_steps``. The
        input dict is not modified.
        """
        return {
            (key if '_step' in key else f'{head}/{key}'): value
            for key, value in metrics.items()
        }