"docs/vscode:/vscode.git/clone" did not exist on "fcf5ad5f494eed6542d7a236346dcd7b359fad09"
Unverified commit 1a08358a authored by Lifu Huang, committed by GitHub

Improve error handling for requests with unloaded LoRA path(s) (#7642)

parent f18a8fdd
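
With this change, a generate request that names a LoRA adapter which is not currently loaded is rejected up front by the TokenizerManager instead of failing later in the pipeline. A minimal sketch of the new client-visible behavior against a running sglang server (the URL and adapter name are illustrative, not taken from this diff; the payload fields follow sglang's native /generate API):

    import requests

    # Assumes a local sglang server launched with --lora-paths.
    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "Write a SQL query that selects all users.",
            "lora_path": "some/adapter-never-loaded",  # hypothetical, not loaded
            "sampling_params": {"max_new_tokens": 32},
        },
    )
    print(resp.status_code)  # 400 after this commit (see the server test below)
    print(resp.text)         # names the requested adapters that are not loaded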
@@ -240,6 +240,12 @@ class TokenizerManager:
             revision=server_args.revision,
         )
 
+        # Initialize loaded LoRA adapters with the initial lora paths in the server_args.
+        # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
+        self.loaded_lora_adapters: Dict[str, str] = dict(
+            self.server_args.lora_paths or {}
+        )
+
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
@@ -549,6 +555,8 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
+        if self.server_args.lora_paths and obj.lora_path:
+            self._validate_lora_adapters(obj)
 
     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -662,6 +670,21 @@ class TokenizerManager:
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )
 
+    def _validate_lora_adapters(self, obj: GenerateReqInput):
+        """Validate that the requested LoRA adapters are loaded."""
+        requested_adapters = (
+            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
+        )
+        loaded_adapters = (
+            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
+        )
+        unloaded_adapters = requested_adapters - loaded_adapters
+        if unloaded_adapters:
+            raise ValueError(
+                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
+                f"Loaded adapters: {loaded_adapters}."
+            )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -988,6 +1011,7 @@ class TokenizerManager:
         async with self.model_update_lock.writer_lock:
             result = (await self.update_lora_adapter_communicator(obj))[0]
+            self.loaded_lora_adapters = result.loaded_adapters
             return result
 
     async def unload_lora_adapter(
@@ -1009,6 +1033,7 @@ class TokenizerManager:
         async with self.model_update_lock.writer_lock:
             result = (await self.update_lora_adapter_communicator(obj))[0]
+            self.loaded_lora_adapters = result.loaded_adapters
             return result
 
     async def get_weights_by_name(
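
The new _validate_lora_adapters check reduces to a set difference between the requested adapter names and the keys of loaded_lora_adapters. A self-contained sketch of the same logic (hypothetical helper, adapter names illustrative):

    from typing import Dict, List, Union

    def find_unloaded(
        lora_path: Union[str, List[str]],
        loaded_lora_adapters: Dict[str, str],
    ) -> set:
        # Requested names minus loaded names; dict_keys supports set operations.
        requested = set(lora_path) if isinstance(lora_path, list) else {lora_path}
        loaded = loaded_lora_adapters.keys() if loaded_lora_adapters else set()
        return requested - loaded

    assert find_unloaded("a", {"a": "/adapters/a"}) == set()
    assert find_unloaded(["a", "b"], {"a": "/adapters/a"}) == {"b"}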
...
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -131,7 +131,7 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None
 
     # LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
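
The widened type lets --lora-paths arrive either as a name-to-path mapping or as a plain list. Note that dict() over a bare list of strings would raise, so the TokenizerManager initialization above relies on the list form having been normalized into name/path pairs beforehand (sglang parses such arguments elsewhere; that code is not part of this diff). A hedged illustration:

    # Dict form: adapter name -> path (illustrative values).
    lora_paths = {"sql-lora": "philschmid/code-llama-3-1-8b-text-to-sql-lora"}
    loaded = dict(lora_paths or {})  # safe copy; {} when lora_paths is None

    # The list form must become name/path pairs first; one plausible
    # normalization (an assumption, not shown in this diff) uses the
    # path itself as the adapter name:
    paths = ["philschmid/code-llama-3-1-8b-text-to-sql-lora"]
    loaded_from_list = {p: p for p in paths}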
...
@@ -16,7 +16,7 @@ import multiprocessing as mp
 import unittest
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import requests
 import torch
@@ -42,14 +42,16 @@ PROMPTS = [
 class OperationType(Enum):
     LOAD = "load"
     UNLOAD = "unload"
-    NOOP = "noop"
     FORWARD = "forward"
+    EXPECT_ERROR = "expect_error"
 
 
 @dataclass
 class Operation:
+    # Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
     type: OperationType
-    data: Optional[str]
+
+    # Data associated with the operation. Exact type varies depending on the operation
+    data: Optional[Any]
 
 
 @dataclass
@@ -62,7 +64,7 @@ class TestCase:
     max_new_tokens: int = 32
 
 
-def create_batch_data(adapters: Union[str, list]) -> dict:
+def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
     if not isinstance(adapters, list):
         adapters = [adapters]
     return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
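
For EXPECT_ERROR operations, the test cases below pair the batch produced by create_batch_data with the error substring the session should observe; the test driver unpacks this tuple (see the `input_data, expected_error = data` branch near the end of the diff). The shape, with an illustrative adapter name:

    # data field of an EXPECT_ERROR Operation:
    batch = create_batch_data("some/unloaded-adapter")  # [(prompt, adapter), ...]
    op_data = (batch, "not loaded")                     # (batch, expected substring)
    input_data, expected_error = op_data                # as unpacked by the driver
    prompts, adapters = zip(*input_data)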
@@ -80,6 +82,26 @@ TEST_CASES = [
         ],
         initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
         op_sequence=[
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
+            ),
+            Operation(
+                type=OperationType.EXPECT_ERROR,
+                data=(
+                    create_batch_data(
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
+                    ),
+                    "not loaded",
+                ),
+            ),
+            Operation(
+                type=OperationType.EXPECT_ERROR,
+                data=(
+                    create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
+                    "not loaded",
+                ),
+            ),
             Operation(
                 type=OperationType.LOAD,
                 data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
@@ -102,6 +124,13 @@ TEST_CASES = [
                 type=OperationType.UNLOAD,
                 data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
             ),
+            Operation(
+                type=OperationType.EXPECT_ERROR,
+                data=(
+                    create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
+                    "not loaded",
+                ),
+            ),
             Operation(
                 type=OperationType.FORWARD,
                 data=create_batch_data(
@@ -115,6 +144,15 @@ TEST_CASES = [
                 type=OperationType.UNLOAD,
                 data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
             ),
+            Operation(
+                type=OperationType.EXPECT_ERROR,
+                data=(
+                    create_batch_data(
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
+                    ),
+                    "not loaded",
+                ),
+            ),
             Operation(
                 type=OperationType.FORWARD,
                 data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
@@ -149,6 +187,22 @@ TEST_CASES = [
                 type=OperationType.FORWARD,
                 data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
             ),
+            Operation(
+                type=OperationType.EXPECT_ERROR,
+                data=(
+                    create_batch_data(
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
+                    ),
+                    "not loaded",
+                ),
+            ),
+            Operation(
+                type=OperationType.EXPECT_ERROR,
+                data=(
+                    create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
+                    "not loaded",
+                ),
+            ),
             Operation(
                 type=OperationType.LOAD,
                 data="pbevan11/llama-3.1-8b-ocr-correction",
@@ -157,6 +211,13 @@ TEST_CASES = [
                 type=OperationType.UNLOAD,
                 data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
             ),
+            Operation(
+                type=OperationType.EXPECT_ERROR,
+                data=(
+                    create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
+                    "not loaded",
+                ),
+            ),
             Operation(
                 type=OperationType.FORWARD,
                 data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
@@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
         prompts: List[str],
         lora_paths: List[str],
         max_new_tokens: int = 32,
+        expected_error: str = None,
     ):
         """
         Perform a batch forward pass with the current set of loaded LoRA adapters.
         """
-        response = self.handle.batch_forward(
-            prompts=prompts,
-            lora_paths=lora_paths,
-            max_new_tokens=max_new_tokens,
-        )
-        output_strs = response.output_strs
+        try:
+            response = self.handle.batch_forward(
+                prompts=prompts,
+                lora_paths=lora_paths,
+                max_new_tokens=max_new_tokens,
+            )
+        except ValueError as e:
+            if expected_error:
+                error_message = str(e)
+                self.testcase.assertIn(expected_error, error_message)
+                print(f"Received error as expected: {error_message}")
+                return error_message
+            raise e
 
-        print(f"output_strs: {output_strs}")
-        return output_strs
+        self.testcase.assertEqual(len(response.output_strs), len(prompts))
+        output = response.output_strs
+        print(f"output_strs: {output}")
+        return output
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
...@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
prompts: List[str], prompts: List[str],
lora_paths: List[str], lora_paths: List[str],
max_new_tokens: int = 32, max_new_tokens: int = 32,
expected_error: str = None,
): ):
""" """
Perform a batch forward pass with the current set of loaded LoRA adapters. Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
                 },
             },
         )
-        self.testcase.assertTrue(response.ok)
-        output_strs = [r["text"] for r in response.json()]
-
-        print(f"output_strs: {output_strs}")
-        return output_strs
+        if expected_error:
+            self.testcase.assertEqual(response.status_code, 400)
+            self.testcase.assertIn(expected_error, response.text)
+            output = response.text
+            print(f"Received error as expected: {response.text}")
+            return output
+        else:
+            self.testcase.assertTrue(response.ok)
+            output = [r["text"] for r in response.json()]
+            self.testcase.assertEqual(len(output), len(prompts))
+            print(f"output_strs: {output}")
+            return output
 
 
 # Factory function to create the appropriate LoRA test session based on mode
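
The two session types observe the same failure through different channels: the engine session catches the ValueError raised in-process, while the server session asserts an HTTP 400 whose body contains the message. A hedged engine-side sketch (following sglang's documented offline Engine API; model and adapter names illustrative):

    import sglang as sgl

    engine = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative
        lora_paths=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
    )
    try:
        engine.generate(
            prompt="Hello",
            sampling_params={"max_new_tokens": 8},
            lora_path="some/adapter-never-loaded",  # hypothetical, not loaded
        )
    except ValueError as e:
        print(e)  # "The following requested LoRA adapters are not loaded: ..."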
@@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase):
                     max_new_tokens=max_new_tokens,
                 )
                 forward_outputs.append(result)
+            elif op_type == OperationType.EXPECT_ERROR:
+                input_data, expected_error = data
+                prompts, adapters = zip(*input_data)
+                result = session.forward(
+                    prompts=list(prompts),
+                    lora_paths=list(adapters),
+                    max_new_tokens=max_new_tokens,
+                    expected_error=expected_error,
+                )
 
         return forward_outputs
 
     def test_dynamic_adapter_updates(self):
         for case_idx, test_case in enumerate(TEST_CASES, start=1):
             for mode in [
-                LoRAUpdateTestSessionMode.SERVER,
                 LoRAUpdateTestSessionMode.ENGINE,
+                LoRAUpdateTestSessionMode.SERVER,
             ]:
                 print("=" * 100)
                 print(f"Starting test case {case_idx} in {mode.value} mode.")
...