# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

4
5
import json
import logging
6
from abc import abstractmethod
7

8
import torch
9

10
11
from .base_trainer import BaseTrainer

12
13
14
15
16
17
18
19
20
21
22
23
24
_logger = logging.getLogger(__name__)


class TorchTensorEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``torch.Tensor`` objects via ``Tensor.tolist()``."""

    def default(self, o):  # pylint: disable=method-hidden
        """
        Convert objects the stock encoder cannot handle.

        Parameters
        ----------
        o : object
            Object that ``json`` failed to serialize natively.

        Returns
        -------
        list or scalar
            ``o.tolist()`` when ``o`` is a ``torch.Tensor`` (a scalar for 0-dim tensors,
            a possibly nested list otherwise).

        Raises
        ------
        TypeError
            Raised by ``json.JSONEncoder.default`` when ``o`` is not a tensor.
        """
        if not isinstance(o, torch.Tensor):
            return super().default(o)

        olist = o.tolist()
        is_bool_tensor = "bool" in o.type().lower()
        # Run the 0/1 check on the tensor itself rather than on ``olist``:
        # ``tolist()`` yields a bare scalar for 0-dim tensors (not iterable) and
        # nested lists for multi-dim tensors (which never compare equal to 0 or 1).
        # Empty tensors are skipped so that they do not trigger a vacuous warning.
        if not is_bool_tensor and o.numel() > 0 and bool(((o == 0) | (o == 1)).all()):
            _logger.warning("Every element in %s is either 0 or 1. "
                            "You might consider convert it into bool.", olist)
        return olist

25
26

class Trainer(BaseTrainer):
    """
    Epoch-driven trainer skeleton.

    Subclasses supply ``train_one_epoch`` and ``validate_one_epoch``; this class
    stores the configuration, moves the model, mutator and loss to the target device,
    drives the epoch loop with callbacks, and exports the mutator's result as JSON.
    """

    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
        """
        Trainer initialization.

        Parameters
        ----------
        model : nn.Module
            Model with mutables.
        mutator : BaseMutator
            A mutator object that has been initialized with the model.
        loss : callable
            Called with logits and targets. Returns a loss tensor.
        metrics : callable
            Returns a dict that maps metrics keys to metrics data.
        optimizer : Optimizer
            Optimizer that optimizes the model.
        num_epochs : int
            Number of epochs of training.
        dataset_train : torch.utils.data.Dataset
            Dataset of training.
        dataset_valid : torch.utils.data.Dataset
            Dataset of validation/testing.
        batch_size : int
            Batch size.
        workers : int
            Number of workers used in data preprocessing.
        device : torch.device
            Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
            automatically detects whether CUDA is available and prefers the GPU.
        log_frequency : int
            Number of mini-batches to log metrics.
        callbacks : list of Callback
            Callbacks to plug into the trainer. See Callbacks. ``None`` is treated as an empty list.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.model = model
        self.mutator = mutator
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer

        self.model.to(self.device)
        self.mutator.to(self.device)
        # ``loss`` is documented as a plain callable; only objects exposing ``to``
        # (e.g. ``nn.Module`` losses) can be moved to the device.
        if hasattr(self.loss, "to"):
            self.loss.to(self.device)

        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.batch_size = batch_size
        self.workers = workers
        self.log_frequency = log_frequency
        self.callbacks = callbacks if callbacks is not None else []
        # Each callback is built with the model, the mutator and this trainer.
        for callback in self.callbacks:
            callback.build(self.model, self.mutator, self)

    @abstractmethod
    def train_one_epoch(self, epoch):
        """
        Train for a single epoch. To be implemented by subclasses.

        Parameters
        ----------
        epoch : int
            Index of the epoch, as passed by ``train``.
        """
        # NOTE(review): ``@abstractmethod`` is only enforced if ``BaseTrainer`` uses an
        # ABC metaclass -- confirm against ``base_trainer``.
        pass

    @abstractmethod
    def validate_one_epoch(self, epoch):
        """
        Validate for a single epoch. To be implemented by subclasses.

        Parameters
        ----------
        epoch : int
            Index of the epoch, as passed by ``train``; ``validate`` passes ``-1``.
        """
        pass

    def train(self, validate=True):
        """
        Run ``num_epochs`` epochs of training.

        Within each epoch, every callback's ``on_epoch_begin`` is invoked first, then
        ``train_one_epoch``, then (optionally) ``validate_one_epoch``, and finally
        every callback's ``on_epoch_end``.

        Parameters
        ----------
        validate : bool
            When true, ``validate_one_epoch`` is called after each training epoch.
        """
        for epoch in range(self.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            # training
            _logger.info("Epoch %d Training", epoch)
            self.train_one_epoch(epoch)

            if validate:
                # validation
                _logger.info("Epoch %d Validating", epoch)
                self.validate_one_epoch(epoch)

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        """Run a single validation pass, passing ``-1`` as the epoch index."""
        self.validate_one_epoch(-1)

    def export(self, file):
        """
        Dump the result of ``self.mutator.export()`` to a JSON file.

        Parameters
        ----------
        file : str
            Path of the file to write; it is opened in ``"w"`` mode, so existing content is overwritten.
        """
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            # Tensors in the exported structure are serialized through TorchTensorEncoder.
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        """
        Checkpointing hook; not available in this class.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Not implemented yet")