torch_verify.py 2.17 KB
Newer Older
wangkaixiong's avatar
init  
wangkaixiong committed
1
2
3
4
5
6
7
8
9
import torch
from torchvision import models, transforms
from PIL import Image


# Pretrained ImageNet weights are fetched on first use, e.g.:
# Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
# Build ResNet-50 with pretrained weights and switch it to evaluation mode.
# Module.eval() returns the module itself, so the call is chained here.
# The device the model runs on is selected separately, further down.
model = models.resnet50(pretrained=True).eval()
wangkx1's avatar
wangkx1 committed
10
# Target device for inference: the first CUDA device. Anything moved onto it
# requires a visible CUDA GPU; the script has no CPU fallback.
device = torch.device("cuda", 0)
wangkaixiong's avatar
init  
wangkaixiong committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Move the model's parameters and buffers onto the selected device.
# Module.to works in place and returns the same module, so rebinding is a no-op alias.
model = model.to(device)

# Image preprocessing: resize the shorter side to 256, center-crop to 224x224,
# convert the PIL image to a tensor, then normalize each channel with the
# ImageNet mean/std values below.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Load the input image (replace the path with your own image).
# Opening it in a `with` block closes the underlying file deterministically.
# convert() returns a new, fully loaded image, so the result stays valid after
# the file is closed.
image_path = 'dog.jpg'
with Image.open(image_path) as _img:
    image = transform(_img.convert('RGB'))

# Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
image = image.unsqueeze(0)

# Run one forward pass on the selected device with autograd disabled.
batch = image.to(device)
with torch.no_grad():
    outputs = model(batch)
    
# Index of the highest-scoring entry along dim 1 (the first one on ties),
# taken from the (values, indices) result of torch.max.
predicted_class = torch.max(outputs, dim=1).indices
predicted_class_idx = predicted_class.item()

# Print the predicted class index.
print(f"Predicted class: {predicted_class_idx}")

# Check the classification result by mapping the predicted index to a label.
# Label file (ILSVRC-2012 synset list, one synset per line), e.g.:
# https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt
synset_words_path = "synset_words.txt"
with open(synset_words_path, 'r', encoding='utf-8') as f:
    synset_words = f.readlines()

# Strip leading/trailing whitespace (including the newline) from every line.
class_names = [line.strip() for line in synset_words]
# The model's output index and this Python list are both 0-based, and line i of
# the synset file is ImageNet class i. The index is therefore used as-is.
# Adding 1 would report the neighbouring class and overflow on the last class.
predicted_class_name = class_names[predicted_class_idx]
wangkx1's avatar
wangkx1 committed
48
# Print the human-readable label. The original author's note records the
# expected label for the sample dog.jpg as "Pembroke, Pembroke Welsh corgi".
name_line = f"Predicted class name: {predicted_class_name}"
print(name_line)
wangkaixiong's avatar
wangkaixiong committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

# Export the model to ONNX (opset 17), tracing it with an all-ones dummy input.
# Declared dynamic: the batch/height/width axes of the input and the
# batch/class axes of the output.
with torch.no_grad():
    # Build the dummy input directly on the model's device instead of
    # hard-coding "cuda:0". Otherwise changing `device` would cause a
    # device mismatch during tracing.
    input_tensor = torch.ones((1, 3, 224, 224), dtype=torch.float, device=device)
    save_path = "./resnet50.onnx"
    dynamic_shape = {
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 1: 'n_class'},
    }

    torch.onnx.export(
        model,
        (input_tensor,),
        f=save_path,
        verbose=False,
        opset_version=17,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes=dynamic_shape,
    )

    print("export success")