# test_lora_huggingface.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import pytest

from vllm.lora.models import LoRAModel
7
from vllm.lora.peft_helper import PEFTHelper
8
from vllm.lora.utils import get_adapter_absolute_path
9
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
10
11

# Names of pytest fixtures that provide the LoRA adapter location: presumably
# "llama32_lora_files" yields an absolute local path and
# "llama32_lora_huggingface_id" a Hugging Face repo id (confirm in conftest).
lora_fixture_name = ["llama32_lora_files", "llama32_lora_huggingface_id"]

# LoRA target modules expected in the Llama adapter. Entries that appear in a
# model's packed_modules_mapping are expanded into their sub-modules by the
# loading test in this module.
LLAMA_LORA_MODULES = [
    "qkv_proj",
    "o_proj",
    "gate_up_proj",
    "down_proj",
    "embed_tokens",
    "lm_head",
]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    """Check that a LoRA checkpoint loads from a local path or a hub id.

    The test is parametrized over fixture *names*. ``request.getfixturevalue``
    resolves each name to the fixture's value, and
    ``get_adapter_absolute_path`` turns that value into the directory handed to
    ``PEFTHelper.from_local_dir`` and ``LoRAModel.from_local_checkpoint``.

    Args:
        lora_fixture_name: Name of the pytest fixture supplying the adapter
            location.
        request: The built-in pytest ``request`` fixture.
    """
    lora_name = request.getfixturevalue(lora_fixture_name)

    # NOTE(review): the adapters come from llama32 fixtures while the packing
    # map is taken from Qwen3ForCausalLM; this relies on both models packing
    # the same module names — confirm if either model changes.
    packed_modules_mapping = Qwen3ForCausalLM.packed_modules_mapping

    # Packed modules are replaced by the sub-module names listed in the
    # mapping; every other module is expected under its own name.
    expected_lora_modules: set[str] = set()
    for module in LLAMA_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.update(packed_modules_mapping[module])
        else:
            expected_lora_modules.add(module)

    lora_path = get_adapter_absolute_path(lora_name)

    # LoRA loading should work for either an absolute path or a Hugging Face
    # id.
    # NOTE(review): 4096 is passed positionally; presumably the max position
    # embeddings — confirm against PEFTHelper.from_local_dir's signature.
    peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
    )

    # Assertions to ensure the model is loaded correctly
    assert lora_model is not None, "LoRAModel is not loaded correctly"