Unverified commit 5ded39ca authored by Lifu Huang, committed by GitHub

Fix race condition in async lora unload (#9084)

parent 4093d460
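For context on the race being fixed: unregister() used to delete an adapter's request counter immediately, and ConcurrentCounter.wait_for_zero() was invoked without await, so wait_for_unload() could return while requests were still holding the adapter. The counter the registry relies on behaves roughly like this minimal sketch (inferred from the diff, not the actual SGLang source):

import asyncio


class ConcurrentCounter:
    # Hypothetical stand-in for the real ConcurrentCounter patched in utils.py.
    def __init__(self) -> None:
        self._count = 0
        self._cond = asyncio.Condition()

    async def increment(self, notify_all: bool = True) -> None:
        async with self._cond:
            self._count += 1
            if notify_all:
                self._cond.notify_all()
            else:
                self._cond.notify()

    async def decrement(self) -> None:
        async with self._cond:
            self._count -= 1
            self._cond.notify_all()

    async def wait_for(self, predicate) -> None:
        # Suspend until predicate(count) holds; decrement() wakes waiters.
        async with self._cond:
            await self._cond.wait_for(lambda: predicate(self._count))

    async def wait_for_zero(self) -> None:
        # Note the await; its absence was half of the bug (see the utils.py
        # hunk at the end of this diff).
        await self.wait_for(lambda count: count == 0)

The other half: the counter was deleted in unregister() before anyone could wait on it. The registry hunks below keep it alive until wait_for_unload() drains it.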
@@ -14,7 +14,6 @@
 import asyncio
-from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4

@@ -106,7 +105,6 @@ class LoRARegistry:
                 f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
             )
         del self._registry[lora_name]
-        del self._counters[lora_ref.lora_id]

         return lora_ref.lora_id

@@ -117,6 +115,9 @@ class LoRARegistry:
         """

         def _lookup(name: str) -> str:
+            if name is None:
+                return None
             lora_ref = self._registry.get(name, None)
             if lora_ref is None:
                 raise ValueError(

@@ -135,7 +136,11 @@ class LoRARegistry:
             # Increment the counters only after all IDs are looked up.
             await asyncio.gather(
-                *[self._counters[id].increment(notify_all=False) for id in lora_ids]
+                *[
+                    self._counters[id].increment(notify_all=False)
+                    for id in lora_ids
+                    if id is not None
+                ]
             )
             return lora_ids
         else:

@@ -153,7 +158,11 @@ class LoRARegistry:
            await self._counters[lora_id].decrement()
        elif isinstance(lora_id, list):
            await asyncio.gather(
-                *[self._counters[id].decrement() for id in lora_id]
+                *[
+                    self._counters[id].decrement()
+                    for id in lora_id
+                    if id is not None
+                ]
            )
        else:
            raise TypeError("lora_id must be either a string or a list of strings.")

@@ -169,10 +178,12 @@ class LoRARegistry:
         assert (
             lora_id not in self._registry
         ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
-        counter = self._counters.get(lora_id)
-        if counter:
-            # Wait until no requests are using this LoRA adapter.
-            await counter.wait_for_zero()
-            del self._counters[lora_id]
+        assert (
+            lora_id in self._counters
+        ), "The LoRA ID should still have a counter if it has been registered before."
+
+        # Wait until no requests are using this LoRA adapter.
+        await self._counters[lora_id].wait_for_zero()
+        del self._counters[lora_id]

     def _register_adapter(self, lora_ref: LoRARef):
......
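Taken together, the registry hunks make unload a two-phase protocol: drop the name first so no new request can acquire the adapter, then wait for in-flight holders to drain. A sketch of the calling side; unload_lora here is a hypothetical wrapper, not code from this commit:

async def unload_lora(registry, lora_name: str) -> str:
    # Phase 1: unregister, so subsequent acquire() calls fail fast.
    # unregister() now deliberately leaves the counter in place.
    lora_id = await registry.unregister(lora_name)

    # Phase 2: block until every in-flight request has released the adapter.
    # The new asserts guarantee the counter still exists here; it is deleted
    # inside wait_for_unload() once it reaches zero.
    await registry.wait_for_unload(lora_id)
    return lora_id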
@@ -455,6 +455,7 @@ class GenerateReqInput:
                 log_metrics=self.log_metrics,
                 modalities=self.modalities[i] if self.modalities else None,
                 lora_path=self.lora_path[i] if self.lora_path is not None else None,
+                lora_id=self.lora_id[i] if self.lora_id is not None else None,
                 custom_logit_processor=(
                     self.custom_logit_processor[i]
                     if self.custom_logit_processor is not None
......
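The single added line above matters for batched requests: when a GenerateReqInput is fanned out into per-item children, each child must inherit the LoRA ID that was already acquired for the parent, just as it inherits lora_path, or the release at finish time would have nothing to decrement. A heavily reduced, hypothetical illustration of that fan-out:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Req:  # hypothetical stand-in for GenerateReqInput, fields reduced
    text: List[str]
    lora_path: Optional[List[str]] = None
    lora_id: Optional[List[str]] = None

    def child(self, i: int) -> "Req":
        return Req(
            text=[self.text[i]],
            lora_path=[self.lora_path[i]] if self.lora_path is not None else None,
            # New in this commit: propagate the already-acquired LoRA ID.
            lora_id=[self.lora_id[i]] if self.lora_id is not None else None,
        )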
@@ -485,6 +485,10 @@ class TokenizerManager:
             await self.is_pause_cond.wait_for(lambda: not self.is_pause)

         async with self.model_update_lock.reader_lock:
+            if self.server_args.enable_lora and obj.lora_path:
+                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)

@@ -552,11 +556,6 @@ class TokenizerManager:
         else:
             mm_inputs = None

-        if self.server_args.enable_lora and obj.lora_path:
-            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
-            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
-
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids

@@ -774,10 +773,6 @@ class TokenizerManager:
            msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
            logger.info(msg)

-            # Mark ongoing LoRA request as finished.
-            if self.server_args.enable_lora and obj.lora_path:
-                await self.lora_registry.release(obj.lora_id)
-
            # Check if this was an abort/error created by scheduler
            if isinstance(out["meta_info"].get("finish_reason"), dict):
                finish_reason = out["meta_info"]["finish_reason"]

@@ -796,6 +791,11 @@ class TokenizerManager:
                    # Delete the key to prevent resending abort request to the scheduler and
                    # to ensure aborted request state is cleaned up.
                    del self.rid_to_state[state.obj.rid]
+
+                    # Mark ongoing LoRA request as finished.
+                    if self.server_args.enable_lora and state.obj.lora_path:
+                        await self.lora_registry.release(state.obj.lora_id)
+
                    raise fastapi.HTTPException(
                        status_code=finish_reason["status_code"],
                        detail=finish_reason["message"],

@@ -1599,6 +1599,10 @@ class TokenizerManager:
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and state.obj.lora_path:
+                    asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
+
            state.out_list.append(out_dict)
            state.event.set()
......
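The tokenizer_manager hunks move acquire() from _tokenize_one_request(), where a batch could reach it once per item and only after validation, up to the single entry point under the model-update read lock, and they add a matching release() on the abort/error path in addition to the normal finish path. The invariant they enforce, as a simplified sketch (handle_request and process are hypothetical names, and the real code releases at the two finish sites rather than in a try/finally):

async def handle_request(tm, obj):
    if tm.server_args.enable_lora and obj.lora_path:
        # Pin the adapter for the request's whole lifetime, before any fan-out.
        obj.lora_id = await tm.lora_registry.acquire(obj.lora_path)
    try:
        return await process(tm, obj)  # hypothetical downstream processing
    finally:
        # Every acquire() needs exactly one release() on every exit path;
        # a leaked acquire() would make wait_for_unload() hang forever.
        if tm.server_args.enable_lora and obj.lora_path:
            await tm.lora_registry.release(obj.lora_id)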
@@ -2960,7 +2960,7 @@ class ConcurrentCounter:
         This suspends the calling coroutine without blocking the thread, allowing
         other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
         """
-        self.wait_for(lambda count: count == 0)
+        await self.wait_for(lambda count: count == 0)

     @lru_cache(maxsize=1)
......
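The utils.py change is the heart of the fix: without await, self.wait_for(...) merely constructs a coroutine object and discards it (Python only emits "RuntimeWarning: coroutine ... was never awaited"), so wait_for_zero() returned immediately and unload never actually waited for in-flight requests. Against the counter sketch above:

async def demo(counter: ConcurrentCounter) -> None:
    counter.wait_for(lambda c: c == 0)        # bug: builds a coroutine, never runs it
    await counter.wait_for(lambda c: c == 0)  # fix: suspends until the count is zero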