"vscode:/vscode.git/clone" did not exist on "87bca129a48f3b4c9248577ac8273a4a8e6cdf40"
Unverified commit 1a08358a authored by Lifu Huang, committed by GitHub
Browse files

Improve error handling for requests with unloaded LoRA path(s) (#7642)

parent f18a8fdd
......@@ -240,6 +240,12 @@ class TokenizerManager:
revision=server_args.revision,
)
# Initialize loaded loRA adapters with the initial lora paths in the server_args.
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
self.loaded_lora_adapters: Dict[str, str] = dict(
self.server_args.lora_paths or {}
)
# Store states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
......@@ -549,6 +555,8 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.lora_paths and obj.lora_path:
self._validate_lora_adapters(obj)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
......@@ -662,6 +670,21 @@ class TokenizerManager:
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
)
def _validate_lora_adapters(self, obj: GenerateReqInput):
    """Validate that the requested LoRA adapters are loaded."""
    # Normalize the request's lora_path (a single name or a list of names)
    # into a set of requested adapter names.
    if isinstance(obj.lora_path, list):
        requested_adapters = set(obj.lora_path)
    else:
        requested_adapters = {obj.lora_path}

    # dict_keys supports set difference directly; fall back to an empty set
    # when nothing has been loaded yet.
    if self.loaded_lora_adapters:
        loaded_adapters = self.loaded_lora_adapters.keys()
    else:
        loaded_adapters = set()

    unloaded_adapters = requested_adapters - loaded_adapters
    if not unloaded_adapters:
        return
    raise ValueError(
        f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
        f"Loaded adapters: {loaded_adapters}."
    )
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......@@ -988,6 +1011,7 @@ class TokenizerManager:
async with self.model_update_lock.writer_lock:
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result
async def unload_lora_adapter(
......@@ -1009,6 +1033,7 @@ class TokenizerManager:
async with self.model_update_lock.writer_lock:
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result
async def get_weights_by_name(
......
......@@ -20,7 +20,7 @@ import logging
import os
import random
import tempfile
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.reasoning_parser import ReasoningParser
......@@ -131,7 +131,7 @@ class ServerArgs:
preferred_sampling_params: Optional[str] = None
# LoRA
lora_paths: Optional[List[str]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
......
......@@ -16,7 +16,7 @@ import multiprocessing as mp
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
from typing import Any, List, Optional, Union
import requests
import torch
......@@ -42,14 +42,16 @@ PROMPTS = [
class OperationType(Enum):
    # Kinds of steps a dynamic LoRA-update test scenario can perform.
    # NOTE(review): this listing comes from a diff with markers stripped —
    # NOOP may be a stale pre-change member; confirm against the repo.
    LOAD = "load"  # dynamically load a LoRA adapter
    UNLOAD = "unload"  # dynamically unload a LoRA adapter
    NOOP = "noop"  # do nothing (placeholder step)
    FORWARD = "forward"  # run a batch forward pass, expecting success
    EXPECT_ERROR = "expect_error"  # run a forward pass, expecting a specific error
@dataclass
class Operation:
    """A single step in a dynamic LoRA-update test scenario.

    Fix: the stripped diff had left both the stale ``data: Optional[str]``
    field and its replacement in the class body; only the replacement is kept.
    """

    # Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
    type: OperationType
    # Data associated with the operation. Exact type varies depending on the operation
    # (e.g. an adapter name for LOAD/UNLOAD, a (batch, expected_error) tuple
    # for EXPECT_ERROR).
    data: Optional[Any]
@dataclass
......@@ -62,7 +64,7 @@ class TestCase:
max_new_tokens: int = 32
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
    """Build a batch of (prompt, adapter) pairs for a forward pass.

    Pairs every prompt in the module-level ``PROMPTS`` with every adapter in
    ``adapters``. A single adapter name is accepted and treated as a
    one-element list.

    Fix: the stripped diff had left two ``def`` header lines (the old
    ``-> dict:`` signature and the new one); only the new signature is kept.
    """
    if not isinstance(adapters, list):
        adapters = [adapters]
    return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
......@@ -80,6 +82,26 @@ TEST_CASES = [
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
......@@ -102,6 +124,13 @@ TEST_CASES = [
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
......@@ -115,6 +144,15 @@ TEST_CASES = [
type=OperationType.UNLOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
......@@ -149,6 +187,22 @@ TEST_CASES = [
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
......@@ -157,6 +211,13 @@ TEST_CASES = [
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
......@@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
"""
response = self.handle.batch_forward(
prompts=prompts,
lora_paths=lora_paths,
max_new_tokens=max_new_tokens,
)
output_strs = response.output_strs
try:
response = self.handle.batch_forward(
prompts=prompts,
lora_paths=lora_paths,
max_new_tokens=max_new_tokens,
)
except ValueError as e:
if expected_error:
error_message = str(e)
self.testcase.assertIn(expected_error, error_message)
print(f"Received error as expected: {error_message}")
return error_message
raise e
self.testcase.assertEqual(len(response.output_strs), len(prompts))
output = response.output_strs
print(f"output_strs: {output}")
print(f"output_strs: {output_strs}")
return output_strs
return output
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
......@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
......@@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
},
},
)
self.testcase.assertTrue(response.ok)
output_strs = [r["text"] for r in response.json()]
print(f"output_strs: {output_strs}")
return output_strs
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
output = response.text
print(f"Received error as expected: {response.text}")
return output
else:
self.testcase.assertTrue(response.ok)
output = [r["text"] for r in response.json()]
self.testcase.assertEqual(len(output), len(prompts))
print(f"output_strs: {output}")
return output
# Factory function to create the appropriate LoRA test session based on mode
......@@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase):
max_new_tokens=max_new_tokens,
)
forward_outputs.append(result)
elif op_type == OperationType.EXPECT_ERROR:
input_data, expected_error = data
prompts, adapters = zip(*input_data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
expected_error=expected_error,
)
return forward_outputs
def test_dynamic_adapter_updates(self):
for case_idx, test_case in enumerate(TEST_CASES, start=1):
for mode in [
LoRAUpdateTestSessionMode.SERVER,
LoRAUpdateTestSessionMode.ENGINE,
LoRAUpdateTestSessionMode.SERVER,
]:
print("=" * 100)
print(f"Starting test case {case_idx} in {mode.value} mode.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment