You need to sign in or sign up before continuing.
Commit 3dcb5c77 authored by Changhan Wang's avatar Changhan Wang Committed by Facebook Github Bot
Browse files

fix levenshtein transformer attn

Summary: When the `if` statements in the levenshtein transformer decoder forward are removed, `attn` may get inconsistent batch sizes with output tokens. This is a fix.

Reviewed By: cndn

Differential Revision: D17936411

fbshipit-source-id: a1583f3806dc9f41caeb783c043429e247035803
parent b5f41f82
...@@ -420,10 +420,16 @@ class LevenshteinTransformerModel(TransformerModel): ...@@ -420,10 +420,16 @@ class LevenshteinTransformerModel(TransformerModel):
initial_output_scores = initial_output_tokens.new_zeros( initial_output_scores = initial_output_tokens.new_zeros(
*initial_output_tokens.size() *initial_output_tokens.size()
).type_as(encoder_out["encoder_out"]) ).type_as(encoder_out["encoder_out"])
initial_attn = None
if getattr(self.decoder.layers[-1], "need_attn", False):
initial_attn = initial_output_tokens.new_zeros(
src_tokens.size(0), 2, src_tokens.size(1)
)
return { return {
"output_tokens": initial_output_tokens, "output_tokens": initial_output_tokens,
"output_scores": initial_output_scores, "output_scores": initial_output_scores,
"attn": None, "attn": initial_attn,
} }
......
...@@ -31,25 +31,45 @@ def skip_tensors(x, mask): ...@@ -31,25 +31,45 @@ def skip_tensors(x, mask):
raise NotImplementedError raise NotImplementedError
def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
    """
    Expand 2D/3D tensor on dim=1

    Args:
        x: a tensor with 2 or 3 dimensions, or ``None``.
        trg_dim: requested size of dim=1; must be >= ``x.size(1)``.
        padding_idx: value written into every newly appended position.

    Returns:
        ``None`` when ``x`` is ``None``; ``x`` itself (no copy) when dim=1
        already has ``trg_dim`` entries; otherwise a new tensor made of ``x``
        followed, along dim=1, by a block filled with ``padding_idx``.

    Raises:
        AssertionError: if ``x`` is not 2D/3D or ``trg_dim < x.size(1)``.
    """
    if x is None:
        return None
    assert x.dim() == 2 or x.dim() == 3
    assert trg_dim >= x.size(1), (trg_dim, x.size())
    if trg_dim == x.size(1):
        return x

    # Shape of the padding block: same dim=0 (and dim=2 for 3D input) as x,
    # with exactly the number of dim=1 entries still missing to reach trg_dim.
    dims = [x.size(0), trg_dim - x.size(1)]
    if x.dim() == 3:
        dims.append(x.size(2))
    x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
    return x
def fill_tensors(x, mask, y, padding_idx): def fill_tensors(x, mask, y, padding_idx):
""" """
Filling tensor x with y at masked positions (dim=0). Filling tensor x with y at masked positions (dim=0).
""" """
if x is None: if x is None:
return y return None
assert x.dim() == y.dim() and mask.size(0) == x.size(0) assert x.dim() == y.dim() and mask.size(0) == x.size(0)
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
n_selected = mask.sum() n_selected = mask.sum()
if n_selected == 0:
return x
assert n_selected == y.size(0) assert n_selected == y.size(0)
if n_selected == x.size(0): if n_selected == x.size(0):
return y return y
if x.size(1) < y.size(1): if x.size(1) < y.size(1):
dims = [x.size(0), y.size(1) - x.size(1)] x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
if x.dim() == 3:
dims.append(x.size(2))
x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
x[mask] = y x[mask] = y
elif x.size(1) > y.size(1): elif x.size(1) > y.size(1):
x[mask] = padding_idx x[mask] = padding_idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment