Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
joytag
Commits
640e3441
Commit
640e3441
authored
Aug 08, 2024
by
chenpangpang
Browse files
feat: gradio页面改成中文
parent
d95d58ab
Pipeline
#1490
failed with stages
in 0 seconds
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
46 deletions
+38
-46
joytag/app.py
joytag/app.py
+38
-46
No files found.
joytag/app.py
View file @
640e3441
...
...
@@ -6,12 +6,8 @@ from pathlib import Path
import
torch
import
torchvision.transforms.functional
as
TVF
MODEL_REPO
=
"fancyfeast/joytag"
THRESHOLD
=
0.4
DESCRIPTION
=
"""
joytag:一款图像多分类打标签工具,预测标签种类多达5000,可生成多标签及相应的概率预测
"""
def
prepare_image
(
image
:
Image
.
Image
,
target_size
:
int
)
->
torch
.
Tensor
:
...
...
@@ -32,7 +28,8 @@ def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
image_tensor
=
TVF
.
pil_to_tensor
(
padded_image
)
/
255.0
# Normalize
image_tensor
=
TVF
.
normalize
(
image_tensor
,
mean
=
[
0.48145466
,
0.4578275
,
0.40821073
],
std
=
[
0.26862954
,
0.26130258
,
0.27577711
])
image_tensor
=
TVF
.
normalize
(
image_tensor
,
mean
=
[
0.48145466
,
0.4578275
,
0.40821073
],
std
=
[
0.26862954
,
0.26130258
,
0.27577711
])
return
image_tensor
...
...
@@ -43,15 +40,12 @@ def predict(image: Image.Image):
batch
=
{
'image'
:
image_tensor
.
unsqueeze
(
0
),
}
with
torch
.
amp
.
autocast_mode
.
autocast
(
'cpu'
,
enabled
=
True
):
preds
=
model
(
batch
)
tag_preds
=
preds
[
'tags'
].
sigmoid
().
cpu
()
scores
=
{
top_tags
[
i
]:
tag_preds
[
0
][
i
]
for
i
in
range
(
len
(
top_tags
))}
predicted_tags
=
[
tag
for
tag
,
score
in
scores
.
items
()
if
score
>
THRESHOLD
]
tag_string
=
', '
.
join
(
predicted_tags
)
return
tag_string
,
scores
...
...
@@ -67,16 +61,14 @@ print("Starting server...")
gradio_app
=
gr
.
Interface
(
predict
,
inputs
=
gr
.
Image
(
label
=
"
Source
"
,
sources
=
[
'upload'
,
'webcam'
],
type
=
'pil'
),
inputs
=
gr
.
Image
(
label
=
"
图像
"
,
sources
=
[
'upload'
,
'webcam'
],
type
=
'pil'
),
outputs
=
[
gr
.
Textbox
(
label
=
"
Tag String
"
),
gr
.
Label
(
label
=
"
Tag Predictions
"
,
num_top_classes
=
100
),
gr
.
Textbox
(
label
=
"
标签字符
"
),
gr
.
Label
(
label
=
"
标签及概率
"
,
num_top_classes
=
100
),
],
title
=
"JoyTag"
,
description
=
DESCRIPTION
,
title
=
"joytag:一款图像多分类打标签工具,预测标签种类多达5000,可生成多标签及相应的概率预测"
,
allow_flagging
=
"never"
,
)
if
__name__
==
'__main__'
:
gradio_app
.
launch
(
server_name
=
'0.0.0.0'
,
share
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment