{
    "_model": {
        "inputs": ["image"],
        "outputs": ["metric"],

        "nodes": {
            "stem": {"operation": {"type": "_cell", "cell_name": "stem"}},
            "flatten": {"operation": {"type": "Flatten"}},
            "fc1": {"operation": {"type": "Dense", "parameters": {"out_features": 256, "in_features": 1024}}},
            "fc2": {"operation": {"type": "Dense", "parameters": {"out_features": 10, "in_features": 256}}},
            "softmax": {"operation": {"type": "Softmax"}}
        },

        "edges": [
            {"head": ["_inputs", 0], "tail": ["stem", null]},
            {"head": ["stem", null], "tail": ["flatten", null]},
            {"head": ["flatten", null], "tail": ["fc1", null]},
            {"head": ["fc1", null], "tail": ["fc2", null]},
            {"head": ["fc2", null], "tail": ["softmax", null]},
            {"head": ["softmax", null], "tail": ["_outputs", 0]}
        ]
    },

    "stem": {
        "nodes": {
            "conv1": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
            "pool1": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}},
            "conv2": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
            "pool2": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}}
        },

        "edges": [
            {"head": ["_inputs", 0], "tail": ["conv1", null]},
            {"head": ["conv1", null], "tail": ["pool1", null]},
            {"head": ["pool1", null], "tail": ["conv2", null]},
            {"head": ["conv2", null], "tail": ["pool2", null]},
            {"head": ["pool2", null], "tail": ["_outputs", 0]}
        ]
    },

    "_evaluator": {
        "module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
        "kwargs": {
            "dataset_cls": "MNIST",
            "dataset_kwargs": {
                "root": "data/mnist",
                "download": true
            },
            "dataloader_kwargs": {
                "batch_size": 32
            },
            "optimizer_cls": "SGD",
            "optimizer_kwargs": {
                "lr": 1e-3
            },
            "trainer_kwargs": {
                "max_epochs": 1
            }
        }
    }
}