README.md 2.4 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# 简介  
本用例用于神经网络架构搜索(NAS)领域darts算法在ROCm平台PyTorch框架下的测试,包括架构搜索和架构评估两部分内容,已在rocm2.9 pytorch1.3.0版本下进行验证,测试流程如下。  
# 测试流程  
## 准备数据
[cifar-10](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)为例进行测试说明,也可以自行下载PTB和ImageNet数据集。  
## 预训练模型
可以选择在已有的预训练模型上进行训练,下载地址如下:  
[CIFAR-10](https://drive.google.com/file/d/1Y13i4zKGKgjtWBdC0HWLavjO7wvEiGOc/view?usp=sharing)  
[PTB](https://drive.google.com/file/d/1Mt_o6fZOlG-VDF3Q5ModgnAJ9W6f_av2/view?usp=sharing)  
[ImageNet](https://drive.google.com/file/d/1AKr6Y_PoYj7j0Upggyzc26W0RVdg4CVX/view?usp=sharing)  
## 运行指令
### 架构搜索  

	python3 cnn/train_search.py --batch_size 100  
运行结束后会在当前目录下生成./search * /log.txt文件,架构格式如下:  

	genotype = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('dil_conv_5x5', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], normal_concat=range(2, 6), reduce=[('skip_connect', 0), ('skip_connect', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2)], reduce_concat=range(2, 6))
### 格式转化
需使用nasnet/protoc工具将上述得到的genotype转换为protobuf格式,操作如下:  

	cd nasnet/protoc  
更改util.py中的main()函数,将架构描述填入LegacyGenotype()  
```python
def main():
    PDARTS = LegacyGenotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('dil_conv_5x5', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], normal_concat=range(2, 6), reduce=[('skip_connect', 0), ('skip_connect', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2)], reduce_concat=range(3, 6))
    new_PDARTS = convert_legacy_format_to_protobuf(PDARTS)
    save_genotype_to_file('pdarts.txt', new_PDARTS)
```
执行如下指令,生成pdarts.txt文件

	python3 main.py
### 架构评估  
运行示例  

	cd evaluation

	./evaluate.sh {node_name} 1 0 /path/to/pdarts.txt /path/to/{save_dir}
# 参考
[https://github.com/quark0/darts](https://github.com/quark0/darts)