# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os

# Select NNI's unit-test platform. This has to happen before the ``nni``
# imports below -- presumably nni consults NNI_PLATFORM at import time.
os.environ.update(NNI_PLATFORM='unittest')

import nni
import nni.platform.test as test_platform
import nni.trial

from unittest import TestCase, main



class SmartParamTestCase(TestCase):
    """Tests for NNI's smart-parameter helpers in ``unittest`` platform mode.

    ``setUp`` installs a hand-written parameter set into ``nni.trial._params``.
    Each test then checks that ``nni.choice``, ``nni.function_choice`` or
    ``nni.mutable_layer`` returns the option selected by that parameter set.
    """

    def setUp(self):
        """Install the fake trial parameters, saving the previous value."""
        params = {
            'test_smartparam/choice1/choice': 'a',
            'test_smartparam/choice2/choice': '3*2+1',
            'test_smartparam/choice3/choice': '[1, 2]',
            'test_smartparam/choice4/choice': '{"a", 2}',
            'test_smartparam/func/function_choice': 'bar',
            'test_smartparam/lambda_func/function_choice': "lambda: 2*3",
            'mutable_block_66': {
                'mutable_layer_0': {
                    'chosen_layer': 'conv2D(size=5)',
                    'chosen_inputs': ['y'],
                },
            },
        }
        # nni.trial._params is module-level state. Remember the old value so
        # tearDown can restore it; otherwise these fake parameters leak into
        # any other test that runs later in the same process.
        self._saved_params = getattr(nni.trial, '_params', None)
        nni.trial._params = {'parameter_id': 'test_trial', 'parameters': params}

    def tearDown(self):
        """Restore ``nni.trial._params`` to its value from before ``setUp``."""
        nni.trial._params = self._saved_params

    def test_specified_name(self):
        """``nni.choice`` returns the option whose key is stored under ``key``."""
        # The keys are the strings stored in the trial parameters above.
        # The values are what nni.choice is expected to hand back.
        options = {'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}
        cases = [
            ('choice1', 'test_smartparam/choice1/choice', 'a'),
            ('choice2', 'test_smartparam/choice2/choice', 7),
            ('choice3', 'test_smartparam/choice3/choice', [1, 2]),
            ('choice4', 'test_smartparam/choice4/choice', {"a", 2}),
        ]
        for name, key, expected in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(name=name):
                val = nni.choice(options, name=name, key=key)
                self.assertEqual(val, expected)

    def test_func(self):
        """``nni.function_choice`` calls the selected function (``bar``)."""
        val = nni.function_choice({'foo': foo, 'bar': bar}, name='func',
                                  key='test_smartparam/func/function_choice')
        self.assertEqual(val, 'bar')

    def test_lambda_func(self):
        """``nni.function_choice`` also works with lambdas keyed by their source."""
        val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4},
                                  name='lambda_func',
                                  key='test_smartparam/lambda_func/function_choice')
        self.assertEqual(val, 6)

    def test_mutable_layer(self):
        """``nni.mutable_layer`` applies the layer and inputs chosen in ``setUp``."""
        layer_out = nni.mutable_layer(
            'mutable_block_66',
            'mutable_layer_0',
            {'conv2D(size=3)': conv2D, 'conv2D(size=5)': conv2D},
            {'conv2D(size=3)': {'size': 3}, 'conv2D(size=5)': {'size': 5}},
            [100],
            {'x': 1, 'y': 2},
            1,
            'classic_mode')
        # The expected value corresponds to conv2D concatenating the fixed input
        # [100], the value of the chosen optional input 'y' (2) and size 5.
        self.assertEqual(layer_out, [100, 2, 5])


def foo():
    """Return the string ``'foo'``; one of the candidates in ``test_func``."""
    result = 'foo'
    return result

def bar():
    """Return the string ``'bar'``; the candidate selected in ``test_func``."""
    result = 'bar'
    return result

def conv2D(inputs, size=3):
    """Fake layer used by ``test_mutable_layer``.

    Concatenates ``inputs[0]`` and ``inputs[1]`` (expected to be lists) and
    appends ``size``. The test can then see which inputs and which ``size``
    argument the layer received.
    """
    first, second = inputs[0], inputs[1]
    return first + second + [size]

if __name__ == '__main__':
    # Hand control to unittest's command-line runner when executed directly.
    main(module='__main__')