# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib
import multiprocessing as mp
import unittest
from typing import Dict, List, Tuple

import torch

from sglang.test.runners import SRTRunner
from sglang.test.test_utils import CustomTestCase

# Prompts run against every adapter; outputs for each (adapter, prompt) pair
# are compared across runs to detect corruption caused by eviction.
PROMPTS = [
    "AI is a field of computer science focused on",
    """
    ### Instruction:
    Compose a SQL query that uses the following table: users, and returns the user_id and name of all users whose name that does not have a duplicate in the table.
    ### Response:
    SELECT user_id, name FROM users WHERE name LIKE 'A%';
    """,
]

# Two adapters with different LoRA target-module sets, so that evicting one
# for the other (max_loras_per_batch=1 in the test runner) exercises
# mismatched weight layouts.
ADAPTERS = [
    "faridlazuarda/valadapt-llama-3.1-8B-it-chinese",  # target_modules = q, v
    "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # target_modules = q, k, v, o, gate, up, down
]

# Base model both adapters were trained against.
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


@contextlib.contextmanager
def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
    """Load a LoRA adapter on entry and guarantee it is unloaded on exit.

    Args:
        runner: Runner exposing ``load_lora_adapter`` / ``unload_lora_adapter``.
        lora_path: Path (or hub id) of the adapter weights to load.
        lora_name: Name to register the adapter under.

    The load happens *before* the ``try`` block: if loading fails, the
    adapter was never registered, so no unload is attempted (the previous
    version unloaded unconditionally in ``finally``, which could raise a
    secondary error that masked the original load failure).
    """
    runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
    try:
        yield
    finally:
        runner.unload_lora_adapter(lora_name=lora_name)


class TestLoRAEviction(CustomTestCase):
    """Checks that LoRA adapter eviction does not corrupt generation output."""

    def test_lora_eviction_with_different_target_modules(self):
        """
        Test LoRA eviction with different target modules.

        This test runs inference against two LoRA adapters in different orders to force eviction behavior, and ensures
        that the outputs of the same (adapter, prompt) pair are consistent across runs.
        """
        output_history = {}
        self._run_test(ADAPTERS, output_history, reverse=False)
        self._run_test(ADAPTERS, output_history, reverse=True)

    def test_lora_eviction_with_reused_lora_name(self):
        """
        Test LoRA eviction with reused LoRA names.

        This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
        works correctly when reusing LoRA names.
        """
        output_history = {}
        self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
        self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)

    def _run_test(
        self,
        lora_paths: List[str],
        output_history: Dict[Tuple[str, str], str],
        reverse: bool = False,
        repeat: int = 2,
        reuse_lora_name: bool = False,
    ):
        """Run all (adapter, prompt) pairs and verify outputs stay consistent.

        Args:
            lora_paths: Adapter paths to exercise; at least two are required so
                that eviction (``max_loras_per_batch=1``) actually occurs.
            output_history: Shared map of ``(adapter_path, prompt) -> output``.
                Pairs seen in an earlier call are compared against; new pairs
                are recorded for later calls to check.
            reverse: Iterate the adapters in reversed order.
            repeat: Number of passes over the adapter sequence.
            reuse_lora_name: If True, load/unload each adapter dynamically under
                one shared name instead of registering all adapters at startup.
        """
        # Single registration name shared by all adapters in reuse mode.
        REUSED_LORA_NAME = "lora"

        max_new_tokens = 256
        torch_dtype = torch.float16
        base_path = BASE_MODEL
        assert len(lora_paths) >= 2

        # In reuse mode no adapters are registered at startup; they are loaded
        # on demand inside the loop via dynamically_loaded_adapter().
        initial_lora_paths = lora_paths if not reuse_lora_name else None

        # Initialize runner; max_loras_per_batch=1 forces an eviction every
        # time the active adapter changes.
        with SRTRunner(
            base_path,
            torch_dtype=torch_dtype,
            model_type="generation",
            lora_paths=initial_lora_paths,
            max_loras_per_batch=1,
            enable_lora=True,
            max_lora_rank=256,
            lora_target_modules=["all"],
        ) as srt_runner:
            adapter_sequence = lora_paths if not reverse else lora_paths[::-1]

            for i in range(repeat):
                for j, lora_path in enumerate(adapter_sequence):
                    print(
                        f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---"
                    )

                    lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
                    context = (
                        dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
                        if reuse_lora_name
                        else contextlib.nullcontext()
                    )
                    with context:
                        for prompt in PROMPTS:
                            print("\nprompt:\n", prompt)
                            srt_outputs = srt_runner.forward(
                                [prompt],
                                max_new_tokens=max_new_tokens,
                                lora_paths=[lora_name],
                            )
                            output = srt_outputs.output_strs[0].strip()
                            print("\noutput:\n", output)

                            prev_output = output_history.get((lora_path, prompt))
                            if prev_output is not None:
                                # NOTE: the message reports the repeat counter
                                # (i), not the adapter index (j) as before.
                                self.assertEqual(
                                    prev_output,
                                    output,
                                    f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {i + 1}, previous: '{prev_output}', current: '{output}'.",
                                )
                            else:
                                output_history[(lora_path, prompt)] = output


if __name__ == "__main__":
    # Use the "spawn" start method; set_start_method raises RuntimeError if a
    # start method was already chosen, in which case we keep the existing one.
    with contextlib.suppress(RuntimeError):
        mp.set_start_method("spawn")

    unittest.main(warnings="ignore")