model_runner.py 7.22 KB
Newer Older
wangkx1's avatar
init  
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# scripts/model_runner.py
import os
import yaml
from typing import Dict, Any, List
from .base_runner import BaseRunner
from .monitor import MemoryMonitor

class ModelRunner(BaseRunner):
    """Run benchmark models described by per-model YAML configs.

    For every configured model file / batch size, the command produced by
    ``build_command`` is executed through ``run_model`` while a
    ``MemoryMonitor`` is running; the collected memory statistics are then
    printed and appended to that run's log file.

    Relies on members provided by ``BaseRunner`` (defined elsewhere):
    ``self.env``, ``self.base_config``, ``build_command`` and ``run_model``.
    """

    def __init__(self, config: Dict[str, Any], env):
        super().__init__(config, env)

    def load_model_config(self, model_name: str) -> Dict[str, Any]:
        """Load and parse ``config/<model_name>.yaml``.

        The path is resolved relative to the current working directory.

        Args:
            model_name: Base name of the YAML file under ``config/``.

        Returns:
            The parsed YAML document, or an empty dict if the file is empty.

        Raises:
            FileNotFoundError: If the config file does not exist.
        """
        config_file = f"config/{model_name}.yaml"

        if not os.path.exists(config_file):
            raise FileNotFoundError(f"模型配置文件不存在: {config_file}")

        # Read as UTF-8 explicitly rather than relying on the locale's default
        # encoding, since the configs may contain non-ASCII text.
        with open(config_file, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; normalise it so
            # callers can always call .get() on the result.
            return yaml.safe_load(f) or {}

    def run_single_model(self, model_name: str) -> None:
        """Load one model's config, prepare its environment and run it.

        A missing config file is reported and the model is skipped.

        Args:
            model_name: Name used to locate ``config/<model_name>.yaml``.
        """
        print(f"\n{'='*60}")
        print(f"开始测试模型: {model_name}")
        print(f"{'='*60}")

        try:
            model_config = self.load_model_config(model_name)
        except FileNotFoundError as e:
            print(f"错误: {e}")
            return

        # `or {}` also covers a `model:` key that is present but null.
        model_info = model_config.get('model') or {}

        # Apply the model-specific environment variables.
        self.env.setup(model_info.get('env_vars', {}))

        monitor = MemoryMonitor(
            device_id=self.env.device_id,
            log_file=self.base_config.get('monitor', {}).get('log_file', 'memory_simple.log')
        )

        result_dir = self.env.get_result_dir()

        if 'model_files' in model_info:
            # One model file per batch size (e.g. YOLOv3).
            self.run_multiple_model_files(model_info, monitor, result_dir)
        else:
            # A single model file exercised at several batch sizes.
            self.run_single_model_file(model_info, monitor, result_dir)

        print(f"\n{model_name} 测试完成!")

    def run_single_model_file(self, model_info: Dict[str, Any], monitor: MemoryMonitor, result_dir: str) -> None:
        """Run ``model_info['model_file']`` once per entry in ``batch_sizes``.

        ``batch_sizes`` defaults to ``[1, 8]``. A missing or non-existent
        model file is reported and nothing is run.

        Args:
            model_info: The ``model`` section of the model config.
            monitor: Memory monitor started/stopped around each run.
            result_dir: Directory that receives the per-batch log files.
        """
        model_file = model_info.get('model_file')
        if not model_file or not os.path.exists(model_file):
            print(f"错误: 模型文件不存在: {model_file}")
            return

        batch_sizes = model_info.get('batch_sizes', [1, 8])

        print(f"模型文件: {model_file}")
        print(f"测试batch大小: {batch_sizes}")
        print(f"{'-'*60}")

        for batch in batch_sizes:
            print(f"\n正在测试 batch={batch} ...")
            print(f"{'-'*40}")
            self._run_batch(model_file, model_info, batch, monitor, result_dir)

    def run_multiple_model_files(self, model_info: Dict[str, Any], monitor: MemoryMonitor, result_dir: str) -> None:
        """Run each entry of ``model_info['model_files']`` at its own batch size.

        Each entry is expected to provide ``path`` and optionally ``batch``
        (default 1). Entries whose file is missing are reported and skipped.

        Args:
            model_info: The ``model`` section of the model config.
            monitor: Memory monitor started/stopped around each run.
            result_dir: Directory that receives the per-batch log files.
        """
        model_files = model_info.get('model_files', [])
        if not model_files:
            print("错误: 没有找到模型文件配置")
            return

        print("测试多个模型文件...")
        print(f"{'-'*60}")

        for model_file_info in model_files:
            model_file = model_file_info.get('path')
            batch = model_file_info.get('batch', 1)

            if not model_file or not os.path.exists(model_file):
                print(f"警告: 模型文件不存在,跳过: {model_file}")
                continue

            print(f"\n正在测试 batch={batch} ...")
            print(f"模型文件: {model_file}")
            print(f"{'-'*40}")
            self._run_batch(model_file, model_info, batch, monitor, result_dir)

    def _run_batch(self, model_file: str, model_info: Dict[str, Any], batch: Any,
                   monitor: MemoryMonitor, result_dir: str) -> None:
        """Run one model file at one batch size under monitoring and report stats.

        The log is written to ``<result_dir>/<basename(model_file)>-<batch>batch.log``
        by ``run_model``; a memory summary is appended to it afterwards.
        """
        # start() returns the total device memory, reported below in MiB.
        total_memory = monitor.start()
        try:
            cmd = self.build_command(model_file, model_info, batch)
            model_name = os.path.basename(model_file)
            log_file = os.path.join(result_dir, f"{model_name}-{batch}batch.log")
            success, _output = self.run_model(cmd, log_file)
        finally:
            # Always stop monitoring, even if the run raises, so the monitor
            # is not left running when the caller moves on to the next model.
            monitor.stop()

        stats = monitor.get_statistics()

        summary_lines = [
            f"最大使用: {stats['max_used']:.2f} MiB",
            f"总显存: {total_memory} MiB",
            f"峰值使用率: {stats['max_percent']:.2f}%",
        ]

        print("\n=== 显存使用统计 ===")
        for line in summary_lines:
            print(line)

        # Append a summary to the run's log file. The default encoding is kept
        # here, presumably matching how run_model writes this file (not visible
        # here) — TODO confirm before switching to an explicit encoding.
        with open(log_file, 'a') as f:
            f.write("\n=== 统计摘要 ===\n")
            f.write("".join(f"{line}\n" for line in summary_lines))

        if success:
            print(f"✓ batch={batch} 测试完成,日志保存至: {log_file}")
        else:
            print(f"✗ batch={batch} 测试失败!")

        print(f"{'-'*40}")

    def run_all_models(self) -> None:
        """Run every model listed under ``models_to_run`` in the base config.

        An exception raised while running one model is reported and does not
        prevent the remaining models from running.
        """
        models_to_run = self.base_config.get('models_to_run', [])

        if not models_to_run:
            print("错误: 没有配置要运行的模型")
            return

        print(f"将运行以下模型: {models_to_run}")

        for model_name in models_to_run:
            try:
                self.run_single_model(model_name)
            except Exception as e:  # top-level boundary: isolate failing models
                print(f"运行模型 {model_name} 时出错: {e}")

        print(f"\n{'='*60}")
        print("所有测试完成!")
        print(f"{'='*60}")