Unverified Commit 809043f5 authored by Masaki Kozuki, committed by GitHub

[contrib] Fix the reference implementation of multihead_attn (#1423)



* follow the current signature
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* call .backward on outputs
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* update the other caller of _softmax_backward_data
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 1337e81e
@@ -263,7 +263,7 @@ class EncdecAttnFunc(torch.autograd.Function):
         dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
         # Softmax Grad (not a publically documented op)
-        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
+        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype)
         # Matmul1 - DGRAD1
         # Input1: (data grads) [seqs*heads, seql_q, seql_k]
...
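The change here (and the identical one in `SelfAttnFunc` below) only touches the last argument of `torch._softmax_backward_data`: in recent PyTorch releases this private op expects the input *dtype* as its final argument rather than a tensor, so the reference implementation now passes `softmax_results.dtype`. A minimal sketch, assuming a PyTorch version with the dtype-based signature, that checks the patched call against the analytic softmax backward:

```python
import torch

# Illustrative check only; torch._softmax_backward_data is a private op whose
# signature has changed across releases (older versions took a tensor, not a dtype).
x = torch.randn(4, 8)
y = torch.softmax(x, dim=-1)
grad_out = torch.randn_like(y)

# Call as used in the patched backward above (dtype as the last argument).
grad_private = torch._softmax_backward_data(grad_out, y, -1, y.dtype)

# Analytic softmax backward: dL/dx = y * (dL/dy - sum(dL/dy * y, dim)).
grad_manual = (grad_out - (grad_out * y).sum(dim=-1, keepdim=True)) * y

assert torch.allclose(grad_private, grad_manual, atol=1e-6)
```

Because the op is private, its signature can change again without notice; the analytic formula above is the stable reference to check against.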
@@ -236,7 +236,7 @@ class SelfAttnFunc(torch.autograd.Function):
         dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
         # Softmax Grad (not a publically documented op)
-        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
+        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype)
         # Matmul1 - DGRAD1
         # Input1: (data grads) [seqs*heads, seql_q, seql_k]
@@ -301,6 +301,7 @@ class SelfAttnFunc(torch.autograd.Function):
             output_bias_grads,
             None,
             None,
+            None,
         )
...
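The extra `None` keeps `SelfAttnFunc.backward` in line with the current `forward` signature: `torch.autograd.Function.backward` must return exactly one value per `forward` argument, with `None` in the slots of arguments (masks, flags, dropout probabilities) that need no gradient, otherwise autograd complains about the number of returned gradients. A minimal sketch of the rule, not apex code:

```python
import torch

class ScaleFunc(torch.autograd.Function):
    """Toy Function whose forward takes (inputs, scale, is_training)."""

    @staticmethod
    def forward(ctx, inputs, scale, is_training):
        ctx.scale = scale
        return inputs * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Three forward arguments -> three returned values; the non-tensor
        # `scale` and `is_training` slots get None.
        return grad_output * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
ScaleFunc.apply(x, 2.0, True).sum().backward()
```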
@@ -47,10 +47,8 @@ class EncdecMultiheadAttnTest(unittest.TestCase):
                                           dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

     def test_encdec_multihead_attn(self) :
-        grads = torch.randn_like(self.tst_inputs_q)
-
         ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
                                                 self.ref_inputs_k,
                                                 self.ref_inputs_k,
                                                 key_padding_mask=None,
                                                 need_weights=False,
@@ -64,13 +62,15 @@ class EncdecMultiheadAttnTest(unittest.TestCase):
                                                 need_weights=False,
                                                 attn_mask=None,
                                                 is_training=True)
-        self.ref_inputs_q.backward(grads)
-        self.tst_inputs_q.backward(grads)
-
         self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
         self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
         self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
+
+        with torch.no_grad():
+            ref_grads = torch.randn_like(ref_outputs)
+            tst_grads = ref_grads.clone()
+
+        ref_outputs.backward(ref_grads)
+        tst_outputs.backward(tst_grads)
         self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))

     def test_encdec_multihead_attn_time_mask(self) :
...
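This is the "call .backward on outputs" part of the commit. The old test seeded gradients on the *inputs*: `self.ref_inputs_q.backward(grads)` just writes `grads` into the leaf tensor's `.grad` without ever running the attention layer's backward, so the final gradient comparison could never fail. The fix seeds a shared upstream gradient on the *outputs*, so gradients actually flow through both implementations. A toy illustration of the difference, not apex code, using `torch.nn.Linear` as a stand-in layer:

```python
import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = layer(x)
g = torch.randn_like(y)

# Old pattern: backward on the *input* leaf never touches the layer;
# x.grad simply becomes g, no matter what the layer computes.
x.backward(g)
assert torch.allclose(x.grad, g)

# Fixed pattern: seed the gradient on the *output* so it propagates
# back through the layer and into x.grad.
x.grad = None
y.backward(g)
print(x.grad)  # now depends on layer.weight, as the test intends
```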
@@ -43,29 +43,31 @@ class SelfMultiheadAttnTest(unittest.TestCase):
                                         dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

     def test_self_multihead_attn(self) :
-        grads = torch.randn_like(self.tst_inputs)
-
         ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
                                                 self.ref_inputs,
                                                 self.ref_inputs,
                                                 key_padding_mask=None,
                                                 need_weights=False,
                                                 attn_mask=None,
                                                 is_training=True)

         tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
                                                 self.tst_inputs,
                                                 self.tst_inputs,
                                                 key_padding_mask=None,
                                                 need_weights=False,
                                                 attn_mask=None,
                                                 is_training=True)
-        self.ref_inputs.backward(grads)
-        self.tst_inputs.backward(grads)
-
         self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
         self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
+
+        with torch.no_grad():
+            ref_grads = torch.randn_like(self.tst_inputs)
+            tst_grads = ref_grads.clone()
+
+        ref_outputs.backward(ref_grads)
+        tst_outputs.backward(tst_grads)
         self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))

     def test_self_multihead_attn_time_mask(self) :
...
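Both tests now follow the same pattern: run the reference and the apex layer on identical inputs, compare the outputs, then push one shared random gradient through each output and compare the input gradients under fp16-level tolerances. A hypothetical helper, not part of apex, sketching that pattern under the forward signature used in the tests above:

```python
import torch

def compare_attn_impls(ref_layer, tst_layer, ref_inputs, tst_inputs,
                       atol=1e-3, rtol=1e-3):
    """Hypothetical helper mirroring the fixed test flow above."""
    ref_out, _ = ref_layer.forward(ref_inputs, ref_inputs, ref_inputs,
                                   key_padding_mask=None, need_weights=False,
                                   attn_mask=None, is_training=True)
    tst_out, _ = tst_layer.forward(tst_inputs, tst_inputs, tst_inputs,
                                   key_padding_mask=None, need_weights=False,
                                   attn_mask=None, is_training=True)

    # Identical upstream gradients for both implementations, created outside
    # any graph, so the only remaining difference is the layers themselves.
    with torch.no_grad():
        grads = torch.randn_like(tst_out)
    ref_out.backward(grads)
    tst_out.backward(grads.clone())

    return (torch.allclose(ref_out, tst_out, atol=atol, rtol=rtol) and
            torch.allclose(ref_inputs.grad, tst_inputs.grad, atol=atol, rtol=rtol))
```

Cloning the gradient, as the tests do with `tst_grads = ref_grads.clone()`, guarantees both backward passes receive numerically identical upstream values without sharing a buffer.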