"vscode:/vscode.git/clone" did not exist on "c984fd24abfd8f24a66c4dbb39c85985a5261dd5"
test_smartparam.py 3.48 KB
Newer Older
Deshui Yu's avatar
Deshui Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

21
22
23
import os

# Must run before `import nni` below; presumably nni picks its platform from
# NNI_PLATFORM when it is first imported.
os.environ.update(NNI_PLATFORM='unittest')
Deshui Yu's avatar
Deshui Yu committed
24
25
26
27
28
29
30
31

import nni
import nni.platform.test as test_platform
import nni.trial

from unittest import TestCase, main


32
33
# Absolute line numbers (in the original, un-annotated file) of the two nni
# calls that omit `name=` and are therefore keyed as `__line<N>`:
# lineno1 -> the `nni.uniform` call, lineno2 -> the `function_choice` call.
lineno1, lineno2 = 61, 75
Deshui Yu's avatar
Deshui Yu committed
34
35
36
37

class SmartParamTestCase(TestCase):
    """Tests for nni's smart-parameter API (``choice``, ``uniform``, ``function_choice``).

    ``setUp`` installs fake trial parameters whose keys appear to follow
    ``<module>/<name>/<api>``. Each test asks nni for a value and checks that
    the entry stored under the matching key was selected.
    """

    def setUp(self):
        # Calls that omit ``name=`` are keyed as ``__line<N>``, N being the line
        # of the call. Derive N from each test method's own ``def`` line instead
        # of hard-coding absolute file line numbers, so edits elsewhere in the
        # file cannot silently break these keys. This requires the nni call to
        # be the first line of the method body (no docstring before it).
        uniform_line = type(self).test_default_name.__code__.co_firstlineno + 1
        func_choice_line = type(self).test_default_name_func.__code__.co_firstlineno + 1
        params = {
            'test_smartparam/choice1/choice': 'a',
            'test_smartparam/choice2/choice': '3*2+1',
            'test_smartparam/choice3/choice': '[1, 2]',
            'test_smartparam/choice4/choice': '{"a", 2}',
            'test_smartparam/__line{:d}/uniform'.format(uniform_line): '5',
            'test_smartparam/func/function_choice': 'bar',
            'test_smartparam/lambda_func/function_choice': "lambda: 2*3",
            'test_smartparam/__line{:d}/function_choice'.format(func_choice_line): 'max(1, 2, 3)',
        }
        nni.trial._params = {'parameter_id': 'test_trial', 'parameters': params}

    def test_specified_name(self):
        """An explicit ``name=`` selects the candidate stored under that name."""
        candidates = {'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}
        expected = {'choice1': 'a', 'choice2': 7, 'choice3': [1, 2], 'choice4': {"a", 2}}
        for name, value in expected.items():
            with self.subTest(name=name):
                self.assertEqual(nni.choice(candidates, name=name), value)

    # No docstring or comment may precede the nni call inside this method:
    # setUp keys the default-named parameter by the line right after `def`.
    def test_default_name(self):
        val = nni.uniform(1, 10)
        self.assertEqual(val, '5')

    def test_specified_name_func(self):
        """``function_choice`` calls the chosen function and returns its result."""
        val = nni.function_choice({'foo': foo, 'bar': bar}, name='func')
        self.assertEqual(val, 'bar')

    def test_lambda_func(self):
        """Lambdas can be candidates; the chosen one is invoked."""
        val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name='lambda_func')
        self.assertEqual(val, 6)

    # The nni call must be the first line of the method body and must fit on
    # one line: setUp keys the parameter by the line right after `def`, and
    # Python < 3.8 reports a multi-line call at its last line while 3.8+
    # reports its first line.
    def test_default_name_func(self):
        val = nni.function_choice({'max(1, 2, 3)': lambda: max(1, 2, 3), 'min(1, 2)': lambda: min(1, 2)})
        self.assertEqual(val, 3)


def foo():
    """Return the marker string ``'foo'``; one of the ``function_choice`` candidates."""
    marker = 'foo'
    return marker

def bar():
    """Return the marker string ``'bar'``; one of the ``function_choice`` candidates."""
    marker = 'bar'
    return marker


if __name__ == '__main__':
    # Hand control to unittest's command-line runner for the tests in this module.
    main(module='__main__')