# debug_strategy.py: debugging strategies that load a converted MNIST model IR and submit it through nni.retiarii.
import json
import os
import logging

from nni.retiarii import Model, submit_models, wait_models



def single_model_strategy():
    """Submit one model built from ``converted_mnist_pytorch.json`` and print its metric.

    The JSON file is read from the directory containing this script and parsed
    as the model IR. That IR is turned into a ``Model`` via ``Model._load``,
    submitted with ``submit_models``, and waited on with ``wait_models``.
    The model's ``metric`` attribute is then printed.
    """
    ir_path = os.path.join(os.path.dirname(__file__), 'converted_mnist_pytorch.json')
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(ir_path, encoding='utf-8') as f:
        ir = json.load(f)
    model = Model._load(ir)
    submit_models(model)
    # NOTE(review): presumably blocks until the submitted trial finishes so that
    # ``metric`` is populated before it is printed; confirm against nni docs.
    wait_models(model)
    print('Strategy says:', model.metric)

def multi_model_cgo(num_forks=3):
    """Submit a base model plus ``num_forks`` forks with ``CGO=true`` and print their metrics.

    Sets the ``CGO`` environment variable to ``'true'`` for the current process.
    The variable is not restored afterwards. The function then loads
    ``converted_mnist_pytorch.json`` from this script's directory as the model
    IR and builds a base ``Model`` from it. It forks that model ``num_forks``
    times, submits all models together, waits for them, and prints each
    model's ``metric``.

    Args:
        num_forks: Number of ``fork()`` copies submitted alongside the base
            model. Defaults to 3, the count previously hard-coded.
    """
    # NOTE(review): presumably read by nni.retiarii during submission to enable
    # cross-graph optimization (CGO); confirm where this flag is consumed.
    os.environ['CGO'] = 'true'
    ir_path = os.path.join(os.path.dirname(__file__), 'converted_mnist_pytorch.json')
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(ir_path, encoding='utf-8') as f:
        ir = json.load(f)
    base = Model._load(ir)
    # Forks are created in order after the base model, matching the original loop.
    models = [base] + [base.fork() for _ in range(num_forks)]
    submit_models(*models)
    wait_models(*models)
    print('Strategy says:', [model.metric for model in models])

if __name__ == '__main__':
    # Only the single-model strategy runs by default; call multi_model_cgo()
    # instead to submit several forked models with the CGO environment flag set.
    single_model_strategy()