"tools/git@developer.sourcefind.cn:OpenDAS/dlib.git" did not exist on "b52e50a30bc994ab5c04f8b0f8adc1334d5f6a61"
Unverified commit b0980af8 authored by Lifu Huang, committed by GitHub

Support pinning adapter via server args. (#9249)

parent 24eaebeb
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
"\n", "\n",
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n", "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
"\n", "\n",
"* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n", "* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}.\n",
"\n", "\n",
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n", "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
"\n", "\n",
...@@ -372,6 +372,15 @@ ...@@ -372,6 +372,15 @@
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
...@@ -387,7 +396,40 @@ ...@@ -387,7 +396,40 @@
"\n", "\n",
"This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n", "This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
"\n", "\n",
"In the example below, we unload `lora1` and reload it as a `pinned` adapter:" "In the example below, we start a server with `lora1` loaded as pinned, `lora2` and `lora3` loaded as regular (unpinned) adapters. Please note that, we intentionally specify `lora2` and `lora3` in two different formats to demonstrate that both are supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 8 \\\n",
" --max-loras-per-batch 3 --lora-backend triton \\\n",
" --max-lora-rank 256 \\\n",
" --lora-target-modules all \\\n",
" --lora-paths \\\n",
" {\"lora_name\":\"lora0\",\"lora_path\":\"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\"pinned\":true} \\\n",
" {\"lora_name\":\"lora1\",\"lora_path\":\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"} \\\n",
" lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora\n",
" \"\"\"\n",
")\n",
"\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also specify adapter as pinned during dynamic adapter loading. In the example below, we reload `lora2` as pinned adapter:"
] ]
}, },
{ {
...@@ -407,7 +449,7 @@ ...@@ -407,7 +449,7 @@
" url + \"/load_lora_adapter\",\n", " url + \"/load_lora_adapter\",\n",
" json={\n", " json={\n",
" \"lora_name\": \"lora1\",\n", " \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n", " \"lora_path\": \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\",\n",
" \"pinned\": True, # Pin the adapter to GPU\n", " \"pinned\": True, # Pin the adapter to GPU\n",
" },\n", " },\n",
")" ")"
...@@ -417,7 +459,7 @@ ...@@ -417,7 +459,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Verify that the result is identical as before:" "Verify that the results are expected:"
] ]
}, },
{ {
...@@ -431,17 +473,19 @@ ...@@ -431,17 +473,19 @@
" \"text\": [\n", " \"text\": [\n",
" \"List 3 countries and their capitals.\",\n", " \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n", " \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n", " ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n", " # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n", " \"lora_path\": [\"lora0\", \"lora1\", \"lora2\"],\n",
"}\n", "}\n",
"response = requests.post(\n", "response = requests.post(\n",
" url + \"/generate\",\n", " url + \"/generate\",\n",
" json=json_data,\n", " json=json_data,\n",
")\n", ")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", "print(f\"Output from lora0 (pinned): \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")" "print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")\n",
"print(f\"Output from lora2 (not pinned): \\n{response.json()[2]['text']}\\n\")"
] ]
}, },
{ {
......
...@@ -179,7 +179,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s ...@@ -179,7 +179,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False | | `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | | `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None | | `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None | | `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> \| <NAME>=<PATH> \| JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool} (see the example below). | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None | | `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
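A hedged sketch of mixing the three formats, borrowing the adapter names and repositories from the notebook example in this change. The programmatic `sgl.Engine` usage is an assumption (that the offline engine forwards these keyword arguments to `ServerArgs`, which performs the normalization shown later in this diff), not something introduced by this PR:

```python
import sglang as sgl

# Assumption: sgl.Engine passes these keyword arguments through to ServerArgs,
# so the same three --lora-paths formats can be mixed in one list.
llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    enable_lora=True,
    max_loras_per_batch=3,
    lora_paths=[
        # Dict/JSON form: explicit name, path, and the new "pinned" flag.
        {
            "lora_name": "lora0",
            "lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "pinned": True,
        },
        # <NAME>=<PATH> form: named adapter, not pinned.
        "lora1=algoprog/fact-generation-llama-3.1-8b-instruct-lora",
        # <PATH> form: the path doubles as the adapter name.
        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
    ],
)
```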
......
...@@ -55,7 +55,7 @@ class LoRAManager: ...@@ -55,7 +55,7 @@ class LoRAManager:
tp_rank: int = 0, tp_rank: int = 0,
max_lora_rank: Optional[int] = None, max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None, target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None, lora_paths: Optional[List[LoRARef]] = None,
): ):
self.base_model: torch.nn.Module = base_model self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config self.base_hf_config: AutoConfig = base_hf_config
...@@ -370,7 +370,7 @@ class LoRAManager: ...@@ -370,7 +370,7 @@ class LoRAManager:
self, self,
max_lora_rank: Optional[int] = None, max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None, target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None, lora_paths: Optional[List[LoRARef]] = None,
): ):
""" """
Initialize the internal (mutable) state of the LoRAManager. Initialize the internal (mutable) state of the LoRAManager.
...@@ -392,7 +392,7 @@ class LoRAManager: ...@@ -392,7 +392,7 @@ class LoRAManager:
self.init_memory_pool() self.init_memory_pool()
self.update_lora_info() self.update_lora_info()
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None): def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
# Configs of all active LoRA adapters, indexed by LoRA ID. # Configs of all active LoRA adapters, indexed by LoRA ID.
self.configs: Dict[str, LoRAConfig] = {} self.configs: Dict[str, LoRAConfig] = {}
...@@ -406,7 +406,7 @@ class LoRAManager: ...@@ -406,7 +406,7 @@ class LoRAManager:
self.num_pinned_loras: int = 0 self.num_pinned_loras: int = 0
if lora_paths: if lora_paths:
for lora_ref in lora_paths.values(): for lora_ref in lora_paths:
result = self.load_lora_adapter(lora_ref) result = self.load_lora_adapter(lora_ref)
if not result.success: if not result.success:
raise RuntimeError( raise RuntimeError(
......
...@@ -59,9 +59,9 @@ class LoRARegistry: ...@@ -59,9 +59,9 @@ class LoRARegistry:
update / eventual consistency model between the tokenizer manager process and the scheduler processes. update / eventual consistency model between the tokenizer manager process and the scheduler processes.
""" """
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None): def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
assert lora_paths is None or all( assert lora_paths is None or all(
isinstance(lora, LoRARef) for lora in lora_paths.values() isinstance(lora, LoRARef) for lora in lora_paths
), ( ), (
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. " "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
"Please file an issue if you see this error." "Please file an issue if you see this error."
...@@ -78,7 +78,7 @@ class LoRARegistry: ...@@ -78,7 +78,7 @@ class LoRARegistry:
# Initialize the registry with provided LoRA paths, if present. # Initialize the registry with provided LoRA paths, if present.
if lora_paths: if lora_paths:
for lora_ref in lora_paths.values(): for lora_ref in lora_paths:
self._register_adapter(lora_ref) self._register_adapter(lora_ref)
async def register(self, lora_ref: LoRARef): async def register(self, lora_ref: LoRARef):
......
...@@ -298,7 +298,7 @@ class TokenizerManager: ...@@ -298,7 +298,7 @@ class TokenizerManager:
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names # serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs. # to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {}) self.lora_registry = LoRARegistry(self.server_args.lora_paths)
# Lock to serialize LoRA update operations. # Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing # Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap. # LoRA updates and inference to overlap.
......
...@@ -153,7 +153,9 @@ class ServerArgs: ...@@ -153,7 +153,9 @@ class ServerArgs:
enable_lora: Optional[bool] = None enable_lora: Optional[bool] = None
max_lora_rank: Optional[int] = None max_lora_rank: Optional[int] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None lora_paths: Optional[
Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
] = None
max_loaded_loras: Optional[int] = None max_loaded_loras: Optional[int] = None
max_loras_per_batch: int = 8 max_loras_per_batch: int = 8
lora_backend: str = "triton" lora_backend: str = "triton"
...@@ -1319,7 +1321,7 @@ class ServerArgs: ...@@ -1319,7 +1321,7 @@ class ServerArgs:
nargs="*", nargs="*",
default=None, default=None,
action=LoRAPathAction, action=LoRAPathAction,
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.", help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
) )
parser.add_argument( parser.add_argument(
"--max-loras-per-batch", "--max-loras-per-batch",
...@@ -2086,28 +2088,42 @@ class ServerArgs: ...@@ -2086,28 +2088,42 @@ class ServerArgs:
) )
if self.enable_lora: if self.enable_lora:
# Normalize lora_paths to a dictionary if it is a list.
# TODO (lifuhuang): support specifying pinned adapters in server_args.
if isinstance(self.lora_paths, list): if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths lora_paths = self.lora_paths
self.lora_paths = {} self.lora_paths = []
for lora_path in lora_paths: for lora_path in lora_paths:
if "=" in lora_path: if isinstance(lora_path, str):
name, path = lora_path.split("=", 1) if "=" in lora_path:
self.lora_paths[name] = LoRARef( name, path = lora_path.split("=", 1)
lora_name=name, lora_path=path, pinned=False lora_ref = LoRARef(
lora_name=name, lora_path=path, pinned=False
)
else:
lora_ref = LoRARef(
lora_name=lora_path, lora_path=lora_path, pinned=False
)
elif isinstance(lora_path, dict):
assert (
"lora_name" in lora_path and "lora_path" in lora_path
), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
lora_ref = LoRARef(
lora_name=lora_path["lora_name"],
lora_path=lora_path["lora_path"],
pinned=lora_path.get("pinned", False),
) )
else: else:
self.lora_paths[lora_path] = LoRARef( raise ValueError(
lora_name=lora_path, lora_path=lora_path, pinned=False f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
"Expected a string or a dictionary."
) )
self.lora_paths.append(lora_ref)
elif isinstance(self.lora_paths, dict): elif isinstance(self.lora_paths, dict):
self.lora_paths = { self.lora_paths = [
k: LoRARef(lora_name=k, lora_path=v, pinned=False) LoRARef(lora_name=k, lora_path=v, pinned=False)
for k, v in self.lora_paths.items() for k, v in self.lora_paths.items()
} ]
elif self.lora_paths is None: elif self.lora_paths is None:
self.lora_paths = {} self.lora_paths = []
else: else:
raise ValueError( raise ValueError(
f"Invalid type for --lora-paths: {type(self.lora_paths)}. " f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
...@@ -2134,9 +2150,7 @@ class ServerArgs: ...@@ -2134,9 +2150,7 @@ class ServerArgs:
"max_loaded_loras should be greater than or equal to max_loras_per_batch. " "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}" f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
) )
assert ( assert len(self.lora_paths) <= self.max_loaded_loras, (
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
), (
"The number of LoRA paths should not exceed max_loaded_loras. " "The number of LoRA paths should not exceed max_loaded_loras. "
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}" f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
) )
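The normalization above builds `LoRARef` records, whose definition is outside this diff. As a rough mental model only (an assumed sketch inferred from how the objects are constructed here, not the actual SGLang class), a record consistent with the fields used above might look like:

```python
import uuid
from dataclasses import dataclass, field


# Assumed sketch of LoRARef; the real class lives in SGLang's LoRA module
# and may differ in details.
@dataclass(frozen=True)
class LoRARef:
    lora_name: str  # user-facing adapter name
    lora_path: str  # local path or Hugging Face repo id
    pinned: bool = False  # keep the adapter resident in a GPU memory-pool slot
    # Assumed internal identifier; LoRARegistry maps user-friendly names to
    # unique LoRA IDs, so some field like this likely exists.
    lora_id: str = field(default_factory=lambda: uuid.uuid4().hex)
```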
...@@ -2357,13 +2371,22 @@ class PortArgs: ...@@ -2357,13 +2371,22 @@ class PortArgs:
class LoRAPathAction(argparse.Action): class LoRAPathAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, {}) lora_paths = []
for lora_path in values: if values:
if "=" in lora_path: assert isinstance(values, list), "Expected a list of LoRA paths."
name, path = lora_path.split("=", 1) for lora_path in values:
getattr(namespace, self.dest)[name] = path lora_path = lora_path.strip()
else: if lora_path.startswith("{") and lora_path.endswith("}"):
getattr(namespace, self.dest)[lora_path] = lora_path obj = json.loads(lora_path)
assert "lora_path" in obj and "lora_name" in obj, (
f"{repr(lora_path)} looks like a JSON str, "
"but it does not contain 'lora_name' and 'lora_path' keys."
)
lora_paths.append(obj)
else:
lora_paths.append(lora_path)
setattr(namespace, self.dest, lora_paths)
class DeprecatedAction(argparse.Action): class DeprecatedAction(argparse.Action):
......
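To make the accepted `--lora-paths` syntaxes concrete, the following standalone sketch mirrors the parsing behaviour of the new `LoRAPathAction` (re-implemented here rather than imported from SGLang, with placeholder adapter paths):

```python
import argparse
import json


class LoRAPathActionSketch(argparse.Action):
    """Mirror of the parsing logic above: JSON entries become dicts,
    everything else is kept as a raw string for later normalization."""

    def __call__(self, parser, namespace, values, option_string=None):
        lora_paths = []
        for lora_path in values or []:
            lora_path = lora_path.strip()
            if lora_path.startswith("{") and lora_path.endswith("}"):
                lora_paths.append(json.loads(lora_path))
            else:
                lora_paths.append(lora_path)
        setattr(namespace, self.dest, lora_paths)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-paths", nargs="*", action=LoRAPathActionSketch)
args = parser.parse_args(
    [
        "--lora-paths",
        '{"lora_name":"lora0","lora_path":"/adapters/lora0","pinned":true}',
        "lora1=/adapters/lora1",
        "/adapters/lora2",
    ]
)
print(args.lora_paths)
# [{'lora_name': 'lora0', 'lora_path': '/adapters/lora0', 'pinned': True},
#  'lora1=/adapters/lora1', '/adapters/lora2']
```

The string entries are then converted into `LoRARef` objects by the `ServerArgs` normalization shown above.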
...@@ -491,7 +491,7 @@ class SRTRunner: ...@@ -491,7 +491,7 @@ class SRTRunner:
tp_size: int = 1, tp_size: int = 1,
model_impl: str = "auto", model_impl: str = "auto",
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None, lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None,
max_loras_per_batch: int = 4, max_loras_per_batch: int = 4,
attention_backend: Optional[str] = None, attention_backend: Optional[str] = None,
prefill_attention_backend: Optional[str] = None, prefill_attention_backend: Optional[str] = None,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import json
import multiprocessing as mp import multiprocessing as mp
import unittest import unittest
from dataclasses import dataclass from dataclasses import dataclass
...@@ -89,8 +90,35 @@ BASIC_TESTS = [ ...@@ -89,8 +90,35 @@ BASIC_TESTS = [
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction", "pbevan11/llama-3.1-8b-ocr-correction",
], ],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"], initial_adapters=[
# Testing 3 supported lora-path formats.
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
{
"lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": False,
},
],
op_sequence=[ op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
]
),
),
Operation(
type=OperationType.UNLOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.UNLOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation( Operation(
type=OperationType.FORWARD, type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
...@@ -147,6 +175,10 @@ BASIC_TESTS = [ ...@@ -147,6 +175,10 @@ BASIC_TESTS = [
type=OperationType.UNLOAD, type=OperationType.UNLOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
), ),
Operation(
type=OperationType.UNLOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation( Operation(
type=OperationType.FORWARD, type=OperationType.FORWARD,
data=create_batch_data( data=create_batch_data(
...@@ -157,18 +189,12 @@ BASIC_TESTS = [ ...@@ -157,18 +189,12 @@ BASIC_TESTS = [
Operation( Operation(
type=OperationType.FORWARD, type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
), expected_error="not loaded",
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
), ),
Operation( Operation(
type=OperationType.FORWARD, type=OperationType.FORWARD,
data=create_batch_data( data=create_batch_data(
[ None,
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pbevan11/llama-3.1-8b-ocr-correction",
]
), ),
), ),
], ],
...@@ -705,7 +731,7 @@ class LoRAUpdateTestSessionBase: ...@@ -705,7 +731,7 @@ class LoRAUpdateTestSessionBase:
*, *,
testcase: Optional[TestCase], testcase: Optional[TestCase],
model_path: str, model_path: str,
lora_paths: list[str], lora_paths: List[Union[str, dict]],
max_loras_per_batch: int, max_loras_per_batch: int,
max_loaded_loras: Optional[int] = None, max_loaded_loras: Optional[int] = None,
max_lora_rank: Optional[int], max_lora_rank: Optional[int],
...@@ -727,7 +753,17 @@ class LoRAUpdateTestSessionBase: ...@@ -727,7 +753,17 @@ class LoRAUpdateTestSessionBase:
self.cuda_graph_max_bs = cuda_graph_max_bs self.cuda_graph_max_bs = cuda_graph_max_bs
self.enable_lora = enable_lora self.enable_lora = enable_lora
self.expected_adapters = set(lora_paths or []) self.expected_adapters = set()
if self.lora_paths:
for adapter in self.lora_paths:
if isinstance(adapter, dict):
lora_name = adapter["lora_name"]
elif "=" in adapter:
lora_name = adapter.split("=")[0]
else:
lora_name = adapter
self.expected_adapters.add(lora_name)
self.handle = None # Will be set in __enter__ self.handle = None # Will be set in __enter__
def __enter__(self): def __enter__(self):
...@@ -926,7 +962,11 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -926,7 +962,11 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
if self.enable_lora: if self.enable_lora:
other_args.append("--enable-lora") other_args.append("--enable-lora")
if self.lora_paths: if self.lora_paths:
other_args.extend(["--lora-paths"] + self.lora_paths) other_args.append("--lora-paths")
for lora_path in self.lora_paths:
if isinstance(lora_path, dict):
lora_path = json.dumps(lora_path)
other_args.append(lora_path)
if self.disable_cuda_graph: if self.disable_cuda_graph:
other_args.append("--disable-cuda-graph") other_args.append("--disable-cuda-graph")
if self.max_lora_rank is not None: if self.max_lora_rank is not None:
...@@ -1093,7 +1133,7 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -1093,7 +1133,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
self, self,
mode: LoRAUpdateTestSessionMode, mode: LoRAUpdateTestSessionMode,
base: str, base: str,
initial_adapters: List[str], initial_adapters: List[Union[str, dict]],
op_sequence: List[Operation], op_sequence: List[Operation],
max_loras_per_batch: int, max_loras_per_batch: int,
max_loaded_loras: Optional[int] = None, max_loaded_loras: Optional[int] = None,
......