import gc
import os
import sys
from pathlib import Path
from unittest import mock

import pytest
import torch

from megatron.core.dist_checkpointing.strategies.base import StrategyAction, get_default_strategy
from megatron.core.utils import is_te_min_version
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


@pytest.fixture(scope="session", autouse=True)
def set_env():
    """Force the ``NVTE_FLASH_ATTN`` / ``NVTE_FUSED_ATTN`` env vars to ``'0'``.

    Runs automatically once per test session. The variables are only
    written when ``is_te_min_version("1.3")`` is true; otherwise the
    environment is left untouched. The values are not restored afterwards.
    """
    # NOTE(review): presumably these switch off Transformer Engine's flash /
    # fused attention backends for the tests - confirm against TE docs.
    if not is_te_min_version("1.3"):
        return

    for env_flag in ('NVTE_FLASH_ATTN', 'NVTE_FUSED_ATTN'):
        os.environ[env_flag] = '0'


@pytest.fixture(scope="session")
def tmp_path_dist_ckpt(tmp_path_factory) -> Path:
    """Yield a common checkpoint directory shared by all processes.

    The path handed out by pytest's `tmp_path_factory` cannot be used as is,
    because the directory must be shared between processes. Instead, a fixed
    name (``tmp_dist_ckpt``) is placed two levels above a throwaway directory
    created through the factory.

    Only rank 0 wraps the yield in `TempNamedDir` (with ``sync=False``);
    every other rank just yields the same path.
    """
    # The 'ignored' directory itself is never used; it only serves to locate
    # the ancestor under which the shared directory lives.
    scratch_dir = tmp_path_factory.mktemp('ignored', numbered=False)
    shared_dir = scratch_dir.parent.parent / 'tmp_dist_ckpt'

    is_primary_rank = Utils.rank == 0
    if not is_primary_rank:
        yield shared_dir
        return

    # NOTE(review): TempNamedDir presumably creates the directory on enter and
    # removes it on exit - confirm in tests.unit_tests.dist_checkpointing.
    with TempNamedDir(shared_dir, sync=False):
        yield shared_dir