test_validation_monitor.py 1.46 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import unittest
from pathlib import Path

from d2go.utils.validation_monitor import fetch_checkpoints_till_final
9
from detectron2.utils.file_io import PathManager
facebook-github-bot's avatar
facebook-github-bot committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from mobile_cv.common.misc.file_utils import make_temp_directory
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint


def create_file(filename):
    """Create an empty file at ``filename`` through ``PathManager``.

    Opening in write mode and closing immediately leaves a zero-byte file
    (truncating any existing content).
    """
    with PathManager.open(filename, "w"):
        pass


class TestValidationMonitor(unittest.TestCase):
    """Tests for ``fetch_checkpoints_till_final`` on a local output directory."""

    def test_fetch_checkpoints_local(self):
        """Fetch ``model_{i}.pth`` checkpoints plus ``model_final.pth``."""
        with make_temp_directory("test") as tmp_dir:
            output_dir = Path(tmp_dir)
            for i in range(5):
                create_file(output_dir / f"model_{i}.pth")
            create_file(output_dir / "model_final.pth")
            checkpoints = list(fetch_checkpoints_till_final(output_dir))
            # 5 intermediate checkpoints + the final one. Use assertEqual (not a
            # bare assert) so the check survives `python -O` and reports values.
            self.assertEqual(len(checkpoints), 6)

    def test_fetch_lightning_checkpoints_local(self):
        """Fetch Lightning-style ``step=N`` checkpoints plus ``model_final``."""
        with make_temp_directory("test") as tmp_dir:
            output_dir = Path(tmp_dir)
            ext = ModelCheckpoint.FILE_EXTENSION
            for i in range(5):
                create_file(output_dir / f"step={i}{ext}")
            create_file(output_dir / f"model_final{ext}")
            create_file(output_dir / f"{ModelCheckpoint.CHECKPOINT_NAME_LAST}{ext}")
            checkpoints = list(fetch_checkpoints_till_final(output_dir))
            # 7 files are written, but only the 5 step checkpoints plus
            # model_final are expected; the "last" checkpoint must not be counted.
            self.assertEqual(len(checkpoints), 6)