import torch
from mmpretrain import get_model

model = get_model("resnet50_8xb16_cifar100", pretrained="./tools/resnet50_b16x8_cifar100_20210528-67b58a1b.pth")

# 设置模型为评估模式
model.eval()

# 创建一个示例的输入张量（需要和模型的输入要求一致）
input_tensor = torch.randn(24, 3, 32, 32)  # 示例输入尺寸为 224x224，3 通道

# 导出模型为 ONNX 格式
torch.onnx.export(model,                 # 模型
                  input_tensor,            # 输入张量
                  "resnet50.onnx",        # 导出的 ONNX 文件路径
                  export_params=True,    # 不导出模型参数，因为已经加载了预训练参数
                  opset_version=13,       # ONNX 版本
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names=['input'],  # 输入节点名称
                  output_names=['output'])  # 输出节点名称