Commit 640e3441 authored by chenpangpang's avatar chenpangpang
Browse files

feat: gradio页面改成中文

parent d95d58ab
Pipeline #1490 failed with stages
in 0 seconds
...@@ -6,53 +6,47 @@ from pathlib import Path ...@@ -6,53 +6,47 @@ from pathlib import Path
import torch import torch
import torchvision.transforms.functional as TVF import torchvision.transforms.functional as TVF
MODEL_REPO = "fancyfeast/joytag" MODEL_REPO = "fancyfeast/joytag"
THRESHOLD = 0.4 THRESHOLD = 0.4
DESCRIPTION = """
joytag:一款图像多分类打标签工具,预测标签种类多达5000,可生成多标签及相应的概率预测
"""
def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor: def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
# Pad image to square # Pad image to square
image_shape = image.size image_shape = image.size
max_dim = max(image_shape) max_dim = max(image_shape)
pad_left = (max_dim - image_shape[0]) // 2 pad_left = (max_dim - image_shape[0]) // 2
pad_top = (max_dim - image_shape[1]) // 2 pad_top = (max_dim - image_shape[1]) // 2
padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
padded_image.paste(image, (pad_left, pad_top))
padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255)) # Resize image
padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size:
padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
# Resize image # Convert to tensor
if max_dim != target_size: image_tensor = TVF.pil_to_tensor(padded_image) / 255.0
padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
# Convert to tensor
image_tensor = TVF.pil_to_tensor(padded_image) / 255.0
# Normalize # Normalize
image_tensor = TVF.normalize(image_tensor, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) image_tensor = TVF.normalize(image_tensor, mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
return image_tensor return image_tensor
@torch.no_grad() @torch.no_grad()
def predict(image: Image.Image): def predict(image: Image.Image):
image_tensor = prepare_image(image, model.image_size) image_tensor = prepare_image(image, model.image_size)
batch = { batch = {
'image': image_tensor.unsqueeze(0), 'image': image_tensor.unsqueeze(0),
} }
with torch.amp.autocast_mode.autocast('cpu', enabled=True):
with torch.amp.autocast_mode.autocast('cpu', enabled=True): preds = model(batch)
preds = model(batch) tag_preds = preds['tags'].sigmoid().cpu()
tag_preds = preds['tags'].sigmoid().cpu() scores = {top_tags[i]: tag_preds[0][i] for i in range(len(top_tags))}
predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
scores = {top_tags[i]: tag_preds[0][i] for i in range(len(top_tags))} tag_string = ', '.join(predicted_tags)
predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD] return tag_string, scores
tag_string = ', '.join(predicted_tags)
return tag_string, scores
print("Loading model...") print("Loading model...")
...@@ -61,22 +55,20 @@ model = VisionModel.load_model(path) ...@@ -61,22 +55,20 @@ model = VisionModel.load_model(path)
model.eval() model.eval()
with open(Path(path) / 'top_tags.txt', 'r') as f: with open(Path(path) / 'top_tags.txt', 'r') as f:
top_tags = [line.strip() for line in f.readlines() if line.strip()] top_tags = [line.strip() for line in f.readlines() if line.strip()]
print("Starting server...") print("Starting server...")
gradio_app = gr.Interface( gradio_app = gr.Interface(
predict, predict,
inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), inputs=gr.Image(label="图像", sources=['upload', 'webcam'], type='pil'),
outputs=[ outputs=[
gr.Textbox(label="Tag String"), gr.Textbox(label="标签字符"),
gr.Label(label="Tag Predictions", num_top_classes=100), gr.Label(label="标签及概率", num_top_classes=100),
], ],
title="JoyTag", title="joytag:一款图像多分类打标签工具,预测标签种类多达5000,可生成多标签及相应的概率预测",
description=DESCRIPTION, allow_flagging="never",
allow_flagging="never",
) )
if __name__ == '__main__': if __name__ == '__main__':
gradio_app.launch(server_name='0.0.0.0', share=True) gradio_app.launch(server_name='0.0.0.0', share=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment