{ "cells": [ { "cell_type": "markdown", "id": "ede95fbe", "metadata": {}, "source": [ "# DeepSolo\n", "\n", "## 1. Model Overview\n", "DeepSolo is a simple DETR-like baseline that lets a single decoder with explicit points perform text detection and recognition simultaneously.\n", "\n", "**Model Architecture**\n", "\n", "In DeepSolo, the encoder takes the image features and generates Bezier center-curve proposals, each represented by four Bezier control points, together with their scores; the top-K scoring proposals are then selected. For each selected curve proposal, N points are sampled uniformly along the curve, and their coordinates are encoded as positional queries, which are added to the content queries to form composite queries. The composite queries are then fed into a deformable cross-attention decoder that gathers useful text features. After the decoder, several simple parallel prediction heads (linear layers or MLPs) decode the queries into the text center line, boundary, script (the character sequence), and confidence, so detection and recognition are solved at the same time." ] },
{ "cell_type": "markdown", "id": "aa7b7b67-9d3b-4926-b039-ba9840eefa4d", "metadata": {}, "source": [ "## 2. Environment Check and Dependency Installation\n", "\n", "### 2.1 Environment Check\n", "\n", "Recommended environment:\n", "\n", "- pytorch=1.13.1, py38\n", "- dcu=23.04.1\n" ] }, { "cell_type": "code", "execution_count": null, "id": "925e2b69", "metadata": {}, "outputs": [], "source": [ "# Check the PyTorch version (1.10 or later is required)\n", "\n", "import torch\n", "import torch.utils.cpp_extension\n", "version = torch.__version__\n", "# Compare (major, minor) as integers; a float comparison wrongly accepts 1.9 because 1.9 > 1.10\n", "major, minor = (int(x) for x in version.split('+')[0].split('.')[:2])\n", "assert (major, minor) >= (1, 10), f'PyTorch >= 1.10 is required, found {version}'\n", "device = \"cpu\"\n", "\n", "# Check the hardware environment (DCU via ROCm/HIP, or NVIDIA GPU via CUDA)\n", "if torch.utils.cpp_extension.HIP_HOME:\n", "    device = \"dtk\"\n", "    !rocm-smi\n", "elif torch.utils.cpp_extension.CUDA_HOME:\n", "    device = \"cuda\"\n", "    !nvidia-smi\n", "print(\"pytorch version:\", version)\n", "print(\"device =\", device)" ] }, { "cell_type": "markdown", "id": "aa7b7b67-9d3b-4926-b039-ba9840eefacc", "metadata": {}, "source": [ "### 2.2 Dependency Installation\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "9e7d7059", "metadata": {}, "outputs": [], "source": [ "!pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com" ] }, { "cell_type": "code", "execution_count": null, "id": "5bf42297", "metadata": {}, "outputs": [], "source": [ "!bash make.sh" ] }, { "cell_type": "code", "execution_count": null, "id": "0a7b0e93", "metadata": {}, "outputs": [], "source": [ "!git clone https://github.com/facebookresearch/detectron2.git\n", "!python -m pip install -e detectron2" ] }, { "cell_type": "code", "execution_count": null, "id": "15f59780", "metadata": {}, "outputs": [], "source": [ "!pip3 install -r requirements.txt" ] }, { "cell_type": "markdown", "id": "64c534dd-ad8f-493d-a5e4-435676d4f162", "metadata": {}, "source": [ "## 3. 
Data Preparation\n", "### 3.1 Dataset Preparation\n" ] }, { "cell_type": "markdown", "id": "b1caaec4", "metadata": {}, "source": [ "The project ships with a lightweight dataset, `simple`, for training and validation. Make sure the project contains a `datasets` directory with the following structure:\n", "\n", "```\n", "├── datasets\n", "│   └── simple\n", "│       ├── test_images\n", "│       ├── train_images\n", "│       ├── test.json\n", "│       └── train.json\n", "```" ] }, { "cell_type": "markdown", "id": "ef0faf15-d9ab-454f-9368-e026372752ad", "metadata": {}, "source": [ "## 4. Training\n", "### 4.1 Start Training\n", "\n", "Choose single-card or multi-card training as needed." ] }, { "cell_type": "code", "execution_count": null, "id": "c305213b", "metadata": {}, "outputs": [], "source": [ "# Single-card training\n", "# `!export` runs in a throwaway subshell and does not persist, so set the variable on the same line\n", "# (on NVIDIA GPUs, use CUDA_VISIBLE_DEVICES instead of HIP_VISIBLE_DEVICES)\n", "!HIP_VISIBLE_DEVICES=4 python tools/train_net.py --config-file configs/simple/train_simple.yaml --num-gpus 1" ] }, { "cell_type": "code", "execution_count": null, "id": "2d95c0e6-4143-40de-b90c-3978c02cb169", "metadata": {}, "outputs": [], "source": [ "# Multi-card training (uses 4 accelerator cards)\n", "!HIP_VISIBLE_DEVICES=4,5,6,7 python tools/train_net.py --config-file configs/simple/train_simple.yaml --num-gpus 4" ] }, { "cell_type": "markdown", "id": "5f1986b2-1ea1-45be-8f56-dfaacfba7694", "metadata": {}, "source": [ "## 5. Inference\n", "### 5.1 Run Inference\n", "\n", "An inference script is provided for testing the model. Run the command below to check the model's output." ] }, { "cell_type": "code", "execution_count": null, "id": "de9d22f1-4764-4c16-a4c9-52e9581669ce", "metadata": {}, "outputs": [], "source": [ "# Inference results are saved to the test_results folder by default\n", "!python demo/demo.py --config-file configs/simple/test_simple.yaml --input datasets/simple/test_images" ] }, { "cell_type": "markdown", "id": "cfe3acde", "metadata": {}, "source": [ "## 6. 
References\n", "https://github.com/ViTAE-Transformer/DeepSolo.git\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }