checpoint_io.py 1.57 KB
Newer Older
wangwei990215's avatar
wangwei990215 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import logging
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def save_model_checkpoint_deepspeed(model_engine, cfg, checkpoint_name="checkpoint", merge_rank=False, model=None):
    """Save a DeepSpeed checkpoint and, optionally, a merged export of it.

    The engine checkpoint is always written to
    ``<cfg.output_dir>/<checkpoint_name>``. When ``merge_rank`` is true, a
    second directory ``<cfg.output_dir>/<checkpoint_name>_merged`` is also
    produced containing:

    * ``new_llm/`` -- whatever ``model.llm.save_pretrained`` writes;
    * ``adapter_project.pt`` -- a ``torch.save`` dump of every entry of the
      consolidated FP32 state dict whose key does not contain ``'llm'``,
      copied to CPU.

    Args:
        model_engine: Object exposing ``save_checkpoint(save_dir=...,
            exclude_frozen_parameters=...)`` (presumably a DeepSpeed engine).
        cfg: Configuration object; only ``cfg.output_dir`` is read.
        checkpoint_name: Name of the checkpoint sub-directory.
        merge_rank: If true, also write the merged export described above.
        model: Object exposing ``model.llm.save_pretrained(path)``. Required
            when ``merge_rank`` is true, ignored otherwise.

    Raises:
        AssertionError: If ``merge_rank`` is true and ``model`` is ``None``.
            This is checked after the engine checkpoint has been written.
    """
    logger.info("--> saving model ...")
    save_full_path = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_full_path, exist_ok=True)
    model_engine.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True)
    logger.info("encoder saved at %s", save_full_path)

    if not merge_rank:
        return

    # if merged, it will be fast when decoding
    assert model is not None, "`model` must be provided when merge_rank=True"
    save_dir_merge = os.path.join(cfg.output_dir, checkpoint_name + '_merged')
    os.makedirs(save_dir_merge, exist_ok=True)
    llm_dir = os.path.join(save_dir_merge, 'new_llm')
    adapter_path = os.path.join(save_dir_merge, 'adapter_project.pt')

    # Consolidate the checkpoint that was just written into one state dict.
    logger.info("CKPT: loading DeepSpeed Model from: %s", save_full_path)
    logger.info("Merge Zero3 model to FP32...")
    ckpt_dict = get_fp32_state_dict_from_zero_checkpoint(save_full_path)

    logger.info("Save Lora Weights...")
    model.llm.save_pretrained(llm_dir)
    logger.info("Save finished... %s", llm_dir)

    # Keep only the non-LLM entries; the LLM part was exported above.
    ckpt_dict_new = {
        key: tensor.to('cpu').clone()
        for key, tensor in ckpt_dict.items()
        if 'llm' not in key
    }
    # Write the file exactly once, after the dict is complete (the previous
    # code re-serialized it on every loop iteration).
    torch.save(ckpt_dict_new, adapter_path)
    logger.info("Save finished... %s", adapter_path)