Unverified Commit 8675bdf2 authored by Lifu Huang, committed by GitHub

Support limiting max loaded loras in CPU. (#8650)

parent a437aa99
@@ -33,6 +33,8 @@
"\n",
"* `max_loras_per_batch`: Maximum number of adapters used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to 8.\n",
"\n",
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
"\n",
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", "* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
"\n", "\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
......
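For illustration, the two limits above can be set together when constructing the offline engine. A minimal sketch, assuming the offline `sglang.Engine` accepts the same LoRA-related keyword arguments as the server (the engine test session in this change passes `max_loras_per_batch` and `max_loaded_loras` through directly); the model, adapter path, and the `lora_path` argument to `generate` are illustrative assumptions:

```python
# Minimal sketch: cap both the per-batch adapter count and the total number of
# adapters kept in CPU memory. Model, adapter, and the `lora_path` argument to
# generate() are illustrative assumptions.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    enable_lora=True,
    lora_paths=["sql_lora=philschmid/code-llama-3-1-8b-text-to-sql-lora"],
    max_loras_per_batch=2,  # adapters usable within one batch (GPU-side)
    max_loaded_loras=4,     # adapters resident in CPU memory; must be >= max_loras_per_batch
)

# Reference a loaded adapter by its registered name for a single request.
output = llm.generate(
    "Write a SQL query that lists all users.",
    {"max_new_tokens": 32},
    lora_path="sql_lora",
)
print(output)
llm.shutdown()
```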
@@ -181,6 +181,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format `{name}={path}`. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
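The corresponding server-side flags can be combined as follows. A minimal launch sketch, assuming the standard `python -m sglang.launch_server` entry point; apart from the LoRA flags listed in the table above, the model path and port are placeholder assumptions:

```python
# Minimal sketch: launch the server with a CPU-side cap of 4 loaded adapters.
# The entry point, model path, and port are assumptions; the LoRA flags mirror
# the table above.
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
    "--enable-lora",
    "--lora-paths", "sql_lora=philschmid/code-llama-3-1-8b-text-to-sql-lora",
    "--max-loras-per-batch", "2",
    "--max-loaded-loras", "4",  # must be >= --max-loras-per-batch
    "--port", "30000",
]
server = subprocess.Popen(cmd)
```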
## Kernel backend
...
@@ -186,3 +186,10 @@ class LoRARegistry:
        self._registry[lora_ref.lora_name] = lora_ref
        self._counters[lora_ref.lora_id] = ConcurrentCounter()
        return lora_ref

    @property
    def num_registered_loras(self) -> int:
        """
        Returns the total number of LoRA adapters currently registered.
        """
        return len(self._registry)
@@ -1097,7 +1097,7 @@ class UnloadLoRAAdapterReqInput:
class LoRAUpdateResult:
    success: bool
    error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
...
@@ -1084,38 +1084,56 @@ class TokenizerManager:
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start load Lora adapter. Lora name=%s, path=%s",
                obj.lora_name,
                obj.lora_path,
            )

            async with self.lora_update_lock:
                if (
                    self.server_args.max_loaded_loras is not None
                    and self.lora_registry.num_registered_loras
                    >= self.server_args.max_loaded_loras
                ):
                    raise ValueError(
                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
                        "Please unload some LoRA adapters before loading new ones."
                    )

                # Generate new uniquely identifiable LoRARef object.
                new_adapter = LoRARef(
                    lora_name=obj.lora_name,
                    lora_path=obj.lora_path,
                )

                # Trigger the actual loading operation at the backend processes.
                obj.lora_id = new_adapter.lora_id
                result = (await self.update_lora_adapter_communicator(obj))[0]

                # Register the LoRA adapter only after loading is successful.
                if result.success:
                    await self.lora_registry.register(new_adapter)

                return result
        except ValueError as e:
            return LoadLoRAAdapterReqOutput(
                success=False,
                error_message=str(e),
            )

    async def unload_lora_adapter(
        self,
@@ -1123,37 +1141,41 @@ class TokenizerManager:
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            assert (
                obj.lora_name is not None
            ), "lora_name must be provided to unload LoRA adapter"

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start unload Lora adapter. Lora name=%s",
                obj.lora_name,
            )

            async with self.lora_update_lock:
                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
                # from being started.
                lora_id = await self.lora_registry.unregister(obj.lora_name)
                obj.lora_id = lora_id

                # Initiate the actual unloading operation at the backend processes only after all
                # ongoing requests using this LoRA adapter are finished.
                await self.lora_registry.wait_for_unload(lora_id)

                result = (await self.update_lora_adapter_communicator(obj))[0]

                return result
        except ValueError as e:
            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
...
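With the load path above, exceeding `max_loaded_loras` no longer raises to the caller; it comes back as a `success=False` result carrying the error message. A client-side sketch of reacting to that error, assuming the dynamic-update HTTP endpoints `/load_lora_adapter` and `/unload_lora_adapter` and a server on localhost:30000 (endpoint paths, port, and adapter names are assumptions for illustration):

```python
# Sketch of a client reacting to the new capacity error. The endpoint paths,
# port, and adapter names/paths are assumptions used for illustration only.
import requests

BASE_URL = "http://localhost:30000"


def load_adapter(name: str, path: str) -> dict:
    # The response body mirrors LoRAUpdateResult: {"success": ..., "error_message": ...}
    resp = requests.post(
        f"{BASE_URL}/load_lora_adapter",
        json={"lora_name": name, "lora_path": path},
    )
    return resp.json()


def unload_adapter(name: str) -> dict:
    resp = requests.post(f"{BASE_URL}/unload_lora_adapter", json={"lora_name": name})
    return resp.json()


result = load_adapter("ocr_lora", "pbevan11/llama-3.1-8b-ocr-correction")
if not result["success"]:
    message = result.get("error_message") or ""
    if "Maximum number of loaded LoRA adapters" in message:
        # Free a CPU slot first, then retry, as the test sequence in this change does.
        unload_adapter("sql_lora")
        result = load_adapter("ocr_lora", "pbevan11/llama-3.1-8b-ocr-correction")
```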
@@ -149,6 +149,7 @@ class ServerArgs:
    max_lora_rank: Optional[int] = None
    lora_target_modules: Optional[Union[set[str], List[str]]] = None
    lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
    max_loaded_loras: Optional[int] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"
@@ -1237,6 +1238,12 @@ class ServerArgs:
            default=8,
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--max-loaded-loras",
            type=int,
            default=ServerArgs.max_loaded_loras,
            help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
@@ -2008,6 +2015,19 @@ class ServerArgs:
                self.max_lora_rank and self.lora_target_modules
            ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

        # Validate max_loaded_loras
        if self.max_loaded_loras is not None:
            assert self.max_loaded_loras >= self.max_loras_per_batch, (
                "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
                f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
            )
            assert (
                not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
            ), (
                "The number of LoRA paths should not exceed max_loaded_loras. "
                f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
            )

    def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
        larger_tp = max(decode_tp, prefill_tp)
        smaller_tp = min(decode_tp, prefill_tp)
...
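The two assertions above define the configuration contract: the CPU-side cap must be at least `max_loras_per_batch`, and any adapters supplied at startup must fit under the cap. A standalone restatement of that check, with a hypothetical helper name that is not part of this change:

```python
# Hypothetical helper restating the startup-time validation above; the function
# name and signature are illustrative, not part of this change.
from typing import Optional, Sequence


def check_lora_limits(
    max_loaded_loras: Optional[int],
    max_loras_per_batch: int,
    lora_paths: Optional[Sequence[str]],
) -> None:
    if max_loaded_loras is None:
        return  # no CPU-side cap configured
    if max_loaded_loras < max_loras_per_batch:
        raise ValueError(
            f"max_loaded_loras ({max_loaded_loras}) must be >= "
            f"max_loras_per_batch ({max_loras_per_batch})"
        )
    if lora_paths and len(lora_paths) > max_loaded_loras:
        raise ValueError(
            f"{len(lora_paths)} initial LoRA paths exceed "
            f"max_loaded_loras ({max_loaded_loras})"
        )


# Valid: a cap of 4 with a per-batch limit of 2 and 3 initial adapters.
check_lora_limits(4, 2, ["a", "b", "c"])
```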
@@ -514,6 +514,7 @@ class SRTRunner:
        max_lora_rank: Optional[int] = None,
        lora_target_modules: Optional[List[str]] = None,
        enable_lora: Optional[bool] = None,
        max_loaded_loras: Optional[int] = None,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
@@ -556,6 +557,7 @@ class SRTRunner:
            max_lora_rank=max_lora_rank,
            lora_target_modules=lora_target_modules,
            enable_lora=enable_lora,
            max_loaded_loras=max_loaded_loras,
            **spec_kwargs,
        )
...
@@ -70,6 +70,7 @@ class TestCase:
    max_lora_rank: Optional[int] = None
    lora_target_modules: Optional[List] = None
    max_new_tokens: int = 32
    max_loaded_loras: Optional[int] = None


def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
@@ -559,7 +560,43 @@ MAX_LORA_RANK_TESTS = [
        ],
    ),
]

MAX_LOADED_LORAS_TESTS = [
    TestCase(
        description="Test max_loaded_loras limit",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=2,
        max_loaded_loras=2,
        all_adapters=[
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "pbevan11/llama-3.1-8b-ocr-correction",
        ],
        initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
        op_sequence=[
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
                expected_error="Maximum number of loaded LoRA adapters",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
            ),
        ],
    ),
]

ALL_TESTS = (
    BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS
)


class LoRAUpdateTestSessionMode(Enum):
@@ -579,6 +616,7 @@ class LoRAUpdateTestSessionBase:
        model_path: str,
        lora_paths: list[str],
        max_loras_per_batch: int,
        max_loaded_loras: Optional[int] = None,
        max_lora_rank: Optional[int],
        enable_lora: Optional[bool] = None,
        lora_target_modules: Optional[List[str]] = None,
@@ -592,6 +630,7 @@ class LoRAUpdateTestSessionBase:
        self.max_lora_rank = max_lora_rank
        self.lora_target_modules = lora_target_modules
        self.max_loras_per_batch = max_loras_per_batch
        self.max_loaded_loras = max_loaded_loras
        self.lora_backend = lora_backend
        self.disable_cuda_graph = disable_cuda_graph
        self.cuda_graph_max_bs = cuda_graph_max_bs
@@ -654,6 +693,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
            torch_dtype=torch.float16,
            mem_fraction_static=MEM_FRACTION_STATIC,
            max_loras_per_batch=self.max_loras_per_batch,
            max_loaded_loras=self.max_loaded_loras,
            disable_cuda_graph=self.disable_cuda_graph,
            cuda_graph_max_bs=self.cuda_graph_max_bs,
            disable_radix_cache=True,
@@ -774,6 +814,8 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
            other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
        if self.lora_target_modules is not None:
            other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
        if self.max_loaded_loras is not None:
            other_args.extend(["--max-loaded-loras", str(self.max_loaded_loras)])

        # launch external server
        self.handle = popen_launch_server(
@@ -898,8 +940,9 @@ class TestLoRADynamicUpdate(CustomTestCase):
        mode: LoRAUpdateTestSessionMode,
        base: str,
        initial_adapters: List[str],
        op_sequence: List[Operation],
        max_loras_per_batch: int,
        max_loaded_loras: Optional[int] = None,
        enable_lora: Optional[bool] = None,
        max_lora_rank: Optional[int] = None,
        lora_target_modules: Optional[List[str]] = None,
@@ -917,6 +960,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
            model_path=base,
            lora_paths=initial_adapters,
            max_loras_per_batch=max_loras_per_batch,
            max_loaded_loras=max_loaded_loras,
            max_lora_rank=max_lora_rank,
            lora_target_modules=lora_target_modules,
            enable_lora=enable_lora,
@@ -972,6 +1016,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
            enable_lora=test_case.enable_lora,
            base=test_case.base,
            max_loras_per_batch=test_case.max_loras_per_batch,
            max_loaded_loras=test_case.max_loaded_loras,
            op_sequence=test_case.op_sequence,
            max_new_tokens=test_case.max_new_tokens,
            max_lora_rank=test_case.max_lora_rank,
@@ -985,6 +1030,12 @@ class TestLoRADynamicUpdate(CustomTestCase):
                if x.type == OperationType.FORWARD and x.expected_error is None
            ]

            if not forward_ops:
                print(
                    f"No forward operations found in test case {case_idx}. Skipping static pass."
                )
                continue

            print("=" * 100)
            print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
            static_output = self._run_operation_sequence(
...