Unverified Commit 01ac7f8e authored by Jacket's avatar Jacket Committed by GitHub
Browse files

[Pytorch] Make test script generate checkpoints if they don't exist (#2650)


Signed-off-by: default avatarKaining Zhong <kainingz@nvidia.com>
parent b09ff7e9
......@@ -48,7 +48,11 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $T
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint
if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then
python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files"
fi
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
......
......@@ -101,7 +101,7 @@ class TestLoadCheckpoint:
# Path to save checkpoint
if checkpoint_dir is None:
checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
checkpoint_dir.mkdir(exist_ok=True)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_file = checkpoint_dir / f"{name}.pt"
# Create module and save checkpoint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment