Commit d4c0717c authored by dengjb's avatar dengjb
Browse files

update code

parent 4c93f0ed
...@@ -29,7 +29,8 @@ class lmdbDataset(Dataset): ...@@ -29,7 +29,8 @@ class lmdbDataset(Dataset):
sys.exit(0) sys.exit(0)
with self.env.begin(write=False) as txn: with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples')) # nSamples = int(txn.get('num-samples'))
nSamples = int(txn.get('num-samples'.encode()).decode('utf-8'))
self.nSamples = nSamples self.nSamples = nSamples
self.transform = transform self.transform = transform
...@@ -42,7 +43,8 @@ class lmdbDataset(Dataset): ...@@ -42,7 +43,8 @@ class lmdbDataset(Dataset):
assert index <= len(self), 'index range error' assert index <= len(self), 'index range error'
index += 1 index += 1
with self.env.begin(write=False) as txn: with self.env.begin(write=False) as txn:
img_key = 'image-%09d' % index # img_key = 'image-%09d' % index
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key) imgbuf = txn.get(img_key)
buf = six.BytesIO() buf = six.BytesIO()
...@@ -57,8 +59,10 @@ class lmdbDataset(Dataset): ...@@ -57,8 +59,10 @@ class lmdbDataset(Dataset):
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
label_key = 'label-%09d' % index # label_key = 'label-%09d' % index
label = str(txn.get(label_key)) label_key = 'label-%09d'.encode() % index
# label = str(txn.get(label_key))
label = str(txn.get(label_key).decode('utf-8'))
if self.target_transform is not None: if self.target_transform is not None:
label = self.target_transform(label) label = self.target_transform(label)
......
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
class BidirectionalLSTM(nn.Module): class BidirectionalLSTM(nn.Module):
def __init__(self, nIn, nHidden, nOut): def __init__(self, nIn, nHidden, nOut):
super(BidirectionalLSTM, self).__init__() super(BidirectionalLSTM, self).__init__()
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, batch_first=True)
self.embedding = nn.Linear(nHidden * 2, nOut) self.embedding = nn.Linear(nHidden * 2, nOut)
def forward(self, input): def forward(self, input):
recurrent, _ = self.rnn(input) recurrent, _ = self.rnn(input)
T, b, h = recurrent.size() # T, b, h = recurrent.size()
t_rec = recurrent.view(T * b, h) b, T, h = recurrent.size()
# t_rec = recurrent.view(T * b, h)
t_rec = recurrent.reshape(T * b, h)
output = self.embedding(t_rec) # [T * b, nOut] output = self.embedding(t_rec) # [T * b, nOut]
output = output.view(T, b, -1) # output = output.view(T, b, -1)
output = output.view(b, T, -1)
return output return output
...@@ -71,9 +74,13 @@ class CRNN(nn.Module): ...@@ -71,9 +74,13 @@ class CRNN(nn.Module):
b, c, h, w = conv.size() b, c, h, w = conv.size()
assert h == 1, "the height of conv must be 1" assert h == 1, "the height of conv must be 1"
conv = conv.squeeze(2) conv = conv.squeeze(2)
conv = conv.permute(2, 0, 1) # [w, b, c] # conv = conv.permute(2, 0, 1) # [w, b, c]
conv = conv.permute(0, 2, 1) # [b, w, c]
# rnn features # rnn features
output = self.rnn(conv) output = self.rnn(conv)
# add log_softmax to converge output
output = F.log_softmax(output, dim=2)
return output return output
...@@ -131,7 +131,9 @@ def oneHot(v, v_length, nc): ...@@ -131,7 +131,9 @@ def oneHot(v, v_length, nc):
def loadData(v, data): def loadData(v, data):
v.data.resize_(data.size()).copy_(data) # v.data.resize_(data.size()).copy_(data)
with torch.no_grad():
v.resize_(data.size()).copy_(data)
def prettyPrint(v): def prettyPrint(v):
...@@ -147,3 +149,25 @@ def assureRatio(img): ...@@ -147,3 +149,25 @@ def assureRatio(img):
main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None) main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
img = main(img) img = main(img)
return img return img
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment