# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
MinIO Service and LoRA Test Utilities.

Provides infrastructure for LoRA adapter testing with S3-compatible storage.
Works in both CI (pre-started MinIO) and local development (auto-starts Docker).
"""

import logging
import os
import shutil
import subprocess
import tempfile
import time
from dataclasses import dataclass
18
from pathlib import Path
19
20
from typing import Optional

21
import boto3
22
import requests
23
24
from botocore.client import Config
from botocore.exceptions import ClientError
25
26
27
28
29
30
31
32
33
34
35
36
37
38

# Module-level logger; used by the free function load_lora_adapter() below.
logger = logging.getLogger(__name__)

# LoRA testing constants
# The MINIO_* values are the defaults for MinioLoraConfig's connection fields.
MINIO_ENDPOINT = "http://localhost:9000"
MINIO_ACCESS_KEY = "minioadmin"
MINIO_SECRET_KEY = "minioadmin"
MINIO_BUCKET = "my-loras"
# Repository passed to `huggingface-cli download` by MinioService.download_lora().
DEFAULT_LORA_REPO = "codelion/Qwen3-0.6B-accuracy-recovery-lora"
# Key prefix under which the adapter files are stored in the bucket.
DEFAULT_LORA_NAME = "codelion/Qwen3-0.6B-accuracy-recovery-lora"


@dataclass
class MinioLoraConfig:
    """Settings bundle for the MinIO endpoint and the LoRA adapter under test."""

    # MinIO connection settings.
    endpoint: str = MINIO_ENDPOINT
    access_key: str = MINIO_ACCESS_KEY
    secret_key: str = MINIO_SECRET_KEY
    bucket: str = MINIO_BUCKET
    # Source repository of the adapter and the key prefix it is stored under.
    lora_repo: str = DEFAULT_LORA_REPO
    lora_name: str = DEFAULT_LORA_NAME
    # Host directory mounted as MinIO's /data; filled in lazily when unset.
    data_dir: Optional[str] = None

    def get_s3_uri(self) -> str:
        """Return the ``s3://<bucket>/<lora_name>`` location of the adapter."""
        return "s3://{}/{}".format(self.bucket, self.lora_name)

    def get_env_vars(self) -> dict:
        """Return the environment variables needed for AWS/MinIO access.

        The connection entries come from this config; the region, the HTTP
        allowance and the two ``DYN_LORA_*`` entries are fixed values.
        """
        env = dict(
            AWS_ENDPOINT=self.endpoint,
            AWS_ACCESS_KEY_ID=self.access_key,
            AWS_SECRET_ACCESS_KEY=self.secret_key,
            AWS_REGION="us-east-1",
            AWS_ALLOW_HTTP="true",
            DYN_LORA_ENABLED="true",
            DYN_LORA_PATH="/tmp/dynamo_loras_minio_test",
        )
        return env


class MinioService:
    """
    Manages MinIO service lifecycle for tests.

    Follows a "connect or create" pattern:
    - First checks if MinIO is already running (CI or manual)
    - If not, starts a Docker container (local development)
    - Only cleans up containers it created
    """

    CONTAINER_NAME = "dynamo-minio-test"

    def __init__(self, config: MinioLoraConfig):
        """Store the configuration; no network or Docker access happens here."""
        self.config = config
        self._logger = logging.getLogger(self.__class__.__name__)
        # Directory created by download_lora(); removed by cleanup_download().
        self._temp_download_dir: Optional[str] = None
        # boto3 client, created on first use by _get_s3_client().
        self._s3_client = None
        # True only when start() launched the container itself.
        self._owns_container: bool = False

    def _get_s3_client(self):
        """Get or create boto3 S3 client for MinIO."""
        if self._s3_client is None:
            self._s3_client = boto3.client(
                "s3",
                endpoint_url=self.config.endpoint,
                aws_access_key_id=self.config.access_key,
                aws_secret_access_key=self.config.secret_key,
                config=Config(signature_version="s3v4"),
                region_name="us-east-1",
            )
        return self._s3_client

    def _is_healthy(self) -> bool:
        """Check if MinIO is running and healthy.

        Returns:
            True when the liveness endpoint answers with HTTP 200, False on
            any other status or on a request error.
        """
        health_url = f"{self.config.endpoint}/minio/health/live"
        try:
            response = requests.get(health_url, timeout=2)
            return response.status_code == 200
        except requests.RequestException:
            return False

    def _is_docker_available(self) -> bool:
        """Check if Docker daemon is accessible.

        Returns:
            True when ``docker info`` exits with status 0; False when it
            fails, times out, or the ``docker`` executable is missing.
        """
        try:
            result = subprocess.run(["docker", "info"], capture_output=True, timeout=5)
            return result.returncode == 0
        except (subprocess.SubprocessError, FileNotFoundError):
            return False

    def _remove_container(self) -> None:
        """Force-remove the test container; the command's result is ignored."""
        subprocess.run(
            ["docker", "rm", "-f", self.CONTAINER_NAME],
            capture_output=True,
        )

    def start(self) -> None:
        """
        Connect to MinIO service, starting a container if necessary.

        Raises:
            RuntimeError: If MinIO cannot be started or connected to.
        """
        self._logger.info("Connecting to MinIO...")

        # Check if MinIO is already running
        if self._is_healthy():
            self._logger.info("Connected to existing MinIO instance")
            self._owns_container = False
            return

        # Try to start Docker container
        if not self._is_docker_available():
            raise RuntimeError(
                "MinIO is not available and Docker is not accessible.\n"
                "Start MinIO manually:\n"
                "  docker run -d -p 9000:9000 -p 9001:9001 "
                "-e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin "
                f"--name {self.CONTAINER_NAME} "
                "quay.io/minio/minio server /data --console-address ':9001'"
            )

        self._start_container()
        self._owns_container = True
        self._logger.info("MinIO container started successfully")

    def _start_container(self) -> None:
        """Start MinIO Docker container and wait until it is healthy.

        Raises:
            RuntimeError: If ``docker run`` fails or MinIO does not become
                ready in time. In the latter case the container is removed
                again before the error propagates.
        """
        # Clean up any existing container
        self._remove_container()

        # Create data directory
        if not self.config.data_dir:
            self.config.data_dir = tempfile.mkdtemp(prefix="minio_test_")

        cmd = [
            "docker",
            "run",
            "-d",
            "--name",
            self.CONTAINER_NAME,
            "-p",
            "9000:9000",
            "-p",
            "9001:9001",
            "-e",
            f"MINIO_ROOT_USER={self.config.access_key}",
            "-e",
            f"MINIO_ROOT_PASSWORD={self.config.secret_key}",
            "-v",
            f"{self.config.data_dir}:/data",
            "quay.io/minio/minio",
            "server",
            "/data",
            "--console-address",
            ":9001",
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to start MinIO: {result.stderr}")

        try:
            self._wait_for_ready()
        except RuntimeError:
            # start() records ownership only after this method returns, so
            # stop() would skip a container that never became healthy.
            self._remove_container()
            raise

    def _wait_for_ready(self, timeout: int = 30) -> None:
        """Wait for MinIO to be ready.

        Polls the health endpoint about once per second.

        Args:
            timeout: Maximum number of seconds to keep polling.

        Raises:
            RuntimeError: If MinIO is not healthy before the timeout elapses.
        """
        # Monotonic clock: wall-clock adjustments must not stretch or cut the wait.
        start_time = time.monotonic()

        while time.monotonic() - start_time < timeout:
            if self._is_healthy():
                return
            time.sleep(1)

        raise RuntimeError(f"MinIO did not become ready within {timeout}s")

    def stop(self) -> None:
        """Stop MinIO container if this instance started it."""
        if not self._owns_container:
            self._logger.debug("Not stopping MinIO (not owned by this instance)")
            return

        self._logger.info("Stopping MinIO container...")
        self._remove_container()
        self._owns_container = False

    def create_bucket(self) -> None:
        """Create the S3 bucket if it doesn't exist.

        Raises:
            RuntimeError: If the existence check fails for a reason other
                than "not found", or if bucket creation fails.
        """
        s3_client = self._get_s3_client()

        try:
            s3_client.head_bucket(Bucket=self.config.bucket)
            self._logger.info(f"Bucket already exists: {self.config.bucket}")
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchBucket"):
                self._logger.info(f"Creating bucket: {self.config.bucket}")
                try:
                    s3_client.create_bucket(Bucket=self.config.bucket)
                except ClientError as create_error:
                    raise RuntimeError(
                        f"Failed to create bucket: {create_error}"
                    ) from create_error
            else:
                raise RuntimeError(f"Failed to check bucket: {e}") from e

    def download_lora(self) -> str:
        """Download LoRA from Hugging Face Hub, returns temp directory path.

        Raises:
            RuntimeError: If ``huggingface-cli download`` exits non-zero.
        """
        self._temp_download_dir = tempfile.mkdtemp(prefix="lora_download_")
        self._logger.info(
            f"Downloading LoRA {self.config.lora_repo} to {self._temp_download_dir}"
        )

        result = subprocess.run(
            [
                "huggingface-cli",
                "download",
                self.config.lora_repo,
                "--local-dir",
                self._temp_download_dir,
                "--local-dir-use-symlinks",
                "False",
            ],
            capture_output=True,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"Failed to download LoRA: {result.stderr}")

        # Clean up cache directory
        cache_dir = os.path.join(self._temp_download_dir, ".cache")
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        return self._temp_download_dir

    def upload_lora(self, local_path: str) -> None:
        """Upload LoRA to MinIO using boto3.

        Every regular file below ``local_path`` is uploaded under the key
        ``<lora_name>/<relative path>``; anything inside a ``.git`` directory
        is skipped.

        Args:
            local_path: Directory holding the downloaded adapter files.

        Raises:
            RuntimeError: If uploading any file fails.
        """
        self._logger.info(
            f"Uploading LoRA to s3://{self.config.bucket}/{self.config.lora_name}"
        )

        s3_client = self._get_s3_client()
        root = Path(local_path)

        for file_path in root.rglob("*"):
            if not file_path.is_file():
                continue
            if ".git" in file_path.parts:
                continue

            relative_path = file_path.relative_to(root).as_posix()
            s3_key = f"{self.config.lora_name}/{relative_path}"

            try:
                s3_client.upload_file(str(file_path), self.config.bucket, s3_key)
            except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
                # boto3's upload_file re-raises client errors as
                # S3UploadFailedError, which is not a ClientError subclass.
                raise RuntimeError(f"Failed to upload {file_path}: {e}") from e

        self._logger.info("LoRA upload completed")

    def cleanup_download(self) -> None:
        """Clean up temporary download directory only."""
        if self._temp_download_dir and os.path.exists(self._temp_download_dir):
            shutil.rmtree(self._temp_download_dir)
            self._temp_download_dir = None

    def cleanup_temp(self) -> None:
        """Clean up all temporary directories including MinIO data dir."""
        self.cleanup_download()

        if self.config.data_dir and os.path.exists(self.config.data_dir):
            shutil.rmtree(self.config.data_dir, ignore_errors=True)


def load_lora_adapter(
    system_port: int, lora_name: str, s3_uri: str, timeout: int = 60
) -> None:
    """Load a LoRA adapter via the system API.

    Sends a POST request to ``http://localhost:<system_port>/v1/loras`` with
    the adapter name and its source URI.

    Args:
        system_port: Port of the local system API.
        lora_name: Name to register the adapter under.
        s3_uri: URI the adapter is loaded from.
        timeout: Timeout in seconds passed to the HTTP request.

    Raises:
        RuntimeError: If the API answers with any status other than 200.
    """
    url = f"http://localhost:{system_port}/v1/loras"
    payload = {"lora_name": lora_name, "source": {"uri": s3_uri}}

    logger.info(f"Loading LoRA adapter: {lora_name} from {s3_uri}")

    response = requests.post(url, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to load LoRA adapter: {response.status_code} - {response.text}"
        )

    # The body is only used for the log line, so a non-JSON reply must not
    # turn a successful load into a failure.
    try:
        details = response.json()
    except ValueError:
        details = response.text
    logger.info(f"LoRA adapter loaded successfully: {details}")