Commit 32eb2157 authored by wangkaixiong's avatar wangkaixiong 🚴🏼
Browse files

init

parent cf5a291c
# torch_inference_resnet50 # torch_inference_resnet50
## 从光合开发者社区安装 torch、torchvision
## 验证:
```bash
git clone http://developer.hpccube.com/codes/wangkx1/torch_inference_resnet50.git
cd torch_inference_resnet50
python torch_verify.py
```
\ No newline at end of file
dog.jpg

3.12 KB

This diff is collapsed.
import torch
from torchvision import models, transforms
from PIL import Image
# Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
# 加载预训练的ResNet50模型,并指定不使用GPU
model = models.resnet50(pretrained=True)
model.eval() # 将模型设置为评估模式
device = torch.device("cpu") # 指定设备为CPU
model.to(device) # 将模型转移到CPU上
# 图片预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image_path = 'dog.jpg' # 替换为你的图片路径
image = Image.open(image_path).convert('RGB')
image = transform(image)
# 添加batch维度
image = image.unsqueeze(0)
# 在CPU上进行推理
with torch.no_grad():
outputs = model(image.to(device))
# 获取预测类别
_, predicted_class = torch.max(outputs, 1)
predicted_class_idx = predicted_class.item()
# 打印预测类别
print(f"Predicted class: {predicted_class_idx}")
# 验证分类结果:
# https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt
synset_words_path = "synset_words.txt"
with open(synset_words_path, 'r') as f:
synset_words = f.readlines()
# 每行前去除空格和制表符,并根据索引获取类别名称
class_names = [line.strip() for line in synset_words]
predicted_class_name = class_names[predicted_class_idx+1] # 注意索引可能从1开始,所以+1
print(f"Predicted class name: {predicted_class_name}") # 彭布罗克,彭布罗克威尔士柯基犬
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment