test_lora_huggingface.py 1.69 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
from typing import List

import pytest

from vllm.lora.models import LoRAModel
8
from vllm.lora.peft_helper import PEFTHelper
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

# Fixture names to parametrize over: one supplies the LoRA adapter as an
# absolute local path, the other as a Hugging Face LoRA id.
lora_fixture_name = [
    "sql_lora_files",
    "sql_lora_huggingface_id",
]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    """Load a LoRA checkpoint given either an absolute path or a HF id.

    The fixture named by ``lora_fixture_name`` is resolved at run time via
    ``request.getfixturevalue``, so one test body covers every adapter
    source listed in the module-level ``lora_fixture_name``.

    Args:
        lora_fixture_name: Name of the pytest fixture that yields the
            adapter reference (an absolute path or a Hugging Face id).
        request: The built-in pytest ``request`` fixture.
    """
    lora_name = request.getfixturevalue(lora_fixture_name)

    # LoRA-related metadata declared on the Llama model class.
    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
    embedding_modules = LlamaForCausalLM.embedding_modules
    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules

    # Modules listed in packed_modules_mapping are replaced by the names they
    # map to; every other supported module is expected under its own name.
    expected_lora_modules: List[str] = []
    for module in supported_lora_modules:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    # Resolve the adapter reference to an absolute path. For a Hugging Face
    # id this presumably locates/downloads a local copy -- confirm.
    lora_path = get_adapter_absolute_path(lora_name)

    # LoRA loading should work for either an absolute path or a Hugging Face
    # id.
    # NOTE(review): 4096 is passed positionally; presumably the model's max
    # position embeddings -- confirm against PEFTHelper.from_local_dir.
    peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules)

    # Assertions to ensure the model is loaded correctly
    assert lora_model is not None, "LoRAModel is not loaded correctly"