Unverified Commit 6210e2c4 authored by Lifu Huang, committed by GitHub

Support GPU pinning for LoRA (#8697)

parent 6ad6c8c9
......@@ -381,6 +381,78 @@
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LoRA GPU Pinning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another advanced option is to specify adapters as `pinned` during loading. When an adapter is pinned, it is permanently assigned to one of the available GPU pool slots (as configured by `--max-loras-per-batch`) and will not be evicted from GPU memory during runtime. Instead, it remains resident until it is explicitly unloaded.\n",
"\n",
"This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
"\n",
"In the example below, we unload `lora1` and reload it as a `pinned` adapter:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" },\n",
")\n",
"\n",
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" \"pinned\": True, # Pin the adapter to GPU\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Verify that the result is identical as before:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
......
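For readers working outside the notebook, the same flag is exposed on the offline Engine API, whose `load_lora_adapter` signature gains a `pinned` parameter in the diff below. A minimal sketch, assuming SGLang's usual `sglang.Engine` constructor keyword arguments; the base model, adapter path, and pool size here are placeholders, not values from this change:

import sglang as sgl

# Sketch only: the constructor kwargs are assumptions, adjust to your deployment.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder base model
    enable_lora=True,
    max_loras_per_batch=4,        # GPU pool slots; at most 3 adapters can be pinned
    max_lora_rank=256,            # assumed to be needed for dynamic loading
    lora_target_modules=["all"],  # assumed to be needed for dynamic loading
)

# Load an adapter and pin it to a GPU pool slot; it stays resident until it is
# explicitly unloaded via llm.unload_lora_adapter("lora1").
llm.load_lora_adapter(
    lora_name="lora1",
    lora_path="philschmid/code-llama-3-1-8b-text-to-sql-lora",  # placeholder adapter
    pinned=True,
)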
......@@ -492,12 +492,13 @@ class Engine(EngineBase):
self.tokenizer_manager.get_weights_by_name(obj, None)
)
def load_lora_adapter(self, lora_name: str, lora_path: str):
def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
"""Load a new LoRA adapter without re-launching the engine."""
obj = LoadLoRAAdapterReqInput(
lora_name=lora_name,
lora_path=lora_path,
pinned=pinned,
)
loop = asyncio.get_event_loop()
......
......@@ -144,6 +144,7 @@ class LoRAManager:
# keep metadata for displayed messages
self.lora_refs[lora_ref.lora_id] = lora_ref
self.num_pinned_loras += int(lora_ref.pinned)
except Exception as e:
return self.create_lora_update_result(
success=False,
......@@ -157,13 +158,22 @@ class LoRAManager:
Validate whether an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
"""
# Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
memory_pool = getattr(self, "memory_pool", None)
incompatible = memory_pool and not memory_pool.can_support(lora_config)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
"LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
"`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
)
# Ensure pinned LoRA adapters do not exceed the maximum limit or cause starvation.
if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
raise ValueError(
f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
"in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
)
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
......@@ -172,15 +182,17 @@ class LoRAManager:
delete the corresponding LoRA modules.
"""
adapter = self.configs.get(lora_ref.lora_id, None)
adapter = self.configs.get(lora_ref.lora_id)
lora_ref = self.lora_refs.get(lora_ref.lora_id)
assert (
adapter is not None
adapter is not None and lora_ref is not None
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
try:
del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id]
self.num_pinned_loras -= int(lora_ref.pinned)
except Exception as e:
return self.create_lora_update_result(
success=False,
......@@ -189,11 +201,49 @@ class LoRAManager:
return self.create_lora_update_result(success=True)
def validate_lora_batch(self, lora_ids: set[str]) -> bool:
"""
Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
"""
if len(lora_ids) > self.max_loras_per_batch:
return False
# Skip the pinned-adapter check if no pinned LoRA adapters are loaded.
if self.num_pinned_loras == 0:
return True
# Count the number of pinned LoRA adapters in the batch.
pinned_loras_in_batch = 0
for lora_id in lora_ids:
if lora_id is not None:
lora_ref = self.lora_refs.get(lora_id)
assert (
lora_ref is not None
), f"LoRA ID {lora_id} not found in lora_refs."
pinned_loras_in_batch += int(lora_ref.pinned)
assert pinned_loras_in_batch <= self.num_pinned_loras, (
f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
)
required_slots = len(lora_ids) - pinned_loras_in_batch
mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
return required_slots <= mem_pool_vacancy
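To make the slot accounting above concrete, here is a standalone sketch with hypothetical numbers; it mirrors the arithmetic of `validate_lora_batch` rather than calling it. Pinned adapters occupy pool slots permanently, so a batch only has `max_loras_per_batch - num_pinned_loras` evictable slots available for its unpinned adapters.

# Illustrative sketch of the vacancy check, not part of the diff.
max_loras_per_batch = 4   # hypothetical pool size
num_pinned_loras = 2      # hypothetical number of pinned adapters currently loaded

def batch_fits(num_distinct_adapters: int, num_pinned_in_batch: int) -> bool:
    if num_distinct_adapters > max_loras_per_batch:
        return False
    required_slots = num_distinct_adapters - num_pinned_in_batch
    vacancy = max_loras_per_batch - num_pinned_loras
    return required_slots <= vacancy

assert batch_fits(4, 2)        # 2 pinned + 2 unpinned exactly fills the pool
assert not batch_fits(4, 1)    # 3 unpinned adapters need 3 slots, but only 2 are free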
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
self.memory_pool.prepare_lora_batch(
cur_uids=cur_uids,
lora_adapters=self.loras,
lora_modules=self.lora_modules,
lora_refs=self.lora_refs.copy(),  # Copy a snapshot of the current lora_refs to avoid mutation during batch preparation.
)
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
......@@ -366,6 +416,9 @@ class LoRAManager:
# Mapping from LoRA ID to LoRARef object.
self.lora_refs: Dict[str, LoRARef] = {}
# Count of pinned LoRA adapters.
self.num_pinned_loras: int = 0
if lora_paths:
for lora_ref in lora_paths.values():
result = self.load_lora_adapter(lora_ref)
......@@ -399,7 +452,7 @@ class LoRAManager:
self.max_lora_rank = max_lora_rank
else:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
[x.r for x in self.configs.values()],
default=0,
)
......
......@@ -28,14 +28,15 @@ class LoRARef:
"""
Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""
lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None
pinned: Optional[bool] = None
def __post_init__(self):
if self.lora_id is None:
......
import logging
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
......@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
......@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
get_weight_name,
)
logger = logging.getLogger(__name__)
class EmptySlot:
"""
Singleton class to represent an empty slot in the memory pool.
This is used to improve readability by avoiding a special string placeholder.
"""
__slots__ = ()
def __repr__(self):
return "|EMPTY|"
def __new__(cls):
if not hasattr(cls, "_instance"):
cls._instance = super().__new__(cls)
return cls._instance
EMPTY_SLOT = EmptySlot()
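A few illustrative assertions (not part of the diff) spelling out why the sentinel is a dedicated singleton rather than the empty string used previously: slots are compared by identity, and `None` remains a valid uid reserved for the base model.

# Sanity checks for the EmptySlot sentinel defined above (illustrative only).
assert EmptySlot() is EMPTY_SLOT       # __new__ always returns the same instance
assert EMPTY_SLOT is not None          # None stays a valid uid (the base model)
assert EMPTY_SLOT != ""                # distinct from the old empty-string placeholder
assert repr(EMPTY_SLOT) == "|EMPTY|"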
class LoRAMemoryPool:
"""Class for memory pool management of lora modules"""
......@@ -54,9 +78,11 @@ class LoRAMemoryPool:
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
# Buffer idx -> lora uid in memory pool
# All uids are initialized as empty strings for empty buffer slots
# All uids are initialized as `EmptySlot` for empty buffer slots
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
EMPTY_SLOT
] * self.max_loras_per_batch
self.init_buffers(base_model)
......@@ -154,17 +180,29 @@ class LoRAMemoryPool:
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
lora_refs: Dict[str, LoRARef],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots
if self.buffer_id_to_uid[buffer_id] == "":
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
return buffer_id
for buffer_id in range(self.max_loras_per_batch):
uid = self.buffer_id_to_uid[buffer_id]
# Evict unneeded lora
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
if uid not in cur_uids:
# Skip pinned LoRAs
# TODO (lifuhuang): we might consider supporting pinning the base model (uid == None) in the future.
if uid is not None:
lora_ref = lora_refs.get(uid)
if lora_ref is not None and lora_ref.pinned:
continue
self.uid_to_buffer_id.pop(uid)
logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
return buffer_id
raise ValueError(
......
......@@ -1082,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
lora_name: str
# The path of loading.
lora_path: str
# Whether to pin the LoRA adapter in memory.
pinned: bool = False
# The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
......@@ -1090,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
lora_id=self.lora_id,
lora_name=self.lora_name,
lora_path=self.lora_path,
pinned=self.pinned,
)
......
......@@ -1538,14 +1538,11 @@ class Scheduler(
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
self.enable_lora
and len(
lora_set
| set([req.lora_id for req in adder.can_run_list])
| set([req.lora_id])
)
> self.max_loras_per_batch
if self.enable_lora and not self.tp_worker.can_run_lora_batch(
lora_set
| set([req.lora_id for req in adder.can_run_list])
| set([req.lora_id])
):
self.running_batch.batch_is_full = True
break
......
......@@ -1129,6 +1129,7 @@ class TokenizerManager:
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
pinned=obj.pinned,
)
# Trigger the actual loading operation at the backend processes.
......@@ -1186,7 +1187,7 @@ class TokenizerManager:
return result
except ValueError as e:
return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e))
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
......
......@@ -311,3 +311,6 @@ class TpModelWorker:
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
......@@ -288,6 +288,9 @@ class TpModelWorkerClient:
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
return self.worker.unload_lora_adapter(recv_req)
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.worker.can_run_lora_batch(lora_ids)
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))
......@@ -2067,21 +2067,23 @@ class ServerArgs:
if self.enable_lora:
# Normalize lora_paths to a dictionary if it is a list.
# TODO (lifuhuang): support specifying pinned adapters in server_args.
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = {}
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
self.lora_paths[name] = LoRARef(
lora_name=name, lora_path=path, pinned=False
)
else:
self.lora_paths[lora_path] = LoRARef(
lora_name=lora_path,
lora_path=lora_path,
lora_name=lora_path, lora_path=lora_path, pinned=False
)
elif isinstance(self.lora_paths, dict):
self.lora_paths = {
k: LoRARef(lora_name=k, lora_path=v)
k: LoRARef(lora_name=k, lora_path=v, pinned=False)
for k, v in self.lora_paths.items()
}
elif self.lora_paths is None:
......
......@@ -568,8 +568,8 @@ class SRTRunner:
else:
self.tokenizer = None
def load_lora_adapter(self, lora_name: str, lora_path: str):
return self.engine.load_lora_adapter(lora_name, lora_path)
def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
return self.engine.load_lora_adapter(lora_name, lora_path, pinned)
def unload_lora_adapter(self, lora_name: str):
return self.engine.unload_lora_adapter(lora_name)
......
......@@ -231,88 +231,6 @@ BASIC_TESTS = [
),
],
),
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=1,
all_adapters=[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
],
),
]
TARGET_MODULE_TESTS = [
TestCase(
......@@ -593,9 +511,135 @@ MAX_LOADED_LORAS_TESTS = [
],
),
]
EVICTION_TESTS = [
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=2,
all_adapters=[
"lora1=philschmid/code-llama-3-1-8b-text-to-sql-lora",
"lora2=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"lora3=pbevan11/llama-3.1-8b-ocr-correction",
],
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
op_sequence=[
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": True,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora2",
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pinned": True,
},
expected_error="starvation",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora2",
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pinned": False,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora3",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": False,
},
),
Operation(
type=OperationType.UNLOAD,
data="lora1",
),
Operation(
type=OperationType.UNLOAD,
data="lora3",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora3",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": True,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": True,
},
expected_error="starvation",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": False,
},
),
# pinned: lora3
# unpinned: lora1, lora2
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora3",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
None,
]
),
),
],
),
]
ALL_TESTS = (
BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS
BASIC_TESTS
+ TARGET_MODULE_TESTS
+ MAX_LORA_RANK_TESTS
+ MAX_LOADED_LORAS_TESTS
+ EVICTION_TESTS
)
......@@ -714,6 +758,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
pinned: bool = False,
):
"""
Load a LoRA adapter by name and path.
......@@ -724,17 +769,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
response = self.handle.load_lora_adapter(
lora_name=lora_name,
lora_path=lora_path,
pinned=pinned,
)
if expected_error:
self.testcase.assertFalse(response.success)
self.testcase.assertIn(expected_error, response.error_message)
self.testcase.assertFalse(
response.success, f"Expected failure for {lora_name}, but got success."
)
self.testcase.assertIn(
expected_error,
response.error_message,
f"Expected error message to contain '{expected_error}', but got '{response.error_message}'",
)
print(f"Received error as expected: {response.error_message}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.success)
self.testcase.assertTrue(
response.success,
f"Failed to load LoRA adapter {lora_name}: {response.error_message}",
)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def unload_lora_adapter(self, lora_name: str):
"""
......@@ -745,11 +804,18 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
response = self.handle.unload_lora_adapter(
lora_name=lora_name,
)
self.testcase.assertTrue(response.success)
self.testcase.assertTrue(
response.success,
f"Failed to unload LoRA adapter {lora_name}: {response.error_message}",
)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def forward(
self,
......@@ -770,13 +836,21 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
except ValueError as e:
if expected_error:
error_message = str(e)
self.testcase.assertIn(expected_error, error_message)
self.testcase.assertIn(
expected_error,
error_message,
f"Expected error message to contain '{expected_error}', but got '{error_message}'",
)
print(f"Received error as expected: {error_message}")
return error_message
raise e
self.testcase.assertEqual(len(response.output_strs), len(prompts))
self.testcase.assertEqual(
len(response.output_strs),
len(prompts),
f"Expected {len(prompts)} outputs, but got {len(response.output_strs)}",
)
output = response.output_strs
print(f"output_strs: {output}")
......@@ -837,6 +911,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
pinned: bool = False,
):
"""
Load a LoRA adapter by name and path.
......@@ -846,18 +921,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
response = requests.post(
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path},
json={"lora_name": lora_name, "lora_path": lora_path, "pinned": pinned},
)
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
self.testcase.assertEqual(
response.status_code,
400,
f"Expected error for {lora_name}, but got success.",
)
self.testcase.assertIn(
expected_error,
response.text,
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
)
print(f"Received error as expected: {response.text}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.ok)
self.testcase.assertTrue(
response.ok, f"Failed to load LoRA adapter {lora_name}: {response.text}"
)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def unload_lora_adapter(self, lora_name: str):
"""
......@@ -869,11 +958,17 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
json={"lora_name": lora_name},
)
self.testcase.assertTrue(response.ok)
self.testcase.assertTrue(
response.ok, f"Failed to unload LoRA adapter {lora_name}: {response.text}"
)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def forward(
self,
......@@ -898,15 +993,29 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
},
)
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
self.testcase.assertEqual(
response.status_code,
400,
f"Expected error for forward pass, but got success: {response.text}",
)
self.testcase.assertIn(
expected_error,
response.text,
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
)
output = response.text
print(f"Received error as expected: {response.text}")
return output
else:
self.testcase.assertTrue(response.ok)
self.testcase.assertTrue(
response.ok, f"Failed to generate text: {response.text}"
)
output = [r["text"] for r in response.json()]
self.testcase.assertEqual(len(output), len(prompts))
self.testcase.assertEqual(
len(output),
len(prompts),
f"Expected {len(prompts)} outputs, but got {len(output)}",
)
print(f"output_strs: {output}")
return output
......@@ -974,10 +1083,18 @@ class TestLoRADynamicUpdate(CustomTestCase):
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
)
if op_type == OperationType.LOAD:
if isinstance(data, str):
adapter_info = {
"lora_name": data,
"lora_path": data,
"pinned": False,
}
else:
adapter_info = data
result = session.load_lora_adapter(
lora_name=data,
lora_path=data,
expected_error=expected_error,
**adapter_info,
)
elif op_type == OperationType.UNLOAD:
result = session.unload_lora_adapter(
......