test_save_checkpoint.py 1.25 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import shutil

import pytest
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@pytest.fixture(scope="module", autouse=True)
def setup_distributed():
    """Initialize a single-process torch.distributed group for the module.

    Sets up rank-0/world-size-1 rendezvous env vars (without clobbering any
    values a launcher already exported), initializes the process group, and
    tears it down after all tests in the module have run.
    """
    # setdefault: respect values injected by torchrun or the CI environment.
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "12345")

    # NCCL requires a CUDA device; fall back to gloo so this autouse fixture
    # does not crash module collection on CPU-only machines (where the GPU
    # test below is skipped anyway).
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend)
    yield
    # Guard teardown so a failed init does not raise a second error here.
    if dist.is_initialized():
        dist.destroy_process_group()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA GPU to run NCCL tests")
def test_save_and_remove_checkpoint():
    """Save an FSDP-wrapped (empty) Mochi transformer checkpoint and clean up.

    Verifies that save_checkpoint creates the target folder; the folder is
    removed in a finally block so a failed assertion cannot leave stale state
    behind for subsequent runs.
    """
    from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
    from fastvideo.utils.checkpoint import save_checkpoint
    from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs

    # num_layers=0 keeps the model tiny — we only exercise the save path.
    transformer = MochiTransformer3DModel(num_layers=0)
    fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none")
    transformer = FSDP(transformer, **fsdp_kwargs)

    test_folder = "./test_checkpoint"
    try:
        save_checkpoint(transformer, 0, test_folder, 0)
        assert os.path.exists(test_folder), "Checkpoint folder was not created."
    finally:
        # Always remove the folder so reruns start from a clean state.
        shutil.rmtree(test_folder, ignore_errors=True)
    assert not os.path.exists(test_folder), "Checkpoint folder still exists."