# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os

import vllm.envs as envs
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry


class FilesystemResolver(LoRAResolver):
    """LoRA resolver that looks adapters up in a local cache directory.

    An adapter named ``lora_name`` is expected at
    ``<lora_cache_dir>/<lora_name>`` and must contain an
    ``adapter_config.json`` describing a LoRA adapter for the requested
    base model.
    """

    def __init__(self, lora_cache_dir: str):
        # Root directory under which adapter directories are looked up.
        self.lora_cache_dir = lora_cache_dir

    async def resolve_lora(
        self, base_model_name: str, lora_name: str
    ) -> LoRARequest | None:
        """Resolve ``lora_name`` to a LoRARequest from the cache directory.

        Args:
            base_model_name: Base model the adapter must target; compared
                against ``base_model_name_or_path`` in the adapter config.
            lora_name: Adapter name, interpreted as a path relative to
                ``lora_cache_dir``.

        Returns:
            A LoRARequest for a matching adapter, or None when the name
            points outside the cache directory, no valid adapter exists
            there, or the adapter targets a different base model.
        """
        lora_path = os.path.join(self.lora_cache_dir, lora_name)
        # lora_name may originate from a client request. Refuse names such
        # as "../other" or "/abs/path" that would make the joined path
        # escape the cache directory (os.path.join drops the prefix when
        # the second argument is absolute).
        if not self._is_within_cache_dir(lora_path):
            return None

        maybe_lora_request = await self._get_lora_req_from_path(
            lora_name, lora_path, base_model_name
        )
        return maybe_lora_request

    def _is_within_cache_dir(self, lora_path: str) -> bool:
        """Return True if ``lora_path`` lies strictly inside the cache dir.

        The check is lexical (``abspath`` normalizes ``..`` components but
        does not resolve symlinks), so adapters symlinked into the cache
        directory keep working.
        """
        cache_dir = os.path.abspath(self.lora_cache_dir)
        candidate = os.path.abspath(lora_path)
        try:
            common = os.path.commonpath([cache_dir, candidate])
        except ValueError:
            # e.g. paths on different drives on Windows.
            return False
        # The cache directory itself (lora_name "" or ".") is not an adapter.
        return common == cache_dir and candidate != cache_dir

    async def _get_lora_req_from_path(
        self, lora_name: str, lora_path: str, base_model_name: str
    ) -> LoRARequest | None:
        """Builds a LoraRequest pointing to the lora path if it's a valid
        LoRA adapter and has a matching base_model_name.

        Returns None when ``lora_path`` has no readable, well-formed
        ``adapter_config.json``, when that config is not a LoRA config, or
        when it was trained against a different base model.
        """
        adapter_config_path = os.path.join(lora_path, "adapter_config.json")
        if not os.path.exists(adapter_config_path):
            return None

        try:
            with open(adapter_config_path, encoding="utf-8") as file:
                adapter_config = json.load(file)
        except (OSError, ValueError):
            # Unreadable file, invalid JSON, or invalid UTF-8
            # (JSONDecodeError and UnicodeDecodeError both subclass
            # ValueError): not a valid adapter, so report "not found"
            # rather than failing the caller.
            return None

        # A valid config is a JSON object; .get() avoids KeyError on
        # configs that lack either field.
        if not isinstance(adapter_config, dict):
            return None
        if (
            adapter_config.get("peft_type") != "LORA"
            or adapter_config.get("base_model_name_or_path") != base_model_name
        ):
            return None

        return LoRARequest(
            lora_name=lora_name,
            # NOTE(review): hash() of a str is salted per process
            # (PYTHONHASHSEED), so this id is not stable across processes;
            # confirm callers do not rely on it being stable.
            lora_int_id=abs(hash(lora_name)),
            lora_path=lora_path,
        )


def register_filesystem_resolver():
    """Register the filesystem LoRA Resolver with vLLM.

    Reads ``VLLM_LORA_RESOLVER_CACHE_DIR``; if it is unset or empty, nothing
    is registered.

    Raises:
        ValueError: if the variable is set but does not name an existing
            directory.
    """
    lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR
    if not lora_cache_dir:
        return

    # os.path.isdir() is False for nonexistent paths, so this single check
    # covers both "missing" and "not a directory".
    if not os.path.isdir(lora_cache_dir):
        # Implicit literal concatenation: the previous backslash
        # continuation embedded the next line's indentation in the message.
        raise ValueError(
            "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory "
            "for Filesystem Resolver plugin to function"
        )

    fs_resolver = FilesystemResolver(lora_cache_dir)
    LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver)