# simple_trainer.py
import os
import shutil
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from allamo.trainer.base import BaseTrainer
from allamo.logging import logger
from allamo.model.model import AllamoTransformer
from allamo.configuration import AllamoConfiguration
from allamo.torch_utils import TORCH_DTYPE_MAP
from allamo.train_utils import (
    get_model_checkpoint_path,
    get_config_checkpoint_path,
    get_optimizer_checkpoint_path,
)

class SimpleTrainer(BaseTrainer):
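    """Basic trainer for single-process or DDP runs.

    Uses regular (file-based) checkpoints; FSDP-only features such as
    PyTorch Distributed Checkpoint (DCP) are not supported here.
    """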

    def __init__(self, config: AllamoConfiguration):
        super().__init__(config)
        if config.distributed_checkpoint:
            config.distributed_checkpoint = False
            logger.warn("PyTorch Distributed Checkpoint (DCP) is only available for FSDP training! Fallback to regular checkpoint")
        
    def distributed(self):
        return self.train_ctx.world_size > 1
        
    def init_torch(self, config: AllamoConfiguration):
        super().init_torch(config)
        self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=TORCH_DTYPE_MAP[config.dtype])
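        # 'bfloat16-true' selects pure bf16 training: with the default dtype set to bfloat16,
        # floating-point tensors created afterwards (including the model parameters) default
        # to bf16, rather than relying on autocast-based mixed precision alone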
        if config.dtype == 'bfloat16-true':
            # torch.set_float32_matmul_precision("high")
            torch.set_default_dtype(torch.bfloat16)
        
    def init_training(self):
        super().init_training()
        
        model = AllamoTransformer(self.model_config)
        print("model: ", model)
        self.model_num_params = model.model_num_params

        self.freeze_model_params(model) # Optionally freezes model parameters depending on the configuration

        if self.checkpoint_manager.is_checkpoint_available():
            self.checkpoint_manager.load_regular_model_checkpoint(model)
        else:
            logger.info("New model initialized from scratch")
        model.to(self.config.device)

        if self.config.compile:
            logger.info("Compiling model")
            try:
                model = torch.compile(model, mode=self.config.compile_mode)
                logger.info("Model compiled and ready to use")
            except Exception as err:
                logger.warning(f"Unable to compile the model: {err}")

        self.raw_model = model # needed in DDP training to keep a handle to the unwrapped model
        self.model = model
        # wrap model into DDP container
        if self.distributed():
            self.model = DDP(self.model, device_ids=[self.train_ctx.local_rank])
            
        # initialize a GradScaler. If enabled=False scaler is a no-op
        self.scaler = torch.amp.GradScaler(self.device_type, enabled=(self.config.dtype == 'float16' or self.config.dtype == 'bfloat16'))
        
        # optimizer
        self.optimizer = self.raw_model.configure_optimizers(self.config, self.device_type)
        if self.checkpoint_manager.is_checkpoint_available():
            self.load_optimizer_checkpoint(self.optimizer)
        
        self.init_gradient_accumulation_scheduler()
        self.log_init_learning_rate()

    def load_optimizer_checkpoint(self, optimizer):
        ckpt_path = get_optimizer_checkpoint_path(self.checkpoint_manager.checkpoint_name, self.checkpoint_manager.checkpoint_dir)
        if os.path.exists(ckpt_path):
            state_dict = torch.load(ckpt_path, map_location=self.config.device, weights_only=True)
            optimizer.load_state_dict(state_dict)
            logger.info(f"Optimizer state loaded from checkpoint {ckpt_path}")
        else:
            logger.warning("Optimizer checkpoint file not found. Initializing optimizer from scratch")

    # saves the checkpoint (model, config and optionally optimizer state) to files
    def save_checkpoint(self, ckpt_file_name, model_only=False, epoch_ckpt=False):
        if not self.train_ctx.master_process:
            return
        
        model_ckpt_file_path = get_model_checkpoint_path(ckpt_file_name, self.config.out_dir)
        md5sum = self.checkpoint_manager.save_regular_model_checkpoint(self.raw_model.state_dict(), model_ckpt_file_path, epoch_ckpt)

        config_ckpt_file_path = get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)
        self.checkpoint_manager.save_config_checkpoint(config_ckpt_file_path, md5sum, self.model_config)
        
        if not model_only and self.checkpoint_manager.should_save_optimizer():
            optim_ckpt_file_path = get_optimizer_checkpoint_path(ckpt_file_name, self.config.out_dir)
            self.checkpoint_manager.save_regular_optimizer_checkpoint(self.optimizer.state_dict(), optim_ckpt_file_path)
            
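            # presumably keeps '.optim'-suffixed copies of the model/config checkpoints so a
            # snapshot matching the saved optimizer state survives when optimizer checkpoints
            # are written on their own interval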
            if self.config.optimizer_checkpoint_interval is not None:
                shutil.copy(model_ckpt_file_path, model_ckpt_file_path + '.optim')
                shutil.copy(config_ckpt_file_path, config_ckpt_file_path + '.optim')
        logger.info(f"checkpoint files saved in {self.config.out_dir}")

    def should_evaluate(self):
        return super().should_evaluate() and self.train_ctx.master_process
    
    def forward(self, batch, last_micro_step):
        if self.distributed():
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with the model.no_sync() context manager, but
            # that bloats the code and forces us to repeat the forward call; looking at
            # the source of that context manager, it just toggles this flag
            self.model.require_backward_grad_sync = last_micro_step
        with self.ctx:
            logits, loss, _ = self.model(**batch)
        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps # scale the loss to account for micro steps
        if batch["target_weights"] is not None:
            if self.config.weighted_loss_method == 'openchat':
                target_weights = batch["target_weights"].sum()
                # sum loss weights over all processes
                target_weights = self.dist_all_reduce(target_weights, op=dist.ReduceOp.SUM)
                loss = (self.train_ctx.world_size / target_weights) * loss
            else:
                loss = loss / torch.sum(batch["target_weights"] > 0).item()
        
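        # token-level accuracy over non-ignored targets; assuming ignore_index lies outside
        # the vocabulary (e.g. -100), masked positions never match an argmax prediction, so
        # dividing by the unmasked count yields accuracy on the supervised tokens only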
        unmasked_labels = torch.sum(batch["target_ids"].view(-1) != self.config.ignore_index).item()
        accuracy = (logits.max(2).indices == batch["target_ids"]).sum().item() / unmasked_labels
        return loss, unmasked_labels, accuracy

    def close(self):
        if self.distributed():
            dist.barrier()
            dist.destroy_process_group()