Commit 980d5f44 authored by hubertlu-tw

Fix torch._softmax_backward_data arguments

parent 5de49cc9
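Background for this fix: newer PyTorch releases (around 1.11, if I recall the change correctly) altered the private op torch._softmax_backward_data so that the fourth argument is the dtype of the forward input rather than a tensor, with the output buffer exposed as a keyword-only grad_input argument. A minimal sketch of the old and new call forms (the tensor shapes below are illustrative assumptions, not taken from this diff):

import torch

# Illustrative shapes only; the attention code in this diff works on
# [seqs*heads, seql_q, seql_k] activations.
dropout_grads = torch.randn(4, 8, 16)
softmax_results = torch.softmax(torch.randn(4, 8, 16), dim=-1)

# Old call (PyTorch <= 1.10): the fourth argument was the forward input tensor.
# softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)

# New call: the fourth argument is the input dtype; the keyword-only grad_input
# selects the out-variant and writes the gradient into an existing buffer
# (the patch below reuses softmax_results as that buffer).
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, torch.float32, grad_input=softmax_results)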
@@ -206,7 +206,8 @@ class EncdecAttnFunc(torch.autograd.Function):
         dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
         # Softmax Grad (not a publically documented op)
-        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
+        ### softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results) # og
+        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, torch.float32, grad_input=softmax_results)
         # Matmul1 - DGRAD1
         # Input1: (data grads) [seqs*heads, seql_q, seql_k]
...
@@ -189,7 +189,8 @@ class SelfAttnFunc(torch.autograd.Function):
         dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
         # Softmax Grad (not a publically documented op)
-        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
+        ### softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results) # og
+        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, torch.float32, grad_input=softmax_results)
         # Matmul1 - DGRAD1
         # Input1: (data grads) [seqs*heads, seql_q, seql_k]
...
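Both call sites receive the identical rewrite. A quick way to sanity-check the new argument order against autograd (my own throwaway snippet, not part of this commit):

import torch

x = torch.randn(4, 8, 16, requires_grad=True)
y = torch.softmax(x, dim=-1)
g = torch.randn_like(y)

# Reference softmax gradient via autograd.
(ref,) = torch.autograd.grad(y, x, grad_outputs=g)

# Same gradient via the private op with the new signature.
out = torch._softmax_backward_data(g, y.detach(), -1, torch.float32)
print(torch.allclose(ref, out, atol=1e-6))  # expect: True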