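"""Pretraining strategies for loading checkpoint weights and freezing or
gradient-rescaling selected parameters; registered in the PRETRAIN registry."""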
import os
import json
import torch
import logging

from utils.registry_class import PRETRAIN

@PRETRAIN.register_function()
def pretrain_specific_strategies(
        model, 
        resume_checkpoint,
        sd_keys_path=None,
        grad_scale=1,
        fix_weight=False,
        **kwargs
    ):
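    """Load resume_checkpoint into model, then freeze or gradient-rescale the
    parameters whose names appear in sd_keys_path.

    Returns the model and the training step parsed from the checkpoint filename.
    """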

    # Load the checkpoint on CPU and unwrap a nested 'state_dict' entry if present.
    state_dict = torch.load(resume_checkpoint, map_location='cpu')
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    
    # [1] Load the model weights. A non-strict load tolerates missing/unexpected
    # keys; if it still raises (e.g. on a tensor shape mismatch), fall back to a
    # shape-checked, key-by-key copy.
    try:
        ret = model.load_state_dict(state_dict, strict=False)
        logging.info(f'Loaded checkpoint with incompatible keys: {ret}')
    except Exception as e:
        logging.info(f'Non-strict load failed ({e}); falling back to key-by-key copy')
        model_dict = model.state_dict()
        for skey, item in state_dict.items():
            if skey not in model_dict:
                logging.info(f'Skip {skey}: not found in model')
                continue
            if item.shape != model_dict[skey].shape:
                logging.info(f'Skip {skey}: shape mismatch {item.shape} vs {model_dict[skey].shape}')
                continue
            model_dict[skey].copy_(item)
        model.load_state_dict(model_dict)
    
    # [2] Assign strategies: parameters whose names appear in the JSON key list
    # are either frozen entirely (fix_weight=True) or have their gradients
    # scaled by grad_scale via a backward hook.
    total_size = 0
    if sd_keys_path is None:
        sd_keys = {}
    else:
        with open(sd_keys_path) as f:
            sd_keys = json.load(f)
    for k, p in model.named_parameters():
        if k in sd_keys:
            total_size += p.numel()
            if fix_weight:
                p.requires_grad = False
            else:
                p.register_hook(lambda grad: grad_scale * grad)
    
    # Infer the training step from a checkpoint filename of the form '<name>_<step>.<ext>'.
    resume_step = int(os.path.basename(resume_checkpoint).split('_')[-1].split('.')[0])
    logging.info(f'Successfully loaded the step-{resume_step} model from {resume_checkpoint}')
    logging.info(f'Applied the freeze/scale strategy to {int(total_size / (1024 ** 2))}M parameters')
    return model, resume_step



@PRETRAIN.register_function()
def pretrain_from_sd():
    """Placeholder registered in PRETRAIN; not yet implemented."""
    pass


@PRETRAIN.register_function()
def pretrain_ema_model():
    """Placeholder registered in PRETRAIN; not yet implemented."""
    pass
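

# Example usage (a minimal sketch: the model variable and file paths below are
# hypothetical, not part of this module):
#
#   model, resume_step = pretrain_specific_strategies(
#       model=unet,
#       resume_checkpoint='workspace/checkpoints/model_00010000.pth',
#       sd_keys_path='workspace/checkpoints/pretrained_keys.json',
#       grad_scale=0.1,
#       fix_weight=False,
#   )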