"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b7d44d97af5778012817bce06da7eec08ec2ffc3"
Unverified commit 700d6825, authored by Burc Eryilmaz, committed by GitHub
Browse files

fix default mode missing additive mask option (#924)


Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
parent 459de22d
...@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function): ...@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function):
def forward(ctx, use_time_mask, is_training, heads, scale, inputs, def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights, input_weights, output_weights,
input_biases, output_biases, input_biases, output_biases,
mask, dropout_prob): mask, is_additive_mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None]) use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads]) heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
...@@ -60,8 +60,11 @@ class SelfAttnFunc(torch.autograd.Function): ...@@ -60,8 +60,11 @@ class SelfAttnFunc(torch.autograd.Function):
batches,seql_q,seql_k = matmul1_results.size() batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads) seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k) matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool) if is_additive_mask:
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf')) matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
else:
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k) matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1) softmax_results = F.softmax(matmul1_results, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment