Unverified Commit 3c2b72b0 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix: [lora] refactor test and clean up examples (#4884)

parent 242a4d5b
...@@ -4,14 +4,6 @@ ...@@ -4,14 +4,6 @@
set -e set -e
trap 'echo Cleaning up...; kill 0' EXIT trap 'echo Cleaning up...; kill 0' EXIT
# Follow the README.md instructions to setup MinIO or upload the LoRA to s3/minio
# Adjust these values to match your local MinIO or S3 setup
# load math lora to minio
# LORA_NAME=Neural-Hacker/Qwen3-Math-Reasoning-LoRA HF_LORA_REPO=Neural-Hacker/Qwen3-Math-Reasoning-LoRA ./setup_minio.sh
export AWS_ENDPOINT=http://localhost:9000 export AWS_ENDPOINT=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin
...@@ -21,8 +13,6 @@ export AWS_ALLOW_HTTP=true ...@@ -21,8 +13,6 @@ export AWS_ALLOW_HTTP=true
# Dynamo LoRA Configuration # Dynamo LoRA Configuration
export DYN_LORA_ENABLED=true export DYN_LORA_ENABLED=true
export DYN_LORA_PATH=/tmp/dynamo_loras_minio export DYN_LORA_PATH=/tmp/dynamo_loras_minio
export DYN_LOG=debug
# export DYN_LOG_LEVEL=debug
mkdir -p $DYN_LORA_PATH mkdir -p $DYN_LORA_PATH
...@@ -63,7 +53,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \ ...@@ -63,7 +53,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "Qwen/Qwen3-0.6B", "model": "Qwen/Qwen3-0.6B",
"messages": [{"role": "user", "content": "Solve (x*x - x + 1 = 0) for x"}], "messages": [{"role": "user", "content": "What is deep learning?"}],
"max_tokens": 300, "max_tokens": 300,
"temperature": 0.0 "temperature": 0.0
}' }'
......
...@@ -4,12 +4,6 @@ ...@@ -4,12 +4,6 @@
set -e set -e
trap 'echo Cleaning up...; kill 0' EXIT trap 'echo Cleaning up...; kill 0' EXIT
# Follow the README.md instructions to setup MinIO or upload the LoRA to s3/minio
# Adjust these values to match your local MinIO or S3 setup
# load math lora to minio
# LORA_NAME=Neural-Hacker/Qwen3-Math-Reasoning-LoRA HF_LORA_REPO=Neural-Hacker/Qwen3-Math-Reasoning-LoRA ./setup_minio.sh
export AWS_ENDPOINT=http://localhost:9000 export AWS_ENDPOINT=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin
...@@ -19,8 +13,6 @@ export AWS_ALLOW_HTTP=true ...@@ -19,8 +13,6 @@ export AWS_ALLOW_HTTP=true
# Dynamo LoRA Configuration # Dynamo LoRA Configuration
export DYN_LORA_ENABLED=true export DYN_LORA_ENABLED=true
export DYN_LORA_PATH=/tmp/dynamo_loras_minio export DYN_LORA_PATH=/tmp/dynamo_loras_minio
export DYN_LOG=debug
# export DYN_LOG_LEVEL=debug
mkdir -p $DYN_LORA_PATH mkdir -p $DYN_LORA_PATH
...@@ -118,7 +110,7 @@ curl localhost:8000/v1/chat/completions \ ...@@ -118,7 +110,7 @@ curl localhost:8000/v1/chat/completions \
"total_tokens": 226, "total_tokens": 226,
"prompt_tokens_details": { "prompt_tokens_details": {
"audio_tokens": null, "audio_tokens": null,
"cached_tokens": 192 "cached_tokens": 192 # tokens that were cached from the previous request.
} }
}, },
"nvext": { "nvext": {
......
...@@ -86,8 +86,8 @@ def minio_lora_service(): ...@@ -86,8 +86,8 @@ def minio_lora_service():
local_path = service.download_lora() local_path = service.download_lora()
service.upload_lora(local_path) service.upload_lora(local_path)
# Clean up downloaded files (keep MinIO running) # Clean up downloaded files (keep MinIO data intact)
service.cleanup_temp() service.cleanup_download()
yield config yield config
......
...@@ -61,7 +61,7 @@ class MinioService: ...@@ -61,7 +61,7 @@ class MinioService:
def __init__(self, config: MinioLoraConfig): def __init__(self, config: MinioLoraConfig):
self.config = config self.config = config
self._logger = logging.getLogger(self.__class__.__name__) self._logger = logging.getLogger(self.__class__.__name__)
self._temp_dir: Optional[str] = None self._temp_download_dir: Optional[str] = None
def start(self) -> None: def start(self) -> None:
"""Start MinIO container""" """Start MinIO container"""
...@@ -183,9 +183,9 @@ class MinioService: ...@@ -183,9 +183,9 @@ class MinioService:
def download_lora(self) -> str: def download_lora(self) -> str:
"""Download LoRA from Hugging Face Hub, returns temp directory path""" """Download LoRA from Hugging Face Hub, returns temp directory path"""
self._temp_dir = tempfile.mkdtemp(prefix="lora_download_") self._temp_download_dir = tempfile.mkdtemp(prefix="lora_download_")
self._logger.info( self._logger.info(
f"Downloading LoRA {self.config.lora_repo} to {self._temp_dir}" f"Downloading LoRA {self.config.lora_repo} to {self._temp_download_dir}"
) )
result = subprocess.run( result = subprocess.run(
...@@ -194,7 +194,7 @@ class MinioService: ...@@ -194,7 +194,7 @@ class MinioService:
"download", "download",
self.config.lora_repo, self.config.lora_repo,
"--local-dir", "--local-dir",
self._temp_dir, self._temp_download_dir,
"--local-dir-use-symlinks", "--local-dir-use-symlinks",
"False", "False",
], ],
...@@ -206,11 +206,11 @@ class MinioService: ...@@ -206,11 +206,11 @@ class MinioService:
raise RuntimeError(f"Failed to download LoRA: {result.stderr}") raise RuntimeError(f"Failed to download LoRA: {result.stderr}")
# Clean up cache directory # Clean up cache directory
cache_dir = os.path.join(self._temp_dir, ".cache") cache_dir = os.path.join(self._temp_download_dir, ".cache")
if os.path.exists(cache_dir): if os.path.exists(cache_dir):
shutil.rmtree(cache_dir) shutil.rmtree(cache_dir)
return self._temp_dir return self._temp_download_dir
def upload_lora(self, local_path: str) -> None: def upload_lora(self, local_path: str) -> None:
"""Upload LoRA to MinIO""" """Upload LoRA to MinIO"""
...@@ -246,11 +246,15 @@ class MinioService: ...@@ -246,11 +246,15 @@ class MinioService:
if result.returncode != 0: if result.returncode != 0:
raise RuntimeError(f"Failed to upload LoRA: {result.stderr}") raise RuntimeError(f"Failed to upload LoRA: {result.stderr}")
def cleanup_download(self) -> None:
"""Clean up temporary download directory only"""
if self._temp_download_dir and os.path.exists(self._temp_download_dir):
shutil.rmtree(self._temp_download_dir)
self._temp_download_dir = None
def cleanup_temp(self) -> None: def cleanup_temp(self) -> None:
"""Clean up temporary directories""" """Clean up all temporary directories including MinIO data dir"""
if self._temp_dir and os.path.exists(self._temp_dir): self.cleanup_download()
shutil.rmtree(self._temp_dir)
self._temp_dir = None
if self.config.data_dir and os.path.exists(self.config.data_dir): if self.config.data_dir and os.path.exists(self.config.data_dir):
shutil.rmtree(self.config.data_dir, ignore_errors=True) shutil.rmtree(self.config.data_dir, ignore_errors=True)
......
...@@ -16,7 +16,7 @@ from tests.serve.common import ( ...@@ -16,7 +16,7 @@ from tests.serve.common import (
run_serve_deployment, run_serve_deployment,
) )
from tests.serve.conftest import MULTIMODAL_IMG_PATH, MULTIMODAL_IMG_URL from tests.serve.conftest import MULTIMODAL_IMG_PATH, MULTIMODAL_IMG_URL
from tests.serve.lora_utils import MinioLoraConfig, load_lora_adapter from tests.serve.lora_utils import MinioLoraConfig
from tests.utils.engine_process import EngineConfig from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import ( from tests.utils.payload_builder import (
chat_payload, chat_payload,
...@@ -26,7 +26,7 @@ from tests.utils.payload_builder import ( ...@@ -26,7 +26,7 @@ from tests.utils.payload_builder import (
completion_payload_with_logprobs, completion_payload_with_logprobs,
metric_payload_default, metric_payload_default,
) )
from tests.utils.payloads import ChatPayload, ToolCallingChatPayload from tests.utils.payloads import LoraTestChatPayload, ToolCallingChatPayload
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -614,93 +614,6 @@ def test_multimodal_b64(request, runtime_services, predownload_models): ...@@ -614,93 +614,6 @@ def test_multimodal_b64(request, runtime_services, predownload_models):
lora_dir = os.path.join(vllm_dir, "launch/lora") lora_dir = os.path.join(vllm_dir, "launch/lora")
class LoraTestChatPayload(ChatPayload):
"""
Chat payload that loads a LoRA adapter before sending inference requests.
This payload first loads the specified LoRA adapter via the system API,
then sends chat completion requests using the LoRA model.
"""
def __init__(
self,
body: dict,
lora_name: str,
s3_uri: str,
system_port: int = 8081,
repeat_count: int = 1,
expected_response: Optional[list] = None,
expected_log: Optional[list] = None,
timeout: int = 60,
):
super().__init__(
body=body,
repeat_count=repeat_count,
expected_response=expected_response or [],
expected_log=expected_log or [],
timeout=timeout,
)
self.system_port = system_port
self.lora_name = lora_name
self.s3_uri = s3_uri
self._lora_loaded = False
def _ensure_lora_loaded(self) -> None:
"""Ensure the LoRA adapter is loaded before making inference requests"""
if not self._lora_loaded:
import time
import requests
load_lora_adapter(
system_port=self.system_port,
lora_name=self.lora_name,
s3_uri=self.s3_uri,
timeout=self.timeout,
)
# Wait for the LoRA model to appear in /v1/models
models_url = f"http://{self.host}:{self.port}/v1/models"
start_time = time.time()
max_wait = 60 # 1 minute timeout
logger.info(
f"Waiting for LoRA model '{self.lora_name}' to appear in /v1/models..."
)
while time.time() - start_time < max_wait:
try:
response = requests.get(models_url, timeout=5)
if response.status_code == 200:
data = response.json()
models = data.get("data", [])
model_ids = [m.get("id", "") for m in models]
if self.lora_name in model_ids:
logger.info(
f"LoRA model '{self.lora_name}' is now available"
)
self._lora_loaded = True
return
logger.debug(
f"Available models: {model_ids}, waiting for '{self.lora_name}'..."
)
except requests.RequestException as e:
logger.debug(f"Error checking /v1/models: {e}")
time.sleep(1)
raise RuntimeError(
f"Timeout: LoRA model '{self.lora_name}' did not appear in /v1/models within {max_wait}s"
)
def url(self) -> str:
"""Load LoRA before first request, then return URL"""
self._ensure_lora_loaded()
return super().url()
def lora_chat_payload( def lora_chat_payload(
lora_name: str, lora_name: str,
s3_uri: str, s3_uri: str,
......
...@@ -21,6 +21,8 @@ from copy import deepcopy ...@@ -21,6 +21,8 @@ from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import requests
from dynamo import prometheus_names # type: ignore[attr-defined] from dynamo import prometheus_names # type: ignore[attr-defined]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -240,6 +242,93 @@ class ToolCallingChatPayload(ChatPayload): ...@@ -240,6 +242,93 @@ class ToolCallingChatPayload(ChatPayload):
logger.info(f"Expected tool '{self.expected_tool_name}' was called") logger.info(f"Expected tool '{self.expected_tool_name}' was called")
@dataclass
class LoraTestChatPayload(ChatPayload):
"""
Chat payload that loads a LoRA adapter before sending inference requests.
This payload first loads the specified LoRA adapter via the system API,
then sends chat completion requests using the LoRA model.
"""
def __init__(
self,
body: dict,
lora_name: str,
s3_uri: str,
system_port: int = 8081,
repeat_count: int = 1,
expected_response: Optional[list] = None,
expected_log: Optional[list] = None,
timeout: int = 60,
):
super().__init__(
body=body,
repeat_count=repeat_count,
expected_response=expected_response or [],
expected_log=expected_log or [],
timeout=timeout,
)
self.system_port = system_port
self.lora_name = lora_name
self.s3_uri = s3_uri
self._lora_loaded = False
def _ensure_lora_loaded(self) -> None:
"""Ensure the LoRA adapter is loaded before making inference requests"""
if not self._lora_loaded:
# Import the load_lora_adapter function
# Note: This import is done here to avoid circular dependencies
from tests.serve.lora_utils import load_lora_adapter
load_lora_adapter(
system_port=self.system_port,
lora_name=self.lora_name,
s3_uri=self.s3_uri,
timeout=self.timeout,
)
# Wait for the LoRA model to appear in /v1/models
models_url = f"http://{self.host}:{self.port}/v1/models"
start_time = time.time()
logger.info(
f"Waiting for LoRA model '{self.lora_name}' to appear in /v1/models..."
)
while time.time() - start_time < self.timeout:
try:
response = requests.get(models_url, timeout=5)
if response.status_code == 200:
data = response.json()
models = data.get("data", [])
model_ids = [m.get("id", "") for m in models]
if self.lora_name in model_ids:
logger.info(
f"LoRA model '{self.lora_name}' is now available"
)
self._lora_loaded = True
return
logger.debug(
f"Available models: {model_ids}, waiting for '{self.lora_name}'..."
)
except requests.RequestException as e:
logger.debug(f"Error checking /v1/models: {e}")
time.sleep(1)
raise RuntimeError(
f"Timeout: LoRA model '{self.lora_name}' did not appear in /v1/models within {self.timeout}s"
)
def url(self) -> str:
"""Load LoRA before first request, then return URL"""
self._ensure_lora_loaded()
return super().url()
@dataclass @dataclass
class CompletionPayload(BasePayload): class CompletionPayload(BasePayload):
"""Payload for completions endpoint.""" """Payload for completions endpoint."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment