Unverified Commit df906455 authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Support overlapped lora updates (#8213)

parent 95217a9b
......@@ -14,10 +14,14 @@
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.utils import ConcurrentCounter
@dataclass(frozen=True)
class LoRARef:
......@@ -48,10 +52,11 @@ class LoRARef:
class LoRARegistry:
"""
The central registry to keep track of available LoRA adapters.
The central registry to keep track of available LoRA adapters and ongoing LoRA requests.
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
"""
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
......@@ -62,8 +67,19 @@ class LoRARegistry:
"Please file an issue if you see this error."
)
# A read-write lock to ensure adapters loading / unloading operations are exclusive.
# Please note that the counter increment/decrement operations are not synchronized through this
# lock, as they are designed to be non-blocking and can be performed concurrently.
self._registry_lock = RWLock()
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
self._registry: Dict[str, LoRARef] = {}
# Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
self._counters: Dict[str, ConcurrentCounter] = {}
# Initialize the registry with provided LoRA paths, if present.
if lora_paths:
for lora_ref in lora_paths.values():
self._register_adapter(lora_ref)
async def register(self, lora_ref: LoRARef):
"""
......@@ -72,11 +88,8 @@ class LoRARegistry:
Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
self._registry[lora_ref.lora_name] = lora_ref
async with self._registry_lock.writer_lock:
self._register_adapter(lora_ref)
async def unregister(self, lora_name: str) -> str:
"""
......@@ -85,12 +98,14 @@ class LoRARegistry:
Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
async with self._registry_lock.writer_lock:
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
del self._counters[lora_ref.lora_id]
return lora_ref.lora_id
......@@ -98,27 +113,76 @@ class LoRARegistry:
"""
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
by incrementing its counter.
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""
async def _acquire_single(name: str) -> str:
def _lookup(name: str) -> str:
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
# await self._counters[lora_ref.lora_id].increment()
return lora_ref.lora_id
if isinstance(lora_name, str):
lora_id = await _acquire_single(lora_name)
return lora_id
elif isinstance(lora_name, list):
lora_ids = await asyncio.gather(
*[_acquire_single(name) for name in lora_name]
async with self._registry_lock.reader_lock:
if isinstance(lora_name, str):
lora_id = _lookup(lora_name)
await self._counters[lora_id].increment(notify_all=False)
return lora_id
elif isinstance(lora_name, list):
lora_ids = [_lookup(name) for name in lora_name]
# Increment the counters only after all IDs are looked up.
await asyncio.gather(
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
)
return lora_ids
else:
raise TypeError(
"lora_name must be either a string or a list of strings."
)
async def release(self, lora_id: Union[str, List[str]]):
"""
Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
"""
async with self._registry_lock.reader_lock:
if isinstance(lora_id, str):
await self._counters[lora_id].decrement()
elif isinstance(lora_id, list):
await asyncio.gather(
*[self._counters[id].decrement() for id in lora_id]
)
else:
raise TypeError("lora_id must be either a string or a list of strings.")
async def wait_for_unload(self, lora_id: str):
"""
Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
This is useful for ensuring that a LoRA adapter can be safely unloaded.
This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
which itself is guaranteed to be sequential.
"""
assert (
lora_id not in self._registry
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
counter = self._counters.get(lora_id)
if counter:
# Wait until no requests are using this LoRA adapter.
await counter.wait_for_zero()
del self._counters[lora_id]
def _register_adapter(self, lora_ref: LoRARef):
"""
Internal helper method to register a LoRA adapter.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")
self._registry[lora_ref.lora_name] = lora_ref
self._counters[lora_ref.lora_id] = ConcurrentCounter()
return lora_ref
......@@ -282,6 +282,11 @@ class TokenizerManager:
None
)
# Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock()
# For pd disaggregtion
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
......@@ -537,7 +542,8 @@ class TokenizerManager:
mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids)
......@@ -747,6 +753,10 @@ class TokenizerManager:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
logger.info(msg)
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and obj.lora_path:
await self.lora_registry.release(obj.lora_path)
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
......@@ -1053,16 +1063,18 @@ class TokenizerManager:
obj.lora_path,
)
async with self.model_update_lock.writer_lock:
async with self.lora_update_lock:
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
)
# Register the new adapter in the registry.
# Trigger the actual loading operation at the backend processes.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
# Register the LoRA adapter only after loading is successful.
if result.success:
await self.lora_registry.register(new_adapter)
......@@ -1093,8 +1105,15 @@ class TokenizerManager:
obj.lora_name,
)
async with self.model_update_lock.writer_lock:
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
async with self.lora_update_lock:
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
# from being started.
lora_id = await self.lora_registry.unregister(obj.lora_name)
obj.lora_id = lora_id
# Initiate the actual unloading operation at the backend processes only after all
# ongoing requests using this LoRA adapter are finished.
await self.lora_registry.wait_for_unload(lora_id)
result = (await self.update_lora_adapter_communicator(obj))[0]
return result
......
......@@ -15,6 +15,7 @@
from __future__ import annotations
import asyncio
import builtins
import ctypes
import dataclasses
......@@ -2862,3 +2863,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
]
LORA_TARGET_ALL_MODULES = "all"
class ConcurrentCounter:
"""
An asynchronous counter for managing concurrent tasks that need
coordinated increments, decrements, and waiting until the count reaches zero.
This class is useful for scenarios like tracking the number of in-flight tasks
and waiting for them to complete.
"""
def __init__(self, initial: int = 0):
"""
Initialize the counter with an optional initial value.
Args:
initial (int): The initial value of the counter. Default is 0.
"""
self._count = initial
self._condition = asyncio.Condition()
def value(self) -> int:
"""
Return the current value of the counter.
Note:
This method is not synchronized. It may return a stale value
if other coroutines are concurrently modifying the counter.
Returns:
int: The current counter value.
"""
return self._count
def __repr__(self) -> str:
"""Return an informative string representation of the counter."""
return f"<ConcurrentCounter value={self.value()}>"
async def increment(self, n: int = 1, notify_all: bool = True):
"""
Atomically increment the counter by a given amount and notify all waiters.
Args:
n (int): The amount to increment the counter by. Default is 1.
notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
"""
async with self._condition:
self._count += n
if notify_all:
self._condition.notify_all()
async def decrement(self, n: int = 1, notify_all: bool = True):
"""
Atomically decrement the counter by a given amount and notify all waiters.
Args:
n (int): The amount to decrement the counter by. Default is 1.
notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
"""
async with self._condition:
self._count -= n
if notify_all:
self._condition.notify_all()
async def wait_for(self, condition: Callable[[int], bool]):
"""
Asynchronously wait until the counter satisfies a given condition.
This suspends the calling coroutine without blocking the thread, allowing
other tasks to run while waiting. When the condition is met, the coroutine resumes.
Args:
condition (Callable[[int], bool]): A function that takes the current counter value
and returns True when the condition is satisfied.
"""
async with self._condition:
await self._condition.wait_for(lambda: condition(self._count))
async def wait_for_zero(self):
"""
Asynchronously wait until the counter reaches zero.
This suspends the calling coroutine without blocking the thread, allowing
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
"""
self.wait_for(lambda count: count == 0)
......@@ -231,8 +231,7 @@ class TestBenchServing(CustomTestCase):
f"median_ttft_ms: {res['median_ttft_ms']:.2f} ms\n"
)
self.assertLess(res["median_e2e_latency_ms"], 4000)
# TODO (lifuhuang): This will be fixed by the overlapped LoRA update in a separate PR.
self.assertLess(res["median_ttft_ms"], 1600)
self.assertLess(res["median_ttft_ms"], 80)
def _run_lora_latency_test(self, enable_background_task: bool):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment