simple_trainer.py

import os
import shutil
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from allamo.trainer.base import BaseTrainer
from allamo.logging import logger
from allamo.model.model import AllamoTransformer
from allamo.configuration import AllamoConfiguration
from allamo.torch_utils import TORCH_DTYPE_MAP
from allamo.train_utils import (
    get_model_checkpoint_path,
    get_config_checkpoint_path,
    get_optimizer_checkpoint_path,
)

class SimpleTrainer(BaseTrainer):
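    """Trainer for single-process or DDP runs that uses regular (file-based) checkpoints.

    Wraps the model in DistributedDataParallel when world_size > 1 and uses
    torch.amp autocast plus GradScaler for mixed-precision training.
    """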

    def __init__(self, config: AllamoConfiguration):
        super().__init__(config)
        if config.distributed_checkpoint:
            config.distributed_checkpoint = False
            logger.warning("PyTorch Distributed Checkpoint (DCP) is only available for FSDP training! Falling back to regular checkpoints")
        
    def distributed(self):
        return self.train_ctx.world_size > 1
        
    def init_torch(self, config: AllamoConfiguration):
        super().init_torch(config)
        self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=TORCH_DTYPE_MAP[config.dtype])
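        # 'bfloat16-true' switches the default dtype so parameters are created directly
        # in bf16 (full bf16 training) instead of relying on autocast mixed precision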
        if config.dtype == 'bfloat16-true':
            # torch.set_float32_matmul_precision("high")
            torch.set_default_dtype(torch.bfloat16)
        
    def init_training(self):
        super().init_training()
        
        model = AllamoTransformer(self.model_config)
        print("model: ", model)
        self.model_num_params = model.model_num_params
        if self.checkpoint_manager.is_checkpoint_available():
            self.checkpoint_manager.load_regular_model_checkpoint(model)
        else:
            logger.info("New model initialized from scratch")
        model.to(self.config.device)

        if self.config.compile:
            logger.info("Compiling model")
            try:
                model = torch.compile(model, mode=self.config.compile_mode)
                logger.info("Model compiled and ready to use")
            except Exception as err:
                logger.warn(f"Unable to compile the model: {err}")

        self.raw_model = model # keep a reference to the unwrapped model (needed in DDP training)
        self.model = model
        # wrap model into DDP container
        if self.distributed():
            self.model = DDP(self.model, device_ids=[self.train_ctx.local_rank])
            
        # initialize a GradScaler. If enabled=False scaler is a no-op
        self.scaler = torch.amp.GradScaler(self.device_type, enabled=(self.config.dtype in ('float16', 'bfloat16')))
        
        # optimizer
        self.optimizer = self.raw_model.configure_optimizers(self.config, self.device_type)
        if self.checkpoint_manager.is_checkpoint_available():
            self.load_optimizer_checkpoint(self.optimizer)
        
        self.init_gradient_accumulation_scheduler()
        self.log_init_learning_rate()

    def load_optimizer_checkpoint(self, optimizer):
        ckpt_path = get_optimizer_checkpoint_path(self.checkpoint_manager.checkpoint_name, self.checkpoint_manager.checkpoint_dir)
        if os.path.exists(ckpt_path):
            state_dict = torch.load(ckpt_path, map_location=self.config.device)
            optimizer.load_state_dict(state_dict)
            logger.info(f"Optimizer state loaded from checkpoint {ckpt_path}")
        else:
            logger.warning("Optimizer checkpoint file not found. Initializing optimizer from scratch")

    # saves model, config, and (optionally) optimizer checkpoints to files
    def save_checkpoint(self, ckpt_file_name, model_only=False, epoch_ckpt=False):
        if not self.train_ctx.master_process:
            return
        
        model_ckpt_file_path = get_model_checkpoint_path(ckpt_file_name, self.config.out_dir)
        md5sum = self.checkpoint_manager.save_regular_model_checkpoint(self.raw_model.state_dict(), model_ckpt_file_path, epoch_ckpt)

        config_ckpt_file_path = get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)
        self.checkpoint_manager.save_config_checkpoint(config_ckpt_file_path, md5sum, self.model_config)
        
        if not model_only and self.checkpoint_manager.should_save_optimizer():
            optim_ckpt_file_path = get_optimizer_checkpoint_path(ckpt_file_name, self.config.out_dir)
            self.checkpoint_manager.save_regular_optimizer_checkpoint(self.optimizer.state_dict(), optim_ckpt_file_path)
            
            if self.config.optimizer_checkpoint_interval is not None:
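                # keep .optim-suffixed copies of the model/config files so the saved
                # optimizer state has a matching model/config snapshot to resume from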
                shutil.copy(model_ckpt_file_path, model_ckpt_file_path + '.optim')
                shutil.copy(config_ckpt_file_path, config_ckpt_file_path + '.optim')
        logger.info(f"checkpoint files saved in {self.config.out_dir}")

    def should_evaluate(self):
        return super().should_evaluate() and self.train_ctx.master_process
    
    def forward(self, batch, last_micro_step):
        if self.distributed():
            # In DDP training we only need to sync gradients on the last micro step.
            # The official way to do this is the model.no_sync() context manager, but
            # that bloats the code and forces duplication; looking at the source of that
            # context manager, it simply toggles this flag.
            self.model.require_backward_grad_sync = last_micro_step
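            # Equivalent form using the official API (kept as a reference sketch):
            #   sync_ctx = nullcontext() if last_micro_step else self.model.no_sync()
            #   with sync_ctx, self.ctx:
            #       logits, loss, _ = self.model(**batch)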
        with self.ctx:
            logits, loss, _ = self.model(**batch)
        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps # scale the loss to account for micro steps
        if batch["target_weights"] is not None:
            if self.config.weighted_loss_method == 'openchat':
                target_weights = batch["target_weights"].sum()
                # sum loss weights over all processes
                target_weights = self.dist_all_reduce(target_weights, op=dist.ReduceOp.SUM)
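                # DDP averages gradients across ranks (divides by world_size), so scaling
                # the local loss by world_size / global weight sum normalizes the loss by
                # the sum of target weights over all processes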
                loss = (self.train_ctx.world_size / target_weights) * loss
            else:
                loss = loss / torch.sum(batch["target_weights"] > 0).item()
        
        # token-level accuracy over positions whose target is not ignore_index
        unmasked_labels = torch.sum(batch["target_ids"].view(-1) != self.config.ignore_index).item()
        accuracy = (logits.max(2).indices == batch["target_ids"]).sum().item() / unmasked_labels if unmasked_labels > 0 else 0.0
        return loss, unmasked_labels, accuracy

    def close(self):
        if self.distributed():
            dist.barrier()
            dist.destroy_process_group()
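
# Example usage (a minimal sketch; the training-loop entry point is defined in
# BaseTrainer and is not shown in this file, so `run_training()` below is a
# hypothetical placeholder, and AllamoConfiguration is assumed to be constructible
# with defaults / CLI parsing):
#
#   config = AllamoConfiguration()
#   trainer = SimpleTrainer(config)
#   trainer.init_training()
#   trainer.run_training()  # hypothetical name for the loop inherited from BaseTrainer
#   trainer.close()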