save_checkpoint.py 2.8 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os

import torch.distributed as dist
from coati.trainer.strategies import ColossalAIStrategy, Strategy
from coati.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer

from .base import Callback


class SaveCheckpoint(Callback):
    """
    The callback for saving checkpoints for coati.

    Only saving the actor and critic models (with their optimizers) is supported.
    A typical layout of the saved checkpoint would be:
        - checkpoint
            - episode_x
                - actor.pt
                - actor-optim-rank-0.pt
                - actor-optim-rank-1.pt
                - critic.pt
                - critic-optim-rank-0.pt
                - critic-optim-rank-1.pt
            - ...

    ``x`` is the episode index passed to ``on_episode_end``. A checkpoint is
    written whenever ``(episode + 1) % interval == 0``.

    Args:
        path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint`
        interval(int): the interval (in episodes) between two checkpoints
        strategy(Strategy): the strategy used to train
        actor(nn.Module): the actor model; if None, neither the actor nor its optimizer is saved
        critic(nn.Module): the critic model; if None, neither the critic nor its optimizer is saved
        actor_optim(Optimizer): the optimizer of actor; if None, only the actor model is saved
        critic_optim(Optimizer): the optimizer of critic; if None, only the critic model is saved

    """

    def __init__(self,
                 path: str,
                 interval: int,
                 strategy: Strategy,
                 actor: nn.Module = None,
                 critic: nn.Module = None,
                 actor_optim: Optimizer = None,
                 critic_optim: Optimizer = None) -> None:
        super().__init__()
        self.path = os.path.join(path, 'checkpoint')
        self.interval = interval
        self.strategy = strategy
        # name -> [model, optimizer]; the name is used as the file-name prefix
        self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}

    def on_episode_end(self, episode: int) -> None:
        """Save the registered models and optimizers if this episode hits the interval.

        Files are written under ``<self.path>/episode_<episode>``.

        Args:
            episode(int): index of the episode that has just finished
        """
        if (episode + 1) % self.interval != 0:
            return
        base_path = os.path.join(self.path, f'episode_{episode}')
        # Every process runs this callback, so a separate "exists?" check followed
        # by makedirs can race between ranks and raise FileExistsError.
        # exist_ok=True makes the creation safe to repeat.
        os.makedirs(base_path, exist_ok=True)

        for name, (model, optimizer) in self.model_dict.items():

            # save model
            if model is None:
                # saving only optimizer states is meaningless, so it would be skipped
                continue
            model_path = os.path.join(base_path, f'{name}.pt')
            self.strategy.save_model(model=model, path=model_path, only_rank0=True)

            # save optimizer
            if optimizer is None:
                continue
            # With ColossalAIStrategy every rank writes its own optimizer file
            # (presumably because the optimizer states are sharded -- confirm
            # against the strategy implementation); other strategies save on rank 0 only.
            only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
            rank = 0 if is_rank_0() else dist.get_rank()
            optim_path = os.path.join(base_path, f'{name}-optim-rank-{rank}.pt')
            self.strategy.save_optimizer(optimizer=optimizer, path=optim_path, only_rank0=only_rank0)