Commit 3c906d41 authored by tink2123's avatar tink2123
Browse files

merge dygraph

parents 8308f332 6a41a37a
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
class PSEHead(nn.Layer):
    """PSENet detection head: conv-BN-ReLU followed by a 1x1 projection.

    Produces ``out_channels`` kernel maps (default 7) from the neck features
    and returns them under the ``'maps'`` key.
    """

    def __init__(self, in_channels, hidden_dim=256, out_channels=7, **kwargs):
        super(PSEHead, self).__init__()
        # 3x3 conv keeps the spatial size and projects to hidden_dim channels.
        self.conv1 = nn.Conv2D(
            in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(hidden_dim)
        self.relu1 = nn.ReLU()
        # 1x1 conv maps the hidden features to the output kernel maps.
        self.conv2 = nn.Conv2D(
            hidden_dim, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, **kwargs):
        feat = self.relu1(self.bn1(self.conv1(x)))
        return {'maps': self.conv2(feat)}
...@@ -71,8 +71,6 @@ class MultiheadAttention(nn.Layer): ...@@ -71,8 +71,6 @@ class MultiheadAttention(nn.Layer):
value, value,
key_padding_mask=None, key_padding_mask=None,
incremental_state=None, incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None): attn_mask=None):
""" """
Inputs of forward function Inputs of forward function
...@@ -88,46 +86,42 @@ class MultiheadAttention(nn.Layer): ...@@ -88,46 +86,42 @@ class MultiheadAttention(nn.Layer):
attn_output: [target length, batch size, embed dim] attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length] attn_output_weights: [batch size, target length, sequence length]
""" """
tgt_len, bsz, embed_dim = query.shape q_shape = paddle.shape(query)
assert embed_dim == self.embed_dim src_shape = paddle.shape(key)
assert list(query.shape) == [tgt_len, bsz, embed_dim]
assert key.shape == value.shape
q = self._in_proj_q(query) q = self._in_proj_q(query)
k = self._in_proj_k(key) k = self._in_proj_k(key)
v = self._in_proj_v(value) v = self._in_proj_v(value)
q *= self.scaling q *= self.scaling
q = paddle.transpose(
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose( paddle.reshape(
[1, 0, 2]) q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( [1, 2, 0, 3])
[1, 0, 2]) k = paddle.transpose(
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( paddle.reshape(
[1, 0, 2]) k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
src_len = k.shape[1] v = paddle.transpose(
paddle.reshape(
v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
if key_padding_mask is not None: if key_padding_mask is not None:
assert key_padding_mask.shape[0] == bsz assert key_padding_mask.shape[0] == q_shape[1]
assert key_padding_mask.shape[1] == src_len assert key_padding_mask.shape[1] == src_shape[0]
attn_output_weights = paddle.matmul(q,
attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1])) paddle.transpose(k, [0, 1, 3, 2]))
assert list(attn_output_weights.
shape) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0) attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
attn_output_weights += attn_mask attn_output_weights += attn_mask
if key_padding_mask is not None: if key_padding_mask is not None:
attn_output_weights = attn_output_weights.reshape( attn_output_weights = paddle.reshape(
[bsz, self.num_heads, tgt_len, src_len]) attn_output_weights,
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
key = paddle.cast(key, 'float32')
y = paddle.full(
shape=paddle.shape(key), dtype='float32', fill_value='-inf')
y = paddle.where(key == 0., key, y) y = paddle.where(key == 0., key, y)
attn_output_weights += y attn_output_weights += y
attn_output_weights = attn_output_weights.reshape(
[bsz * self.num_heads, tgt_len, src_len])
attn_output_weights = F.softmax( attn_output_weights = F.softmax(
attn_output_weights.astype('float32'), attn_output_weights.astype('float32'),
axis=-1, axis=-1,
...@@ -136,43 +130,34 @@ class MultiheadAttention(nn.Layer): ...@@ -136,43 +130,34 @@ class MultiheadAttention(nn.Layer):
attn_output_weights = F.dropout( attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training) attn_output_weights, p=self.dropout, training=self.training)
attn_output = paddle.bmm(attn_output_weights, v) attn_output = paddle.matmul(attn_output_weights, v)
assert list(attn_output. attn_output = paddle.reshape(
shape) == [bsz * self.num_heads, tgt_len, self.head_dim] paddle.transpose(attn_output, [2, 0, 1, 3]),
attn_output = attn_output.transpose([1, 0, 2]).reshape( [q_shape[0], q_shape[1], self.embed_dim])
[tgt_len, bsz, embed_dim])
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
if need_weights: return attn_output
# average attention weights over heads
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(
axis=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
def _in_proj_q(self, query): def _in_proj_q(self, query):
query = query.transpose([1, 2, 0]) query = paddle.transpose(query, [1, 2, 0])
query = paddle.unsqueeze(query, axis=2) query = paddle.unsqueeze(query, axis=2)
res = self.conv1(query) res = self.conv1(query)
res = paddle.squeeze(res, axis=2) res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1]) res = paddle.transpose(res, [2, 0, 1])
return res return res
def _in_proj_k(self, key): def _in_proj_k(self, key):
key = key.transpose([1, 2, 0]) key = paddle.transpose(key, [1, 2, 0])
key = paddle.unsqueeze(key, axis=2) key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key) res = self.conv2(key)
res = paddle.squeeze(res, axis=2) res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1]) res = paddle.transpose(res, [2, 0, 1])
return res return res
def _in_proj_v(self, value): def _in_proj_v(self, value):
value = value.transpose([1, 2, 0]) #(1, 2, 0) value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
value = paddle.unsqueeze(value, axis=2) value = paddle.unsqueeze(value, axis=2)
res = self.conv3(value) res = self.conv3(value)
res = paddle.squeeze(res, axis=2) res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1]) res = paddle.transpose(res, [2, 0, 1])
return res return res
...@@ -61,12 +61,12 @@ class Transformer(nn.Layer): ...@@ -61,12 +61,12 @@ class Transformer(nn.Layer):
custom_decoder=None, custom_decoder=None,
in_channels=0, in_channels=0,
out_channels=0, out_channels=0,
dst_vocab_size=99,
scale_embedding=True): scale_embedding=True):
super(Transformer, self).__init__() super(Transformer, self).__init__()
self.out_channels = out_channels + 1
self.embedding = Embeddings( self.embedding = Embeddings(
d_model=d_model, d_model=d_model,
vocab=dst_vocab_size, vocab=self.out_channels,
padding_idx=0, padding_idx=0,
scale_embedding=scale_embedding) scale_embedding=scale_embedding)
self.positional_encoding = PositionalEncoding( self.positional_encoding = PositionalEncoding(
...@@ -96,9 +96,10 @@ class Transformer(nn.Layer): ...@@ -96,9 +96,10 @@ class Transformer(nn.Layer):
self.beam_size = beam_size self.beam_size = beam_size
self.d_model = d_model self.d_model = d_model
self.nhead = nhead self.nhead = nhead
self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False) self.tgt_word_prj = nn.Linear(
d_model, self.out_channels, bias_attr=False)
w0 = np.random.normal(0.0, d_model**-0.5, w0 = np.random.normal(0.0, d_model**-0.5,
(d_model, dst_vocab_size)).astype(np.float32) (d_model, self.out_channels)).astype(np.float32)
self.tgt_word_prj.weight.set_value(w0) self.tgt_word_prj.weight.set_value(w0)
self.apply(self._init_weights) self.apply(self._init_weights)
...@@ -156,46 +157,41 @@ class Transformer(nn.Layer): ...@@ -156,46 +157,41 @@ class Transformer(nn.Layer):
return self.forward_test(src) return self.forward_test(src)
def forward_test(self, src): def forward_test(self, src):
bs = src.shape[0] bs = paddle.shape(src)[0]
if self.encoder is not None: if self.encoder is not None:
src = self.positional_encoding(src.transpose([1, 0, 2])) src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
memory = self.encoder(src) memory = self.encoder(src)
else: else:
memory = src.squeeze(2).transpose([2, 0, 1]) memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
for len_dec_seq in range(1, 25): for len_dec_seq in range(1, 25):
src_enc = memory.clone() dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2])
dec_seq_embed = self.positional_encoding(dec_seq_embed) dec_seq_embed = self.positional_encoding(dec_seq_embed)
tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[ tgt_mask = self.generate_square_subsequent_mask(
0]) paddle.shape(dec_seq_embed)[0])
output = self.decoder( output = self.decoder(
dec_seq_embed, dec_seq_embed,
src_enc, memory,
tgt_mask=tgt_mask, tgt_mask=tgt_mask,
memory_mask=None, memory_mask=None,
tgt_key_padding_mask=tgt_key_padding_mask, tgt_key_padding_mask=None,
memory_key_padding_mask=None) memory_key_padding_mask=None)
dec_output = output.transpose([1, 0, 2]) dec_output = paddle.transpose(output, [1, 0, 2])
dec_output = dec_output[:, -1, :]
dec_output = dec_output[:, word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
-1, :] # Pick the last step: (bh * bm) * d_h preds_idx = paddle.argmax(word_prob, axis=1)
word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
word_prob = word_prob.reshape([1, bs, -1])
preds_idx = word_prob.argmax(axis=2)
if paddle.equal_all( if paddle.equal_all(
preds_idx[-1], preds_idx,
paddle.full( paddle.full(
preds_idx[-1].shape, 3, dtype='int64')): paddle.shape(preds_idx), 3, dtype='int64')):
break break
preds_prob = paddle.max(word_prob, axis=1)
preds_prob = word_prob.max(axis=2)
dec_seq = paddle.concat( dec_seq = paddle.concat(
[dec_seq, preds_idx.reshape([-1, 1])], axis=1) [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
dec_prob = paddle.concat(
return dec_seq [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
return [dec_seq, dec_prob]
def forward_beam(self, images): def forward_beam(self, images):
''' Translation work in one batch ''' ''' Translation work in one batch '''
...@@ -211,14 +207,15 @@ class Transformer(nn.Layer): ...@@ -211,14 +207,15 @@ class Transformer(nn.Layer):
n_prev_active_inst, n_bm): n_prev_active_inst, n_bm):
''' Collect tensor parts associated to active instances. ''' ''' Collect tensor parts associated to active instances. '''
_, *d_hs = beamed_tensor.shape beamed_tensor_shape = paddle.shape(beamed_tensor)
n_curr_active_inst = len(curr_active_inst_idx) n_curr_active_inst = len(curr_active_inst_idx)
new_shape = (n_curr_active_inst * n_bm, *d_hs) new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
beamed_tensor_shape[2])
beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1]) beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
beamed_tensor = beamed_tensor.index_select( beamed_tensor = beamed_tensor.index_select(
paddle.to_tensor(curr_active_inst_idx), axis=0) curr_active_inst_idx, axis=0)
beamed_tensor = beamed_tensor.reshape([*new_shape]) beamed_tensor = beamed_tensor.reshape(new_shape)
return beamed_tensor return beamed_tensor
...@@ -249,44 +246,26 @@ class Transformer(nn.Layer): ...@@ -249,44 +246,26 @@ class Transformer(nn.Layer):
b.get_current_state() for b in inst_dec_beams if not b.done b.get_current_state() for b in inst_dec_beams if not b.done
] ]
dec_partial_seq = paddle.stack(dec_partial_seq) dec_partial_seq = paddle.stack(dec_partial_seq)
dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq]) dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
return dec_partial_seq return dec_partial_seq
def prepare_beam_memory_key_padding_mask(
inst_dec_beams, memory_key_padding_mask, n_bm):
keep = []
for idx in (memory_key_padding_mask):
if not inst_dec_beams[idx].done:
keep.append(idx)
memory_key_padding_mask = memory_key_padding_mask[
paddle.to_tensor(keep)]
len_s = memory_key_padding_mask.shape[-1]
n_inst = memory_key_padding_mask.shape[0]
memory_key_padding_mask = paddle.concat(
[memory_key_padding_mask for i in range(n_bm)], axis=1)
memory_key_padding_mask = memory_key_padding_mask.reshape(
[n_inst * n_bm, len_s]) #repeat(1, n_bm)
return memory_key_padding_mask
def predict_word(dec_seq, enc_output, n_active_inst, n_bm, def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
memory_key_padding_mask): memory_key_padding_mask):
tgt_key_padding_mask = self.generate_padding_mask(dec_seq) dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
dec_seq = self.embedding(dec_seq).transpose([1, 0, 2])
dec_seq = self.positional_encoding(dec_seq) dec_seq = self.positional_encoding(dec_seq)
tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[ tgt_mask = self.generate_square_subsequent_mask(
0]) paddle.shape(dec_seq)[0])
dec_output = self.decoder( dec_output = self.decoder(
dec_seq, dec_seq,
enc_output, enc_output,
tgt_mask=tgt_mask, tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask, tgt_key_padding_mask=None,
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, )
).transpose([1, 0, 2]) dec_output = paddle.transpose(dec_output, [1, 0, 2])
dec_output = dec_output[:, dec_output = dec_output[:,
-1, :] # Pick the last step: (bh * bm) * d_h -1, :] # Pick the last step: (bh * bm) * d_h
word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
word_prob = word_prob.reshape([n_active_inst, n_bm, -1]) word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
return word_prob return word_prob
def collect_active_inst_idx_list(inst_beams, word_prob, def collect_active_inst_idx_list(inst_beams, word_prob,
...@@ -302,9 +281,8 @@ class Transformer(nn.Layer): ...@@ -302,9 +281,8 @@ class Transformer(nn.Layer):
n_active_inst = len(inst_idx_to_position_map) n_active_inst = len(inst_idx_to_position_map)
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
memory_key_padding_mask = None
word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
memory_key_padding_mask) None)
# Update the beam with predicted word prob information and collect incomplete instances # Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list = collect_active_inst_idx_list( active_inst_idx_list = collect_active_inst_idx_list(
inst_dec_beams, word_prob, inst_idx_to_position_map) inst_dec_beams, word_prob, inst_idx_to_position_map)
...@@ -324,27 +302,21 @@ class Transformer(nn.Layer): ...@@ -324,27 +302,21 @@ class Transformer(nn.Layer):
with paddle.no_grad(): with paddle.no_grad():
#-- Encode #-- Encode
if self.encoder is not None: if self.encoder is not None:
src = self.positional_encoding(images.transpose([1, 0, 2])) src = self.positional_encoding(images.transpose([1, 0, 2]))
src_enc = self.encoder(src).transpose([1, 0, 2]) src_enc = self.encoder(src)
else: else:
src_enc = images.squeeze(2).transpose([0, 2, 1]) src_enc = images.squeeze(2).transpose([0, 2, 1])
#-- Repeat data for beam search
n_bm = self.beam_size n_bm = self.beam_size
n_inst, len_s, d_h = src_enc.shape src_shape = paddle.shape(src_enc)
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) inst_dec_beams = [Beam(n_bm) for _ in range(1)]
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( active_inst_idx_list = list(range(1))
[1, 0, 2]) # Repeat data for beam search
#-- Prepare beams src_enc = paddle.tile(src_enc, [1, n_bm, 1])
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
#-- Bookkeeping for active or not
active_inst_idx_list = list(range(n_inst))
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
active_inst_idx_list) active_inst_idx_list)
#-- Decode # Decode
for len_dec_seq in range(1, 25): for len_dec_seq in range(1, 25):
src_enc_copy = src_enc.clone() src_enc_copy = src_enc.clone()
active_inst_idx_list = beam_decode_step( active_inst_idx_list = beam_decode_step(
...@@ -358,10 +330,19 @@ class Transformer(nn.Layer): ...@@ -358,10 +330,19 @@ class Transformer(nn.Layer):
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
1) 1)
result_hyp = [] result_hyp = []
for bs_hyp in batch_hyp: hyp_scores = []
bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0])) for bs_hyp, score in zip(batch_hyp, batch_scores):
l = len(bs_hyp[0])
bs_hyp_pad = bs_hyp[0] + [3] * (25 - l)
result_hyp.append(bs_hyp_pad) result_hyp.append(bs_hyp_pad)
return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64) score = float(score) / l
hyp_score = [score for _ in range(25)]
hyp_scores.append(hyp_score)
return [
paddle.to_tensor(
np.array(result_hyp), dtype=paddle.int64),
paddle.to_tensor(hyp_scores)
]
def generate_square_subsequent_mask(self, sz): def generate_square_subsequent_mask(self, sz):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
...@@ -376,7 +357,7 @@ class Transformer(nn.Layer): ...@@ -376,7 +357,7 @@ class Transformer(nn.Layer):
return mask return mask
def generate_padding_mask(self, x): def generate_padding_mask(self, x):
padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype)) padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
return padding_mask return padding_mask
def _reset_parameters(self): def _reset_parameters(self):
...@@ -514,17 +495,17 @@ class TransformerEncoderLayer(nn.Layer): ...@@ -514,17 +495,17 @@ class TransformerEncoderLayer(nn.Layer):
src, src,
src, src,
attn_mask=src_mask, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0] key_padding_mask=src_key_padding_mask)
src = src + self.dropout1(src2) src = src + self.dropout1(src2)
src = self.norm1(src) src = self.norm1(src)
src = src.transpose([1, 2, 0]) src = paddle.transpose(src, [1, 2, 0])
src = paddle.unsqueeze(src, 2) src = paddle.unsqueeze(src, 2)
src2 = self.conv2(F.relu(self.conv1(src))) src2 = self.conv2(F.relu(self.conv1(src)))
src2 = paddle.squeeze(src2, 2) src2 = paddle.squeeze(src2, 2)
src2 = src2.transpose([2, 0, 1]) src2 = paddle.transpose(src2, [2, 0, 1])
src = paddle.squeeze(src, 2) src = paddle.squeeze(src, 2)
src = src.transpose([2, 0, 1]) src = paddle.transpose(src, [2, 0, 1])
src = src + self.dropout2(src2) src = src + self.dropout2(src2)
src = self.norm2(src) src = self.norm2(src)
...@@ -598,7 +579,7 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -598,7 +579,7 @@ class TransformerDecoderLayer(nn.Layer):
tgt, tgt,
tgt, tgt,
attn_mask=tgt_mask, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0] key_padding_mask=tgt_key_padding_mask)
tgt = tgt + self.dropout1(tgt2) tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt) tgt = self.norm1(tgt)
tgt2 = self.multihead_attn( tgt2 = self.multihead_attn(
...@@ -606,18 +587,18 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -606,18 +587,18 @@ class TransformerDecoderLayer(nn.Layer):
memory, memory,
memory, memory,
attn_mask=memory_mask, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0] key_padding_mask=memory_key_padding_mask)
tgt = tgt + self.dropout2(tgt2) tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
# default # default
tgt = tgt.transpose([1, 2, 0]) tgt = paddle.transpose(tgt, [1, 2, 0])
tgt = paddle.unsqueeze(tgt, 2) tgt = paddle.unsqueeze(tgt, 2)
tgt2 = self.conv2(F.relu(self.conv1(tgt))) tgt2 = self.conv2(F.relu(self.conv1(tgt)))
tgt2 = paddle.squeeze(tgt2, 2) tgt2 = paddle.squeeze(tgt2, 2)
tgt2 = tgt2.transpose([2, 0, 1]) tgt2 = paddle.transpose(tgt2, [2, 0, 1])
tgt = paddle.squeeze(tgt, 2) tgt = paddle.squeeze(tgt, 2)
tgt = tgt.transpose([2, 0, 1]) tgt = paddle.transpose(tgt, [2, 0, 1])
tgt = tgt + self.dropout3(tgt2) tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt) tgt = self.norm3(tgt)
...@@ -656,8 +637,8 @@ class PositionalEncoding(nn.Layer): ...@@ -656,8 +637,8 @@ class PositionalEncoding(nn.Layer):
(-math.log(10000.0) / dim)) (-math.log(10000.0) / dim))
pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0) pe = paddle.unsqueeze(pe, 0)
pe = pe.transpose([1, 0, 2]) pe = paddle.transpose(pe, [1, 0, 2])
self.register_buffer('pe', pe) self.register_buffer('pe', pe)
def forward(self, x): def forward(self, x):
...@@ -670,7 +651,7 @@ class PositionalEncoding(nn.Layer): ...@@ -670,7 +651,7 @@ class PositionalEncoding(nn.Layer):
Examples: Examples:
>>> output = pos_encoder(x) >>> output = pos_encoder(x)
""" """
x = x + self.pe[:x.shape[0], :] x = x + self.pe[:paddle.shape(x)[0], :]
return self.dropout(x) return self.dropout(x)
...@@ -702,7 +683,7 @@ class PositionalEncoding_2d(nn.Layer): ...@@ -702,7 +683,7 @@ class PositionalEncoding_2d(nn.Layer):
(-math.log(10000.0) / dim)) (-math.log(10000.0) / dim))
pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0).transpose([1, 0, 2]) pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
self.register_buffer('pe', pe) self.register_buffer('pe', pe)
self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1)) self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
...@@ -722,22 +703,23 @@ class PositionalEncoding_2d(nn.Layer): ...@@ -722,22 +703,23 @@ class PositionalEncoding_2d(nn.Layer):
Examples: Examples:
>>> output = pos_encoder(x) >>> output = pos_encoder(x)
""" """
w_pe = self.pe[:x.shape[-1], :] w_pe = self.pe[:paddle.shape(x)[-1], :]
w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0) w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
w_pe = w_pe * w1 w_pe = w_pe * w1
w_pe = w_pe.transpose([1, 2, 0]) w_pe = paddle.transpose(w_pe, [1, 2, 0])
w_pe = w_pe.unsqueeze(2) w_pe = paddle.unsqueeze(w_pe, 2)
h_pe = self.pe[:x.shape[-2], :] h_pe = self.pe[:paddle.shape(x).shape[-2], :]
w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0) w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
h_pe = h_pe * w2 h_pe = h_pe * w2
h_pe = h_pe.transpose([1, 2, 0]) h_pe = paddle.transpose(h_pe, [1, 2, 0])
h_pe = h_pe.unsqueeze(3) h_pe = paddle.unsqueeze(h_pe, 3)
x = x + w_pe + h_pe x = x + w_pe + h_pe
x = x.reshape( x = paddle.transpose(
[x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose( paddle.reshape(x,
[2, 0, 1]) [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
[2, 0, 1])
return self.dropout(x) return self.dropout(x)
...@@ -817,7 +799,7 @@ class Beam(): ...@@ -817,7 +799,7 @@ class Beam():
def sort_scores(self): def sort_scores(self):
"Sort the scores." "Sort the scores."
return self.scores, paddle.to_tensor( return self.scores, paddle.to_tensor(
[i for i in range(self.scores.shape[0])], dtype='int32') [i for i in range(int(self.scores.shape[0]))], dtype='int32')
def get_the_best_score_and_idx(self): def get_the_best_score_and_idx(self):
"Get the score of the best in the beam." "Get the score of the best in the beam."
......
...@@ -235,7 +235,8 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -235,7 +235,8 @@ class ParallelSARDecoder(BaseDecoder):
# cal mask of attention weight # cal mask of attention weight
for i, valid_ratio in enumerate(valid_ratios): for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio)) valid_width = min(w, math.ceil(w * valid_ratio))
attn_weight[i, :, :, valid_width:, :] = float('-inf') if valid_width < w:
attn_weight[i, :, :, valid_width:, :] = float('-inf')
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1]) attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1) attn_weight = F.softmax(attn_weight, axis=-1)
......
...@@ -22,7 +22,8 @@ def build_neck(config): ...@@ -22,7 +22,8 @@ def build_neck(config):
from .rnn import SequenceEncoder from .rnn import SequenceEncoder
from .pg_fpn import PGFPN from .pg_fpn import PGFPN
from .table_fpn import TableFPN from .table_fpn import TableFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN'] from .fpn import FPN
support_dict = ['FPN','DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format( assert module_name in support_dict, Exception('neck only support {}'.format(
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
import paddle
import math
import paddle.nn.functional as F
class Conv_BN_ReLU(nn.Layer):
    """Conv2D -> BatchNorm2D -> ReLU, with normal(0, sqrt(2/n)) conv init."""

    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1,
                 padding=0):
        super(Conv_BN_ReLU, self).__init__()
        self.conv = nn.Conv2D(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias_attr=False)
        self.bn = nn.BatchNorm2D(out_planes, momentum=0.1)
        self.relu = nn.ReLU()

        # Re-initialize parameters: Kaiming-style normal for conv weights,
        # constant ones/zeros for the batch-norm scale/shift.
        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                fan_out = (layer._kernel_size[0] * layer._kernel_size[1] *
                           layer._out_channels)
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Normal(
                        0, math.sqrt(2. / fan_out)))
            elif isinstance(layer, nn.BatchNorm2D):
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(1.0))
                layer.bias = paddle.create_parameter(
                    shape=layer.bias.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(0.0))

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
class FPN(nn.Layer):
    """Feature Pyramid Network neck (PSENet variant).

    Takes the four backbone feature maps ``(f2, f3, f4, f5)`` and fuses them
    top-down with lateral 1x1 projections, then upsamples every pyramid
    level to the f2 resolution and concatenates along channels, so
    ``self.out_channels`` is ``out_channels * 4``.
    """

    def __init__(self, in_channels, out_channels):
        super(FPN, self).__init__()
        # Top layer: project the deepest feature map.
        self.toplayer_ = Conv_BN_ReLU(
            in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)
        # Lateral 1x1 projections onto the common channel count.
        self.latlayer1_ = Conv_BN_ReLU(
            in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)
        self.latlayer2_ = Conv_BN_ReLU(
            in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
        self.latlayer3_ = Conv_BN_ReLU(
            in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)
        # 3x3 smoothing convs applied after each top-down merge.
        self.smooth1_ = Conv_BN_ReLU(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.smooth2_ = Conv_BN_ReLU(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.smooth3_ = Conv_BN_ReLU(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.out_channels = out_channels * 4

        # Same re-initialization scheme as Conv_BN_ReLU, applied to every
        # conv / batch-norm sublayer of this module.
        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                fan_out = (layer._kernel_size[0] * layer._kernel_size[1] *
                           layer._out_channels)
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Normal(
                        0, math.sqrt(2. / fan_out)))
            elif isinstance(layer, nn.BatchNorm2D):
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(1.0))
                layer.bias = paddle.create_parameter(
                    shape=layer.bias.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(0.0))

    def _upsample(self, x, scale=1):
        # Bilinear upsampling by an integer scale factor.
        return F.upsample(x, scale_factor=scale, mode='bilinear')

    def _upsample_add(self, x, y, scale=1):
        # Upsample x and merge it with the lateral feature y.
        return F.upsample(x, scale_factor=scale, mode='bilinear') + y

    def forward(self, x):
        f2, f3, f4, f5 = x

        # Top-down pathway with lateral connections and smoothing.
        p5 = self.toplayer_(f5)
        p4 = self.smooth1_(self._upsample_add(p5, self.latlayer1_(f4), 2))
        p3 = self.smooth2_(self._upsample_add(p4, self.latlayer2_(f3), 2))
        p2 = self.smooth3_(self._upsample_add(p3, self.latlayer3_(f2), 2))

        # Bring every level to the p2 resolution and fuse along channels.
        fuse = paddle.concat(
            [p2,
             self._upsample(p3, 2),
             self._upsample(p4, 4),
             self._upsample(p5, 8)],
            axis=1)
        return fuse
\ No newline at end of file
...@@ -28,13 +28,14 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di ...@@ -28,13 +28,14 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .pse_postprocess import PSEPostProcess
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
'DistillationCTCLabelDecode', 'TableLabelDecode', 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode' 'SEEDLabelDecode'
] ]
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pse_postprocess import PSEPostProcess
\ No newline at end of file
## Compile (编译)
Code adapted from https://github.com/whai362/pan_pp.pytorch.
```python
python3 setup.py build_ext --inplace
```
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import subprocess
# Build the Cython `pse` extension in place the first time this package is
# imported, so no separate build step is required before inference.
# NOTE(review): the hard-coded relative path assumes the current working
# directory is the repository root — confirm this holds for all entry points.
python_path = sys.executable  # build with the same interpreter that is running
if subprocess.call('cd ppocr/postprocess/pse_postprocess/pse;{} setup.py build_ext --inplace;cd -'.format(python_path), shell=True) != 0:
    raise RuntimeError('Cannot compile pse: {}'.format(os.path.dirname(os.path.realpath(__file__))))
from .pse import pse
\ No newline at end of file
import numpy as np
import cv2
cimport numpy as np
cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
from libcpp.pair cimport *
from libcpp.queue cimport *
@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
                                         np.ndarray[np.int32_t, ndim=2] label,
                                         int kernel_num,
                                         int label_num,
                                         float min_area=0):
    # Progressive Scale Expansion: starting from the connected components of
    # the smallest kernel (given in `label`), grow each component outward by
    # breadth-first search through the successively larger kernel maps so
    # that adjacent text instances remain separated.
    #
    # kernels:  [kernel_num - 1, H, W] binary kernel maps (the smallest
    #           kernel itself is excluded by the caller, see pse()).
    # label:    [H, W] int32 connected-component labels of the smallest kernel.
    # Returns an int32 [H, W] map assigning each pixel its instance label.
    cdef np.ndarray[np.int32_t, ndim=2] pred
    pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)

    # Suppress components smaller than min_area (noise).
    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    # Two queues: `que` is the current BFS frontier; `nxt_que` collects edge
    # pixels that could not expand, to seed the next (larger) kernel pass.
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef np.int16_t* dx = [-1, 1, 0, 0]  # 4-connected neighborhood offsets
    cdef np.int16_t* dy = [0, 0, -1, 1]
    cdef np.int16_t tmpx, tmpy

    # Seed the search with every labeled pixel of the smallest kernel.
    points = np.array(np.where(label > 0)).transpose((1, 0))
    for point_idx in range(points.shape[0]):
        tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
        que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
        pred[tmpx, tmpy] = label[tmpx, tmpy]

    cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
    cdef int cur_label
    # NOTE(review): the first iteration indexes kernels[kernel_num - 1] while
    # only kernel_num - 1 maps are passed in; with boundscheck disabled this
    # relies on the argument being a view over the full buffer — confirm the
    # intended start index (kernel_num - 2 would stay in bounds).
    for kernel_idx in range(kernel_num - 1, -1, -1):
        while not que.empty():
            cur = que.front()
            que.pop()
            cur_label = pred[cur.first, cur.second]

            is_edge = True
            for j in range(4):
                tmpx = cur.first + dx[j]
                tmpy = cur.second + dy[j]
                if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
                    continue
                # Only claim pixels inside the current kernel that are not
                # already assigned to some instance.
                if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
                pred[tmpx, tmpy] = cur_label
                is_edge = False
            if is_edge:
                # Could not grow in any direction: keep this pixel as a seed
                # for the next (larger) kernel pass.
                nxt_que.push(cur)

        # Swap frontier and next-pass queue for the following kernel.
        que, nxt_que = nxt_que, que

    return pred
def pse(kernels, min_area):
    """Run Progressive Scale Expansion over a stack of binary kernel maps.

    kernels: uint8 array [kernel_num, H, W]; the last map is the smallest
        kernel and seeds the connected components.
    min_area: connected components smaller than this are discarded.
    Returns an int32 [H, W] instance-label map.
    """
    kernel_num = kernels.shape[0]
    label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
    return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
\ No newline at end of file
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy
# Build the `pse` Cython extension as a C++ module; invoked via
# `python setup.py build_ext --inplace` (see the sibling package __init__).
setup(ext_modules=cythonize(Extension(
    'pse',
    sources=['pse.pyx'],
    language='c++',  # the .pyx uses libcpp queue/pair
    include_dirs=[numpy.get_include()],  # numpy C headers for typed buffers
    library_dirs=[],
    libraries=[],
    extra_compile_args=['-O3'],
    extra_link_args=[]
)))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import paddle
from paddle.nn import functional as F
from ppocr.postprocess.pse_postprocess.pse import pse
class PSEPostProcess(object):
    """
    Post process for PSE: thresholds the network output into shrunk text
    kernels, grows them with Progressive Scale Expansion (pse) and converts
    the resulting instance-label map into boxes or polygons in the original
    image coordinate system.
    """

    def __init__(self,
                 thresh=0.5,
                 box_thresh=0.85,
                 min_area=16,
                 box_type='box',
                 scale=4,
                 **kwargs):
        """
        Args:
            thresh: binarization threshold applied to the raw prediction maps.
            box_thresh: minimum mean score for an instance to be kept.
            min_area: instances with fewer pixels than this are dropped.
            box_type: 'box' for rotated rectangles, 'poly' for free polygons.
            scale: output stride of the prediction relative to stride 4; the
                maps are upsampled by a factor of 4 // scale before decoding.
        """
        assert box_type in ['box', 'poly'], 'Only box and poly is supported'
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.min_area = min_area
        self.box_type = box_type
        self.scale = scale

    def __call__(self, outs_dict, shape_list):
        """Decode a batch of prediction maps into boxes with scores.

        Args:
            outs_dict: dict with key 'maps' -> [N, C, H, W] raw logits
                (channel 0 is the full text map, later channels the kernels).
            shape_list: per-image (src_h, src_w, ratio_h, ratio_w).
        Returns:
            list of {'points': boxes, 'scores': scores}, one dict per image.
        """
        pred = outs_dict['maps']
        if not isinstance(pred, paddle.Tensor):
            pred = paddle.to_tensor(pred)
        pred = F.interpolate(pred, scale_factor=4 // self.scale, mode='bilinear')

        # Channel 0 is the complete text region; its sigmoid is the score map.
        score = F.sigmoid(pred[:, 0, :, :])

        kernels = (pred > self.thresh).astype('float32')
        # Keep the channel dim (0:1) so the text mask broadcasts over all
        # kernel channels: [N, C, H, W] * [N, 1, H, W]. Dropping the dim
        # (kernels[:, 0, :, :]) yields [N, H, W], which does not broadcast
        # against [N, C, H, W].
        text_mask = kernels[:, 0:1, :, :]
        kernels = kernels * text_mask

        score = score.numpy()
        kernels = kernels.numpy().astype(np.uint8)

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            boxes, scores = self.boxes_from_bitmap(score[batch_index], kernels[batch_index], shape_list[batch_index])
            boxes_batch.append({'points': boxes, 'scores': scores})
        return boxes_batch

    def boxes_from_bitmap(self, score, kernels, shape):
        """Label one image's kernels with pse, then build boxes from labels."""
        label = pse(kernels, self.min_area)
        return self.generate_box(score, label, shape)

    def generate_box(self, score, label, shape):
        """Convert an instance-label map into boxes/polygons and scores.

        Args:
            score: [H, W] float score map.
            label: [H, W] int instance-label map (0 = background).
            shape: (src_h, src_w, ratio_h, ratio_w) of the source image.
        Returns:
            (boxes, scores) lists; rejected instances are zeroed in `label`.
        """
        src_h, src_w, ratio_h, ratio_w = shape
        label_num = np.max(label) + 1

        boxes = []
        scores = []
        for i in range(1, label_num):
            ind = label == i
            # (x, y) point list of this instance (note the axis flip).
            points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < self.min_area:
                label[ind] = 0
                continue

            score_i = np.mean(score[ind])
            if score_i < self.box_thresh:
                label[ind] = 0
                continue

            if self.box_type == 'box':
                rect = cv2.minAreaRect(points)
                bbox = cv2.boxPoints(rect)
            elif self.box_type == 'poly':
                box_height = np.max(points[:, 1]) + 10
                box_width = np.max(points[:, 0]) + 10

                mask = np.zeros((box_height, box_width), np.uint8)
                mask[points[:, 1], points[:, 0]] = 255

                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                bbox = np.squeeze(contours[0], 1)
            else:
                raise NotImplementedError

            # Map box coordinates back to the source image resolution.
            bbox[:, 0] = np.clip(
                np.round(bbox[:, 0] / ratio_w), 0, src_w)
            bbox[:, 1] = np.clip(
                np.round(bbox[:, 1] / ratio_h), 0, src_h)
            boxes.append(bbox)
            scores.append(score_i)
        return boxes, scores
...@@ -169,15 +169,20 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -169,15 +169,20 @@ class NRTRLabelDecode(BaseRecLabelDecode):
character_type, use_space_char) character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if preds[0][0] == 2:
preds_idx = preds[:, 1:]
else:
preds_idx = preds
text = self.decode(preds_idx) if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
if isinstance(preds_id, paddle.Tensor):
preds_id = preds_id.numpy()
if isinstance(preds_prob, paddle.Tensor):
preds_prob = preds_prob.numpy()
if preds_id[0][0] == 2:
preds_idx = preds_id[:, 1:]
preds_prob = preds_prob[:, 1:]
else:
preds_idx = preds_id
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None: if label is None:
return text return text
label = self.decode(label[:, 1:]) label = self.decode(label[:, 1:])
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
EPS = 1e-6
def iou_single(a, b, mask, n_class):
    """Mean IoU between label maps `a` and `b` over pixels where mask == 1.

    Args:
        a, b: paddle tensors of class indices (flattened maps).
        mask: tensor of the same length; only positions equal to 1 are scored.
        n_class: number of classes averaged over.
    Returns:
        scalar paddle tensor: the mean IoU across classes.
    """
    valid = mask == 1
    a = a.masked_select(valid)
    b = b.masked_select(valid)
    miou = []
    for i in range(n_class):
        # No valid pixels at all: count the class as IoU 0 instead of 0/EPS.
        if a.shape == [0] and a.shape==b.shape:
            inter = paddle.to_tensor(0.0)
            union = paddle.to_tensor(0.0)
        else:
            inter = ((a == i).logical_and(b == i)).astype('float32')
            union = ((a == i).logical_or(b == i)).astype('float32')
        # EPS guards against division by zero when a class is absent.
        miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
    miou = sum(miou) / len(miou)
    return miou
def iou(a, b, mask, n_class=2, reduce=True):
    """Per-sample (or batch-mean) IoU for a batch of segmentation maps.

    Args:
        a, b: [N, ...] paddle tensors of class indices; flattened per sample.
        mask: [N, ...] tensor selecting which pixels are scored (value 1).
        n_class: number of classes (default 2: background / text).
        reduce: if True return the scalar mean over the batch, otherwise
            a [N] tensor of per-sample IoUs.
    """
    batch_size = a.shape[0]
    a = a.reshape([batch_size, -1])
    b = b.reshape([batch_size, -1])
    mask = mask.reshape([batch_size, -1])
    iou = paddle.zeros((batch_size,), dtype='float32')
    # Score each sample independently against its own mask.
    for i in range(batch_size):
        iou[i] = iou_single(a[i], b[i], mask[i], n_class)
    if reduce:
        iou = paddle.mean(iou)
    return iou
\ No newline at end of file
...@@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer): ...@@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
for k1, k2 in zip(state_dict.keys(), params.keys()): for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape): if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2] new_state_dict[k1] = params[k2]
else: else:
logger.info( logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
) )
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}") logger.info(f"loaded pretrained_model successful from {pm}")
return {} return {}
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
if path is None: if path is None:
return False return False
...@@ -138,6 +139,7 @@ def load_pretrained_params(model, path): ...@@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
print(f"load pretrain successful from {path}") print(f"load pretrain successful from {path}")
return model return model
def save_model(model, def save_model(model,
optimizer, optimizer,
model_path, model_path,
......
# 表格识别 # 表格识别
* [1. 表格识别 pipeline](#1)
* [2. 性能](#2)
* [3. 使用](#3)
+ [3.1 快速开始](#31)
+ [3.2 训练](#32)
+ [3.3 评估](#33)
+ [3.4 预测](#34)
<a name="1"></a>
## 1. 表格识别 pipeline ## 1. 表格识别 pipeline
表格识别主要包含三个模型 表格识别主要包含三个模型
1. 单行文本检测-DB 1. 单行文本检测-DB
2. 单行文本识别-CRNN 2. 单行文本识别-CRNN
...@@ -17,6 +27,8 @@ ...@@ -17,6 +27,8 @@
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。 3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
4. 单元格的识别结果和表格结构一起构造表格的html字符串。 4. 单元格的识别结果和表格结构一起构造表格的html字符串。
<a name="2"></a>
## 2. 性能 ## 2. 性能
我们在 PubTabNet<sup>[1]</sup> 评估数据集上对算法进行了评估,性能如下 我们在 PubTabNet<sup>[1]</sup> 评估数据集上对算法进行了评估,性能如下
...@@ -26,8 +38,9 @@ ...@@ -26,8 +38,9 @@
| EDD<sup>[2]</sup> | 88.3 | | EDD<sup>[2]</sup> | 88.3 |
| Ours | 93.32 | | Ours | 93.32 |
<a name="3"></a>
## 3. 使用 ## 3. 使用
<a name="31"></a>
### 3.1 快速开始 ### 3.1 快速开始
```python ```python
...@@ -48,7 +61,7 @@ python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_ta ...@@ -48,7 +61,7 @@ python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_ta
运行完成后,每张图片的excel表格会保存到output字段指定的目录下 运行完成后,每张图片的excel表格会保存到output字段指定的目录下
note: 上述模型是在 PubLayNet 数据集上训练的表格识别模型,仅支持英文扫描场景,如需识别其他场景需要自己训练模型后替换 `det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。 note: 上述模型是在 PubLayNet 数据集上训练的表格识别模型,仅支持英文扫描场景,如需识别其他场景需要自己训练模型后替换 `det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
<a name="32"></a>
### 3.2 训练 ### 3.2 训练
在这一章节中,我们仅介绍表格结构模型的训练,[文字检测](../../doc/doc_ch/detection.md)[文字识别](../../doc/doc_ch/recognition.md)的模型训练请参考对应的文档。 在这一章节中,我们仅介绍表格结构模型的训练,[文字检测](../../doc/doc_ch/detection.md)[文字识别](../../doc/doc_ch/recognition.md)的模型训练请参考对应的文档。
...@@ -75,7 +88,7 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo ...@@ -75,7 +88,7 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
**注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。 **注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
<a name="33"></a>
### 3.3 评估 ### 3.3 评估
表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下: 表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
...@@ -100,7 +113,7 @@ python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_di ...@@ -100,7 +113,7 @@ python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_di
```bash ```bash
teds: 93.32 teds: 93.32
``` ```
<a name="34"></a>
### 3.4 预测 ### 3.4 预测
```python ```python
......
...@@ -32,7 +32,6 @@ def run_shell_command(cmd): ...@@ -32,7 +32,6 @@ def run_shell_command(cmd):
else: else:
return None return None
def parser_results_from_log_by_name(log_path, names_list): def parser_results_from_log_by_name(log_path, names_list):
if not os.path.exists(log_path): if not os.path.exists(log_path):
raise ValueError("The log file {} does not exists!".format(log_path)) raise ValueError("The log file {} does not exists!".format(log_path))
...@@ -46,11 +45,13 @@ def parser_results_from_log_by_name(log_path, names_list): ...@@ -46,11 +45,13 @@ def parser_results_from_log_by_name(log_path, names_list):
outs = run_shell_command(cmd) outs = run_shell_command(cmd)
outs = outs.split("\n")[0] outs = outs.split("\n")[0]
result = outs.split("{}".format(name))[-1] result = outs.split("{}".format(name))[-1]
result = json.loads(result) try:
result = json.loads(result)
except:
result = np.array([int(r) for r in result.split()]).reshape(-1, 4)
parser_results[name] = result parser_results[name] = result
return parser_results return parser_results
def load_gt_from_file(gt_file): def load_gt_from_file(gt_file):
if not os.path.exists(gt_file): if not os.path.exists(gt_file):
raise ValueError("The log file {} does not exists!".format(gt_file)) raise ValueError("The log file {} does not exists!".format(gt_file))
...@@ -60,7 +61,11 @@ def load_gt_from_file(gt_file): ...@@ -60,7 +61,11 @@ def load_gt_from_file(gt_file):
parser_gt = {} parser_gt = {}
for line in data: for line in data:
image_name, result = line.strip("\n").split("\t") image_name, result = line.strip("\n").split("\t")
result = json.loads(result) image_name = image_name.split('/')[-1]
try:
result = json.loads(result)
except:
result = np.array([int(r) for r in result.split()]).reshape(-1, 4)
parser_gt[image_name] = result parser_gt[image_name] = result
return parser_gt return parser_gt
......
...@@ -23,10 +23,10 @@ Architecture: ...@@ -23,10 +23,10 @@ Architecture:
name: MobileNetV3 name: MobileNetV3
scale: 0.5 scale: 0.5
model_name: large model_name: large
disable_se: True disable_se: False
Neck: Neck:
name: DBFPN name: DBFPN
out_channels: 96 out_channels: 256
Head: Head:
name: DBHead name: DBHead
k: 50 k: 50
...@@ -74,7 +74,7 @@ Train: ...@@ -74,7 +74,7 @@ Train:
channel_first: False channel_first: False
- DetLabelEncode: # Class handling label - DetLabelEncode: # Class handling label
- Resize: - Resize:
# size: [640, 640] size: [640, 640]
- MakeBorderMap: - MakeBorderMap:
shrink_ratio: 0.4 shrink_ratio: 0.4
thresh_min: 0.3 thresh_min: 0.3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment