.DS_Store
*.pth
*.pyc
*.pyo
*.log
*.tmp
The MIT License (MIT)
Copyright (c) 2017 Jieru Mei <meijieru@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# CRNN_pytorch
## Paper
[An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/abs/1507.05717)
## Model Structure
The CRNN model consists of three parts: convolutional layers, recurrent layers, and a transcription layer.
![model_structure.jpg](asserts/model_structure.jpg)
## Algorithm
CRNN combines a CNN and an RNN and trains them jointly: the CNN extracts a feature sequence from the input image, the RNN predicts a per-frame label distribution over that sequence, and CTC loss converts the per-frame label distributions into the final label sequence. The RNN used here is a two-layer bidirectional LSTM with 256 hidden units per layer.
![Algorithm.jpg](asserts/Algorithm.jpg)
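The transcription step can be made concrete with a short sketch. The snippet below is illustrative only (in this repository that role is played by `utils.strLabelConverter`): take the argmax label at every time step, collapse consecutive repeats, then drop CTC blanks.

```python
# Greedy CTC decoding sketch (illustrative; the repo's own decoder lives in utils.py).
# Assumption: index 0 is the CTC blank; indices 1..36 map to '0123456789abcdefghijklmnopqrstuvwxyz'.
def ctc_greedy_decode(indices, alphabet='0123456789abcdefghijklmnopqrstuvwxyz', blank=0):
    chars = []
    prev = blank
    for idx in indices:
        if idx != blank and idx != prev:  # skip blanks and repeated labels
            chars.append(alphabet[idx - 1])
        prev = idx
    return ''.join(chars)

print(ctc_greedy_decode([11, 11, 0, 0, 32, 32, 0, 11]))  # -> 'ava'
```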
## Environment Setup
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10.1-py39-latest
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /path/workspace/
pip3 install -r requirements.txt
```
### Dockerfile (Option 2)
```
cd ./docker
docker build --no-cache -t crnn_pytorch:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded from the Guanghe developer community: https://developer.hpccube.com/tool/
```
DTK software stack: dtk22.10
python: python3.8
pytorch: 1.10.1
```
Tip: the DTK stack, Python, and PyTorch versions above must match each other exactly.
2. Install the remaining ordinary dependencies from requirements.txt:
```
pip3 install -r requirements.txt
```
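To confirm the environment is working, the generic PyTorch check below can be run (this is not a script shipped with this repository; the assumption is that, as on other ROCm/DTK builds, DCU devices are reported through the `torch.cuda` API):

```python
# Generic PyTorch environment sanity check (assumption: the DTK build exposes
# DCU devices through the torch.cuda API).
import torch

print('torch version :', torch.__version__)
print('device visible:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('device count  :', torch.cuda.device_count())
    print('device name   :', torch.cuda.get_device_name(0))
```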
## Dataset
Synth90k (a synthetic text dataset of 9 million images generated from a lexicon of 90k common English words)
[Training data](https://www.robots.ox.ac.uk/~vgg/data/text/)
The dataset directory layout is as follows:
```
├── IIIT5K_lmdb
│   ├── data.mdb
│   ├── error_image_log.txt
│   └── lock.mdb
└── MJ_LMDB
    ├── data.mdb
    └── lock.mdb
```
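Before training, it is worth confirming that the LMDB files open and yield samples. Here is a minimal smoke test using this repository's `dataset.py` (the path below is a placeholder):

```python
# Quick LMDB smoke test using this repo's dataset.py (the path is a placeholder).
import dataset

ds = dataset.lmdbDataset(root='../Datasets/Synth90k/MJ_LMDB')
print('num samples:', len(ds))
img, label = ds[0]          # returns a PIL image and its transcription
print('first label:', label, '| image size:', img.size)
```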
## Training
### Single node, single GPU
```
export HIP_VISIBLE_DEVICES=0
export USE_MIOPEN_BATCHNORM=1
python3 train.py --adadelta --trainRoot ../Datasets/Synth90k/MJ_LMDB --valRoot ../Datasets/Synth90k/IIIT5K_lmdb --cuda --ngpu 1 --batchSize 64 --workers 8
```
### Single node, multiple GPUs
```
# Example: a single node with two GPUs (HIP_VISIBLE_DEVICES selects devices 6 and 7)
export HSA_FORCE_FINE_GRAIN_PCIE=1
export USE_MIOPEN_BATCHNORM=1
export HIP_VISIBLE_DEVICES=6,7
python -m torch.distributed.launch --nproc_per_node=2 train_ddp.py --adadelta --trainRoot ../Datasets/Synth90k/MJ_LMDB --valRoot ../Datasets/Synth90k/IIIT5K_lmdb --cuda --ngpu 1 --batchSize 64 --workers 8
```
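`torch.distributed.launch` spawns one process per GPU and passes `--local_rank` to each. The sketch below shows the initialization pattern such a script typically follows; it is an assumption about `train_ddp.py`, not an excerpt from it:

```python
# Typical torch.distributed.launch setup (assumed pattern; consult train_ddp.py
# for the script's actual code).
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by the launcher
args, _ = parser.parse_known_args()

dist.init_process_group(backend='nccl')  # routed through RCCL on ROCm/DTK builds
torch.cuda.set_device(args.local_rank)
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
```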
## Inference
### Single-GPU inference
[Pretrained weights](https://pan.baidu.com/s/1pLbeCND)
```
python demo.py
```
## Results
Inference test with the CRNN model:
| Input | Output |
|:--:|:--:|
|![demo.png](data/demo.png)|![result.jpg](asserts/result.jpg)|
### Accuracy
| Model | Precision | ACC | Loss |
|:-------:|:----:|:------:|:--------:|
| crnn | FP32 | 0.9376 | 0.000768 |
## Application Scenarios
### Algorithm Category
OCR
### Key Application Industries
Finance, retail, transportation
## Source Repository and Issue Feedback
https://developer.hpccube.com/codes/modelzoo/crnn_pytorch
## References
[GitHub - crnn.pytorch](https://github.com/meijieru/crnn.pytorch/tree/master)
Convolutional Recurrent Neural Network
======================================
This software implements the Convolutional Recurrent Neural Network (CRNN) in pytorch.
The original software can be found in [crnn](https://github.com/bgshih/crnn).
Run demo
--------
A demo program can be found in ``demo.py``. Before running the demo, download a pretrained model
from [Baidu Netdisk](https://pan.baidu.com/s/1pLbeCND) or [Dropbox](https://www.dropbox.com/s/dboqjk20qjkpta3/crnn.pth?dl=0).
This pretrained model was converted from the one offered by the author using ``tool``.
Put the downloaded model file ``crnn.pth`` into directory ``data/``. Then launch the demo by:
    python demo.py
The demo reads an example image and recognizes its text content.
Example image:
![Example Image](./data/demo.png)
Expected output:
    loading pretrained model from ./data/crnn.pth
    a-----v--a-i-l-a-bb-l-ee-- => available
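The raw prediction above shows how CTC decoding works: consecutive duplicates are collapsed and blanks (``-``) are removed. A minimal sketch of that collapsing rule (the repository's real decoder is ``utils.strLabelConverter.decode``):

```python
# Collapse a raw CTC string (illustrative sketch; the repo's decoder is
# utils.strLabelConverter.decode).
def collapse(raw, blank='-'):
    out = []
    prev = blank
    for ch in raw:
        if ch != blank and ch != prev:
            out.append(ch)
        prev = ch
    return ''.join(out)

print(collapse('a-----v--a-i-l-a-bb-l-ee--'))  # -> 'available'
```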
Dependencies
----------
* [warp_ctc_pytorch](https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding)
* lmdb
Train a new model
-----------------
1. Construct the dataset following the [origin guide](https://github.com/bgshih/crnn#train-a-new-model); a usage sketch follows this list. If you want to train with variable-length images (for example, keeping the original aspect ratio), modify `tool/create_dataset.py` and sort the images by text length.
2. Execute ``python train.py --adadelta --trainRoot {train_path} --valRoot {val_path} --cuda``. Explore ``train.py`` for details.
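For step 1, the LMDB can be built with ``tool/create_dataset.py``. A hedged usage sketch (the function name and argument order follow the upstream crnn.pytorch script; verify them against your copy, and the paths and labels below are placeholders):

```python
# Building a training LMDB (sketch; verify createDataset's signature in
# tool/create_dataset.py -- the paths and labels here are placeholders).
from tool.create_dataset import createDataset

image_paths = ['./images/word_1.png', './images/word_2.png']
labels = ['available', 'shining']
createDataset('./train_lmdb', image_paths, labels)
```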
#!/usr/bin/python
# encoding: utf-8
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
import lmdb
import six
import sys
from PIL import Image
import numpy as np
class lmdbDataset(Dataset):

    def __init__(self, root=None, transform=None, target_transform=None):
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot create lmdb from %s' % (root))
            sys.exit(1)

        with self.env.begin(write=False) as txn:
            # LMDB keys are bytes under Python 3, hence the encode()
            nSamples = int(txn.get('num-samples'.encode()))
            self.nSamples = nSamples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1  # LMDB sample keys are 1-based
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key.encode())

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                return self[index + 1]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % index
            label = txn.get(label_key.encode()).decode()

            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
class resizeNormalize(object):

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)  # map [0, 1] to [-1, 1]
        return img


class randomSequentialSampler(sampler.Sampler):

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batch = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        index = torch.LongTensor(len(self)).fill_(0)
        for i in range(n_batch):
            random_start = random.randint(0, len(self) - self.batch_size)
            batch_index = random_start + torch.arange(0, self.batch_size)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with tail
        if tail:
            random_start = random.randint(0, len(self) - self.batch_size)
            tail_index = random_start + torch.arange(0, tail)
            index[(i + 1) * self.batch_size:] = tail_index
        return iter(index)

    def __len__(self):
        return self.num_samples
class alignCollate(object):

    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        images, labels = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            ratios = []
            for image in images:
                w, h = image.size
                ratios.append(w / float(h))
            ratios.sort()
            max_ratio = ratios[-1]
            imgW = int(np.floor(max_ratio * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # ensure imgW >= imgH * min_ratio

        transform = resizeNormalize((imgW, imgH))
        images = [transform(image) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)

        return images, labels
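

if __name__ == '__main__':
    # Usage sketch, not part of the original file: a condensed version of how
    # train.py wires these classes together. The dataset path is a placeholder.
    train_ds = lmdbDataset(root='../Datasets/Synth90k/MJ_LMDB')
    print('num samples:', len(train_ds))
    loader = torch.utils.data.DataLoader(
        train_ds, batch_size=64, shuffle=True,
        collate_fn=alignCollate(imgH=32, imgW=100, keep_ratio=True))
    images, labels = next(iter(loader))
    print(images.shape)  # [64, 1, 32, W]; W is set by the widest image in the batch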
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
import models.crnn as crnn
model_path = './data/crnn.pth'
img_path = './data/demo.png'
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
model = crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))
converter = utils.strLabelConverter(alphabet)
transformer = dataset.resizeNormalize((100, 32))
image = Image.open(img_path).convert('L')
image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
print('%-20s => %-20s' % (raw_pred, sim_pred))
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10.1-py39-latest
RUN source /opt/dtk/env.sh
COPY requirements.txt requirements.txt
COPY requirements/ requirements/
RUN pip3 install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -r requirements.txt
lmdb==1.0.0
# Unique model identifier
modelCode=499
# Model name
modelName=crnn_pytorch
# Model description
modelDescription=CRNN is an end-to-end OCR network for text recognition
# Application scenarios
appScenario=inference,training,finance,retail,transportation
# Framework type
frameType=pytorch
import torch.nn as nn
class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)

        return output
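

if __name__ == '__main__':
    # Shape sanity check, not part of the original file: with a 32x100 grayscale
    # input the CNN collapses the height to 1 and produces 26 time steps.
    import torch
    net = CRNN(imgH=32, nc=1, nclass=37, nh=256)  # 36 characters + 1 CTC blank
    x = torch.randn(4, 1, 32, 100)
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # torch.Size([26, 4, 37]) -> [time steps, batch, classes]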
#!/usr/bin/python
# encoding: utf-8
import sys
import unittest
import torch
from torch.autograd import Variable
import collections.abc
origin_path = sys.path
sys.path.append("..")
import utils
sys.path = origin_path
def equal(a, b):
    if isinstance(a, torch.Tensor):
        return a.equal(b)
    elif isinstance(a, str):
        return a == b
    elif isinstance(a, collections.abc.Iterable):
        res = True
        for (x, y) in zip(a, b):
            res = res & equal(x, y)
        return res
    else:
        return a == b


class utilsTestCase(unittest.TestCase):

    def checkConverter(self):
        encoder = utils.strLabelConverter('abcdefghijklmnopqrstuvwxyz')

        # Encode
        # trivial mode
        result = encoder.encode('efa')
        target = (torch.IntTensor([5, 6, 1]), torch.IntTensor([3]))
        self.assertTrue(equal(result, target))

        # batch mode
        result = encoder.encode(['efa', 'ab'])
        target = (torch.IntTensor([5, 6, 1, 1, 2]), torch.IntTensor([3, 2]))
        self.assertTrue(equal(result, target))

        # Decode
        # trivial mode
        result = encoder.decode(
            torch.IntTensor([5, 6, 1]), torch.IntTensor([3]))
        target = 'efa'
        self.assertTrue(equal(result, target))

        # replicate mode
        result = encoder.decode(
            torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([4]))
        target = 'ea'
        self.assertTrue(equal(result, target))

        # raise AssertionError
        def f():
            result = encoder.decode(
                torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([3]))
        self.assertRaises(AssertionError, f)

        # batch mode
        result = encoder.decode(
            torch.IntTensor([5, 6, 1, 1, 2]), torch.IntTensor([3, 2]))
        target = ['efa', 'ab']
        self.assertTrue(equal(result, target))

    def checkOneHot(self):
        v = torch.LongTensor([1, 2, 1, 2, 0])
        v_length = torch.LongTensor([2, 3])
        v_onehot = utils.oneHot(v, v_length, 4)
        target = torch.FloatTensor([[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
                                    [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]])
        assert target.equal(v_onehot)

    def checkAverager(self):
        acc = utils.averager()
        acc.add(Variable(torch.Tensor([1, 2])))
        acc.add(Variable(torch.Tensor([[5, 6]])))
        assert acc.val() == 3.5

        acc = utils.averager()
        acc.add(torch.Tensor([1, 2]))
        acc.add(torch.Tensor([[5, 6]]))
        assert acc.val() == 3.5

    def checkAssureRatio(self):
        img = torch.Tensor([[1], [3]]).view(1, 1, 2, 1)
        img = Variable(img)
        img = utils.assureRatio(img)
        assert torch.Size([1, 1, 2, 2]) == img.size()


def _suite():
    suite = unittest.TestSuite()
    suite.addTest(utilsTestCase("checkConverter"))
    suite.addTest(utilsTestCase("checkOneHot"))
    suite.addTest(utilsTestCase("checkAverager"))
    suite.addTest(utilsTestCase("checkAssureRatio"))
    return suite


if __name__ == "__main__":
    suite = _suite()
    runner = unittest.TextTestRunner()
    runner.run(suite)
require('table')
require('torch')
require('os')
function clone(t)
    -- deep-copy a table
    if type(t) ~= "table" then return t end
    local meta = getmetatable(t)
    local target = {}
    for k, v in pairs(t) do
        if type(v) == "table" then
            target[k] = clone(v)
        else
            target[k] = v
        end
    end
    setmetatable(target, meta)
    return target
end

function tableMerge(lhs, rhs)
    local output = clone(lhs)
    for _, v in pairs(rhs) do
        table.insert(output, v)
    end
    return output
end

function isInTable(val, val_list)
    for _, item in pairs(val_list) do
        if val == item then
            return true
        end
    end
    return false
end

function modelToList(model)
    local ignoreList = {
        'nn.Copy',
        'nn.AddConstant',
        'nn.MulConstant',
        'nn.View',
        'nn.Transpose',
        'nn.SplitTable',
        'nn.SharedParallelTable',
        'nn.JoinTable',
    }
    local state = {}
    local param
    for i, layer in pairs(model.modules) do
        local typeName = torch.type(layer)
        if not isInTable(typeName, ignoreList) then
            if typeName == 'nn.Sequential' or typeName == 'nn.ConcatTable' then
                param = modelToList(layer)
            elseif typeName == 'cudnn.SpatialConvolution' or typeName == 'nn.SpatialConvolution' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialBatchNormalization' or typeName == 'nn.SpatialBatchNormalization' then
                param = layer:parameters()
                local bn_vars = {layer.running_mean, layer.running_var}
                param = tableMerge(param, bn_vars)
            elseif typeName == 'nn.LstmLayer' then
                param = layer:parameters()
            elseif typeName == 'nn.BiRnnJoin' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialMaxPooling' or typeName == 'nn.SpatialMaxPooling' then
                param = {}
            elseif typeName == 'cudnn.ReLU' or typeName == 'nn.ReLU' then
                param = {}
            else
                print(string.format('Unknown class %s', typeName))
                os.exit(0)
            end
            table.insert(state, {typeName, param})
        else
            print(string.format('pass %s', typeName))
        end
    end
    return state
end

function saveModel(model, output_path)
    local state = modelToList(model)
    torch.save(output_path, state)
end
import torchfile
import argparse
import torch
from torch.nn.parameter import Parameter
import numpy as np
import models.crnn as crnn
layer_map = {
    'SpatialConvolution': 'Conv2d',
    'SpatialBatchNormalization': 'BatchNorm2d',
    'ReLU': 'ReLU',
    'SpatialMaxPooling': 'MaxPool2d',
    'SpatialAveragePooling': 'AvgPool2d',
    'SpatialUpSamplingNearest': 'UpsamplingNearest2d',
    'View': None,
    'Linear': 'linear',
    'Dropout': 'Dropout',
    'SoftMax': 'Softmax',
    'Identity': None,
    'SpatialFullConvolution': 'ConvTranspose2d',
    'SpatialReplicationPadding': None,
    'SpatialReflectionPadding': None,
    'Copy': None,
    'Narrow': None,
    'SpatialCrossMapLRN': None,
    'Sequential': None,
    'ConcatTable': None,  # output is list
    'CAddTable': None,  # input is list
    'Concat': None,
    'TorchObject': None,
    'LstmLayer': 'LSTM',
    'BiRnnJoin': 'Linear'
}
def torch_layer_serial(layer, layers):
    name = layer[0]
    if name == 'nn.Sequential' or name == 'nn.ConcatTable':
        tmp_layers = []
        for sub_layer in layer[1]:
            torch_layer_serial(sub_layer, tmp_layers)
        layers.extend(tmp_layers)
    else:
        layers.append(layer)


def py_layer_serial(layer, layers):
    """
    Assume modules are defined as executive sequence.
    """
    if len(layer._modules) >= 1:
        tmp_layers = []
        for sub_layer in layer.children():
            py_layer_serial(sub_layer, tmp_layers)
        layers.extend(tmp_layers)
    else:
        layers.append(layer)


def trans_pos(param, part_indexes, dim=0):
    parts = np.split(param, len(part_indexes), dim)
    new_parts = []
    for i in part_indexes:
        new_parts.append(parts[i])
    return np.concatenate(new_parts, dim)


def load_params(py_layer, t7_layer):
    if type(py_layer).__name__ == 'LSTM':
        # LSTM
        all_weights = []
        num_directions = 2 if py_layer.bidirectional else 1
        for i in range(py_layer.num_layers):
            for j in range(num_directions):
                suffix = '_reverse' if j == 1 else ''
                weights = ['weight_ih_l{}{}', 'bias_ih_l{}{}',
                           'weight_hh_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(i, suffix) for x in weights]
                all_weights += weights
        params = []
        for i in range(len(t7_layer)):
            params.extend(t7_layer[i][1])
        params = [trans_pos(p, [0, 1, 3, 2], dim=0) for p in params]
    else:
        all_weights = []
        name = t7_layer[0].split('.')[-1]
        if name == 'BiRnnJoin':
            weight_0, bias_0, weight_1, bias_1 = t7_layer[1]
            weight = np.concatenate((weight_0, weight_1), axis=1)
            bias = bias_0 + bias_1
            t7_layer[1] = [weight, bias]
            all_weights += ['weight', 'bias']
        elif name == 'SpatialConvolution' or name == 'Linear':
            all_weights += ['weight', 'bias']
        elif name == 'SpatialBatchNormalization':
            all_weights += ['weight', 'bias', 'running_mean', 'running_var']
        params = t7_layer[1]

    params = [torch.from_numpy(item) for item in params]
    assert len(all_weights) == len(params), "params' number not match"
    for py_param_name, t7_param in zip(all_weights, params):
        item = getattr(py_layer, py_param_name)
        if isinstance(item, Parameter):
            item = item.data
        try:
            item.copy_(t7_param)
        except RuntimeError:
            print('Size not match between %s and %s' %
                  (item.size(), t7_param.size()))


def torch_to_pytorch(model, t7_file, output):
    py_layers = []
    for layer in list(model.children()):
        py_layer_serial(layer, py_layers)

    t7_data = torchfile.load(t7_file)
    t7_layers = []
    for layer in t7_data:
        torch_layer_serial(layer, t7_layers)

    j = 0
    for i, py_layer in enumerate(py_layers):
        py_name = type(py_layer).__name__
        t7_layer = t7_layers[j]
        t7_name = t7_layer[0].split('.')[-1]
        if layer_map[t7_name] != py_name:
            raise RuntimeError('%s does not match %s' % (py_name, t7_name))
        if py_name == 'LSTM':
            n_layer = 2 if py_layer.bidirectional else 1
            n_layer *= py_layer.num_layers
            t7_layer = t7_layers[j:j + n_layer]
            j += n_layer
        else:
            j += 1
        load_params(py_layer, t7_layer)
    torch.save(model.state_dict(), output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Convert torch t7 model to pytorch'
    )
    parser.add_argument(
        '--model_file',
        '-m',
        type=str,
        required=True,
        help='torch model file in t7 format'
    )
    parser.add_argument(
        '--output',
        '-o',
        type=str,
        default=None,
        help='output file name prefix, xxx.py xxx.pth'
    )
    args = parser.parse_args()

    py_model = crnn.CRNN(32, 1, 37, 256, 1)
    torch_to_pytorch(py_model, args.model_file, args.output)
from __future__ import print_function
from __future__ import division
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
# from warpctc_pytorch import CTCLoss
from torch.nn import CTCLoss
import os
import utils
import dataset
from datetime import datetime
import models.crnn as crnn
import time
parser = argparse.ArgumentParser()
parser.add_argument('--trainRoot', required=True, help='path to dataset')
parser.add_argument('--valRoot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
# TODO(meijieru): epoch -> iter
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--pretrained', default='', help="path to pretrained model (to continue training)")
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz')
parser.add_argument('--expr_dir', default='expr', help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for Critic, not used by adadelta')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--manualSeed', type=int, default=1234, help='seed for reproducible experiments')
parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)
if not os.path.exists(opt.expr_dir):
    os.makedirs(opt.expr_dir)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

train_dataset = dataset.lmdbDataset(root=opt.trainRoot)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=False, sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(
    root=opt.valRoot, transform=dataset.resizeNormalize((100, 32)))
nclass = len(opt.alphabet) + 1
nc = 1
converter = utils.strLabelConverter(opt.alphabet)
criterion = CTCLoss()
# custom weights initialization called on crnn
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.pretrained != '':
    print('loading pretrained model from %s' % opt.pretrained)
    crnn.load_state_dict(torch.load(opt.pretrained))
print(crnn)
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)
if opt.cuda:
    crnn.cuda()
    crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
    image = image.cuda()
    criterion = criterion.cuda()
image = Variable(image)
text = Variable(text)
length = Variable(length)
# loss averager
loss_avg = utils.averager()
# setup optimizer
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
                           betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters())
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
def val(net, dataset, criterion, max_iter=100):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    n_correct = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = crnn(image).permute(1, 0, 2)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target.lower():
                n_correct += 1

    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    accuracy = n_correct / float(max_iter * opt.batchSize)
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
def trainBatch(net, criterion, optimizer):
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    end = time.time()
    data = next(train_iter)
    data_time.update((time.time() - end) * 1000)
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image).permute(1, 0, 2)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    optimizer.step()

    batch_time.update((time.time() - end) * 1000)
    fps = (batch_size / batch_time.val) * 1000
    msg = 'Time {batch_time.val:.3f}ms (avg_time:{batch_time.avg:.3f}ms)\t' \
          'Data {data_time.val:.3f}ms ({data_time.avg:.3f}ms)\t' \
          'Fps {fps:.3f}\t'.format(
              batch_time=batch_time,
              data_time=data_time, fps=fps)
    return cost
for epoch in range(opt.nepoch):
    train_iter = iter(train_loader)
    i = 0
    while i < len(train_loader):
        for p in crnn.parameters():
            p.requires_grad = True
        crnn.train()

        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)
        i += 1

        print('\r[%d/%d][%d/%d] Loss: %f' %
              (epoch, opt.nepoch, i, len(train_loader), loss_avg.val()), end='')

        if i % opt.valInterval == 0:
            loss_avg.reset()
            val(crnn, test_dataset, criterion)

        # do checkpointing
        if i % opt.saveInterval == 0:
            torch.save(
                crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.expr_dir, epoch, i))