Unverified commit 700d6825, authored by Burc Eryilmaz, committed by GitHub
Browse files

fix default mode missing additive mask option (#924)


Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
parent 459de22d
@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function):
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
-mask, dropout_prob):
+mask, is_additive_mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
@@ -60,8 +60,11 @@ class SelfAttnFunc(torch.autograd.Function):
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
-mask = mask.to(torch.bool)
-matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
+if is_additive_mask:
+    matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
+else:
+    mask = mask.to(torch.bool)
+    matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment