Unverified Commit 6210e2c4 authored by Lifu Huang's avatar Lifu Huang Committed by GitHub

Support GPU pinning for LoRA (#8697)

parent 6ad6c8c9
...@@ -381,6 +381,78 @@ ...@@ -381,6 +381,78 @@
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LoRA GPU Pinning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another advanced option is to specify adapters as `pinned` during loading. When an adapter is pinned, it is permanently assigned to one of the available GPU pool slots (as configured by `--max-loras-per-batch`) and will not be evicted from GPU memory during runtime. Instead, it remains resident until it is explicitly unloaded.\n",
"\n",
"This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
"\n",
"In the example below, we unload `lora1` and reload it as a `pinned` adapter:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" },\n",
")\n",
"\n",
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" \"pinned\": True, # Pin the adapter to GPU\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Verify that the result is identical as before:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
......
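For quick reference outside the notebook, here is a minimal standalone sketch of the pinning limit described in the markdown cell above: with `--max-loras-per-batch 2`, at most one adapter (`max-loras-per-batch - 1`) can be pinned, and a second pin request is rejected. The server URL and adapter paths are placeholders, and the exact error wording may differ.

```python
import requests

url = "http://127.0.0.1:30000"  # placeholder server address (started with --max-loras-per-batch 2)

# First pin succeeds: one of the two pool slots becomes permanently assigned.
resp = requests.post(
    url + "/load_lora_adapter",
    json={"lora_name": "lora0", "lora_path": "path/to/lora0", "pinned": True},  # placeholder path
)
print(resp.ok)  # expected: True

# Second pin is rejected: pinning every slot would starve unpinned adapters.
resp = requests.post(
    url + "/load_lora_adapter",
    json={"lora_name": "lora1", "lora_path": "path/to/lora1", "pinned": True},  # placeholder path
)
print(resp.status_code, resp.text)  # expected: 400 with a starvation-related error message
```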
...@@ -492,12 +492,13 @@ class Engine(EngineBase): ...@@ -492,12 +492,13 @@ class Engine(EngineBase):
self.tokenizer_manager.get_weights_by_name(obj, None) self.tokenizer_manager.get_weights_by_name(obj, None)
) )
def load_lora_adapter(self, lora_name: str, lora_path: str): def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
"""Load a new LoRA adapter without re-launching the engine.""" """Load a new LoRA adapter without re-launching the engine."""
obj = LoadLoRAAdapterReqInput( obj = LoadLoRAAdapterReqInput(
lora_name=lora_name, lora_name=lora_name,
lora_path=lora_path, lora_path=lora_path,
pinned=pinned,
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
......
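The same flag is exposed on the offline `Engine` API. A minimal sketch, assuming an engine launched with LoRA enabled (the constructor arguments below are illustrative and not mandated by this change):

```python
import sglang as sgl

# Illustrative launch arguments; adapt to your deployment.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    enable_lora=True,
    max_loras_per_batch=2,
)

# Load an adapter and pin it so it is never evicted from its GPU pool slot.
result = engine.load_lora_adapter(
    lora_name="sql_lora",
    lora_path="philschmid/code-llama-3-1-8b-text-to-sql-lora",
    pinned=True,
)
print(result)  # reports success/failure and the set of loaded adapters
```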
...@@ -144,6 +144,7 @@ class LoRAManager: ...@@ -144,6 +144,7 @@ class LoRAManager:
# keep metadata for displayed messages # keep metadata for displayed messages
self.lora_refs[lora_ref.lora_id] = lora_ref self.lora_refs[lora_ref.lora_id] = lora_ref
self.num_pinned_loras += int(lora_ref.pinned)
except Exception as e: except Exception as e:
return self.create_lora_update_result( return self.create_lora_update_result(
success=False, success=False,
...@@ -157,13 +158,22 @@ class LoRAManager: ...@@ -157,13 +158,22 @@ class LoRAManager:
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible. Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
""" """
# Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
memory_pool = getattr(self, "memory_pool", None) memory_pool = getattr(self, "memory_pool", None)
incompatible = memory_pool and not memory_pool.can_support(lora_config) incompatible = memory_pool and not memory_pool.can_support(lora_config)
if incompatible: if incompatible:
raise ValueError( raise ValueError(
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
"included in `--enable_lora_modules`." "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
)
# Ensure pinned LoRA adapters do not exceed the maximum limit or cause starvation.
if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
raise ValueError(
f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
"in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
) )
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
...@@ -172,15 +182,17 @@ class LoRAManager: ...@@ -172,15 +182,17 @@ class LoRAManager:
delete the corresponding LoRA modules. delete the corresponding LoRA modules.
""" """
adapter = self.configs.get(lora_ref.lora_id, None) adapter = self.configs.get(lora_ref.lora_id)
lora_ref = self.lora_refs.get(lora_ref.lora_id)
assert ( assert (
adapter is not None adapter is not None and lora_ref is not None
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend." ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
try: try:
del self.configs[lora_ref.lora_id] del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id] del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id] del self.lora_refs[lora_ref.lora_id]
self.num_pinned_loras -= int(lora_ref.pinned)
except Exception as e: except Exception as e:
return self.create_lora_update_result( return self.create_lora_update_result(
success=False, success=False,
...@@ -189,11 +201,49 @@ class LoRAManager: ...@@ -189,11 +201,49 @@ class LoRAManager:
return self.create_lora_update_result(success=True) return self.create_lora_update_result(success=True)
def validate_lora_batch(self, lora_ids: set[str]) -> bool:
"""
Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
"""
if len(lora_ids) > self.max_loras_per_batch:
return False
# Skip the pinned LoRA check if no pinned adapters are loaded.
if self.num_pinned_loras == 0:
return True
# Count the number of pinned LoRA adapters in the batch.
pinned_loras_in_batch = 0
for lora_id in lora_ids:
if lora_id is not None:
lora_ref = self.lora_refs.get(lora_id)
assert (
lora_ref is not None
), f"LoRA ID {lora_id} not found in lora_refs."
pinned_loras_in_batch += int(lora_ref.pinned)
assert pinned_loras_in_batch <= self.num_pinned_loras, (
f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
)
required_slots = len(lora_ids) - pinned_loras_in_batch
mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
return required_slots <= mem_pool_vacancy
def prepare_lora_batch(self, forward_batch: ForwardBatch): def prepare_lora_batch(self, forward_batch: ForwardBatch):
# Load active loras into lora memory pool # Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids) cur_uids = set(forward_batch.lora_ids)
assert len(cur_uids) <= self.max_loras_per_batch assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules) self.memory_pool.prepare_lora_batch(
cur_uids=cur_uids,
lora_adapters=self.loras,
lora_modules=self.lora_modules,
lora_refs=self.lora_refs.copy(), # Pass a snapshot of the current lora_refs to avoid mutation during batch preparation.
)
# set up batch info shared by all lora modules # set up batch info shared by all lora modules
bs = forward_batch.batch_size bs = forward_batch.batch_size
...@@ -366,6 +416,9 @@ class LoRAManager: ...@@ -366,6 +416,9 @@ class LoRAManager:
# Mapping from LoRA ID to LoRARef object. # Mapping from LoRA ID to LoRARef object.
self.lora_refs: Dict[str, LoRARef] = {} self.lora_refs: Dict[str, LoRARef] = {}
# Count of pinned LoRA adapters.
self.num_pinned_loras: int = 0
if lora_paths: if lora_paths:
for lora_ref in lora_paths.values(): for lora_ref in lora_paths.values():
result = self.load_lora_adapter(lora_ref) result = self.load_lora_adapter(lora_ref)
...@@ -399,7 +452,7 @@ class LoRAManager: ...@@ -399,7 +452,7 @@ class LoRAManager:
self.max_lora_rank = max_lora_rank self.max_lora_rank = max_lora_rank
else: else:
self.max_lora_rank = max( self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()], [x.r for x in self.configs.values()],
default=0, default=0,
) )
......
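A standalone sketch of the vacancy check that `validate_lora_batch` now performs (the function name and call shape below are illustrative; the real method reads the manager's own state):

```python
# Illustrative version of the vacancy check in validate_lora_batch.
def can_run(lora_ids, pinned_ids, max_loras_per_batch):
    """lora_ids: unique adapter IDs in the batch (None stands for the base model)."""
    if len(lora_ids) > max_loras_per_batch:
        return False
    # Pinned adapters already own their slots, so only the unpinned members of
    # the batch compete for the remaining unpinned slots in the pool.
    pinned_in_batch = sum(1 for uid in lora_ids if uid in pinned_ids)
    required_slots = len(lora_ids) - pinned_in_batch
    vacancy = max_loras_per_batch - len(pinned_ids)
    return required_slots <= vacancy

# Pool of 3 slots with one pinned adapter "A":
print(can_run({"A", "B", None}, {"A"}, 3))  # True: "A" keeps its pinned slot; "B" and the base model fit in the 2 free slots
print(can_run({"B", "C", None}, {"A"}, 3))  # False: 3 unpinned entries, but only 2 unpinned slots
```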
...@@ -28,14 +28,15 @@ class LoRARef: ...@@ -28,14 +28,15 @@ class LoRARef:
""" """
Reference record for a LoRA model. Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache). keys (e.g., radix cache).
""" """
lora_id: str = field(default_factory=lambda: uuid4().hex) lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None lora_name: Optional[str] = None
lora_path: Optional[str] = None lora_path: Optional[str] = None
pinned: Optional[bool] = None
def __post_init__(self): def __post_init__(self):
if self.lora_id is None: if self.lora_id is None:
......
import logging
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
...@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig ...@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.utils import ( from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES, ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType, LoRAType,
...@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import ( ...@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
get_weight_name, get_weight_name,
) )
logger = logging.getLogger(__name__)
class EmptySlot:
"""
Singleton class representing an empty slot in the memory pool.
This improves readability compared to using a special string as a placeholder.
"""
__slots__ = ()
def __repr__(self):
return "|EMPTY|"
def __new__(cls):
if not hasattr(cls, "_instance"):
cls._instance = super().__new__(cls)
return cls._instance
EMPTY_SLOT = EmptySlot()
class LoRAMemoryPool: class LoRAMemoryPool:
"""Class for memory pool management of lora modules""" """Class for memory pool management of lora modules"""
...@@ -54,9 +78,11 @@ class LoRAMemoryPool: ...@@ -54,9 +78,11 @@ class LoRAMemoryPool:
self.uid_to_buffer_id: Dict[Optional[str], int] = {} self.uid_to_buffer_id: Dict[Optional[str], int] = {}
# Buffer idx -> lora uid in memory pool # Buffer idx -> lora uid in memory pool
# All uids are initialized as empty strings for empty buffer slots # All uids are initialized as `EmptySlot` for empty buffer slots
# Here we don't initialize to None since None is a valid uid # Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
EMPTY_SLOT
] * self.max_loras_per_batch
self.init_buffers(base_model) self.init_buffers(base_model)
...@@ -154,17 +180,29 @@ class LoRAMemoryPool: ...@@ -154,17 +180,29 @@ class LoRAMemoryPool:
cur_uids: Set[Optional[str]], cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter], lora_adapters: Dict[str, LoRAAdapter],
lora_modules: List[Dict[str, BaseLayerWithLoRA]], lora_modules: List[Dict[str, BaseLayerWithLoRA]],
lora_refs: Dict[str, LoRARef],
): ):
def get_available_buffer_slot(): def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch): for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots # Prioritize empty slots
if self.buffer_id_to_uid[buffer_id] == "": if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
return buffer_id return buffer_id
for buffer_id in range(self.max_loras_per_batch): for buffer_id in range(self.max_loras_per_batch):
uid = self.buffer_id_to_uid[buffer_id]
# Evict unneeded lora # Evict unneeded lora
if self.buffer_id_to_uid[buffer_id] not in cur_uids: if uid not in cur_uids:
self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id]) # Skip pinned LoRAs
# TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
if uid is not None:
lora_ref = lora_refs.get(uid)
if lora_ref is not None and lora_ref.pinned:
continue
self.uid_to_buffer_id.pop(uid)
logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
return buffer_id return buffer_id
raise ValueError( raise ValueError(
......
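A simplified sketch of the slot-selection policy implemented by `get_available_buffer_slot`: empty slots are preferred, and pinned adapters are skipped when searching for an eviction victim (names below are illustrative, not the actual class API):

```python
# EMPTY marks a free slot; `pinned` is the set of pinned adapter IDs.
EMPTY = object()

def pick_slot(buffer_id_to_uid, cur_uids, pinned):
    # 1) Prefer an empty slot.
    for i, uid in enumerate(buffer_id_to_uid):
        if uid is EMPTY:
            return i
    # 2) Otherwise evict an adapter that is neither needed by the current batch nor pinned.
    for i, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids and uid not in pinned:
            buffer_id_to_uid[i] = EMPTY  # evict
            return i
    raise ValueError("No free slot: every resident adapter is pinned or used by this batch.")

slots = ["lora_a", "lora_b", "lora_c"]
print(pick_slot(slots, cur_uids={"lora_a"}, pinned={"lora_b"}))  # 2 -> "lora_c" is evicted
```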
...@@ -1082,6 +1082,8 @@ class LoadLoRAAdapterReqInput: ...@@ -1082,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
lora_name: str lora_name: str
# The path of loading. # The path of loading.
lora_path: str lora_path: str
# Whether to pin the LoRA adapter in memory.
pinned: bool = False
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None lora_id: Optional[str] = None
...@@ -1090,6 +1092,7 @@ class LoadLoRAAdapterReqInput: ...@@ -1090,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
lora_id=self.lora_id, lora_id=self.lora_id,
lora_name=self.lora_name, lora_name=self.lora_name,
lora_path=self.lora_path, lora_path=self.lora_path,
pinned=self.pinned,
) )
......
...@@ -1538,14 +1538,11 @@ class Scheduler( ...@@ -1538,14 +1538,11 @@ class Scheduler(
# Get requests from the waiting queue to a new prefill batch # Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue: for req in self.waiting_queue:
if (
self.enable_lora if self.enable_lora and not self.tp_worker.can_run_lora_batch(
and len(
lora_set lora_set
| set([req.lora_id for req in adder.can_run_list]) | set([req.lora_id for req in adder.can_run_list])
| set([req.lora_id]) | set([req.lora_id])
)
> self.max_loras_per_batch
): ):
self.running_batch.batch_is_full = True self.running_batch.batch_is_full = True
break break
......
...@@ -1129,6 +1129,7 @@ class TokenizerManager: ...@@ -1129,6 +1129,7 @@ class TokenizerManager:
new_adapter = LoRARef( new_adapter = LoRARef(
lora_name=obj.lora_name, lora_name=obj.lora_name,
lora_path=obj.lora_path, lora_path=obj.lora_path,
pinned=obj.pinned,
) )
# Trigger the actual loading operation at the backend processes. # Trigger the actual loading operation at the backend processes.
...@@ -1186,7 +1187,7 @@ class TokenizerManager: ...@@ -1186,7 +1187,7 @@ class TokenizerManager:
return result return result
except ValueError as e: except ValueError as e:
return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e)) return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
async def get_weights_by_name( async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
......
...@@ -311,3 +311,6 @@ class TpModelWorker: ...@@ -311,3 +311,6 @@ class TpModelWorker:
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref()) result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
...@@ -288,6 +288,9 @@ class TpModelWorkerClient: ...@@ -288,6 +288,9 @@ class TpModelWorkerClient:
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
return self.worker.unload_lora_adapter(recv_req) return self.worker.unload_lora_adapter(recv_req)
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.worker.can_run_lora_batch(lora_ids)
def __delete__(self): def __delete__(self):
self.input_queue.put((None, None)) self.input_queue.put((None, None))
self.copy_queue.put((None, None, None)) self.copy_queue.put((None, None, None))
...@@ -2067,21 +2067,23 @@ class ServerArgs: ...@@ -2067,21 +2067,23 @@ class ServerArgs:
if self.enable_lora: if self.enable_lora:
# Normalize lora_paths to a dictionary if it is a list. # Normalize lora_paths to a dictionary if it is a list.
# TODO (lifuhuang): support specifying pinned adapters in server_args.
if isinstance(self.lora_paths, list): if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths lora_paths = self.lora_paths
self.lora_paths = {} self.lora_paths = {}
for lora_path in lora_paths: for lora_path in lora_paths:
if "=" in lora_path: if "=" in lora_path:
name, path = lora_path.split("=", 1) name, path = lora_path.split("=", 1)
self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path) self.lora_paths[name] = LoRARef(
lora_name=name, lora_path=path, pinned=False
)
else: else:
self.lora_paths[lora_path] = LoRARef( self.lora_paths[lora_path] = LoRARef(
lora_name=lora_path, lora_name=lora_path, lora_path=lora_path, pinned=False
lora_path=lora_path,
) )
elif isinstance(self.lora_paths, dict): elif isinstance(self.lora_paths, dict):
self.lora_paths = { self.lora_paths = {
k: LoRARef(lora_name=k, lora_path=v) k: LoRARef(lora_name=k, lora_path=v, pinned=False)
for k, v in self.lora_paths.items() for k, v in self.lora_paths.items()
} }
elif self.lora_paths is None: elif self.lora_paths is None:
......
...@@ -568,8 +568,8 @@ class SRTRunner: ...@@ -568,8 +568,8 @@ class SRTRunner:
else: else:
self.tokenizer = None self.tokenizer = None
def load_lora_adapter(self, lora_name: str, lora_path: str): def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
return self.engine.load_lora_adapter(lora_name, lora_path) return self.engine.load_lora_adapter(lora_name, lora_path, pinned)
def unload_lora_adapter(self, lora_name: str): def unload_lora_adapter(self, lora_name: str):
return self.engine.unload_lora_adapter(lora_name) return self.engine.unload_lora_adapter(lora_name)
......
...@@ -231,88 +231,6 @@ BASIC_TESTS = [ ...@@ -231,88 +231,6 @@ BASIC_TESTS = [
), ),
], ],
), ),
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=1,
all_adapters=[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
],
),
] ]
TARGET_MODULE_TESTS = [ TARGET_MODULE_TESTS = [
TestCase( TestCase(
...@@ -593,9 +511,135 @@ MAX_LOADED_LORAS_TESTS = [ ...@@ -593,9 +511,135 @@ MAX_LOADED_LORAS_TESTS = [
], ],
), ),
] ]
EVICTION_TESTS = [
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=2,
all_adapters=[
"lora1=philschmid/code-llama-3-1-8b-text-to-sql-lora",
"lora2=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"lora3=pbevan11/llama-3.1-8b-ocr-correction",
],
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
op_sequence=[
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": True,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora2",
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pinned": True,
},
expected_error="starvation",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora2",
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pinned": False,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora3",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": False,
},
),
Operation(
type=OperationType.UNLOAD,
data="lora1",
),
Operation(
type=OperationType.UNLOAD,
data="lora3",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora3",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": True,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": True,
},
expected_error="starvation",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": False,
},
),
# pinned: lora3
# unpinned: lora1, lora2
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora3",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
None,
]
),
),
],
),
]
ALL_TESTS = ( ALL_TESTS = (
BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS BASIC_TESTS
+ TARGET_MODULE_TESTS
+ MAX_LORA_RANK_TESTS
+ MAX_LOADED_LORAS_TESTS
+ EVICTION_TESTS
) )
...@@ -714,6 +758,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): ...@@ -714,6 +758,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
lora_name: str, lora_name: str,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
expected_error: Optional[str] = None, expected_error: Optional[str] = None,
pinned: bool = False,
): ):
""" """
Load a LoRA adapter by name and path. Load a LoRA adapter by name and path.
...@@ -724,17 +769,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): ...@@ -724,17 +769,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
response = self.handle.load_lora_adapter( response = self.handle.load_lora_adapter(
lora_name=lora_name, lora_name=lora_name,
lora_path=lora_path, lora_path=lora_path,
pinned=pinned,
) )
if expected_error: if expected_error:
self.testcase.assertFalse(response.success) self.testcase.assertFalse(
self.testcase.assertIn(expected_error, response.error_message) response.success, f"Expected failure for {lora_name}, but got success."
)
self.testcase.assertIn(
expected_error,
response.error_message,
f"Expected error message to contain '{expected_error}', but got '{response.error_message}'",
)
print(f"Received error as expected: {response.error_message}") print(f"Received error as expected: {response.error_message}")
else: else:
self.expected_adapters.add(lora_name) self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.success) self.testcase.assertTrue(
response.success,
f"Failed to load LoRA adapter {lora_name}: {response.error_message}",
)
loaded_adapters = set(response.loaded_adapters) loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}") print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters) self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def unload_lora_adapter(self, lora_name: str): def unload_lora_adapter(self, lora_name: str):
""" """
...@@ -745,11 +804,18 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): ...@@ -745,11 +804,18 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
response = self.handle.unload_lora_adapter( response = self.handle.unload_lora_adapter(
lora_name=lora_name, lora_name=lora_name,
) )
self.testcase.assertTrue(response.success) self.testcase.assertTrue(
response.success,
f"Failed to unload LoRA adapter {lora_name}: {response.error_message}",
)
loaded_adapters = set(response.loaded_adapters) loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}") print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters) self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def forward( def forward(
self, self,
...@@ -770,13 +836,21 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): ...@@ -770,13 +836,21 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
except ValueError as e: except ValueError as e:
if expected_error: if expected_error:
error_message = str(e) error_message = str(e)
self.testcase.assertIn(expected_error, error_message) self.testcase.assertIn(
expected_error,
error_message,
f"Expected error message to contain '{expected_error}', but got '{error_message}'",
)
print(f"Received error as expected: {error_message}") print(f"Received error as expected: {error_message}")
return error_message return error_message
raise e raise e
self.testcase.assertEqual(len(response.output_strs), len(prompts)) self.testcase.assertEqual(
len(response.output_strs),
len(prompts),
f"Expected {len(prompts)} outputs, but got {len(response.output_strs)}",
)
output = response.output_strs output = response.output_strs
print(f"output_strs: {output}") print(f"output_strs: {output}")
...@@ -837,6 +911,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -837,6 +911,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
lora_name: str, lora_name: str,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
expected_error: Optional[str] = None, expected_error: Optional[str] = None,
pinned: bool = False,
): ):
""" """
Load a LoRA adapter by name and path. Load a LoRA adapter by name and path.
...@@ -846,18 +921,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -846,18 +921,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
response = requests.post( response = requests.post(
DEFAULT_URL_FOR_TEST + "/load_lora_adapter", DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path}, json={"lora_name": lora_name, "lora_path": lora_path, "pinned": pinned},
) )
if expected_error: if expected_error:
self.testcase.assertEqual(response.status_code, 400) self.testcase.assertEqual(
self.testcase.assertIn(expected_error, response.text) response.status_code,
400,
f"Expected error for {lora_name}, but got success.",
)
self.testcase.assertIn(
expected_error,
response.text,
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
)
print(f"Received error as expected: {response.text}") print(f"Received error as expected: {response.text}")
else: else:
self.expected_adapters.add(lora_name) self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.ok) self.testcase.assertTrue(
response.ok, f"Failed to load LoRA adapter {lora_name}: {response.text}"
)
loaded_adapters = set(response.json()["loaded_adapters"]) loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}") print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters) self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def unload_lora_adapter(self, lora_name: str): def unload_lora_adapter(self, lora_name: str):
""" """
...@@ -869,11 +958,17 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -869,11 +958,17 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
DEFAULT_URL_FOR_TEST + "/unload_lora_adapter", DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
json={"lora_name": lora_name}, json={"lora_name": lora_name},
) )
self.testcase.assertTrue(response.ok) self.testcase.assertTrue(
response.ok, f"Failed to unload LoRA adapter {lora_name}: {response.text}"
)
loaded_adapters = set(response.json()["loaded_adapters"]) loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}") print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters) self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def forward( def forward(
self, self,
...@@ -898,15 +993,29 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -898,15 +993,29 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
}, },
) )
if expected_error: if expected_error:
self.testcase.assertEqual(response.status_code, 400) self.testcase.assertEqual(
self.testcase.assertIn(expected_error, response.text) response.status_code,
400,
f"Expected error for forward pass, but got success: {response.text}",
)
self.testcase.assertIn(
expected_error,
response.text,
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
)
output = response.text output = response.text
print(f"Received error as expected: {response.text}") print(f"Received error as expected: {response.text}")
return output return output
else: else:
self.testcase.assertTrue(response.ok) self.testcase.assertTrue(
response.ok, f"Failed to generate text: {response.text}"
)
output = [r["text"] for r in response.json()] output = [r["text"] for r in response.json()]
self.testcase.assertEqual(len(output), len(prompts)) self.testcase.assertEqual(
len(output),
len(prompts),
f"Expected {len(prompts)} outputs, but got {len(output)}",
)
print(f"output_strs: {output}") print(f"output_strs: {output}")
return output return output
...@@ -974,10 +1083,18 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -974,10 +1083,18 @@ class TestLoRADynamicUpdate(CustomTestCase):
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---" f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
) )
if op_type == OperationType.LOAD: if op_type == OperationType.LOAD:
if isinstance(data, str):
adapter_info = {
"lora_name": data,
"lora_path": data,
"pinned": False,
}
else:
adapter_info = data
result = session.load_lora_adapter( result = session.load_lora_adapter(
lora_name=data,
lora_path=data,
expected_error=expected_error, expected_error=expected_error,
**adapter_info,
) )
elif op_type == OperationType.UNLOAD: elif op_type == OperationType.UNLOAD:
result = session.unload_lora_adapter( result = session.unload_lora_adapter(
......