loader.py
import torch
import torch.nn as nn

from .model_utils import find_layers
from .quant import make_quant
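
# As used below, find_layers returns a dict mapping module names to the
# quantizable layers of a model, and make_quant swaps those layers for
# quantized equivalents (see .model_utils and .quant).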


def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int) -> nn.Module:
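    """Replace quantizable layers with `wbits`-bit modules and load weights.

    The model's quantizable layers are collected with `find_layers`,
    `lm_head` is dropped from that set so the output projection stays in
    full precision, `make_quant` rewrites the remaining layers, and the
    checkpoint's state dict is then loaded into the rewritten model.
    """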
    model = model.eval()
    layers = find_layers(model)

    # Exclude the lm_head so the output projection stays in full precision.
    for name in ["lm_head"]:
        if name in layers:
            del layers[name]

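    # Swap the selected layers for quantized modules so their parameter
    # names match the packed tensors stored in the checkpoint.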
    make_quant(model, layers, wbits, groupsize)

20
    if checkpoint.endswith(".safetensors"):
21
        from safetensors.torch import load_file as safe_load
22

23
24
25
26
27
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))

    return model
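

# A minimal usage sketch. The model name, checkpoint path, and quantization
# settings below are illustrative placeholders, not values this module defines:
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
#   model = load_quant(model, "opt-1.3b-4bit.safetensors", wbits=4, groupsize=128)
#   model.to("cuda")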