conftest.py 2.95 KB
Newer Older
Neelay Shah's avatar
Neelay Shah committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

import logging
import os

import pytest

logger = logging.getLogger(__name__)

# Hugging Face model repositories exercised by the serve tests; the
# ``predownload_models`` fixture below fetches each of them ahead of time.
SERVE_TEST_MODELS = [
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "llava-hf/llava-1.5-7b-hf",
]


@pytest.fixture(scope="session")
def predownload_models():
    """Best-effort download of every model in ``SERVE_TEST_MODELS``.

    Session-scoped, so the downloads are attempted at most once per pytest
    session. Problems (missing ``huggingface_hub``, network or auth errors for
    an individual model) are logged and never fail the fixture; individual
    tests are left to handle models that could not be fetched.

    Yields:
        None. The fixture exists only for its download side effect.
    """
    # Treat an empty HF_TOKEN like an unset one: passing ``token=""`` to
    # huggingface_hub sends an empty bearer token and bypasses its fallback to a
    # locally saved login token, whereas ``None`` keeps that fallback.
    hf_token = os.environ.get("HF_TOKEN") or None
    if hf_token:
        logger.info("HF_TOKEN found in environment")
    else:
        logger.warning(
            "HF_TOKEN not found in environment. "
            "Some models may fail to download or you may encounter rate limits. "
            "Get a token from https://huggingface.co/settings/tokens"
        )

    # Guard only the import itself; the download loop has its own handler.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        logger.warning(
            "huggingface_hub not installed. "
            "Models will be downloaded during test execution."
        )
    else:
        for model_id in SERVE_TEST_MODELS:
            logger.info("Pre-downloading model: %s", model_id)
            try:
                # Download the full model snapshot (includes all files);
                # HuggingFace handles caching automatically.
                snapshot_download(repo_id=model_id, token=hf_token)
            except Exception as e:
                # Best effort: record the failure with its traceback, but don't
                # fail the fixture - let individual tests handle missing models.
                logger.exception("Failed to pre-download %s: %s", model_id, e)
            else:
                logger.info("Successfully pre-downloaded: %s", model_id)

    yield


def _is_serve_test(path, rootpath) -> bool:
    """Return True if *path* has a path component named exactly ``serve``.

    The path is made relative to *rootpath* when possible, so directories above
    the project root (e.g. a checkout under ``/home/observer``) are ignored.
    Whole components are compared, so names that merely contain "serve"
    (``observer``, ``reserved``, ``preserve``) do not match.
    """
    try:
        path = path.relative_to(rootpath)
    except ValueError:
        # Item lives outside the rootdir; fall back to its full path.
        pass
    return "serve" in path.parts


# Automatically use the predownload fixture for all serve tests
def pytest_collection_modifyitems(config, items):
    """Add the ``predownload_models`` fixture to every test under a ``serve`` dir.

    Tests are left untouched when they have no ``fixturenames`` (non-function
    items such as MypyFileItem), already request the fixture, or carry the
    ``skip_model_download`` marker.

    Args:
        config: The pytest ``Config``; its ``rootpath`` anchors the path check.
        items: Collected test items; modified in place.
    """
    for item in items:
        # Skip items that don't have fixturenames (like MypyFileItem)
        if not hasattr(item, "fixturenames"):
            continue
        # Only apply to tests in the serve directory
        if not _is_serve_test(item.path, config.rootpath):
            continue
        # Check if the test already uses the fixture
        if "predownload_models" in item.fixturenames:
            continue
        # Don't add if test explicitly marks to skip model download
        if item.get_closest_marker("skip_model_download"):
            continue
        # Copy before appending: the original list may be shared with other
        # items (e.g. parametrized variants of the same test).
        item.fixturenames = list(item.fixturenames)
        item.fixturenames.append("predownload_models")