{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ede95fbe",
   "metadata": {},
   "source": [
    "# DeepSolo\n",
    "\n",
    "## 1. 模型简介\n",
    "DeepSolo是一个简洁的类似DETR的基线模型,允许一个具有显式点的解码器同时进行检测和识别。\n",
    "\n",
    "**模型结构**\n",
    "\n",
    "DeepSolo中,编码器在接收到图像特征后,生成由四个Bezier控制点表示的Bezier中心曲线候选和相应的分数,然后,选择前K个评分的候选。对于每个选定的曲线候选,在曲线上均匀采样N个点,这些点的坐标被编码为位置query并将其添加到内容query中形成复合query。接下来,将复合query输入deformable cross-attention解码器收集有用的文本特征。在解码器之后,采用了几个简单的并行预测头(线性层或MLP)将query解码为文本的中心线、边界、script和置信度,从而同时解决检测和识别问题。\n",
    "\n",
    "<div align=center>\n",
    "    <img src=\"./doc/DeepSolo.jpg\"/>\n"
   ]
  },
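  {
   "cell_type": "markdown",
   "id": "1f2e3d4c-bezier-sketch-md",
   "metadata": {},
   "source": [
    "Below is a minimal, illustrative sketch (not the code used by this repository) of how a curve proposal given by four Bezier control points can be sampled into N explicit points, which DeepSolo then encodes as positional queries. The function name, the choice of N, and the example control points are assumptions for illustration only.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f2e3d4c-bezier-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: uniformly sampling N points on a cubic Bezier\n",
    "# center curve from its four control points. The repository's own sampling\n",
    "# and query construction may differ; names and values here are assumptions.\n",
    "import torch\n",
    "\n",
    "def sample_bezier_points(ctrl_pts, n=25):\n",
    "    # ctrl_pts: (4, 2) control points of one cubic Bezier center curve\n",
    "    # returns: (n, 2) points sampled uniformly in the curve parameter t\n",
    "    t = torch.linspace(0, 1, n).unsqueeze(-1)   # (n, 1)\n",
    "    p0, p1, p2, p3 = ctrl_pts                   # each of shape (2,)\n",
    "    pts = ((1 - t) ** 3) * p0 + 3 * ((1 - t) ** 2) * t * p1\n",
    "    pts = pts + 3 * (1 - t) * (t ** 2) * p2 + (t ** 3) * p3\n",
    "    return pts\n",
    "\n",
    "# Hypothetical proposal: control points in normalized image coordinates.\n",
    "ctrl = torch.tensor([[0.1, 0.5], [0.3, 0.4], [0.6, 0.4], [0.9, 0.5]])\n",
    "point_coords = sample_bezier_points(ctrl, n=25)\n",
    "print(point_coords.shape)  # torch.Size([25, 2])"
   ]
  },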
  {
   "cell_type": "markdown",
   "id": "aa7b7b67-9d3b-4926-b039-ba9840eefa4d",
   "metadata": {},
   "source": [
    "¶\n",
    "## 2. 环境检查及依赖补全\n",
    "\n",
    "### 2.1 环境检查\n",
    "\n",
    "推荐环境:pytorch=1.13.1 py38\n",
    "推荐环境:dcu=23.04.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "925e2b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 检查torch版本\n",
    "\n",
    "import torch\n",
    "import torch.utils.cpp_extension\n",
    "version = torch.__version__\n",
    "num = float(version[:version.rfind('.')])\n",
    "assert num >= 1.10\n",
    "device = \"cpu\"\n",
    "\n",
    "# 检查硬件环境\n",
    "if torch.utils.cpp_extension.HIP_HOME:\n",
    "    device = \"dtk\"\n",
    "    !rocm-smi\n",
    "elif torch.utils.cpp_extension.CUDA_HOME:\n",
    "    device = \"cuda\"\n",
    "    !nvidia-smi\n",
    "print(\"pytorch version:\", version)\n",
    "print(\"device =\", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa7b7b67-9d3b-4926-b039-ba9840eefacc",
   "metadata": {},
   "source": [
    "### 2.2 依赖安装\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7d7059",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bf42297",
   "metadata": {},
   "outputs": [],
   "source": [
    "!bash make.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7b0e93",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/facebookresearch/detectron2.git\n",
    "!python -m pip install -e detectron2"
   ]
  },
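  {
   "cell_type": "markdown",
   "id": "2a3b4c5d-env-verify-md",
   "metadata": {},
   "source": [
    "As an optional sanity check after installation (not part of the original workflow), the cell below simply imports `torch` and `detectron2` and prints their versions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3b4c5d-env-verify-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: confirm the core dependencies import correctly.\n",
    "import torch\n",
    "import detectron2\n",
    "\n",
    "print(\"torch:\", torch.__version__)\n",
    "print(\"detectron2:\", detectron2.__version__)"
   ]
  },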
  {
   "cell_type": "markdown",
   "id": "64c534dd-ad8f-493d-a5e4-435676d4f162",
   "metadata": {},
   "source": [
    "## 3. 素材准备\n",
    "### 3.1 数据集准备\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1caaec4",
   "metadata": {},
   "source": [
    "项目已经预制了轻量数据simple进行训练验证,请确保当前项目中包含datasets目录且结构如下:\n",
    "\n",
    "```\n",
    "├── datasets\n",
    "│   ├── simple\n",
    "│       ├── test_images\n",
    "│       ├── train_images\n",
    "│       ├── test.json\n",
    "│       └── train.json\n",
    "```"
   ]
  },
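  {
   "cell_type": "markdown",
   "id": "3b4c5d6e-dataset-check-md",
   "metadata": {},
   "source": [
    "The cell below is a small optional helper (not part of the original workflow) that checks whether the expected files and folders of the `simple` dataset are present before training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4c5d6e-dataset-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional helper: verify that the expected dataset layout exists.\n",
    "from pathlib import Path\n",
    "\n",
    "root = Path(\"datasets/simple\")\n",
    "for name in [\"train_images\", \"test_images\", \"train.json\", \"test.json\"]:\n",
    "    path = root / name\n",
    "    status = \"OK\" if path.exists() else \"MISSING\"\n",
    "    print(f\"{path}: {status}\")"
   ]
  },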
  {
   "cell_type": "markdown",
   "id": "ef0faf15-d9ab-454f-9368-e026372752ad",
   "metadata": {},
   "source": [
    "## 4 训练\n",
    "### 4.1 开始训练\n",
    "\n",
    "根据需求选择单卡或多卡训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c305213b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 单卡训练\n",
    "!export HIP_VISIBLE_DEVICES=4\n",
    "!python tools/train_net.py --config-file configs/simple/train_simple.yaml --num-gpus 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d95c0e6-4143-40de-b90c-3978c02cb169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 多卡训练(需要2个加速卡)\n",
    "!export HIP_VISIBLE_DEVICES=4,5,6,7\n",
    "!python tools/train_net.py --config-file configs/simple/train_simple.yaml --num-gpus 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f1986b2-1ea1-45be-8f56-dfaacfba7694",
   "metadata": {},
   "source": [
    "## 5. 推理\n",
    "## 5.1 开始推理\n",
    "\n",
    "提供了一个推理脚本来测试模型,执行下面的脚本来测试模型输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de9d22f1-4764-4c16-a4c9-52e9581669ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python demo/demo.py --config-file configs/simple/test_simple.yaml --input datasets/simple/test_images\n",
    "# 推理结果默认保存在test_results文件夹下"
   ]
  },
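  {
   "cell_type": "markdown",
   "id": "4c5d6e7f-show-results-md",
   "metadata": {},
   "source": [
    "The optional cell below previews a few of the saved visualizations inline. It assumes the demo wrote image files (e.g. `.jpg`/`.png`) into the `test_results` directory; adjust the path if your output location differs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c5d6e7f-show-results-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: preview a few inference visualizations inline.\n",
    "# Assumes image files were saved under test_results (adjust if needed).\n",
    "from pathlib import Path\n",
    "from IPython.display import Image, display\n",
    "\n",
    "results_dir = Path(\"test_results\")\n",
    "images = sorted(list(results_dir.glob(\"*.jpg\")) + list(results_dir.glob(\"*.png\")))\n",
    "for img_path in images[:3]:\n",
    "    print(img_path.name)\n",
    "    display(Image(filename=str(img_path)))"
   ]
  },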
  {
   "cell_type": "markdown",
   "id": "cfe3acde",
   "metadata": {},
   "source": [
    "## 6. 相关文献和引用\n",
    "https://github.com/ViTAE-Transformer/DeepSolo.git\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}