fsdp_trainer.py
import os
import shutil
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig, # gathers the full model state: non-sharded, non-flattened params
)

from allamo.trainer.base import BaseTrainer
from allamo.logging import logger
from allamo.model.model import AllamoTransformer
from allamo.configuration import AllamoConfiguration
from allamo.parallelisms.fsdp_utils import parallelize_model_with_fsdp1
from allamo.parallelisms.fsdp2_utils import build_world_mesh, parallelize_model_with_fsdp2
from allamo.train_utils import (
    get_model_checkpoint_path,
    get_config_checkpoint_path,
    get_optimizer_checkpoint_path,
)

class FSDPTrainer(BaseTrainer):
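    """FSDP-based trainer: uses FSDP1 with full-state-dict checkpoints, or FSDP2 with distributed (DCP) checkpoints when enabled."""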

    def __init__(self, config: AllamoConfiguration):
        super().__init__(config)
        
    def distributed(self):
        return True
                    
    def init_torch(self, config: AllamoConfiguration):
        super().init_torch(config)
        if config.dtype == 'bfloat16-true':
            raise Exception('Full bfloat16 training is not supported with FSDP')
        
        self.fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        if config.gradient_checkpointing:
            self.fsdp_activation_checkpointing = True
            config.gradient_checkpointing = False # control gradient checkpointing with FSDP 
            logger.info(
                "Deactivated gradient checkpointing at the model configuration level. "
                "Activated gradient checkpointing at the FSDP level."
            )
        else:
            self.fsdp_activation_checkpointing = False
        
        # Distributed checkpointing (DCP) requires FSDP2, which in turn needs a device mesh
        if self.config.distributed_checkpoint:
            assert self.config.dtype != 'float16', "GradScaler does not work correctly with FSDP2"
            self.world_mesh = build_world_mesh(self.train_ctx, self.device_type)
        else:
            self.world_mesh = None
            
    def init_training(self):
        super().init_training()
        self.model_config.gradient_checkpointing = False # activation checkpointing is handled by FSDP
            
        with torch.device('meta'):
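            # Build the model on the meta device so no parameter memory is allocated yet;
            # weights are materialized later via to_empty() and init_model_weights().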
            model = AllamoTransformer(self.model_config)
        self.model_num_params = model.model_num_params

        self.freeze_model_params(model) # Optionally freezes model parameters depending on the configuration

        if self.checkpoint_manager.checkpoint_name is None:
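            # Fresh start: wrap the meta-device model with FSDP, materialize parameters
            # on the target device, and initialize weights from scratch.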
            if self.world_mesh is None:
                self.model = parallelize_model_with_fsdp1(model, self.config, self.fsdp_activation_checkpointing)
            else:
                self.model = parallelize_model_with_fsdp2(model, self.world_mesh, self.config, self.fsdp_activation_checkpointing)
            self.model.to_empty(device=self.device_type)
            self.model.init_model_weights()
            logger.info("Initialized a new model from scratch")
            
            self.optimizer = self.model.configure_optimizers(self.config, self.device_type)
            logger.info("Initializing optimizer from scratch")
        else:
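            # Resuming from a checkpoint. With DCP/FSDP2 the model is sharded first and the
            # sharded state is loaded afterwards; with FSDP1 the full checkpoint is loaded
            # into the unsharded model before wrapping it with FSDP.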
            if self.config.distributed_checkpoint:
                self.model = parallelize_model_with_fsdp2(model, self.world_mesh, self.config, self.fsdp_activation_checkpointing)
                logger.info("model.to_empty")
                self.model.to_empty(device=self.device_type)
                logger.info("model.init_model_weights")
                self.model.init_model_weights()
                logger.info("checkpoint_manager.load_distributed_model_checkpoint")
                self.checkpoint_manager.load_distributed_model_checkpoint(self.model)
                
                logger.info("model.configure_optimizers")
                self.optimizer = self.model.configure_optimizers(self.config, self.device_type)
                logger.info("checkpoint_manager.load_distributed_optimizer_checkpoint")
                self.checkpoint_manager.load_distributed_optimizer_checkpoint(self.model, self.optimizer)
                logger.info("ready")
            else:
                model.to_empty(device=self.device_type)
                model.init_model_weights()
                self.checkpoint_manager.load_regular_model_checkpoint(model)
                
                self.model = parallelize_model_with_fsdp1(model, self.config, self.fsdp_activation_checkpointing)
                
                self.optimizer = self.model.configure_optimizers(self.config, self.device_type)
                self.load_optimizer_checkpoint(self.model, self.optimizer)
                
        # initialize a GradScaler only for FSDP's built-in mixed precision with fp16
        self.scaler = torch.amp.GradScaler(self.device_type, enabled=(self.config.dtype == 'float16'))
        
        self.init_gradient_accumulation_scheduler()
        self.log_init_learning_rate()
    
    def load_optimizer_checkpoint(self, model, optimizer):
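        # FSDP1 path: load a full (unsharded) optimizer state dict and convert it into this rank's sharded state.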
        ckpt_path = get_optimizer_checkpoint_path(self.checkpoint_manager.checkpoint_name, self.checkpoint_manager.checkpoint_dir)
        if os.path.exists(ckpt_path):
            # each rank loads the full optimizer state dict into CPU memory; reading it from disk
            # on every rank avoids broadcasting it from rank0
            full_osd = torch.load(ckpt_path, map_location='cpu', weights_only=True)
            sharded_osd = FSDP.optim_state_dict_to_load(model, optimizer, full_osd)
            optimizer.load_state_dict(sharded_osd)
            logger.info(f"Shared optimizer state loaded from checkpoint {ckpt_path}")
        else:
            if self.train_ctx.master_process:
                logger.warning("Optimizer checkpoint file not found. Initializing optimizer from scratch")
        
    # Saves model, config, and (optionally) optimizer state, either as a distributed (DCP)
    # checkpoint or as a regular full-state-dict checkpoint
    def save_checkpoint(self, ckpt_file_name, model_only=False, epoch_ckpt=False):
        if self.config.distributed_checkpoint:
            config_ckpt_file_path = get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)
            self.checkpoint_manager.save_config_checkpoint(config_ckpt_file_path, None, self.model_config)
            
            if self.train_ctx.master_process and not self.config.ignore_last_checkpoint_backup:
                logger.warning("Backing up a previous checkpoint is not supported for distributed checkpoints")
            model_ckpt_dir_path = self.checkpoint_manager.save_distributed_model_checkpoint(self.model, ckpt_file_name)
            
            if not model_only and self.checkpoint_manager.should_save_optimizer():
                self.checkpoint_manager.save_distributed_optimizer_checkpoint(self.model, self.optimizer, ckpt_file_name)
                
                if self.config.optimizer_checkpoint_interval is not None:
                    shutil.copytree(model_ckpt_dir_path, model_ckpt_dir_path + '-optim')
                    shutil.copy(config_ckpt_file_path, config_ckpt_file_path + '.optim')
        else:
            with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, self.fullstate_save_policy):
                full_msd = self.model.state_dict()
            if self.train_ctx.master_process:
                model_ckpt_file_path = get_model_checkpoint_path(ckpt_file_name, self.config.out_dir)
                md5sum = self.checkpoint_manager.save_regular_model_checkpoint(full_msd, model_ckpt_file_path, epoch_ckpt)
                del full_msd
                
                config_ckpt_file_path = get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)
                self.checkpoint_manager.save_config_checkpoint(config_ckpt_file_path, md5sum, self.model_config)
            
            if not model_only and self.checkpoint_manager.should_save_optimizer():
                # pull all sharded optimizer states to rank0 cpu.
                full_osd = FSDP.full_optim_state_dict(self.model, self.optimizer)
                if self.train_ctx.master_process:
                    optim_ckpt_file_path = get_optimizer_checkpoint_path(ckpt_file_name, self.config.out_dir)
                    self.checkpoint_manager.save_regular_optimizer_checkpoint(full_osd, optim_ckpt_file_path)
                    del full_osd
                    
                    if self.config.optimizer_checkpoint_interval is not None:
                        shutil.copy(model_ckpt_file_path, model_ckpt_file_path + '.optim')
                        shutil.copy(config_ckpt_file_path, config_ckpt_file_path + '.optim')
    
    def dist_all_reduce(self, x: torch.Tensor, op: dist.ReduceOp):
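        # FSDP1: in-place all-reduce over the default process group.
        # FSDP2: functional collective over the data-parallel ("dp") sub-mesh.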
        if self.world_mesh is None:
            dist.all_reduce(x, op=op)
            return x
        else:
            return funcol.all_reduce(x, reduceOp=op.name, group=self.world_mesh["dp"])
    
    def clip_grad_norm(self):
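        # FSDP1 wrappers expose clip_grad_norm_(), which handles sharded gradients;
        # with FSDP2 the base trainer's implementation is used instead.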
        if self.world_mesh is None:
            return self.model.clip_grad_norm_(self.config.grad_clip).item()
        else:
            return super().clip_grad_norm()
            
    def forward(self, batch, last_micro_step):
        logits, loss, _ = self.model(**batch)
        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps # scale the loss to account for micro steps
        
        if batch["target_weights"] is not None:
            if self.config.weighted_loss_method == 'openchat':
                target_weights = batch["target_weights"].sum()
                # sum loss weights over all processes
                target_weights = self.dist_all_reduce(target_weights, op=dist.ReduceOp.SUM)
                loss = (self.train_ctx.world_size / target_weights) * loss
            else:
                loss = loss / torch.sum(batch["target_weights"] > 0).item()
        
        unmasked_labels = torch.sum(batch["target_ids"].view(-1) != self.config.ignore_index).item()
        accuracy = (logits.max(2).indices == batch["target_ids"]).sum().item() / unmasked_labels
        return loss, unmasked_labels, accuracy
            
    def close(self):
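        # Make sure all ranks have finished before tearing down the process group.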
        dist.barrier()
        dist.destroy_process_group()
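
# Usage sketch (illustrative only; assumes a distributed launcher such as torchrun and that
# BaseTrainer provides the training-loop entry point; names outside this file are assumptions):
#
#   config = AllamoConfiguration()
#   trainer = FSDPTrainer(config)
#   trainer.init_training()
#   trainer.train()   # assumed to be provided by BaseTrainer
#   trainer.close()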