Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangkx1
torch_inference_resnet50
Commits
32eb2157
Commit
32eb2157
authored
Jul 20, 2024
by
wangkaixiong
🚴🏼
Browse files
init
parent
cf5a291c
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1057 additions
and
0 deletions
+1057
-0
README.md
README.md
+9
-0
dog.jpg
dog.jpg
+0
-0
synset_words.txt
synset_words.txt
+1000
-0
torch_verify.py
torch_verify.py
+48
-0
No files found.
README.md
View file @
32eb2157
# torch_inference_resnet50
# torch_inference_resnet50
## 从光合开发者社区安装 torch、torchvision
## 验证:
```
bash
git clone http://developer.hpccube.com/codes/wangkx1/torch_inference_resnet50.git
cd
torch_inference_resnet50
python torch_verify.py
```
\ No newline at end of file
dog.jpg
0 → 100644
View file @
32eb2157
3.12 KB
synset_words.txt
0 → 100644
View file @
32eb2157
This diff is collapsed.
Click to expand it.
torch_verify.py
0 → 100644
View file @
32eb2157
import
torch
from
torchvision
import
models
,
transforms
from
PIL
import
Image
# Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
# 加载预训练的ResNet50模型,并指定不使用GPU
model
=
models
.
resnet50
(
pretrained
=
True
)
model
.
eval
()
# 将模型设置为评估模式
device
=
torch
.
device
(
"cpu"
)
# 指定设备为CPU
model
.
to
(
device
)
# 将模型转移到CPU上
# 图片预处理
transform
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
CenterCrop
(
224
),
transforms
.
ToTensor
(),
transforms
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
])
image_path
=
'dog.jpg'
# 替换为你的图片路径
image
=
Image
.
open
(
image_path
).
convert
(
'RGB'
)
image
=
transform
(
image
)
# 添加batch维度
image
=
image
.
unsqueeze
(
0
)
# 在CPU上进行推理
with
torch
.
no_grad
():
outputs
=
model
(
image
.
to
(
device
))
# 获取预测类别
_
,
predicted_class
=
torch
.
max
(
outputs
,
1
)
predicted_class_idx
=
predicted_class
.
item
()
# 打印预测类别
print
(
f
"Predicted class:
{
predicted_class_idx
}
"
)
# 验证分类结果:
# https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt
synset_words_path
=
"synset_words.txt"
with
open
(
synset_words_path
,
'r'
)
as
f
:
synset_words
=
f
.
readlines
()
# 每行前去除空格和制表符,并根据索引获取类别名称
class_names
=
[
line
.
strip
()
for
line
in
synset_words
]
predicted_class_name
=
class_names
[
predicted_class_idx
+
1
]
# 注意索引可能从1开始,所以+1
print
(
f
"Predicted class name:
{
predicted_class_name
}
"
)
# 彭布罗克,彭布罗克威尔士柯基犬
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment