# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import Optional

import pytest
from filelock import FileLock

from tests.utils.constants import TEST_MODELS
from tests.utils.managed_process import ManagedProcess


def pytest_configure(config):
    """Register the custom markers used by the test suite.

    Args:
        config: The pytest config object; each marker is added through
            ``config.addinivalue_line("markers", ...)``.
    """
    # Registering markers here avoids "not found in `markers` configuration
    # option" errors when pyproject.toml is not available in the container
    # (e.g. some CI jobs).
    # IMPORTANT: Keep this marker list in sync with [tool.pytest.ini_options].markers
    # in pyproject.toml. If you add or remove markers there, mirror the change here.
    markers = [
        "pre_merge: marks tests to run before merging",
        "post_merge: marks tests to run after merge",
        "parallel: marks tests that can run in parallel with pytest-xdist",
        "nightly: marks tests to run nightly",
        "weekly: marks tests to run weekly",
        "gpu_0: marks tests that don't require GPU",
        "gpu_1: marks tests to run on GPU",
        "gpu_2: marks tests to run on 2GPUs",
        "gpu_4: marks tests to run on 4GPUs",
        "gpu_8: marks tests to run on 8GPUs",
        "e2e: marks tests as end-to-end tests",
        "integration: marks tests as integration tests",
        "unit: marks tests as unit tests",
        "stress: marks tests as stress tests",
        "performance: marks tests as performance tests",
        "vllm: marks tests as requiring vllm",
        "trtllm: marks tests as requiring trtllm",
        "sglang: marks tests as requiring sglang",
        "multimodal: marks tests as multimodal (image/video) tests",
        "slow: marks tests as known to be slow",
        "h100: marks tests to run on H100",
        "router: marks tests for router component",
        "planner: marks tests for planner component",
        "kvbm: marks tests for KV behavior and model determinism",
        "kvbm_v2: marks tests using KVBM V2",
        "model: model id used by a test or parameter",
        "custom_build: marks tests that require custom builds or special setup (e.g., MoE models)",
        "k8s: marks tests as requiring Kubernetes",
        "fault_tolerance: marks tests as fault tolerance tests",
    ]
    for marker in markers:
        config.addinivalue_line("markers", marker)


LOG_FORMAT = "[TEST] %(asctime)s %(levelname)s %(name)s: %(message)s"
DATE_FORMAT = "%Y-%m-%dT%H:%M:%S"

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    # ISO 8601-style timestamp.
    # NOTE(review): no UTC converter is configured in this file, so the
    # timestamps are local time unless a converter is set elsewhere.
    datefmt=DATE_FORMAT,
)


@pytest.fixture()
def set_ucx_tls_no_mm():
    """Set ``UCX_TLS=^mm`` for the duration of a test that requests this fixture.

    The fixture is not autouse, so only tests (or fixtures) that explicitly
    depend on it are affected. The environment change is undone on teardown.
    """
    mp = pytest.MonkeyPatch()
    # CI note:
    # - Affected test: tests/fault_tolerance/cancellation/test_vllm.py::test_request_cancellation_vllm_decode_cancel
    # - Symptom on L40 CI: UCX/NIXL mm transport assertion during worker init
    #   (uct_mem.c:482: mem.memh != UCT_MEM_HANDLE_NULL) when two workers
    #   start on the same node (maybe a shared-memory segment collision/limits).
    # - Mitigation: disable the UCX "mm" shared-memory transport for the
    #   requesting test.
    mp.setenv("UCX_TLS", "^mm")
    yield
    mp.undo()


def download_models(model_list=None, ignore_weights=False):
    """Download models - can be called directly or via fixture.

    Download failures are logged and swallowed on purpose so that individual
    tests can decide how to handle a missing model.

    Args:
        model_list: List of model IDs to download. If None, downloads TEST_MODELS.
        ignore_weights: If True, skips downloading model weight files. Default is False.
    """
    if model_list is None:
        model_list = TEST_MODELS

    # Check for HF_TOKEN in environment
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        logging.info("HF_TOKEN found in environment")
    else:
        logging.warning(
            "HF_TOKEN not found in environment. "
            "Some models may fail to download or you may encounter rate limits. "
            "Get a token from https://huggingface.co/settings/tokens"
        )

    # Only the import itself is guarded by ImportError; per-model failures are
    # handled inside the loop below.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        logging.warning(
            "huggingface_hub not installed. "
            "Models will be downloaded during test execution."
        )
        return

    # The keyword arguments do not depend on the model, so build them once.
    download_kwargs = {"token": hf_token}
    if ignore_weights:
        # Weight file patterns to exclude (based on hub.rs implementation);
        # everything except weight files is downloaded.
        download_kwargs["ignore_patterns"] = [
            "*.bin",
            "*.safetensors",
            "*.h5",
            "*.msgpack",
            "*.ckpt.index",
        ]

    for model_id in model_list:
        logging.info(
            f"Pre-downloading {'model (no weights)' if ignore_weights else 'model'}: {model_id}"
        )
        try:
            snapshot_download(repo_id=model_id, **download_kwargs)
            logging.info(f"Successfully pre-downloaded: {model_id}")
        except Exception as e:
            # Deliberate best effort: don't fail the caller - let individual
            # tests handle missing models.
            logging.error(f"Failed to pre-download {model_id}: {e}")


@pytest.fixture(scope="session")
def predownload_models(pytestconfig):
    """Fixture wrapper around download_models for models used in collected tests."""
    # Get models from pytest config if available, otherwise fall back to TEST_MODELS
    models = getattr(pytestconfig, "models_to_download", None)
    if models:
        logging.info(
            f"Downloading {len(models)} models needed for collected tests\nModels: {models}"
        )
        download_models(model_list=list(models))
    else:
        # Fallback to original behavior if extraction failed
        download_models()
    yield


@pytest.fixture(scope="session")
def predownload_tokenizers(pytestconfig):
    """Fixture wrapper around download_models for tokenizers used in collected tests.

    Same as ``predownload_models`` but passes ``ignore_weights=True`` so weight
    files are skipped.
    """
    # Get models from pytest config if available, otherwise fall back to TEST_MODELS
    models = getattr(pytestconfig, "models_to_download", None)
    if models:
        logging.info(
            f"Downloading tokenizers for {len(models)} models needed for collected tests\nModels: {models}"
        )
        download_models(model_list=list(models), ignore_weights=True)
    else:
        # Fallback to original behavior if extraction failed
        download_models(ignore_weights=True)
    yield


@pytest.fixture(autouse=True)
def logger(request):
    """Attach a per-test file handler to the root logger.

    A directory named after the test node (relative to the current working
    directory) is wiped and recreated, and ``test.log.txt`` inside it receives
    all root-logger records for the duration of the test.
    """
    log_dir = request.node.name
    log_path = os.path.join(log_dir, "test.log.txt")
    root_logger = logging.getLogger()

    # Start from a clean directory so stale logs from earlier runs are dropped.
    shutil.rmtree(log_dir, ignore_errors=True)
    os.makedirs(log_dir, exist_ok=True)

    handler = logging.FileHandler(log_path, mode="w")
    handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT))
    root_logger.addHandler(handler)
    yield
    # Detach before closing so no record can be routed to a closed handler.
    root_logger.removeHandler(handler)
    handler.close()


@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(config, items):
    """Record which models the collected (non-skipped) tests need.

    Model ids are taken from the first positional argument of each item's
    closest ``model`` marker and stored on ``config.models_to_download``.
    The attribute is only set when at least one model was found.
    """
    # Collect models via explicit pytest mark from final filtered items only
    models_to_download = set()
    for item in items:
        # Only collect from items that are not skipped
        if any(
            getattr(m, "name", "") == "skip" for m in getattr(item, "own_markers", [])
        ):
            continue
        model_mark = item.get_closest_marker("model")
        if model_mark and model_mark.args:
            models_to_download.add(model_mark.args[0])

    # Store models to download in pytest config for fixtures to access
    if models_to_download:
        config.models_to_download = models_to_download


def pytest_runtestloop(session):
    """Download models after collection but before any tests run.

    This hook runs after pytest_collection_modifyitems (so models are collected)
    but before any test execution, ensuring model downloads don't count against
    test timeouts.
    """
    models = getattr(session.config, "models_to_download", None)
    if models:
        logging.info(
            f"Downloading {len(models)} models before test execution\nModels: {models}"
        )
        # Monotonic clock: the measured duration is immune to wall-clock jumps.
        start_time = time.monotonic()
        download_models(model_list=list(models))
        download_duration = time.monotonic() - start_time
        logging.info(f"Model download completed in {download_duration:.1f}s")


class EtcdServer(ManagedProcess):
    """ManagedProcess wrapper that launches ``etcd`` with a fresh temp data dir."""

    def __init__(self, request, port=2379, timeout=300):
        """Build the etcd command line and hand it to ManagedProcess.

        Args:
            request: pytest request; ``request.node.name`` is used as log_dir.
            port: Client port etcd listens on and that is health-checked.
            timeout: Forwarded to ManagedProcess.
        """
        port_string = str(port)
        etcd_env = os.environ.copy()
        etcd_env["ALLOW_NONE_AUTHENTICATION"] = "yes"
        data_dir = tempfile.mkdtemp(prefix="etcd_")
        command = [
            "etcd",
            "--listen-client-urls",
            f"http://0.0.0.0:{port_string}",
            "--advertise-client-urls",
            f"http://0.0.0.0:{port_string}",
            "--data-dir",
            data_dir,
        ]
        super().__init__(
            env=etcd_env,
            command=command,
            timeout=timeout,
            display_output=False,
            health_check_ports=[port],
            data_dir=data_dir,
            log_dir=request.node.name,
        )


class NatsServer(ManagedProcess):
    """ManagedProcess wrapper that launches ``nats-server`` with JetStream enabled."""

    def __init__(self, request, port=4222, timeout=300):
        """Build the nats-server command line and hand it to ManagedProcess.

        Args:
            request: pytest request; ``request.node.name`` is used as log_dir.
            port: Client port nats-server listens on and that is health-checked.
            timeout: Forwarded to ManagedProcess.
        """
        data_dir = tempfile.mkdtemp(prefix="nats_")
        command = [
            "nats-server",
            "-js",
            "--trace",
            "--store_dir",
            data_dir,
            # Pass the port explicitly so the server listens on the same port
            # that is health-checked; 4222 is also nats-server's own default.
            "-p",
            str(port),
        ]
        super().__init__(
            command=command,
            timeout=timeout,
            display_output=False,
            data_dir=data_dir,
            health_check_ports=[port],
            log_dir=request.node.name,
        )


class SharedManagedProcess:
    """Base class for ManagedProcess with file-based reference counting for multi-process sharing.

    A counter file under ``<tempdir>/pytest_ref_counting`` tracks how many
    users hold the resource; all counter access happens under a FileLock.
    The instance that raises the count from 0 to 1 starts the server.

    NOTE(review): the server is only stopped when the count drops to 0 in the
    same instance that started it. If the starter exits while other users are
    still registered, nothing in this class stops the process later.
    """

    def __init__(
        self,
        request,
        tmp_path_factory,
        resource_name: str,
        port: int,
        timeout: int = 300,
    ):
        """Store configuration and derive the ref-count and lock file paths.

        Args:
            request: pytest request, forwarded to the concrete server class.
            tmp_path_factory: Accepted for interface compatibility; not used here.
            resource_name: Short name used in file names and log messages.
            port: Port of the shared server; part of the ref-count file name.
            timeout: Forwarded to the concrete server class.
        """
        self.request = request
        self.port = port
        self.timeout = timeout
        self.resource_name = resource_name
        self._server: Optional[ManagedProcess] = None
        self._owns_process = False

        root_tmp = Path(tempfile.gettempdir()) / "pytest_ref_counting"
        root_tmp.mkdir(parents=True, exist_ok=True)

        self.ref_file = root_tmp / f"pytest_{resource_name}_{port}_ref_count"
        self.lock_file = str(self.ref_file) + ".lock"

    def _create_server(self) -> ManagedProcess:
        """Create the underlying server instance. Must be implemented by subclasses."""
        raise NotImplementedError

    def _read_ref_count(self) -> int:
        """Read current reference count; a missing or unreadable file counts as 0."""
        try:
            return int(self.ref_file.read_text().strip())
        except (ValueError, OSError):
            # OSError covers a missing file as well as other read failures.
            return 0

    def _write_ref_count(self, count: int):
        """Write reference count atomically (temp file + rename)."""
        tmp_file = self.ref_file.with_name(self.ref_file.name + ".tmp")
        tmp_file.write_text(str(count))
        # os.replace swaps the file in one step, so a reader never observes a
        # partially written counter.
        os.replace(tmp_file, self.ref_file)

    def _increment_ref_count(self) -> int:
        """Increment reference count and return new count."""
        count = self._read_ref_count() + 1
        self._write_ref_count(count)
        return count

    def _decrement_ref_count(self) -> int:
        """Decrement reference count (never below 0) and return new count."""
        count = max(0, self._read_ref_count() - 1)
        self._write_ref_count(count)
        return count

    def __enter__(self):
        """Register a reference; start the server if this is the first one.

        Returns:
            self

        Raises:
            Whatever server creation/start-up raises. In that case the
            reference taken here is released again before re-raising.
        """
        with FileLock(self.lock_file):
            ref_count = self._increment_ref_count()
            if ref_count == 1:
                # First reference - start the process
                try:
                    self._server = self._create_server()
                    self._server.__enter__()
                except BaseException:
                    # __exit__ is never called when __enter__ raises, so undo
                    # the increment here; otherwise later users would see a
                    # non-zero count and "reuse" a server that never started.
                    self._server = None
                    self._decrement_ref_count()
                    raise
                self._owns_process = True
                logging.info(f"[{self.resource_name}] Started process (ref_count=1)")
            else:
                # Process already running, just track reference
                self._owns_process = False
                logging.info(
                    f"[{self.resource_name}] Reusing existing process (ref_count={ref_count})"
                )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release a reference; stop the server if this was the last one and we own it."""
        with FileLock(self.lock_file):
            ref_count = self._decrement_ref_count()
            if ref_count == 0 and self._owns_process:
                # Last reference - stop the process
                if self._server:
                    self._server.__exit__(exc_type, exc_val, exc_tb)
                logging.info(f"[{self.resource_name}] Stopped process (ref_count=0)")
            elif ref_count == 0:
                # Last reference but we don't own it - shouldn't happen, but clean up ref file
                if self.ref_file.exists():
                    self.ref_file.unlink()
                logging.warning(
                    f"[{self.resource_name}] Ref count reached 0 but we don't own process"
                )
            else:
                logging.info(
                    f"[{self.resource_name}] Released reference (ref_count={ref_count})"
                )


class SharedEtcdServer(SharedManagedProcess):
    """EtcdServer with file-based reference counting for multi-process sharing."""

    def __init__(self, request, tmp_path_factory, port=2379, timeout=300):
        super().__init__(request, tmp_path_factory, "etcd", port, timeout)
        # Create a log directory for session-scoped servers
        self._log_dir = tempfile.mkdtemp(prefix=f"pytest_{self.resource_name}_logs_")

    def _create_server(self) -> ManagedProcess:
        """Create EtcdServer instance."""
        server = EtcdServer(self.request, port=self.port, timeout=self.timeout)
        # Override log_dir since request.node.name is empty in session scope
        server.log_dir = self._log_dir
        return server


class SharedNatsServer(SharedManagedProcess):
    """NatsServer with file-based reference counting for multi-process sharing."""

    def __init__(self, request, tmp_path_factory, port=4222, timeout=300):
        super().__init__(request, tmp_path_factory, "nats", port, timeout)
        # Create a log directory for session-scoped servers
        self._log_dir = tempfile.mkdtemp(prefix=f"pytest_{self.resource_name}_logs_")

    def _create_server(self) -> ManagedProcess:
        """Create NatsServer instance."""
        server = NatsServer(self.request, port=self.port, timeout=self.timeout)
        # Override log_dir since request.node.name is empty in session scope
        server.log_dir = self._log_dir
        return server


@pytest.fixture
def store_kv(request):
    """
    KV store for runtime. Defaults to "etcd".

    To iterate over multiple stores in a test:
        @pytest.mark.parametrize("store_kv", ["file", "etcd"], indirect=True)
        def test_example(runtime_services):
            ...
    """
    return getattr(request, "param", "etcd")


@pytest.fixture
def request_plane(request):
    """
    Request plane for runtime. Defaults to "nats".

    To iterate over multiple transports in a test:
        @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
        def test_example(runtime_services):
            ...
    """
    return getattr(request, "param", "nats")


@pytest.fixture()
def runtime_services(request, store_kv, request_plane):
    """
    Start runtime services (NATS and/or etcd) based on store_kv and request_plane.

    Yields a ``(nats_process, etcd_process)`` tuple:

    - If store_kv != "etcd", etcd is not started (its slot is None)
    - If request_plane != "nats", NATS is not started (its slot is None)
    """
    want_nats = request_plane == "nats"
    want_etcd = store_kv == "etcd"

    if want_nats and want_etcd:
        # NATS is entered first and therefore torn down last.
        with NatsServer(request) as nats_process, EtcdServer(request) as etcd_process:
            yield nats_process, etcd_process
    elif want_nats:
        with NatsServer(request) as nats_process:
            yield nats_process, None
    elif want_etcd:
        with EtcdServer(request) as etcd_process:
            yield None, etcd_process
    else:
        yield None, None


@pytest.fixture(scope="session")
def runtime_services_session(request, tmp_path_factory):
    """Session-scoped fixture that provides shared NATS and etcd instances for all tests.

    Uses file-based reference counting to coordinate between pytest-xdist
    worker processes. The first worker to register starts the services; the
    services are stopped when the reference count returns to zero in the
    worker that started them.

    Per the original note, test isolation is expected to come from unique
    namespaces (test-namespace-{random-suffix}); nothing in this fixture
    enforces that.
    """
    with SharedNatsServer(request, tmp_path_factory) as nats:
        with SharedEtcdServer(request, tmp_path_factory) as etcd:
            yield nats, etcd


@pytest.fixture
def file_storage_backend():
    """Fixture that sets up and tears down file storage backend.

    Creates a temporary directory for file-based KV storage and sets
    the DYN_FILE_KV environment variable. On teardown the previous value of
    DYN_FILE_KV is restored (or the variable is removed if it was unset) and
    the temporary directory is deleted.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        old_env = os.environ.get("DYN_FILE_KV")
        os.environ["DYN_FILE_KV"] = tmpdir
        logging.info(f"Set up file storage backend in: {tmpdir}")
        yield tmpdir
        # Cleanup
        if old_env is not None:
            os.environ["DYN_FILE_KV"] = old_env
        else:
            os.environ.pop("DYN_FILE_KV", None)