# model_utils.py
import torch
from transformers.trainer_pt_utils import LabelSmoother

# Label id masked out of the loss; reuses the sentinel defined by
# transformers' LabelSmoother so it stays consistent with HF trainers.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


def find_all_linear_names(use_8bit, model):
    """Collect the names of all ``torch.nn.Linear`` submodules of *model*.

    The result is suitable as LoRA ``target_modules``: for nested modules
    only the final path component is kept (e.g. ``"q_proj"`` for
    ``"layers.0.self_attn.q_proj"``), so one name matches every occurrence.

    Args:
        use_8bit: Currently unused.  NOTE(review): when a model is loaded
            in 8-bit, its linear layers are typically
            ``bnb.nn.Linear8bitLt`` (not a ``torch.nn.Linear`` subclass),
            so they would not be detected here — confirm whether 8-bit
            support is actually needed before relying on this flag.
        model: The ``torch.nn.Module`` whose submodules are scanned.

    Returns:
        A sorted list of unique leaf module names (deterministic order;
        previously the order depended on arbitrary ``set`` iteration).
    """
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            parts = name.split('.')
            # Keep only the leaf component of a dotted module path.
            lora_module_names.add(parts[0] if len(parts) == 1 else parts[-1])
    # Sort so callers get a reproducible target-module list.
    return sorted(lora_module_names)


def load_from_checkpoint(resume_from_checkpoint, model=None):
    """Unimplemented stub; currently a no-op that returns ``None``.

    NOTE(review): presumably intended to restore *model* weights from the
    checkpoint path *resume_from_checkpoint* — confirm the intended
    semantics with callers before implementing.
    """
    pass