"vscode:/vscode.git/clone" did not exist on "b047b553c26ebc83b8b749a605ef1c44f5bb4416"
peft.py 1.44 KB
Newer Older
1
2
3
4
5
6
7
8
import os
import json
from loguru import logger
import torch

from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM

OlivierDehaene's avatar
OlivierDehaene committed
9

10
11
12
def download_and_unload_peft(model_id, revision, trust_remote_code):
    """Merge a PEFT adapter into its base model and save the result to disk.

    The adapter is loaded first as a causal LM and, if that raises, as a
    seq2seq LM. The adapter weights are then merged via ``merge_and_unload()``
    and the merged model, its config, and the base model's tokenizer are
    written to a local directory whose path is ``model_id`` itself (created
    if missing).

    Args:
        model_id: Identifier passed to ``from_pretrained`` to load the PEFT
            model; also used verbatim as the output directory path.
        revision: Forwarded as ``revision`` to the PEFT ``from_pretrained``
            calls. It is not applied when loading the tokenizer.
        trust_remote_code: Forwarded as ``trust_remote_code`` to both the
            model and the tokenizer ``from_pretrained`` calls.

    Returns:
        None. The function is called for its side effect of writing files.

    Raises:
        Exception: Whatever the seq2seq loader raises when both load attempts
            fail, plus any error from merging, tokenizer loading, or saving.
    """
    # Shared by both load attempts so the two calls cannot drift apart.
    load_kwargs = {
        "revision": revision,
        "torch_dtype": torch.float16,
        "trust_remote_code": trust_remote_code,
        "low_cpu_mem_usage": True,
    }

    logger.info("Trying to load a Peft model. It might take a while without feedback")
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    except Exception as causal_err:
        # Deliberately broad: any failure to load as a causal LM triggers the
        # seq2seq fallback. Record the first error instead of dropping it so a
        # double failure can still be diagnosed.
        logger.debug(
            f"Loading as a causal LM Peft model failed ({causal_err!r}); "
            "falling back to seq2seq."
        )
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, **load_kwargs)
    logger.info("Peft model detected.")
    logger.info("Merging the lora weights.")

    # Must be read before merge_and_unload(): `model` is rebound to the merged
    # result on the next statement.
    base_model_id = model.peft_config["default"].base_model_name_or_path

    model = model.merge_and_unload()

    # The merged model is written into a directory named after model_id.
    os.makedirs(model_id, exist_ok=True)
    cache_dir = model_id
    logger.info(f"Saving the newly created merged model to {cache_dir}")

    # The tokenizer is loaded from the base model recorded in the adapter
    # config, not from model_id.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id, trust_remote_code=trust_remote_code
    )

    model.save_pretrained(cache_dir, safe_serialization=True)
    model.config.save_pretrained(cache_dir)
    tokenizer.save_pretrained(cache_dir)