test_model_log_utils.py 3.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for model_log_utils module."""

from unittest.mock import Mock
from superbench.common import model_log_utils


class TestRecordStepLoss:
    """Unit tests covering ``model_log_utils.record_step_loss``."""
    def test_record_loss_conversion_failure(self):
        """A loss whose ``detach()`` raises is recorded as None and reported through ``logger.info``."""
        mock_logger = Mock()
        recorded = []

        # detach() blows up, so the helper has no way to obtain a float from this loss.
        failing_loss = Mock()
        failing_loss.detach.side_effect = RuntimeError('Conversion failed')

        outcome = model_log_utils.record_step_loss(
            failing_loss, curr_step=5, losses_list=recorded, logger=mock_logger
        )

        assert outcome is None
        assert recorded == [None]
        mock_logger.info.assert_called_once_with('Unable to convert loss to float at step 5')

    def test_record_loss_success(self):
        """A tensor-like loss yields its float value, which is both returned and appended."""
        mock_logger = Mock()
        recorded = []

        # Stand-in for a tensor: the mock is configured so loss.detach().item() returns 2.5.
        tensor_loss = Mock()
        tensor_loss.detach.return_value.item.return_value = 2.5

        outcome = model_log_utils.record_step_loss(
            tensor_loss, curr_step=10, losses_list=recorded, logger=mock_logger
        )

        assert outcome == 2.5
        assert recorded == [2.5]

    def test_record_loss_from_float(self):
        """A plain Python float is returned and appended as-is when no logger is supplied."""
        recorded = []

        outcome = model_log_utils.record_step_loss(1.234, curr_step=1, losses_list=recorded, logger=None)

        assert outcome == 1.234
        assert recorded == [1.234]


class TestRecordPeriodicFingerprint:
    """Unit tests covering ``model_log_utils.record_periodic_fingerprint``."""
    @staticmethod
    def _run(periodic_dict, curr_step, loss_value, enable_determinism):
        """Call record_periodic_fingerprint with check_frequency=10 and no logits or logger."""
        model_log_utils.record_periodic_fingerprint(
            curr_step=curr_step,
            loss_value=loss_value,
            logits=None,
            periodic_dict=periodic_dict,
            check_frequency=10,
            enable_determinism=enable_determinism,
            logger=None
        )

    def test_skips_when_determinism_disabled(self):
        """Nothing is recorded with determinism disabled, even at step 100 with check_frequency=10."""
        fingerprints = {}
        self._run(fingerprints, curr_step=100, loss_value=1.0, enable_determinism=False)
        assert fingerprints == {}

    def test_skips_when_not_at_frequency(self):
        """Nothing is recorded at step 15 with check_frequency=10, even with determinism enabled."""
        fingerprints = {}
        self._run(fingerprints, curr_step=15, loss_value=1.0, enable_determinism=True)
        assert fingerprints == {}

    def test_records_at_frequency(self):
        """At step 20 with check_frequency=10, the loss and step are each recorded as one-element lists."""
        fingerprints = {}
        self._run(fingerprints, curr_step=20, loss_value=1.5, enable_determinism=True)
        # .get() returns None for a missing key, which can never equal the expected lists.
        assert fingerprints.get('loss') == [1.5]
        assert fingerprints.get('step') == [20]