trainer.py 2.62 KB
Newer Older
1
2
import json
import logging
3
from abc import abstractmethod
4

5
import torch
6

7
8
from .base_trainer import BaseTrainer

9
_logger = logging.getLogger(__name__)
# NOTE(review): setting the level at import time overrides whatever level the
# application configured for this logger; kept so INFO epoch messages still show.
_logger.setLevel(logging.INFO)
11
12
13
14
15
16
17
18
19
20
21
22


class TorchTensorEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``torch.Tensor`` values via ``Tensor.tolist()``.

    A non-empty, non-bool tensor whose every element equals 0 or 1 is still
    serialized as numbers, but a warning suggests converting it to a bool tensor.
    Any other non-JSON-native object falls back to ``json.JSONEncoder.default``,
    which raises ``TypeError``.
    """

    def default(self, o):  # pylint: disable=method-hidden
        if isinstance(o, torch.Tensor):
            olist = o.tolist()
            # Check on the tensor itself, not on ``olist``: ``tolist()`` of a 0-dim
            # tensor is a plain scalar (not iterable), and for multi-dim tensors
            # its elements are nested lists. ``numel() > 0`` avoids warning about
            # empty tensors, for which ``all(...)`` would be vacuously true.
            if o.dtype != torch.bool and o.numel() > 0 and bool(((o == 0) | (o == 1)).all()):
                _logger.warning("Every element in %s is either 0 or 1. "
                                "You might consider convert it into bool.", olist)
            return olist
        return super().default(o)

23
24

class Trainer(BaseTrainer):
    """Epoch-based training driver.

    Subclasses implement :meth:`train_one_epoch` and :meth:`validate_one_epoch`.
    This class stores the configuration, moves model/mutator/loss to the target
    device, and invokes the callbacks around every epoch of :meth:`train`.
    Note that ``@abstractmethod`` is only enforced at instantiation time if
    ``BaseTrainer`` uses ``ABCMeta`` (not visible here).
    """

    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
        """Store the training configuration and move trainable parts to ``device``.

        Parameters
        ----------
        model :
            Moved to the device with ``.to``; presumably a ``torch.nn.Module``.
        mutator :
            Moved to the device with ``.to``; its ``export()`` result is what
            :meth:`export` writes.
        loss :
            Loss object or callable. It is moved to the device only if it has a
            callable ``.to`` attribute, so plain functions are accepted as well.
        metrics, optimizer, dataset_train, dataset_valid, batch_size, workers, log_frequency :
            Stored unchanged for use by subclasses.
        num_epochs : int
            Number of epochs :meth:`train` runs.
        device : torch.device or None
            Target device; ``None`` picks CUDA when available, otherwise CPU.
        callbacks : list or None
            Each callback gets ``build(model, mutator, trainer)`` here and
            ``on_epoch_begin``/``on_epoch_end`` in :meth:`train`.
            ``None`` means no callbacks.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = model
        self.mutator = mutator
        self.loss = loss

        self.metrics = metrics
        self.optimizer = optimizer

        self.model.to(self.device)
        self.mutator.to(self.device)
        # Functional losses (plain callables) have no ``.to``; only move objects
        # that support it (e.g. ``nn.Module`` losses).
        if callable(getattr(self.loss, "to", None)):
            self.loss.to(self.device)

        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.batch_size = batch_size
        self.workers = workers
        self.log_frequency = log_frequency
        self.callbacks = callbacks if callbacks is not None else []
        for callback in self.callbacks:
            callback.build(self.model, self.mutator, self)

    @abstractmethod
    def train_one_epoch(self, epoch):
        """Run one training epoch; ``epoch`` is the 0-based epoch index."""

    @abstractmethod
    def validate_one_epoch(self, epoch):
        """Run one validation pass; ``epoch`` is the epoch index (-1 from :meth:`validate`)."""

    def train(self, validate=True):
        """Run ``num_epochs`` epochs of training, optionally validating after each.

        Callbacks' ``on_epoch_begin`` runs before training and ``on_epoch_end``
        runs after validation of every epoch.
        """
        for epoch in range(self.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            # training
            _logger.info("Epoch %d Training", epoch)
            self.train_one_epoch(epoch)

            if validate:
                # validation
                _logger.info("Epoch %d Validating", epoch)
                self.validate_one_epoch(epoch)

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        """Run a single stand-alone validation pass, passing ``-1`` as the epoch index."""
        self.validate_one_epoch(-1)

    def export(self, file):
        """Write ``self.mutator.export()`` to ``file`` as indented, key-sorted JSON.

        Tensors in the export are serialized with ``TorchTensorEncoder``.
        """
        mutator_export = self.mutator.export()
        with open(file, "w", encoding="utf-8") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        """Not supported yet.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Not implemented yet")