Commit e9171ce1 authored by Xian Li's avatar Xian Li Committed by Facebook Github Bot
Browse files

Fix LevT edge cases

Summary:
This avoids the case where `can_ins_mask` is all False, so `max_lengths` has size [0, 1], which caused the `expand_as` operator to fail. The computation is moved back inside the skipping branch of the scripted function.

The same fix is applied to the deletion (`del_word`) and word-insertion (`ins_words`) branches.

Reviewed By: kahne

Differential Revision: D18365340

fbshipit-source-id: 509ac21d7d6fd9083d0710697288203977314c52
parent 13d9e2ba
...@@ -285,7 +285,7 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -285,7 +285,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
output_scores, output_scores,
attn: Tensor, attn: Tensor,
word_del_attn: Optional[Tensor], word_del_attn: Optional[Tensor],
word_del_pred, word_del_out,
can_del_word, can_del_word,
pad_idx: int, pad_idx: int,
bos_idx: int, bos_idx: int,
...@@ -294,6 +294,8 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -294,6 +294,8 @@ class LevenshteinTransformerModel(TracingTransformerModel):
# delete words # delete words
# do not delete tokens if it is <s> </s> # do not delete tokens if it is <s> </s>
if can_del_word.sum() != 0: # we cannot delete, skip if can_del_word.sum() != 0: # we cannot delete, skip
word_del_score = F.log_softmax(word_del_out, 2)
word_del_pred = torch.jit.Attribute(word_del_score.max(-1)[1], bool)
in_tokens = output_tokens[can_del_word] in_tokens = output_tokens[can_del_word]
in_scores = output_scores[can_del_word] in_scores = output_scores[can_del_word]
# apply deletion to a tensor # apply deletion to a tensor
...@@ -331,14 +333,24 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -331,14 +333,24 @@ class LevenshteinTransformerModel(TracingTransformerModel):
def ins_placeholders( def ins_placeholders(
output_tokens, output_tokens,
output_scores, output_scores,
mask_ins_pred, mask_ins_out,
can_ins_mask, can_ins_mask,
pad_idx: int, pad_idx: int,
unk_idx: int, unk_idx: int,
eos_idx: int, eos_idx: int,
max_ratio: float,
max_lengths,
): ):
# insert placeholders # insert placeholders
if can_ins_mask.sum() != 0: if can_ins_mask.sum() != 0:
mask_ins_score = F.log_softmax(mask_ins_out, 2)
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] -= eos_penalty
mask_ins_pred = mask_ins_score.max(-1)[1]
if max_ratio is not None and encoder_out[1] is not None:
mask_ins_pred = torch.min(
mask_ins_pred, max_lengths[can_ins_mask, None].expand_as(mask_ins_pred)
)
in_tokens = output_tokens[can_ins_mask] in_tokens = output_tokens[can_ins_mask]
in_scores = output_scores[can_ins_mask] in_scores = output_scores[can_ins_mask]
in_masks = in_tokens.ne(pad_idx) in_masks = in_tokens.ne(pad_idx)
...@@ -380,14 +392,15 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -380,14 +392,15 @@ class LevenshteinTransformerModel(TracingTransformerModel):
output_scores, output_scores,
attn: Tensor, attn: Tensor,
word_ins_attn, word_ins_attn,
word_ins_pred, word_ins_out,
word_ins_scores,
can_ins_word, can_ins_word,
pad_idx: int, pad_idx: int,
unk_idx: int, unk_idx: int,
): ):
# insert words # insert words
if can_ins_word.sum() != 0: if can_ins_word.sum() != 0:
word_ins_scores = F.log_softmax(word_ins_out, 2)
word_ins_pred = word_ins_scores.max(-1)[1]
in_tokens = output_tokens[can_ins_word] in_tokens = output_tokens[can_ins_word]
in_scores = output_scores[can_ins_word] in_scores = output_scores[can_ins_word]
word_ins_masks = in_tokens.eq(unk_idx) word_ins_masks = in_tokens.eq(unk_idx)
...@@ -411,15 +424,13 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -411,15 +424,13 @@ class LevenshteinTransformerModel(TracingTransformerModel):
script_skip_tensor(output_tokens, can_del_word), script_skip_tensor(output_tokens, can_del_word),
script_skip_tensor_list(list(encoder_out), can_del_word), script_skip_tensor_list(list(encoder_out), can_del_word),
) )
word_del_score = F.log_softmax(word_del_out, 2)
word_del_pred = word_del_score.max(-1)[1].bool()
output_tokens, output_scores, attn = del_word( output_tokens, output_scores, attn = del_word(
output_tokens, output_tokens,
output_scores, output_scores,
attn, attn,
word_del_attn, word_del_attn,
word_del_pred, word_del_out,
can_del_word, can_del_word,
self.pad, self.pad,
self.bos, self.bos,
...@@ -431,23 +442,16 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -431,23 +442,16 @@ class LevenshteinTransformerModel(TracingTransformerModel):
script_skip_tensor(output_tokens, can_ins_mask), script_skip_tensor(output_tokens, can_ins_mask),
script_skip_tensor_list(encoder_out, can_ins_mask), script_skip_tensor_list(encoder_out, can_ins_mask),
) )
mask_ins_score = F.log_softmax(mask_ins_out, 2)
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] -= eos_penalty
mask_ins_pred = mask_ins_score.max(-1)[1]
if max_ratio is not None and encoder_out[1] is not None:
mask_ins_pred = torch.min(
mask_ins_pred, max_lengths[can_ins_mask, None].expand_as(mask_ins_pred)
)
output_tokens, output_scores = ins_placeholders( output_tokens, output_scores = ins_placeholders(
output_tokens, output_tokens,
output_scores, output_scores,
mask_ins_pred, mask_ins_out,
can_ins_mask, can_ins_mask,
self.pad, self.pad,
self.unk, self.unk,
self.eos, self.eos,
max_ratio,
max_lengths,
) )
can_ins_word = output_tokens.eq(self.unk).sum(1) > 0 can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
...@@ -455,16 +459,14 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -455,16 +459,14 @@ class LevenshteinTransformerModel(TracingTransformerModel):
script_skip_tensor(output_tokens, can_ins_word), script_skip_tensor(output_tokens, can_ins_word),
script_skip_tensor_list(encoder_out, can_ins_word), script_skip_tensor_list(encoder_out, can_ins_word),
) )
word_ins_score = F.log_softmax(word_ins_out, 2)
word_ins_pred = word_ins_score.max(-1)[1]
output_tokens, output_scores, attn = ins_words( output_tokens, output_scores, attn = ins_words(
output_tokens, output_tokens,
output_scores, output_scores,
attn, attn,
word_ins_attn, word_ins_attn,
word_ins_pred, word_ins_out,
word_ins_score,
can_ins_word, can_ins_word,
self.pad, self.pad,
self.unk, self.unk,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment