Commit c2165224 authored by Changhan Wang, committed by Facebook Github Bot

fix max lengths in Levenshtein Transformer

Summary: Fix the max length calculation in Levenshtein Transformer

Reviewed By: jhcross

Differential Revision: D17672946

fbshipit-source-id: e5efbe7e56cf879d3e822864e4398f99f45b04d4
parent 6f58e15e
@@ -323,12 +323,16 @@ class LevenshteinTransformerModel(TransformerModel):
         output_scores = decoder_out["output_scores"]
         attn = decoder_out["attn"]
+        bsz = output_tokens.size(0)
         if max_ratio is None:
-            max_lens = output_tokens.new().fill_(255)
+            max_lens = output_tokens.new(output_tokens.size(0)).fill_(255)
         else:
-            max_lens = (
-                (~encoder_out["encoder_padding_mask"]).sum(1) * max_ratio
-            ).clamp(min=10)
+            if encoder_out["encoder_padding_mask"] is None:
+                max_src_len = encoder_out["encoder_out"].size(1)
+                src_lens = encoder_out["encoder_out"].new(bsz).fill_(max_src_len)
+            else:
+                src_lens = (~encoder_out["encoder_padding_mask"]).sum(1)
+            max_lens = (src_lens * max_ratio).clamp(min=10).long()
         # delete words
         # do not delete tokens if it is <s> </s>
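The fixed else branch above derives one length cap per sentence: it falls back to the full source length when there is no encoder padding mask, otherwise counts non-pad source tokens, then scales by max_ratio and clamps. A minimal standalone sketch of that calculation follows (illustrative only, not the fairseq API; the batch-major (bsz, src_len, dim) encoder output layout, the dict keys, and the toy inputs are assumptions for this sketch):

import torch

def compute_max_lens(output_tokens, encoder_out, max_ratio=None):
    bsz = output_tokens.size(0)
    if max_ratio is None:
        # one cap per sentence, sized by the batch dimension
        return output_tokens.new(bsz).fill_(255)
    if encoder_out["encoder_padding_mask"] is None:
        # no padding: every sentence gets the full source length
        max_src_len = encoder_out["encoder_out"].size(1)
        src_lens = encoder_out["encoder_out"].new(bsz).fill_(max_src_len)
    else:
        # count non-pad source tokens per sentence
        src_lens = (~encoder_out["encoder_padding_mask"]).sum(1)
    return (src_lens * max_ratio).clamp(min=10).long()

# toy check: src_len 4 and ratio 2.0 give 8, clamped up to the minimum of 10
tokens = torch.tensor([[1, 5, 6, 2], [1, 7, 2, 0]])
enc = {"encoder_out": torch.randn(2, 4, 8), "encoder_padding_mask": None}
print(compute_max_lens(tokens, enc, max_ratio=2.0))  # tensor([10, 10])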
@@ -364,7 +368,7 @@ class LevenshteinTransformerModel(TransformerModel):
                 mask_ins_score[:, :, 0] -= eos_penalty
             mask_ins_pred = mask_ins_score.max(-1)[1]
             mask_ins_pred = torch.min(
-                mask_ins_pred, max_lens[:, None].expand_as(mask_ins_pred)
+                mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
             )
             _tokens, _scores = _apply_ins_masks(
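In the hunk above, mask_ins_pred only has rows for the sentences selected by can_ins_mask, so the per-sentence cap is gathered with that same mask before the element-wise min; indexing with max_lens[:, None] would keep all bsz rows and no longer line up. A small self-contained check (toy shapes and values, assumed for illustration):

import torch

max_lens = torch.tensor([10, 3, 7, 5])                  # one cap per sentence
can_ins_mask = torch.tensor([True, False, True, True])  # rows kept for insertion
mask_ins_pred = torch.randint(0, 20, (int(can_ins_mask.sum()), 6))

capped = torch.min(
    mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
)
print((capped <= max_lens[can_ins_mask, None]).all())   # tensor(True)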