Commit c5a94a61 authored by Rémi Louf

fix function that defines masks in XLM

The definition of `get_masks` would blow up for certain combinations of
arguments: `alen` was only bound inside the `else` branch of the padding
check, so calling the function with an explicit `padding_mask` and
`causal=True` raised a `NameError`. It was just a matter of moving the
definition outside of the control structure.
parent 488a6641
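For context, the failure mode distills to a scoping bug: `alen` was bound only when `padding_mask` was `None`, yet the `causal` branch read it unconditionally. A minimal sketch of the pre-fix control flow (the `return` line and the example tensors are illustrative, not part of the diff):

import torch

def get_masks_old(slen, lengths, causal, padding_mask=None):
    # Pre-fix control flow, condensed from the hunk below.
    bs = lengths.size(0)
    if padding_mask is not None:
        mask = padding_mask
    else:
        # `alen` is bound only on this branch...
        alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
        mask = alen < lengths[:, None]
    if causal:
        # ...but read unconditionally here: NameError whenever an
        # explicit `padding_mask` is combined with `causal=True`.
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask
    return mask, attn_mask

lengths = torch.tensor([5, 3])
padding_mask = torch.arange(5)[None, :] < lengths[:, None]
get_masks_old(5, lengths, causal=True, padding_mask=padding_mask)  # NameError: 'alen'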
...
@@ -73,16 +73,16 @@ def get_masks(slen, lengths, causal, padding_mask=None):
     """
     Generate hidden states mask, and optionally an attention mask.
     """
-    bs = lengths.size(0)
+    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
     if padding_mask is not None:
         mask = padding_mask
     else:
         assert lengths.max().item() <= slen
-        alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
         mask = alen < lengths[:, None]
 
     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
+        bs = lengths.size(0)
         attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
     else:
         attn_mask = mask
...
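After the patch, that same call goes through. A quick usage sketch of the fixed function as it appears in the hunk above (the trailing `return mask, attn_mask` is assumed, since the hunk is truncated before the end of the function):

import torch

def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Generate hidden states mask, and optionally an attention mask.
    """
    # Post-fix: `alen` is bound before any branching, so both the
    # padding path and the causal path can use it.
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    if padding_mask is not None:
        mask = padding_mask
    else:
        assert lengths.max().item() <= slen
        mask = alen < lengths[:, None]
    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
        bs = lengths.size(0)
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask
    return mask, attn_mask  # assumed; not shown in the truncated hunk

lengths = torch.tensor([5, 3])
padding_mask = torch.arange(5)[None, :] < lengths[:, None]
mask, attn_mask = get_masks(5, lengths, causal=True, padding_mask=padding_mask)
print(mask.shape)       # torch.Size([2, 5])
print(attn_mask.shape)  # torch.Size([2, 5, 5])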