trainer.py 6.36 KB
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

4
5
import json
import logging
6
7
import os
import time
8
from abc import abstractmethod
9

10
import torch
11

12
13
from .base_trainer import BaseTrainer

14
15
16
17
18
19
20
21
22
23
24
25
26
_logger = logging.getLogger(__name__)


class TorchTensorEncoder(json.JSONEncoder):
    """
    JSON encoder that serializes ``torch.Tensor`` objects via ``Tensor.tolist()``.

    Objects that are not tensors are delegated to :class:`json.JSONEncoder`.
    """

    def default(self, o):  # pylint: disable=method-hidden
        """
        Convert an object the standard encoder cannot handle into a JSON-serializable one.

        Parameters
        ----------
        o : object
            Object to convert.

        Returns
        -------
        list or number
            ``o.tolist()`` when ``o`` is a tensor: a (possibly nested) list, or a bare Python
            scalar for a 0-dim tensor.

        Raises
        ------
        TypeError
            Raised by the base encoder when ``o`` is not a tensor and not serializable.
        """
        if not isinstance(o, torch.Tensor):
            return super().default(o)

        olist = o.tolist()
        # The 0/1 check runs on the tensor itself rather than on ``olist``: ``tolist()`` yields a
        # bare scalar for 0-dim tensors (not iterable) and nested lists for multi-dim tensors
        # (whose elements never compare equal to 0 or 1). Empty tensors are skipped so that the
        # hint is not emitted vacuously.
        if ("bool" not in o.type().lower()
                and o.numel() > 0
                and bool(((o == 0) | (o == 1)).all())):
            _logger.warning("Every element in %s is either 0 or 1. "
                            "You might consider convert it into bool.", olist)
        return olist

27
28

class Trainer(BaseTrainer):
    """
    A trainer with some helper functions implemented. To implement a new trainer,
    users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    mutator : BaseMutator
        A mutator object that has been initialized with the model.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
        See `PyTorch loss functions`_ for examples. If the object exposes a callable ``to``
        attribute (as ``nn.Module`` losses do), it is moved to ``device``.
    metrics : callable
        Called with logits and targets. Returns a dict that maps metrics keys to metrics data. For example,

        .. code-block:: python

            def metrics_fn(output, target):
                return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}

    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    dataset_train : torch.utils.data.Dataset
        Dataset of training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
        PyTorch Dataset. See `torch.utils.data`_ for examples.
    dataset_valid : torch.utils.data.Dataset
        Dataset of validation/testing.
    batch_size : int
        Batch size.
    workers : int
        Number of workers used in data preprocessing.
    device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
        uses CUDA if ``torch.cuda.is_available()`` and falls back to CPU otherwise.
    log_frequency : int
        Number of mini-batches to log metrics.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks. ``None`` is treated as an empty list.


    .. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
    .. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
    """

    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = model
        self.mutator = mutator
        self.loss = loss

        self.metrics = metrics
        self.optimizer = optimizer

        self.model.to(self.device)
        self.mutator.to(self.device)
        # ``loss`` is documented as an arbitrary callable, so it may be a plain function without
        # a ``to`` method; only move it when it supports being moved.
        move_loss = getattr(self.loss, "to", None)
        if callable(move_loss):
            move_loss(self.device)

        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.batch_size = batch_size
        self.workers = workers
        self.log_frequency = log_frequency

        # Each trainer instance gets its own timestamped directory under ``logs``. The status file
        # is opened here and later written to by ``_write_graph_status``.
        # NOTE(review): nothing in this class closes ``status_writer`` -- confirm whether the
        # handle is meant to live for the whole process.
        self.log_dir = os.path.join("logs", str(time.time()))
        os.makedirs(self.log_dir, exist_ok=True)
        self.status_writer = open(os.path.join(self.log_dir, "log"), "w")

        self.callbacks = callbacks if callbacks is not None else []
        for callback in self.callbacks:
            callback.build(self.model, self.mutator, self)

    @abstractmethod
    def train_one_epoch(self, epoch):
        """
        Train one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """

    @abstractmethod
    def validate_one_epoch(self, epoch):
        """
        Validate one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0. :meth:`validate` passes ``-1``.
        """

    def train(self, validate=True):
        """
        Train ``num_epochs``.
        Trigger callbacks at the start and the end of each epoch.

        Parameters
        ----------
        validate : bool
            If ``True``, will do validation every epoch.
        """
        for epoch in range(self.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            # training (log messages are 1-based, the epoch argument is 0-based)
            _logger.info("Epoch %d Training", epoch + 1)
            self.train_one_epoch(epoch)

            if validate:
                # validation
                _logger.info("Epoch %d Validating", epoch + 1)
                self.validate_one_epoch(epoch)

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        """
        Do one validation by calling :meth:`validate_one_epoch` with epoch ``-1``.
        """
        self.validate_one_epoch(-1)

    def export(self, file):
        """
        Call ``mutator.export()`` and dump the architecture to ``file``.

        Tensors in the exported object are serialized through ``TorchTensorEncoder``.

        Parameters
        ----------
        file : str
            A file path. Expected to be a JSON.
        """
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        """
        Return trainer checkpoint.

        Raises
        ------
        NotImplementedError
            Always; this class does not provide an implementation.
        """
        raise NotImplementedError("Not implemented yet")

    def enable_visualization(self):
        """
        Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.

        Takes the first two items of the first batch from ``self.train_loader`` as a sample
        input, passes it to ``mutator.graph()`` and dumps the result to ``graph.json`` in
        ``self.log_dir``.
        """
        # NOTE(review): ``train_loader`` is not assigned anywhere in this class; presumably the
        # subclass creates it and yields ``(input, target)`` pairs -- confirm against subclasses.
        sample = None
        for inputs, _ in self.train_loader:
            sample = inputs.to(self.device)[:2]
            break
        if sample is None:
            # An empty loader leaves ``sample`` as ``None``; it is still passed to the mutator.
            _logger.warning("Sample is %s.", sample)
        _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
        with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
            json.dump(self.mutator.graph(sample), f)
        self.visualization_enabled = True

    def _write_graph_status(self):
        """
        Write ``mutator.status()`` as one JSON line to the status log, if visualization is enabled.
        """
        # The flag only exists after ``enable_visualization`` has been called.
        if getattr(self, "visualization_enabled", False):
            print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)