Commit 2e912f00 authored by wangkx1's avatar wangkx1
Browse files

init

parents
# 模型性能测试框架
这是一个基于Python的模型性能测试框架,用于在AMD GPU上使用MIGraphX驱动测试各种深度学习模型的性能。该框架替代了原来的Shell脚本,使用YAML配置文件进行灵活配置。
## 项目结构
```
benchmark/
├── config/ # 配置文件目录
│ ├── base.yaml # 基础配置文件
│ ├── ppocr-v5-rec.yaml # PPOCR V5识别模型配置
│ ├── yolov3.yaml # YOLOv3模型配置
│ ├── inception.yaml # Inception模型配置
│ └── ... # 其他模型配置文件
├── scripts/ # 核心脚本目录
│ ├── __init__.py # Python包文件
│ ├── base_runner.py # 基础运行器
│ ├── model_runner.py # 模型运行器
│ └── monitor.py # 显存监控器
├── all_test.py # 主测试脚本
├── env.py # 环境配置
├── requirements.txt # Python依赖
└── README.md # 本文件
```
## 文件说明
### 1. 配置文件 (config/)
#### base.yaml
基础配置文件,包含:
- **通用环境变量**:LD_LIBRARY_PATH等系统路径
- **通用设置**:结果目录、设备ID、MIGraphX驱动路径、迭代次数等
- **监控设置**:监控间隔、日志文件名
- **要运行的模型列表**:指定需要测试的模型
- **日志文件夹路径配置**: 输出的日志文件存放的文件夹路径设置
#### 模型配置文件(如ppocr-v5-rec.yaml)
每个模型对应的配置文件,包含:
- **模型名称**:用于标识模型
- **模型文件路径**:ONNX模型文件位置
- **批次大小列表**:要测试的batch大小
- **输入配置**:输入节点名称和维度
- **环境变量**:模型特定的环境变量
- **额外参数**:传递给MIGraphX驱动的额外参数
### 2. 核心脚本 (scripts/)
#### monitor.py - 显存监控器
负责监控GPU显存使用情况:
- **MemoryMonitor类**:管理显存监控线程
- **功能**
- 获取总显存
- 实时监控显存使用率
- 记录使用数据到日志文件
- 计算最大使用量和峰值使用率
#### base_runner.py - 基础运行器
包含通用的模型运行逻辑:
- **BaseRunner类**:所有模型运行器的基类
- **功能**
- 构建MIGraphX命令
- 执行性能测试命令
- 处理标准输出和错误
#### model_runner.py - 模型运行器
实现具体的模型测试逻辑:
- **ModelRunner类**:继承BaseRunner
- **功能**
- 加载模型配置文件
- 支持单模型文件多批次测试
- 支持多模型文件(如YOLOv3不同批次有不同文件)
- 集成显存监控
- 生成测试结果和统计信息
### 3. 主程序文件
#### env.py - 环境配置
- **Environment类**:管理测试环境
- **功能**
- 设置环境变量
- 创建结果目录
- 管理设备可见性
#### all_test.py - 主测试脚本
- 程序入口点
- 加载配置文件
- 初始化环境和运行器
- 执行所有模型的测试
## 安装与使用
### 环境要求
- Python 3.7+
- ROCm环境
- MIGraphX驱动
### 安装步骤
1. **克隆或创建项目结构**
```bash
mkdir -p benchmark/{config,scripts}
```
2. **安装Python依赖**
```bash
# 或手动安装
pip install pyyaml
```
3. **准备配置文件**
将示例配置文件放入`config/`目录,并根据实际情况修改:
- 修改`base.yaml`中的路径和设备ID
- 为每个模型创建对应的配置文件
### 运行测试
1. **运行所有模型**
```bash
python all_test.py
```
2. **运行特定模型**
```bash
# 可以直接修改all_test.py中的models_to_run列表
# 或使用命令行参数(需要稍作修改支持)
```
### 配置示例
#### 单个模型文件配置(如PPOCR)
```yaml
# config/ppocr-v5-rec.yaml
model:
name: "ppocr-v5-rec"
model_file: "/path/to/ppocr-v5-rec_model.onnx"
batch_sizes: [1, 8] # 测试批次大小
inputs:
- name: "x" # 输入节点名称
shape: [3, 48, 320] # 输入维度(批次大小会自动插入)
env_vars:
MIGRAPHX_ENABLE_NHWC: 1 # 模型特定环境变量
```
#### 多个模型文件配置(如YOLOv3)
```yaml
# config/yolov3.yaml
model:
name: "yolov3"
model_files: # 多个模型文件
- path: "../models/yolov3/yolov3.onnx"
batch: 1
- path: "../models/yolov3/yolov3_b8.onnx"
batch: 8
inputs:
- name: "input"
shape: [3, 416, 416]
env_vars:
MIGRAPHX_ENABLE_NHWC: 1
```
## 输出结果
测试结果保存在`result_dir`指定的目录中(默认为`../result-v1`),每个测试生成的文件包括:
1. **性能日志文件**`{模型名}-{批次大小}batch.log`
- MIGraphX驱动的输出
- 性能统计信息
- 显存使用统计
2. **显存监控日志**`memory_simple.log`
- 实时显存使用记录
## 添加新模型
要添加新模型测试,只需以下步骤:
1. **创建配置文件**
`config/`目录下创建`{模型名}.yaml`文件
2. **配置模型参数**
参考现有配置文件设置:
- 模型文件路径
- 测试批次大小
- 输入节点配置
3. **添加到运行列表**
`base.yaml``models_to_run`中添加模型名称
## 注意事项
1. **环境要求**
- 确保ROCm和MIGraphX已正确安装
- 确认模型文件路径正确
- 检查GPU设备ID是否可用
2. **权限要求**
- 需要访问GPU的权限
- 需要有模型文件的读取权限
3. **监控功能**
- 监控依赖于`hy-smi`命令
- 确保监控间隔设置合理(默认1秒)
4. **错误处理**
- 模型文件不存在时会跳过测试
- 命令执行失败会记录错误信息
- 监控异常不会中断主测试
## 从Shell脚本迁移
### 主要改进
1. **配置管理**:从Shell变量改为YAML文件
2. **代码复用**:统一命令构建逻辑
3. **错误处理**:更好的异常捕获和处理
4. **线程安全**:独立的监控线程
5. **类型安全**:Python类型提示
### 对应关系
| Shell脚本部分 | Python对应部分 |
|--------------|---------------|
| env.sh配置 | config/base.yaml + env.py |
| 模型特定配置 | config/{模型名}.yaml |
| 监控脚本 | scripts/monitor.py |
| migraphx-driver调用 | base_runner.py中build_command方法 |
| 循环测试批次 | model_runner.py中run方法 |
## 故障排除
### 常见问题
1. **找不到模型文件**
- 检查配置文件中的路径
- 确认文件是否存在且可读
2. **监控失败**
- 检查`hy-smi`命令是否可用
- 确认设备ID正确
3. **权限问题**
- 确保有运行MIGraphX驱动的权限
- 检查结果目录的写入权限
### 调试方法
1. **增加日志输出**
修改配置文件或代码中的日志级别
2. **单独运行命令**
手动运行生成的MIGraphX命令进行调试
3. **检查环境变量**
使用`printenv`确认环境变量已正确设置
## 扩展功能
框架支持以下扩展:
1. **自定义监控指标**:扩展Monitor类添加更多监控项
2. **性能数据解析**:添加结果解析和汇总功能
3. **自动化报告**:生成HTML或PDF测试报告
4. **批量测试**:支持跨多个设备的并行测试
---
如有问题或建议,请参考代码注释或创建Issue。
\ No newline at end of file
# all_test.py
#!/usr/bin/env python3
import yaml
import os
import sys
# 添加当前目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from env import Environment
from scripts.model_runner import ModelRunner
def load_config():
"""加载配置文件"""
config_file = "config/base.yaml"
if not os.path.exists(config_file):
print(f"错误: 配置文件不存在: {config_file}")
sys.exit(1)
with open(config_file, 'r') as f:
return yaml.safe_load(f)
def main():
"""主函数"""
print("开始运行模型性能测试...")
# 加载配置
config = load_config()
# 初始化环境
env = Environment(config)
# 初始化模型运行器
runner = ModelRunner(config, env)
# 运行所有模型
runner.run_all_models()
if __name__ == "__main__":
main()
\ No newline at end of file
# 日志解析工具
这是一个用于解析模型性能测试日志文件的Python脚本,能够从ONNX模型性能测试日志中提取关键信息并汇总到CSV文件中。
## 功能特性
- 自动解析日志文件中的关键信息:
- 模型名称 (model_name)
- 批处理大小 (batch_size)
- 输入形状 (input_shape)
- 推理速率 (FPS)
- 最大内存使用量 (MaxMemoryUsageMiB)
- HCU使用率 (HCU%)
- 支持批量处理日志文件
- 自动检测并跳过重复记录
- 将结果保存到CSV文件以便后续分析
- 提供详细的过程输出和结果预览
## 安装要求
- Python 3.x
- 无需额外依赖包(仅使用Python标准库)
## 使用方法
### 基本用法
```bash
# 删除之前保存的csv
# 解析当前目录下所有.log文件,输出到result.csv
python parse_logs.py
# 解析特定模式的日志文件
python parse_logs.py "*.log"
# 指定输入文件模式和输出CSV文件名
python parse_logs.py "test_*.log" "output.csv"
```
### 参数说明
1. **log_pattern** (可选,默认: `"*.log"`):
- 用于匹配日志文件的通配符模式
- 示例: `"resnet*.log"`, `"test_*.txt"`
2. **csv_file** (可选,默认: `"result.csv"`):
- 输出CSV文件的路径
- 如果文件已存在,新记录将被追加到文件末尾
### 命令行示例
```bash
# 示例1:解析特定模型的日志文件
python parse_logs.py "./result/*.log" "results.csv"
# 示例2:解析不同格式的日志文件
python parse_logs.py "performance_*.txt" "summary.csv"
```
## 输出格式
脚本将生成一个CSV文件,包含以下列:
| 列名 | 描述 | 示例 |
|------|------|------|
| model_name | 模型名称 | resnet50 |
| batch_size | 批处理大小 | 4 |
| input_shape | 输入形状 | 1 3 224 224 |
| FPS | 每秒推理次数 | x |
| MaxMemoryUsageMiB | 最大内存使用量(MiB) | x |
| HCU% | HCU峰值使用率 | x |
## 注意事项
1. **重复记录检测**: 脚本会自动检测并跳过CSV文件中已存在的相同模型、相同batch size和相同输入形状的记录
2. **排序**: 结果会按模型名称(字母顺序)和batch size(数值大小)升序排序
3. **错误处理**: 如果某个日志文件解析失败,脚本会继续处理其他文件并显示错误信息
4. **编码**: 脚本使用UTF-8编码读取和写入文件,确保支持中文等非ASCII字符
## 故障排除
### 常见问题
1. **"没有找到匹配的日志文件"**
- 检查当前目录是否正确
- 确认文件扩展名是否匹配(默认为`.log`
2. **"解析不完整"警告**
- 检查日志文件格式是否符合预期
- 确保日志包含所有必要的信息字段
3. **编码错误**
- 如果日志文件使用其他编码,可能需要修改脚本中的编码设置
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re
import os
import glob
import csv
import sys
from pathlib import Path
from collections import defaultdict
def parse_log_file(log_file_path):
"""
解析单个日志文件,提取关键信息
参数:
log_file_path: 日志文件路径
返回:
字典包含解析结果,或None表示解析失败
"""
try:
with open(log_file_path, 'r', encoding='utf-8') as f:
content = f.read()
result = {
'model_name': None,
'batch_size': None,
'input_shape': None,
'FPS': None
}
# 1. 提取模型名称:从日志命令中提取
# 查找 "perf resnet50.onnx" 这样的模式
model_match = re.search(r'[/\s]([a-zA-Z0-9_-]+)\.onnx', content)
if model_match:
result['model_name'] = model_match.group(1)
else:
# 如果没找到,尝试从文件名提取
filename = os.path.basename(log_file_path)
# 匹配类似 "resnet50-4batch.log" 的文件名
filename_match = re.search(r'([a-zA-Z0-9_-]+)-\d+batch\.log', filename)
if filename_match:
result['model_name'] = filename_match.group(1)
# 2. 提取batch size:从"Batch size: 1"中提取
batch_match = re.search(r'Batch size:\s*(\d+)', content)
if batch_match:
result['batch_size'] = batch_match.group(1)
# 3. 提取输入shape:从"--input-dim @input 1 3 224 224"中提取
input_match = re.search(r'--input-dim\s+(?:@)?[a-zA-Z0-9_\-\.]+\s+(\d+\s+\d+\s+\d+\s+\d+)', content)
if input_match:
result['input_shape'] = input_match.group(1)
# 4. 提取总时间:从"Total time: 2.08637ms"中提取
# time_match = re.search(r'Total time:\s*([\d.]+)ms', content)
# if time_match:
# result['total_time'] = time_match.group(1)
# 优先尝试提取 Rate 行中的数值
rate_match = re.search(r'Rate:\s*([\d.]+)\s*inferences/sec', content)
if rate_match:
result['FPS'] = rate_match.group(1) # 注意:这里实际存的是速率,不是时间
else:
# 回退:从 Total time 提取(如果需要保留原逻辑)
rate_match = re.search(r'Total time:\s*([\d.]+)ms', content)
if rate_match:
result['FPS'] = rate_match.group(1)
# 优先尝试提取 Rate 行中的数值
# 尝试提取“最大使用: XXXX MiB”中的数值
memory_match = re.search(r'最大使用:\s*([\d.]+)\s*MiB', content)
if memory_match:
result['MaxMemoryUsageMiB'] = memory_match.group(1)
else:
result['MaxMemoryUsageMiB'] = "null"
# 优先尝试提取 Rate 行中的数值
peak_usage_match = re.search(r'峰值使用率:\s*([\d.]+)%', content)
if peak_usage_match:
result['HCU%'] = peak_usage_match.group(1)
else:
# 回退:从 Total time 提取(如果需要保留原逻辑)
peak_usage_match = re.search(r'峰值使用率:\s*([\d.]+)%', content)
if peak_usage_match:
result['HCU%'] = peak_usage_match.group(1)
# 检查是否成功提取了所有必要信息
if all(result.values()):
return result
else:
print(f"警告: 文件 {log_file_path} 解析不完整:")
print(f" 模型名: {result['model_name']}")
print(f" Batch: {result['batch_size']}")
print(f" 输入shape: {result['input_shape']}")
print(f" FPS: {result['FPS']}")
print(f" MaxMemoryUsageMiB: {result['MaxMemoryUsageMiB']}")
print(f" HCU%: {result['HCU%']}")
return None
except Exception as e:
print(f"错误: 读取或解析文件 {log_file_path} 时出错: {e}")
return None
def main():
"""主函数"""
# 解析命令行参数
if len(sys.argv) > 1:
log_pattern = sys.argv[1]
else:
log_pattern = "*.log" # 默认匹配所有log文件
if len(sys.argv) > 2:
csv_file = sys.argv[2]
else:
csv_file = "result.csv" # 默认输出文件名
print(f"开始解析日志文件,模式: {log_pattern}")
print(f"输出到: {csv_file}")
print("=" * 50)
# 获取所有匹配的日志文件
log_files = glob.glob(log_pattern)
if not log_files:
print(f"错误: 没有找到匹配 '{log_pattern}' 的日志文件")
return
print(f"找到 {len(log_files)} 个日志文件")
# 检查CSV文件是否存在,决定是否需要写表头
write_header = not os.path.exists(csv_file)
# 解析所有日志文件
results = []
for log_file in sorted(log_files):
print(f"正在解析: {log_file}")
result = parse_log_file(log_file)
if result:
results.append(result)
print(f" ✓ 成功解析: {result['model_name']} (batch={result['batch_size']})")
if not results:
print("错误: 没有成功解析任何日志文件")
return
# 写入CSV文件
try:
with open(csv_file, 'a', newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=['model_name', 'batch_size', 'input_shape', 'FPS', "MaxMemoryUsageMiB", "HCU%"])
if write_header:
writer.writeheader()
print(f"创建新的CSV文件: {csv_file}")
# 写入数据,检查是否已存在相同的记录
existing_data = []
if os.path.exists(csv_file) and os.path.getsize(csv_file) > 0:
with open(csv_file, 'r', encoding='utf-8') as existing_f:
existing_reader = csv.DictReader(existing_f)
existing_data = list(existing_reader)
# 对结果按 model_name 字符串升序,batch_size 数值升序排序
results.sort(key=lambda x: (x['model_name'], int(x['batch_size'])), reverse=False)
new_count = 0
for result in results:
# 检查是否已存在相同模型、相同batch、相同输入的记录
is_duplicate = False
for existing in existing_data:
if (existing.get('model_name') == result['model_name'] and
existing.get('batch_size') == result['batch_size'] and
existing.get('input_shape') == result['input_shape']):
is_duplicate = True
print(f" ⚠ 跳过重复记录: {result['model_name']}, batch={result['batch_size']}")
break
if not is_duplicate:
writer.writerow({
'model_name': result['model_name'],
'batch_size': result['batch_size'],
'input_shape': result['input_shape'],
'FPS': result['FPS'],
'MaxMemoryUsageMiB': result['MaxMemoryUsageMiB'],
'HCU%': result['HCU%']
})
new_count += 1
print(f"\n" + "=" * 50)
print("解析完成!")
print(f"处理日志文件数: {len(log_files)}")
print(f"成功解析文件数: {len(results)}")
print(f"新增记录数: {new_count}")
print(f"CSV文件: {csv_file}")
# 显示CSV文件内容预览
print("\nCSV内容预览:")
print("-" * 50)
try:
with open(csv_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for i, line in enumerate(lines[:20]): # 显示前20行
print(line.rstrip())
if len(lines) > 20:
print(f"... (共 {len(lines)} 行)")
except Exception as e:
print(f"读取CSV文件时出错: {e}")
except Exception as e:
print(f"错误: 写入CSV文件时出错: {e}")
if __name__ == "__main__":
main()
\ No newline at end of file
model_name,batch_size,input_shape,FPS,MaxMemoryUsageMiB,HCU%
inception-v3,1,1 3 299 299,,,
inception-v3_b16,16,16 3 299 299,,,
inception-v3_b4,4,4 3 299 299,,,
inception-v3_b8,8,8 3 299 299,,,
ppocr-v5-det_model,8,8 3 640 640,,,
ppocr-v5-det_model-new-1x3x640x640,1,1 3 640 640,,,
ppocr-v5-rec_model,1,1 3 48 320,,,
ppocr-v5-rec_model,8,8 3 48 320,,,
yolov3,1,1 3 416 416,,,
yolov3_b16,16,16 3 416 416,,,
yolov3_b4,4,4 3 416 416,,,
yolov3_b8,8,8 3 416 416,,,
yolov8n,1,1 3 640 640,,,
yolov8n,8,8 3 640 640,,,
resnet50,1,1 3 224 224,,,
resnet50,8,8 3 224 224,,,
# config/base.yaml
base:
# 通用环境变量
env_vars:
LD_LIBRARY_PATH: "/data/wkx/benchmark/env-0204/rocblas-install/lib:$LD_LIBRARY_PATH"
# 通用设置
common:
result_dir: "./result"
device_id: 3
migraphx_driver: "/opt/dtk/bin/migraphx-driver"
fp16: true
iterations: 100
# 监控设置
monitor:
log_file: "memory_simple.log"
interval: 1 # 秒
# 要运行的模型列表
models_to_run:
# - ppocr-v5-rec
# - yolov3
# - ppocr-v5-det
# - yolov8n
# - eva-2
- resnet50
# - swin-B
# - vit-L
# - inception-v3
\ No newline at end of file
# config/yolov3.yaml
model:
name: "yolov3"
# 多个模型文件(对应不同batch size)
model_files:
- path: "../models/eva/eva02_large_224_1x3x224x224-opt.onnx"
batch: 1
- path: "../models/eva/eva02_large_224_8x3x224x224-opt.onnx"
batch: 8
# 测试的batch大小(如果使用model_files,这里也可以指定)
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "input.1"
shape: [3, 224, 224]
# 环境变量
env_vars:
MIGRAPHX_ENABLE_LAYERNORM_FUSION: 1
MIGRAPHX_ENABLE_MHA_BHSD: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/inception-v3.yaml
model:
name: "inception-v3"
# 模型文件路径
# model_file: "/data/wkx/benchmark/models/ppocr-v5-rec_model.onnx"
model_files:
- path: "../models/inception/inception-v3.onnx"
batch: 1
# - path: "../models/inception/inception-v3_b4.onnx"
# batch: 4
- path: "../models/inception/inception-v3_b8.onnx"
batch: 8
# - path: "../models/inception/inception-v3_b16.onnx"
# batch: 16
# 测试的batch大小
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "input"
shape: [3, 299, 299] # BATCH会在运行时插入到第一个位置
# 环境变量
env_vars:
# MIGRAPHX_ENABLE_NHWC: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/ppocr-v5-det.yaml
model:
name: "ppocr-v5-det"
# 多个模型文件(对应不同batch size)
model_files:
- path: "/data/wkx/models/ppocr-v5-det_model-new-1x3x640x640.onnx"
batch: 1
# - path: "../models/yolov3/yolov3_b4.onnx"
# batch: 4
- path: "/data/wkx/benchmark/models/ppocr-v5-det_model.onnx"
batch: 8
# - path: "../models/yolov3/yolov3_b16.onnx"
# batch: 16
# 测试的batch大小(如果使用model_files,这里也可以指定)
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "x"
shape: [3, 640, 640]
# 环境变量
env_vars:
# MIGRAPHX_ENABLE_NHWC: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/ppocr-v5-rec.yaml
model:
name: "ppocr-v5-rec"
# 模型文件路径
model_file: "/data/wkx/benchmark/models/ppocr-v5-rec_model.onnx"
# 测试的batch大小
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "x"
shape: [3, 48, 320] # BATCH会在运行时插入到第一个位置
# 环境变量
env_vars:
# MIGRAPHX_ENABLE_NHWC: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/resnet50.yaml
model:
name: "resnet50"
# 多个模型文件(对应不同batch size)
model_file: ../models/resnet50.onnx
# 测试的batch大小(如果使用model_files,这里也可以指定)
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "input"
shape: [3, 224, 224]
# 环境变量
env_vars:
# MIGRAPHX_ENABLE_NHWC: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/yolov3.yaml
model:
name: "yolov3"
# 多个模型文件(对应不同batch size)
model_files:
- path: "../models/yolov3/yolov3.onnx"
batch: 1
- path: "../models/yolov3/yolov3_b4.onnx"
batch: 4
- path: "../models/yolov3/yolov3_b8.onnx"
batch: 8
- path: "../models/yolov3/yolov3_b16.onnx"
batch: 16
# 测试的batch大小(如果使用model_files,这里也可以指定)
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "input"
shape: [3, 416, 416]
# 环境变量
env_vars:
MIGRAPHX_ENABLE_GEMM_SOFTMAX_GEMM_FUSE: 1
MIGRAPHX_ENABLE_LAYERNORM_FUSION: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/yolov3.yaml
model:
name: "yolov3"
# 多个模型文件(对应不同batch size)
model_files:
- path: "../models/yolov3/yolov3.onnx"
batch: 1
- path: "../models/yolov3/yolov3_b4.onnx"
batch: 4
- path: "../models/yolov3/yolov3_b8.onnx"
batch: 8
- path: "../models/yolov3/yolov3_b16.onnx"
batch: 16
# 测试的batch大小(如果使用model_files,这里也可以指定)
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "input"
shape: [3, 416, 416]
# 环境变量
env_vars:
# MIGRAPHX_ENABLE_NHWC: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/yolov3.yaml
model:
name: "yolov3"
# 多个模型文件(对应不同batch size)
model_files:
- path: "../models/yolov3/yolov3.onnx"
batch: 1
- path: "../models/yolov3/yolov3_b4.onnx"
batch: 4
- path: "../models/yolov3/yolov3_b8.onnx"
batch: 8
- path: "../models/yolov3/yolov3_b16.onnx"
batch: 16
# 测试的batch大小(如果使用model_files,这里也可以指定)
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "input"
shape: [3, 416, 416]
# 环境变量
env_vars:
# MIGRAPHX_ENABLE_NHWC: 1
# 额外参数
extra_args: []
\ No newline at end of file
# config/yolov8n.yaml
model:
name: "yolov8n"
# 多个模型文件(对应不同batch size)
model_files:
- path: "/data/wkx/benchmark/models/yolov8/yolov8n.onnx"
batch: 1
# - path: "../models/yolov3/yolov3_b4.onnx"
# batch: 4
- path: "/data/wkx/benchmark/models/yolov8/yolov8n.onnx"
batch: 8
# - path: "../models/yolov3/yolov3_b16.onnx"
# batch: 16
# 测试的batch大小(如果使用model_files,这里也可以指定)
batch_sizes: [1, 8]
# 输入配置
inputs:
- name: "images"
shape: [3, 640, 640]
# 环境变量
env_vars:
# MIGRAPHX_ENABLE_NHWC: 1
# 额外参数
extra_args: []
\ No newline at end of file
# env.py
import os
from typing import Dict, Any
class Environment:
"""管理环境变量和系统配置"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.base_env = config.get('base', {})
self.device_id = self.base_env.get('common', {}).get('device_id', 3)
def setup(self, model_env_vars: Dict[str, str] = None):
"""设置环境变量"""
# 设置基础环境变量
base_env_vars = self.base_env.get('env_vars', {})
for key, value in base_env_vars.items():
if key in os.environ:
# 如果环境变量已存在,追加
os.environ[key] = f"{value}:{os.environ[key]}"
else:
os.environ[key] = value
# 设置模型特定的环境变量
if model_env_vars:
for key, value in model_env_vars.items():
os.environ[key] = str(value)
# 设置设备可见性
os.environ['HIP_VISIBLE_DEVICES'] = str(self.device_id)
def get_result_dir(self) -> str:
"""获取结果目录"""
result_dir = self.base_env.get('common', {}).get('result_dir', '../result-v1')
# 确保目录存在
if not os.path.exists(result_dir):
os.makedirs(result_dir)
print(f"创建目录: {result_dir}")
return result_dir
\ No newline at end of file
2026-02-06 11:03:53 2 0.0
2026-02-06 11:03:54 2 0.0
2026-02-06 11:03:55 2 0.0
2026-02-06 11:03:56 2 0.0
2026-02-06 11:03:57 2 0.0
2026-02-06 11:03:58 2 0.0
2026-02-06 11:03:59 2 0.0
2026-02-06 11:04:00 2 0.0
2026-02-06 11:04:01 2 0.0
2026-02-06 11:04:02 2 0.0
2026-02-06 11:04:03 2 0.0
2026-02-06 11:04:04 2 0.0
2026-02-06 11:04:05 2 0.0
2026-02-06 11:04:06 2 0.0
2026-02-06 11:04:07 2 0.0
2026-02-06 11:04:08 2 0.0
2026-02-06 11:04:09 2 0.0
2026-02-06 11:04:10 2 0.0
2026-02-06 11:04:11 2 0.0
2026-02-06 11:04:12 2 0.0
2026-02-06 11:04:13 2 0.0
2026-02-06 11:04:14 2 0.0
2026-02-06 11:04:15 2 0.0
2026-02-06 11:04:16 2 0.0
2026-02-06 11:04:17 2 0.0
2026-02-06 11:04:18 279 0.0
2026-02-06 11:04:19 279 0.0
2026-02-06 11:04:20 279 0.0
2026-02-06 11:04:22 279 0.0
2026-02-06 11:04:23 279 0.0
2026-02-06 11:04:24 279 0.0
2026-02-06 11:04:25 281 0.0
2026-02-06 11:04:26 281 0.0
2026-02-06 11:04:27 281 0.0
2026-02-06 11:04:28 540 95.8
2026-02-06 11:04:29 540 0.0
2026-02-06 11:04:30 199 0.0
2026-02-06 11:04:31 199 0.0
2026-02-06 11:04:32 199 0.0
2026-02-06 11:04:33 145 0.0
2026-02-06 11:04:34 145 0.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment