Unverified Commit 637faccb authored by Tim Moon, committed by GitHub

[PyTorch] Tests for loading previously-generated checkpoints (#1899)



* Add tests for loading previously-generated checkpoints
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use `NVTE_` prefix for envvar

Review suggestion from @ksivaman
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2f25d121
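In outline, the new test captures a module's state dict once and verifies that a freshly built module can still consume it later. A minimal sketch of that round-trip (assumes a CUDA device; the `linear.pt` filename is illustrative, matching the artifact naming the test uses below):

import torch
import transformer_engine.pytorch as te

# Generate a checkpoint artifact once (normally done via
# `python3 test_checkpoint.py --save-checkpoint all`).
module = te.Linear(1, 1)
torch.save(module.state_dict(), "linear.pt")

# Later, possibly with a newer library version, verify the saved
# state dict still loads into a freshly constructed module.
fresh = te.Linear(1, 1)
fresh.load_state_dict(torch.load("linear.pt", weights_only=False), strict=True)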
......@@ -38,4 +38,5 @@ downloads/
.pytest_cache/
compile_commands.json
.nfs
tensor_dumps/
\ No newline at end of file
tensor_dumps/
artifacts/
......@@ -45,6 +45,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations

import argparse
import functools
import os
import pathlib
from typing import Optional

import pytest
import torch

import transformer_engine.pytorch as te
from utils import make_recipe

# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()

# Test cases for loading checkpoint files
_TestLoadCheckpoint_name_list: tuple[str, ...] = (
"linear",
"layernorm_linear",
"layernorm_mlp",
"layernorm",
"rmsnorm",
"transformer_layer",
"ops_linear",
"linear.fp8",
"ops_linear.fp8",
"linear.mxfp8",
"ops_linear.mxfp8",
)


class TestLoadCheckpoint:
    """Tests for loading checkpoint files

    Tests assume that checkpoint files have already been created. To
    regenerate checkpoint files, e.g. after a breaking change in the
    checkpoint format, run this file directly as a Python script:
    `python3 test_checkpoint.py --save-checkpoint all`.
    """

    @staticmethod
def _make_module(name: str) -> torch.nn.Module:
"""Construct a module"""
if name == "linear":
return te.Linear(1, 1)
if name == "layernorm_linear":
return te.LayerNormLinear(1, 1)
if name == "layernorm_mlp":
return te.LayerNormMLP(1, 1)
if name == "layernorm":
return te.LayerNorm(1)
if name == "rmsnorm":
return te.RMSNorm(1)
if name == "transformer_layer":
return te.TransformerLayer(1, 1, 1)
if name == "ops_linear":
return te.ops.Linear(1, 1)
if name == "linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
return te.Linear(16, 16)
if name == "ops_linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
return te.ops.Linear(16, 16)
if name == "linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
return te.Linear(32, 32)
if name == "ops_linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
return te.ops.Linear(32, 32)
raise ValueError(f"Unrecognized module name ({name})")

    @staticmethod
@functools.lru_cache(maxsize=None)
def _checkpoint_dir() -> pathlib.Path:
"""Path to directory with checkpoint files"""
# Check environment variable
path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
if path:
return pathlib.Path(path).resolve()
# Fallback to path in root dir
root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
return root_dir / "artifacts" / "tests" / "pytorch" / "test_checkpoint"

    @staticmethod
def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -> None:
"""Save a module's checkpoint file"""
# Path to save checkpoint
if checkpoint_dir is None:
checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_file = checkpoint_dir / f"{name}.pt"
# Create module and save checkpoint
module = TestLoadCheckpoint._make_module(name)
torch.save(module.state_dict(), checkpoint_file)
print(f"Saved checkpoint for {name} at {checkpoint_file}")

    @pytest.mark.parametrize("name", _TestLoadCheckpoint_name_list)
def test_module(self, name: str) -> None:
"""Test for loading a module's checkpoint file"""
# Skip if quantization is not supported
quantization = None
if "." in name:
quantization = name.split(".")[1]
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Construct module
module = self._make_module(name)
# Load checkpoint from file
checkpoint_file = self._checkpoint_dir() / f"{name}.pt"
if not checkpoint_file.is_file():
raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}")
state_dict = torch.load(checkpoint_file, weights_only=False)
# Update module from checkpoint
module.load_state_dict(state_dict, strict=True)


def main() -> None:
    """Main function

    Typically used to generate checkpoint files.
    """
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--save-checkpoint",
type=str,
default=None,
help="Save checkpoint file for a module",
)
parser.add_argument(
"--checkpoint-dir",
type=str,
default=None,
help="Directory to save checkpoint file in",
)
args = parser.parse_args()
# Save checkpoint files if needed
if args.save_checkpoint is not None:
checkpoint_dir = args.checkpoint_dir
if checkpoint_dir is not None:
checkpoint_dir = pathlib.Path(checkpoint_dir).resolve()
if args.save_checkpoint == "all":
for name in _TestLoadCheckpoint_name_list:
TestLoadCheckpoint._save_checkpoint(name, checkpoint_dir=checkpoint_dir)
else:
TestLoadCheckpoint._save_checkpoint(
args.save_checkpoint,
checkpoint_dir=checkpoint_dir,
)


if __name__ == "__main__":
    main()
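For the quantized cases, the round-trip exercised above looks roughly like this (a minimal sketch using the sizes from `_make_module`; assumes FP8-capable hardware, and the `linear.fp8.pt` path is illustrative):

import torch
import transformer_engine.pytorch as te
from utils import make_recipe

# Build a module with FP8 weights, as _make_module does for "linear.fp8",
# and save its state dict, including serialized quantization state.
with te.fp8_model_init(recipe=make_recipe("fp8")):
    module = te.Linear(16, 16)
torch.save(module.state_dict(), "linear.fp8.pt")

# The checkpoint must load into a fresh FP8 module with strict key matching.
with te.fp8_model_init(recipe=make_recipe("fp8")):
    fresh = te.Linear(16, 16)
fresh.load_state_dict(torch.load("linear.fp8.pt", weights_only=False), strict=True)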
......@@ -93,6 +93,7 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
if name in ("fp8", "fp8_delayed_scaling"):
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
amax_history_len=8,
)
if name == "fp8_current_scaling":
return transformer_engine.common.recipe.Float8CurrentScaling(
......
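The recipe change above pins `amax_history_len` rather than relying on the library default. A plausible motivation, consistent with the new checkpoint tests: the FP8 state serialized into a checkpoint depends on the amax-history length, so fixing it keeps previously generated artifacts loadable even if the default changes in a future release. A minimal sketch of the pinned recipe in use (this reading of the change is an assumption):

import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te

# Recipe mirrors the diff: E4M3 format with a pinned amax history length.
recipe = te_recipe.DelayedScaling(
    fp8_format=te_recipe.Format.E4M3,
    amax_history_len=8,
)

# Modules built under this recipe serialize FP8 state sized for an
# 8-entry amax history, independent of the library default (assumption).
with te.fp8_model_init(recipe=recipe):
    module = te.Linear(16, 16)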