train_utils.py
import logging
import time
from contextlib import nullcontext

import torch
import torch.distributed as dist

try:
    import torch_musa  # optional backend for Moore Threads GPUs
except ImportError:
    print("You should install torch_musa if you want to run on Moore Threads GPUs")

from mooer.utils.utils import *
from mooer.utils.checpoint_io import save_model_checkpoint_deepspeed

logger = logging.getLogger(__name__)

device = str(get_device())
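# Assumption for readers: get_device() comes from mooer.utils.utils; the resulting device
# string is expected to contain 'musa' on Moore Threads hardware and refer to a CUDA device
# otherwise, which is what clear_gpu_cache() branches on below.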


def clear_gpu_cache(rank=None):
    """Clear the GPU cache for all ranks."""
    if rank == 0:
        logger.info("Clearing GPU cache for all ranks")
    if 'musa' in device:
        torch.musa.empty_cache()
    else:
        torch.cuda.empty_cache()

def train(
    model,
    train_dataloader,
    train_config,
    local_rank=None,
    rank=None,
    train_data_set=None,
    model_org=None
):
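    """Run the training loop for a model engine prepared for DeepSpeed.

    Args:
        model: engine exposing backward(), step() and .optimizer (typically a DeepSpeed
            engine); may be wrapped in DistributedDataParallel.
        train_dataloader: dataloader whose batches are consumed via next().
        train_config: config with resume_epoch, num_epochs, use_fp16/use_bf16,
            log_interval, save_interval, model_name and (optionally) save_merge_rank.
        local_rank: device index that batch tensors are moved to.
        rank: global rank, used to restrict logging to rank 0.
        train_data_set: optional dataset whose set_epoch() is called each epoch.
        model_org: original (unwrapped) model forwarded to checkpoint saving.
    """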
    epoch_times = []
    total_step = 0
    # Under DDP, model.join() lets ranks that exhaust their data early wait for the
    # others; otherwise a no-op context is used.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_context = model.join
    else:
        model_context = nullcontext
    for epoch in range(train_config.resume_epoch, train_config.num_epochs):
        if train_data_set is not None:
            dist.barrier()
            logger.info(f"RANK:{rank} Reset Dataset Epoch {epoch}...")
            train_data_set.set_epoch(epoch)
            dist.barrier()
        # Build a fresh iterator every epoch; the loop below runs until it is exhausted.
        train_dataloader_iterator = iter(train_dataloader)
        epoch_start_time = time.perf_counter()
        model.train()
        total_loss = 0.0
        total_acc = 0.0
        step = 0
        input_dtype = torch.float32
        if train_config.use_fp16:
            input_dtype = torch.float16
        elif train_config.use_bf16:
            input_dtype = torch.bfloat16
        logging.info(f"Input data type: {input_dtype}")
        with model_context():
            while True:
                try:
                    batch = next(train_dataloader_iterator)
                    total_step += 1
                    # Move tensors to the local device; cast float32 tensors to the training dtype.
                    for key in batch.keys():
                        value = batch[key]
                        if isinstance(value, torch.Tensor):
                            value = value.to(local_rank)
                            if value.dtype == torch.float32:
                                value = value.to(input_dtype)
                        batch[key] = value
                    outputs, acc = model(**batch)
                    loss = outputs.loss

                    acc_report = acc
                    loss_report = loss.detach().float()

                    total_loss += loss.detach().float()
                    total_acc += acc

                    # The engine handles the backward pass, gradient accumulation and optimizer step.
                    model.backward(loss)
                    model.step()

                    current_lr = model.optimizer.param_groups[0]['lr']
                    if rank == 0 and step % train_config.log_interval == 0:
                        logger.info(
                            f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, "
                            f"step {step} lr {current_lr} "
                            f"completed (loss: {loss_report}, "
                            f"acc: {acc_report})")
                    
                    if step % train_config.save_interval == 0:
                        checkpoint_name = f"{train_config.model_name}_epoch_{epoch + 1}_step_{step + 1}"
                        save_model_checkpoint_deepspeed(
                            model, train_config, checkpoint_name,
                            merge_rank=train_config.get('save_merge_rank', True),
                            model=model_org
                        )
                    step += 1
                
                except Exception as e:
                    # The iterator signals end of epoch by raising StopIteration, which
                    # (like any other failure) lands here: log, checkpoint, then move on
                    # to the next epoch.
                    logger.error(f"Exception occurred on Rank {rank}: {e}")
                    epoch_end_time = time.perf_counter() - epoch_start_time
                    logger.info(f"Epoch {epoch + 1}, Cost Time: {epoch_end_time}")
                    epoch_times.append(epoch_end_time)
                    # save model
                    checkpoint_name = f"{train_config.model_name}_epoch_{epoch + 2}_step_1"
                    save_model_checkpoint_deepspeed(
                        model, train_config, checkpoint_name,
                        merge_rank=train_config.get('save_merge_rank', True),
                        model=model_org
                    )
                    break
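
# Usage sketch (not part of the original module; `net`, `ds_config`, `train_dataset`,
# `local_rank` and `rank` are placeholder names): `train` assumes `model` has already
# been wrapped by deepspeed.initialize so that model.backward()/model.step() drive the
# parameter update.
#
#   import deepspeed
#   engine, _, _, _ = deepspeed.initialize(model=net, config=ds_config)
#   train(engine, train_dataloader, train_config,
#         local_rank=local_rank, rank=rank,
#         train_data_set=train_dataset, model_org=net)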