Commit d4c0717c authored by dengjb

update code

parent 4c93f0ed
@@ -29,7 +29,8 @@ class lmdbDataset(Dataset):
            sys.exit(0)
        with self.env.begin(write=False) as txn:
            # nSamples = int(txn.get('num-samples'))
            nSamples = int(txn.get('num-samples'.encode()).decode('utf-8'))
            self.nSamples = nSamples
        self.transform = transform
@@ -42,7 +43,8 @@ class lmdbDataset(Dataset):
        assert index <= len(self), 'index range error'
        index += 1
        with self.env.begin(write=False) as txn:
            # img_key = 'image-%09d' % index
            img_key = 'image-%09d'.encode() % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
@@ -57,8 +59,10 @@ class lmdbDataset(Dataset):
            if self.transform is not None:
                img = self.transform(img)

            # label_key = 'label-%09d' % index
            label_key = 'label-%09d'.encode() % index
            # label = str(txn.get(label_key))
            label = str(txn.get(label_key).decode('utf-8'))

            if self.target_transform is not None:
                label = self.target_transform(label)
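
Under Python 3, py-lmdb treats keys and values as bytes, which is why the dataset now encodes the key strings before the lookup and decodes the stored values back to str. A minimal sketch of the same access pattern, assuming an LMDB built with the usual image-%09d / label-%09d key scheme and a hypothetical path:

import lmdb

env = lmdb.open('/path/to/lmdb', readonly=True, lock=False)       # hypothetical path
with env.begin(write=False) as txn:
    n = int(txn.get('num-samples'.encode()).decode('utf-8'))      # value comes back as bytes
    label = txn.get(('label-%09d' % 1).encode()).decode('utf-8')  # keys must be bytes as well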
import torch.nn as nn
import torch.nn.functional as F


class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        # self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, batch_first=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        # T, b, h = recurrent.size()
        b, T, h = recurrent.size()
        # t_rec = recurrent.view(T * b, h)
        t_rec = recurrent.reshape(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        # output = output.view(T, b, -1)
        output = output.view(b, T, -1)
        return output
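
With batch_first=True the recurrent layer now takes input of shape [batch, seq_len, nIn] and returns [batch, seq_len, nOut]. A quick shape check, using example sizes that are not part of this commit:

import torch

rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=37)  # example sizes, chosen for illustration
x = torch.randn(4, 26, 512)                             # [batch, seq_len, nIn]
y = rnn(x)
print(y.shape)                                          # torch.Size([4, 26, 37])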
@@ -71,9 +74,13 @@ class CRNN(nn.Module):
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        # conv = conv.permute(2, 0, 1)  # [w, b, c]
        conv = conv.permute(0, 2, 1)  # [b, w, c]

        # rnn features
        output = self.rnn(conv)

        # add log_softmax so the network outputs log-probabilities (helps convergence)
        output = F.log_softmax(output, dim=2)

        return output
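
Because the network output is now batch-first, [b, w, nclass], it has to be permuted back to [w, b, nclass] before torch.nn.CTCLoss, which expects log-probabilities of shape (T, N, C). A sketch of the loss computation, where crnn, images, targets and target_lengths are placeholder names, not part of this commit:

import torch
import torch.nn as nn

ctc = nn.CTCLoss(blank=0, zero_infinity=True)
log_probs = crnn(images)                # [b, T, nclass], already log_softmax-ed
log_probs = log_probs.permute(1, 0, 2)  # CTCLoss wants [T, b, nclass]
input_lengths = torch.full((log_probs.size(1),), log_probs.size(0), dtype=torch.long)
loss = ctc(log_probs, targets, input_lengths, target_lengths)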
@@ -131,7 +131,9 @@ def oneHot(v, v_length, nc):

def loadData(v, data):
    # v.data.resize_(data.size()).copy_(data)
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)

def prettyPrint(v):
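
The update drops the old Variable-style .data access and hides the in-place resize_/copy_ from autograd via torch.no_grad(). A minimal usage sketch with illustrative sizes:

import torch

image = torch.FloatTensor(1, 1, 32, 100)  # preallocated buffer, sizes are illustrative
batch = torch.randn(8, 1, 32, 100)
loadData(image, batch)                    # buffer is resized to [8, 1, 32, 100] and filled in place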
@@ -147,3 +149,25 @@ def assureRatio(img):
        main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
        img = main(img)
    return img


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
\ No newline at end of file
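
A typical way to use the new AverageMeter during training, where loader and train_step are placeholder names:

losses = AverageMeter()
for images, targets in loader:           # hypothetical data loader
    loss = train_step(images, targets)   # assumed to return a Python float
    losses.update(loss, n=images.size(0))
print('avg loss: %.4f' % losses.avg)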