Unverified commit 809043f5, authored by Masaki Kozuki and committed by GitHub

[contrib] Fix the reference implementation of multihead_attn (#1423)



* follow the current signature
  Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* call .backward on outputs
  Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* update the other caller of _softmax_backward_data
  Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 1337e81e
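A note on the test changes below: the tests previously called .backward on the input tensors, which for a leaf tensor merely stores the supplied gradient in .grad and never runs the attention layers' backward passes. The fix draws one upstream gradient, backpropagates it through the outputs of both the reference and the test layer, and compares the .grad accumulated on the inputs. A minimal sketch of that pattern, not the apex test itself (ref_layer, tst_layer, and the input tensors are placeholders):

# Sketch only: ref_layer / tst_layer are any two modules expected to match,
# and the inputs are leaf tensors created with requires_grad_(True).
import torch

def check_forward_and_backward(ref_layer, tst_layer, ref_inputs, tst_inputs):
    ref_outputs = ref_layer(ref_inputs)
    tst_outputs = tst_layer(tst_inputs)

    # Create the upstream gradient outside the graph so both paths receive
    # an identical, constant gradient tensor.
    with torch.no_grad():
        ref_grads = torch.randn_like(ref_outputs)
        tst_grads = ref_grads.clone()

    # Backpropagate through the outputs (not the inputs) so each layer's
    # backward implementation actually runs and fills in input .grad.
    ref_outputs.backward(ref_grads)
    tst_outputs.backward(tst_grads)

    assert torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)
    assert torch.allclose(ref_inputs.grad, tst_inputs.grad, atol=1e-3, rtol=1e-3)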
@@ -263,7 +263,7 @@ class EncdecAttnFunc(torch.autograd.Function):
         dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
         # Softmax Grad (not a publically documented op)
-        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
+        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype)
         # Matmul1 - DGRAD1
         # Input1: (data grads) [seqs*heads, seql_q, seql_k]
@@ -236,7 +236,7 @@ class SelfAttnFunc(torch.autograd.Function):
         dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
         # Softmax Grad (not a publically documented op)
-        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
+        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype)
         # Matmul1 - DGRAD1
         # Input1: (data grads) [seqs*heads, seql_q, seql_k]
@@ -301,6 +301,7 @@ class SelfAttnFunc(torch.autograd.Function):
             output_bias_grads,
             None,
             None,
+            None,
         )
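The two hunks above follow a signature change in the private op torch._softmax_backward_data: newer PyTorch releases take the input dtype as the last argument instead of a tensor (assumed here to be PyTorch 1.11 or later; the commit itself does not pin a version). A standalone sanity-check sketch of the updated call against public autograd:

import torch

# Sanity-check sketch (assumes a PyTorch build where the private op takes
# (grad_output, output, dim, input_dtype), as used in the hunks above).
x = torch.randn(2, 4, 8, requires_grad=True)
y = torch.softmax(x, dim=-1)
grad_out = torch.randn_like(y)

# Reference gradient through softmax via public autograd.
(ref_grad,) = torch.autograd.grad(y, x, grad_out)

# Same gradient via the private op used in the fused attention backward.
tst_grad = torch._softmax_backward_data(grad_out, y, -1, x.dtype)

assert torch.allclose(ref_grad, tst_grad, atol=1e-6, rtol=1e-5)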
@@ -47,10 +47,8 @@ class EncdecMultiheadAttnTest(unittest.TestCase):
                                     dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

    def test_encdec_multihead_attn(self) :
-        grads = torch.randn_like(self.tst_inputs_q)
        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
                                               self.ref_inputs_k,
                                               self.ref_inputs_k,
                                               key_padding_mask=None,
                                               need_weights=False,
@@ -64,13 +62,15 @@ class EncdecMultiheadAttnTest(unittest.TestCase):
                                               need_weights=False,
                                               attn_mask=None,
                                               is_training=True)
-        self.ref_inputs_q.backward(grads)
-        self.tst_inputs_q.backward(grads)

        self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))

+        with torch.no_grad():
+            ref_grads = torch.randn_like(ref_outputs)
+            tst_grads = ref_grads.clone()
+        ref_outputs.backward(ref_grads)
+        tst_outputs.backward(tst_grads)
        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))

    def test_encdec_multihead_attn_time_mask(self) :
@@ -43,29 +43,31 @@ class SelfMultiheadAttnTest(unittest.TestCase):
                                   dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

    def test_self_multihead_attn(self) :
-        grads = torch.randn_like(self.tst_inputs)
-        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
-                                               self.ref_inputs,
-                                               key_padding_mask=None,
-                                               need_weights=False,
+        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
+                                               self.ref_inputs,
+                                               self.ref_inputs,
+                                               key_padding_mask=None,
+                                               need_weights=False,
                                                attn_mask=None,
                                                is_training=True)
-        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
-                                               self.tst_inputs,
-                                               key_padding_mask=None,
-                                               need_weights=False,
+        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
+                                               self.tst_inputs,
+                                               self.tst_inputs,
+                                               key_padding_mask=None,
+                                               need_weights=False,
                                                attn_mask=None,
                                                is_training=True)
-        self.ref_inputs.backward(grads)
-        self.tst_inputs.backward(grads)

        self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))

+        with torch.no_grad():
+            ref_grads = torch.randn_like(self.tst_inputs)
+            tst_grads = ref_grads.clone()
+        ref_outputs.backward(ref_grads)
+        tst_outputs.backward(tst_grads)
        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))

    def test_self_multihead_attn_time_mask(self) :