Commit 3dcb5c77 authored by Changhan Wang, committed by Facebook Github Bot

fix levenshtein transformer attn

Summary: When the `if` statements in the Levenshtein transformer decoder forward are removed, `attn` may end up with a batch size that is inconsistent with the output tokens. This commit fixes that.

Reviewed By: cndn

Differential Revision: D17936411

fbshipit-source-id: a1583f3806dc9f41caeb783c043429e247035803
parent b5f41f82
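To illustrate the failure mode this commit addresses, here is a minimal sketch (not the actual decoding loop; `fill_tensors_old` condenses the pre-fix behavior shown in the second hunk below, and all sizes are made up):

```python
import torch

def fill_tensors_old(x, mask, y, padding_idx):
    # Pre-fix behavior: a None `x` (no attention accumulated yet) was
    # replaced wholesale by `y`, which only covers the masked subset
    # of the batch.
    if x is None:
        return y

bsz, src_len = 4, 7
output_tokens = torch.zeros(bsz, 5, dtype=torch.long)  # full batch of 4
mask = torch.tensor([True, False, True, False])        # only 2 rows updated
new_attn = torch.rand(int(mask.sum()), 2, src_len)     # batch dim is 2

attn = fill_tensors_old(None, mask, new_attn, padding_idx=0)
print(attn.size(0), output_tokens.size(0))  # 2 vs. 4: inconsistent batches
```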
@@ -420,10 +420,16 @@ class LevenshteinTransformerModel(TransformerModel):
         initial_output_scores = initial_output_tokens.new_zeros(
             *initial_output_tokens.size()
         ).type_as(encoder_out["encoder_out"])
+        initial_attn = None
+
+        if getattr(self.decoder.layers[-1], "need_attn", False):
+            initial_attn = initial_output_tokens.new_zeros(
+                src_tokens.size(0), 2, src_tokens.size(1)
+            )
         return {
             "output_tokens": initial_output_tokens,
             "output_scores": initial_output_scores,
-            "attn": None,
+            "attn": initial_attn,
         }
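With this change, `attn` is batch-aligned with `output_tokens` from the start whenever the last decoder layer needs attention. A quick shape check (a sketch with made-up sizes; the `2` in dim 1 presumably mirrors the two initial output tokens):

```python
import torch

src_tokens = torch.randint(4, 100, (4, 7))        # (bsz, src_len)
initial_output_tokens = torch.zeros(4, 2).long()  # (bsz, 2) initial tokens

# Same construction as the added lines above.
initial_attn = initial_output_tokens.new_zeros(
    src_tokens.size(0), 2, src_tokens.size(1)
)
assert initial_attn.size(0) == initial_output_tokens.size(0)  # same batch
```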
@@ -31,25 +31,45 @@ def skip_tensors(x, mask):
     raise NotImplementedError


+def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
+    """
+    Expand 2D/3D tensor on dim=1
+    """
+    if x is None:
+        return None
+
+    assert x.dim() == 2 or x.dim() == 3
+    assert trg_dim >= x.size(1), (trg_dim, x.size())
+    if trg_dim == x.size(1):
+        return x
+
+    dims = [x.size(0), trg_dim - x.size(1)]
+    if x.dim() == 3:
+        dims.append(x.size(2))
+
+    x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
+    return x
+
+
 def fill_tensors(x, mask, y, padding_idx):
     """
     Filling tensor x with y at masked positions (dim=0).
     """
     if x is None:
-        return y
+        return None
     assert x.dim() == y.dim() and mask.size(0) == x.size(0)
     assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
     n_selected = mask.sum()
     if n_selected == 0:
         return x
     assert n_selected == y.size(0)
     if n_selected == x.size(0):
         return y
     if x.size(1) < y.size(1):
-        dims = [x.size(0), y.size(1) - x.size(1)]
-        if x.dim() == 3:
-            dims.append(x.size(2))
-        x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
+        x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
         x[mask] = y
     elif x.size(1) > y.size(1):
         x[mask] = padding_idx
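For reference, a usage sketch of the new helper, restated self-contained from the hunk above (sizes are made up):

```python
import torch

def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
    # Pad x on dim=1 up to trg_dim with padding_idx (as added above).
    if x is None:
        return None
    assert x.dim() == 2 or x.dim() == 3
    assert trg_dim >= x.size(1), (trg_dim, x.size())
    if trg_dim == x.size(1):
        return x
    dims = [x.size(0), trg_dim - x.size(1)]
    if x.dim() == 3:
        dims.append(x.size(2))
    return torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)

x = torch.ones(4, 3).long()
y = expand_2d_or_3d_tensor(x, 5, padding_idx=1)  # grow dim=1 from 3 to 5
print(y.shape)  # torch.Size([4, 5]); the new columns hold padding_idx
```

Factoring this out lets `fill_tensors` grow `x` and `attn`-style tensors through one code path instead of duplicating the concatenation logic.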