trainer_monkey_patch.py 6.18 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
import os

import torch
import torch.nn as nn
import transformers
from transformers import Trainer, logging
from transformers.trainer import is_sagemaker_mp_enabled

logger = logging.get_logger(__name__)


def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
    """Return the layer id used for layer-wise lr decay of parameter `var_name`.

    A leading ``'internvl.'`` prefix is stripped before matching.

    Args:
        var_name: Fully qualified parameter name.
        vit_num_max_layer: Id returned for ``clip_projector.*`` parameters.
        llama_num_max_layer: Id returned for ``clip_projector2.*``,
            ``itm_head.*``, ``text_projection`` and for ``qllama.*``
            parameters that are neither token embeddings nor inside a
            ``layers.<i>`` block.

    Returns:
        ``0`` for ``query_tokens``, ``logit_scale``, embeddings and any
        unrecognised name; ``i + 1`` for a parameter inside ``layers.<i>``;
        otherwise one of the two "max layer" arguments as described above.

    Raises:
        ValueError: If the text following the last ``'layers.'`` does not
            start with an integer.
    """
    wrapper_prefix = 'internvl.'
    if var_name.startswith(wrapper_prefix):
        var_name = var_name[len(wrapper_prefix):]

    def _block_index(param_name):
        # The text after the last 'layers.' starts with the block number.
        tail = param_name.rsplit('layers.', 1)[-1]
        return int(tail.partition('.')[0]) + 1

    if var_name == 'query_tokens' or var_name == 'logit_scale':
        return 0
    if var_name.startswith('clip_projector.'):
        return vit_num_max_layer
    if var_name == 'text_projection' or var_name.startswith(('clip_projector2.', 'itm_head.')):
        return llama_num_max_layer

    if var_name.startswith('vision_model.'):
        if 'embeddings.' in var_name:
            return 0
        if 'layers.' in var_name:
            return _block_index(var_name)
        # Any other vision_model parameter ends up with the default id.
        return 0

    if var_name.startswith('qllama.'):
        if 'embed_tokens' in var_name:
            return 0
        return _block_index(var_name) if 'layers.' in var_name else llama_num_max_layer

    return 0


def param_classification(name):
    """Classify a parameter name as ``'vit'``, ``'qllama'`` or ``'other'``.

    A leading ``'internvl.'`` prefix is stripped first. The stand-alone
    parameters ``query_tokens``, ``text_projection`` and ``logit_scale`` are
    classified as ``'qllama'``; everything else is decided by its name prefix.
    """
    wrapper_prefix = 'internvl.'
    if name.startswith(wrapper_prefix):
        name = name[len(wrapper_prefix):]

    if name in ('query_tokens', 'text_projection', 'logit_scale'):
        return 'qllama'

    # Checked in order; the first matching prefix decides the family.
    prefix_to_family = (
        ('vision_model.', 'vit'),
        ('qllama.', 'qllama'),
        ('clip_projector.', 'vit'),
        ('clip_projector2.', 'qllama'),
        ('itm_head.', 'qllama'),
    )
    for prefix, family in prefix_to_family:
        if name.startswith(prefix):
            return family
    return 'other'


def create_optimizer(self):
    """
    Build the optimizer with per-layer learning-rate scaling.

    Every trainable parameter is assigned to a group keyed by its model family
    (`param_classification`), its layer id (`get_num_layer_for_vit_and_qllama`)
    and whether weight decay applies (1-D tensors and `.bias` parameters get no
    decay). Each group's learning rate is `self.args.learning_rate` multiplied
    by a scale derived from these environment variables (all default to 1.0):

    - `VIT_LAYER_DECAY_RATE`: per-layer decay for the `vit` family.
    - `QLLAMA_LAYER_DECAY_RATE`: per-layer decay for the `qllama` family.
    - `QLLAMA_LR_SCALE`: extra multiplier for the `qllama` family.

    Note that this function always builds a new optimizer and stores it in
    `self.optimizer`, replacing whatever was there before.

    Returns:
        The optimizer that was created (also available as `self.optimizer`).
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

    parameter_groups = {}
    try:  # for stage2 model
        vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
    except AttributeError:  # for stage3 model, where the configs live under `.internvl`
        vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
    print('vit_num_layers:', vit_num_layers)
    print('qllama_num_layers:', qllama_num_layers)

    vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
    qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
    qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
    print('vit_layer_decay_rate:', vit_layer_decay_rate)
    print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
    print('qllama_lr_scale:', qllama_lr_scale)

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay

        cls = param_classification(name)
        layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
        group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
        if group_name not in parameter_groups:
            if cls == 'vit':
                scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
            elif cls == 'qllama':
                scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
                scale = scale * qllama_lr_scale
            else:
                scale = 1.0
            # Never let a group train faster than the base learning rate
            # (the exponent above is negative for the projector/head groups).
            scale = min(1.0, scale)
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    # Report the final grouping once, on the main process only. Without an
    # initialised process group (plain single-process run) `get_rank()` would
    # raise, so treat that case as rank 0.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0
    if rank == 0:
        to_display = {
            key: {
                'param_names': group['param_names'],
                'lr_scale': group['lr_scale'],
                'lr': group['lr'],
                'weight_decay': group['weight_decay'],
            }
            for key, group in parameter_groups.items()
        }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        # Register a 32-bit optimizer-state override for every embedding weight.
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                # Key by data_ptr so parameters sharing storage are counted once.
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')

    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp
        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer


def replace_create_optimizer():
    """Monkey-patch `transformers.Trainer` so it uses this module's `create_optimizer`.

    The replacement is installed on the class itself, so it affects every
    `Trainer` instance in the process.
    """
    print('Replace original create_optimizer with custom create_optimizer')
    setattr(transformers.Trainer, 'create_optimizer', create_optimizer)