Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
joytag
Commits
640e3441
Commit
640e3441
authored
Aug 08, 2024
by
chenpangpang
Browse files
feat: gradio页面改成中文
parent
d95d58ab
Pipeline
#1490
failed with stages
in 0 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
46 deletions
+38
-46
joytag/app.py
joytag/app.py
+38
-46
No files found.
joytag/app.py
View file @
640e3441
...
@@ -6,53 +6,47 @@ from pathlib import Path
...
@@ -6,53 +6,47 @@ from pathlib import Path
import
torch
import
torch
import
torchvision.transforms.functional
as
TVF
import
torchvision.transforms.functional
as
TVF
MODEL_REPO
=
"fancyfeast/joytag"
MODEL_REPO
=
"fancyfeast/joytag"
THRESHOLD
=
0.4
THRESHOLD
=
0.4
DESCRIPTION
=
"""
joytag:一款图像多分类打标签工具,预测标签种类多达5000,可生成多标签及相应的概率预测
"""
def
prepare_image
(
image
:
Image
.
Image
,
target_size
:
int
)
->
torch
.
Tensor
:
def
prepare_image
(
image
:
Image
.
Image
,
target_size
:
int
)
->
torch
.
Tensor
:
# Pad image to square
# Pad image to square
image_shape
=
image
.
size
image_shape
=
image
.
size
max_dim
=
max
(
image_shape
)
max_dim
=
max
(
image_shape
)
pad_left
=
(
max_dim
-
image_shape
[
0
])
//
2
pad_left
=
(
max_dim
-
image_shape
[
0
])
//
2
pad_top
=
(
max_dim
-
image_shape
[
1
])
//
2
pad_top
=
(
max_dim
-
image_shape
[
1
])
//
2
padded_image
=
Image
.
new
(
'RGB'
,
(
max_dim
,
max_dim
),
(
255
,
255
,
255
))
padded_image
.
paste
(
image
,
(
pad_left
,
pad_top
))
padded_image
=
Image
.
new
(
'RGB'
,
(
max_dim
,
max_dim
),
(
255
,
255
,
255
))
# Resize image
padded_image
.
paste
(
image
,
(
pad_left
,
pad_top
))
if
max_dim
!=
target_size
:
padded_image
=
padded_image
.
resize
((
target_size
,
target_size
),
Image
.
BICUBIC
)
# Resize image
# Convert to tensor
if
max_dim
!=
target_size
:
image_tensor
=
TVF
.
pil_to_tensor
(
padded_image
)
/
255.0
padded_image
=
padded_image
.
resize
((
target_size
,
target_size
),
Image
.
BICUBIC
)
# Convert to tensor
image_tensor
=
TVF
.
pil_to_tensor
(
padded_image
)
/
255.0
# Normalize
# Normalize
image_tensor
=
TVF
.
normalize
(
image_tensor
,
mean
=
[
0.48145466
,
0.4578275
,
0.40821073
],
std
=
[
0.26862954
,
0.26130258
,
0.27577711
])
image_tensor
=
TVF
.
normalize
(
image_tensor
,
mean
=
[
0.48145466
,
0.4578275
,
0.40821073
],
std
=
[
0.26862954
,
0.26130258
,
0.27577711
])
return
image_tensor
return
image_tensor
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
predict
(
image
:
Image
.
Image
):
def
predict
(
image
:
Image
.
Image
):
image_tensor
=
prepare_image
(
image
,
model
.
image_size
)
image_tensor
=
prepare_image
(
image
,
model
.
image_size
)
batch
=
{
batch
=
{
'image'
:
image_tensor
.
unsqueeze
(
0
),
'image'
:
image_tensor
.
unsqueeze
(
0
),
}
}
with
torch
.
amp
.
autocast_mode
.
autocast
(
'cpu'
,
enabled
=
True
):
with
torch
.
amp
.
autocast_mode
.
autocast
(
'cpu'
,
enabled
=
True
):
preds
=
model
(
batch
)
preds
=
model
(
batch
)
tag_preds
=
preds
[
'tags'
].
sigmoid
().
cpu
()
tag_preds
=
preds
[
'tags'
].
sigmoid
().
cpu
()
scores
=
{
top_tags
[
i
]:
tag_preds
[
0
][
i
]
for
i
in
range
(
len
(
top_tags
))}
predicted_tags
=
[
tag
for
tag
,
score
in
scores
.
items
()
if
score
>
THRESHOLD
]
scores
=
{
top_tags
[
i
]:
tag_preds
[
0
][
i
]
for
i
in
range
(
len
(
top_tags
))}
tag_string
=
', '
.
join
(
predicted_tags
)
predicted_tags
=
[
tag
for
tag
,
score
in
scores
.
items
()
if
score
>
THRESHOLD
]
return
tag_string
,
scores
tag_string
=
', '
.
join
(
predicted_tags
)
return
tag_string
,
scores
print
(
"Loading model..."
)
print
(
"Loading model..."
)
...
@@ -61,22 +55,20 @@ model = VisionModel.load_model(path)
...
@@ -61,22 +55,20 @@ model = VisionModel.load_model(path)
model
.
eval
()
model
.
eval
()
with
open
(
Path
(
path
)
/
'top_tags.txt'
,
'r'
)
as
f
:
with
open
(
Path
(
path
)
/
'top_tags.txt'
,
'r'
)
as
f
:
top_tags
=
[
line
.
strip
()
for
line
in
f
.
readlines
()
if
line
.
strip
()]
top_tags
=
[
line
.
strip
()
for
line
in
f
.
readlines
()
if
line
.
strip
()]
print
(
"Starting server..."
)
print
(
"Starting server..."
)
gradio_app
=
gr
.
Interface
(
gradio_app
=
gr
.
Interface
(
predict
,
predict
,
inputs
=
gr
.
Image
(
label
=
"Source"
,
sources
=
[
'upload'
,
'webcam'
],
type
=
'pil'
),
inputs
=
gr
.
Image
(
label
=
"图像"
,
sources
=
[
'upload'
,
'webcam'
],
type
=
'pil'
),
outputs
=
[
outputs
=
[
gr
.
Textbox
(
label
=
"Tag String"
),
gr
.
Textbox
(
label
=
"标签字符"
),
gr
.
Label
(
label
=
"Tag Predictions"
,
num_top_classes
=
100
),
gr
.
Label
(
label
=
"标签及概率"
,
num_top_classes
=
100
),
],
],
title
=
"JoyTag"
,
title
=
"joytag:一款图像多分类打标签工具,预测标签种类多达5000,可生成多标签及相应的概率预测"
,
description
=
DESCRIPTION
,
allow_flagging
=
"never"
,
allow_flagging
=
"never"
,
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
gradio_app
.
launch
(
server_name
=
'0.0.0.0'
,
share
=
True
)
gradio_app
.
launch
(
server_name
=
'0.0.0.0'
,
share
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment