# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import unittest

from nni.algorithms.hpo.curvefitting_assessor import CurvefittingAssessor
from nni.algorithms.hpo.curvefitting_assessor.model_factory import CurveModel
from nni.assessor import AssessResult

class TestCurveFittingAssessor(unittest.TestCase):
    """Unit tests for ``CurvefittingAssessor`` and its ``CurveModel``."""

    def test_init(self):
        """Constructing the assessor with ``20`` sets ``start_step`` and ``target_pos``."""
        new_assessor = CurvefittingAssessor(20)
        self.assertEqual(new_assessor.start_step, 6)
        self.assertEqual(new_assessor.target_pos, 20)

    def test_insufficient_point(self):
        """Assessing trial 1 with a single-point history returns ``AssessResult.Good``.

        NOTE(review): presumably the assessor does not judge a trial until it
        has at least ``start_step`` points — confirm against the implementation.
        """
        new_assessor = CurvefittingAssessor(20)
        ret = new_assessor.assess_trial(1, [1])
        self.assertEqual(ret, AssessResult.Good)

    def test_not_converged(self):
        """An erratic 7-point history is assessed as Good and cannot be predicted.

        Also checks that calling ``assess_trial`` without the trial id raises
        ``TypeError``.
        """
        history = [1, 199, 0, 199, 1, 209, 2]
        new_assessor = CurvefittingAssessor(20)
        # The trial id is a required positional argument; omitting it must fail.
        with self.assertRaises(TypeError):
            new_assessor.assess_trial(list(history))
        # Pass fresh copies in case a callee mutates its input list.
        ret = new_assessor.assess_trial(1, list(history))
        self.assertEqual(ret, AssessResult.Good)

        models = CurveModel(21)
        self.assertIsNone(models.predict(list(history)))

    def test_curve_model(self):
        """Check single-model and weighted-combination predictions on a flat curve.

        The model's internal state is populated directly, bypassing any fitting,
        so the expected values below are pinned regression numbers.
        """
        test_model = CurveModel(21)
        test_model.effective_model = [
            'vap', 'pow3', 'linear', 'logx_linear', 'dr_hill_zero_background',
            'log_power', 'pow4', 'mmf', 'exp4', 'ilog2', 'weibull', 'janoschek',
        ]
        # Keep the count in sync with the list instead of hard-coding it.
        test_model.effective_model_num = len(test_model.effective_model)
        test_model.trial_history = [1] * 9
        test_model.point_num = len(test_model.trial_history)
        test_model.target_pos = 20
        # Uniform weights over the effective models. ``np.float`` was removed
        # in NumPy 1.24; ``np.float64`` is the dtype it used to resolve to.
        test_model.weight_samples = (
            np.ones(test_model.effective_model_num, dtype=np.float64)
            / test_model.effective_model_num
        )

        self.assertAlmostEqual(test_model.predict_y('vap', 9), 0.5591906328335763)
        self.assertAlmostEqual(test_model.predict_y('logx_linear', 15), 1.0704360293379522)
        self.assertAlmostEqual(test_model.f_comb(9, test_model.weight_samples), 1.1543379521172443)
        self.assertAlmostEqual(test_model.f_comb(15, test_model.weight_samples), 1.6949395581692737)


if __name__ == '__main__':
    # Executing this file directly runs the TestCase classes defined in it.
    unittest.main()