Commit 51783cc7 authored by hubertlu-tw

Revert code changes to multihead_attn tests

parent 038ed999
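For context, the tests touched by this revert all follow the same pattern: build a reference layer with impl='default' and a test layer with impl='fast' from the same seed, run identical fp16 CUDA inputs through both, and compare outputs and input gradients with torch.allclose. Below is a minimal sketch of that pattern; it is illustrative only and not part of the commit, it assumes an apex build with the multihead_attn contrib extension and an available CUDA device, and the helper name compare_impls is made up for this note.

# Illustrative sketch only -- not part of this commit. It mirrors the pattern
# used by the tests below: same seed, same fp16 inputs, 'default' vs 'fast' impl.
import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn


def compare_impls(seq_len=80, batch=10, hidden=1024, heads=16, seed=1234):
    layers, inputs, outputs = {}, {}, {}

    for impl in ('default', 'fast'):
        # Reset the seed before each construction so reset_parameters()
        # produces identical weights for both implementations.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        layer = SelfMultiheadAttn(hidden, heads, dropout=0.0, bias=False,
                                  include_norm_add=False, impl=impl)
        layer.cuda().half()
        layer.reset_parameters()
        layers[impl] = layer

        # Identical fp16 inputs for both implementations, drawn from the same seed.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        inputs[impl] = torch.randn(seq_len, batch, hidden, dtype=torch.float16,
                                   device='cuda', requires_grad=True)

        # Self-attention: query, key and value are the same tensor.
        outputs[impl], _ = layer.forward(inputs[impl], inputs[impl], inputs[impl],
                                         key_padding_mask=None, need_weights=False,
                                         attn_mask=None, is_training=True)

    # Backpropagate the same upstream gradient through both graphs and compare.
    grads = torch.randn_like(outputs['default'])
    outputs['default'].backward(grads)
    outputs['fast'].backward(grads.clone())

    assert torch.allclose(outputs['default'], outputs['fast'], atol=1e-3, rtol=1e-3)
    assert torch.allclose(inputs['default'].grad, inputs['fast'].grad, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
    compare_impls()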
@@ -40,37 +40,37 @@ class EncdecMultiheadAttnTest(unittest.TestCase):
                                             impl='fast')
        self.tst_layer.cuda().half()
        self.tst_layer.reset_parameters()

        self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                        dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
        self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                        dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

    def test_encdec_multihead_attn(self):
        grads = torch.randn_like(self.tst_inputs_q)

        ref_outputs, _ = self.ref_layer.forward(self.ref_inputs_q,
                                                self.ref_inputs_k,
                                                self.ref_inputs_k,
                                                key_padding_mask=None,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)
        tst_outputs, _ = self.tst_layer.forward(self.tst_inputs_q,
                                                self.tst_inputs_k,
                                                self.tst_inputs_k,
                                                key_padding_mask=None,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)

        self.ref_inputs_q.backward(grads)
        self.tst_inputs_q.backward(grads)

        self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))

        with torch.no_grad():
            ref_grads = torch.randn_like(ref_outputs)
            tst_grads = ref_grads.clone()

        ref_outputs.backward(ref_grads)
        tst_outputs.backward(tst_grads)
        self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))

    def test_encdec_multihead_attn_time_mask(self):
@@ -15,34 +15,36 @@ class SelfMultiheadAttnTest(unittest.TestCase):
        self.heads = 16
        self.dropout_prob = 0.0

        self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
                                           self.heads,
                                           dropout=self.dropout_prob,
                                           bias=False,
                                           include_norm_add=False,
                                           impl='default')
        self.ref_layer.cuda().half()
        self.ref_layer.reset_parameters()

        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                      dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

        # Reset seed so parameters are identical
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
                                           self.heads,
                                           dropout=self.dropout_prob,
                                           bias=False,
                                           include_norm_add=False,
                                           impl='fast')
        self.tst_layer.cuda().half()
        self.tst_layer.reset_parameters()

        self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                      dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

    def test_self_multihead_attn(self):
        grads = torch.randn_like(self.tst_inputs)

        ref_outputs, _ = self.ref_layer.forward(self.ref_inputs,
                                                self.ref_inputs,
                                                self.ref_inputs,
@@ -59,15 +61,11 @@ class SelfMultiheadAttnTest(unittest.TestCase):
                                                attn_mask=None,
                                                is_training=True)

        self.ref_inputs.backward(grads)
        self.tst_inputs.backward(grads)

        self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))

        with torch.no_grad():
            ref_grads = torch.randn_like(self.tst_inputs)
            tst_grads = ref_grads.clone()

        ref_outputs.backward(ref_grads)
        tst_outputs.backward(tst_grads)
        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))

    def test_self_multihead_attn_time_mask(self):
@@ -75,23 +73,23 @@ class SelfMultiheadAttnTest(unittest.TestCase):
        time_mask_byte = torch.triu(torch.ones(self.tst_inputs.size(0), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
        time_mask_bool = time_mask_byte.to(torch.bool)

        ref_outputs, _ = self.ref_layer.forward(self.ref_inputs,
                                                self.ref_inputs,
                                                self.ref_inputs,
                                                key_padding_mask=None,
                                                need_weights=False,
                                                attn_mask=time_mask_bool,
                                                is_training=True)
        tst_outputs, _ = self.tst_layer.forward(self.tst_inputs,
                                                self.tst_inputs,
                                                self.tst_inputs,
                                                key_padding_mask=None,
                                                need_weights=False,
                                                attn_mask=time_mask_byte,
                                                is_training=True)

        self.ref_inputs.backward(grads)
        self.tst_inputs.backward(grads)
@@ -104,23 +102,23 @@ class SelfMultiheadAttnTest(unittest.TestCase):
        pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
        pad_mask_bool = pad_mask_byte.to(torch.bool)

        ref_outputs, _ = self.ref_layer.forward(self.ref_inputs,
                                                self.ref_inputs,
                                                self.ref_inputs,
                                                key_padding_mask=pad_mask_bool,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)
        tst_outputs, _ = self.tst_layer.forward(self.tst_inputs,
                                                self.tst_inputs,
                                                self.tst_inputs,
                                                key_padding_mask=pad_mask_byte,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)

        self.ref_inputs.backward(grads)
        self.tst_inputs.backward(grads)
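A side note on the masks exercised above, sketched from the conventions visible in the tests rather than taken from this commit: the time-mask test builds a causal mask of shape (seq_len, seq_len), the padding-mask test a per-sequence mask of shape (batch, seq_len), and each test hands the bool variant to the impl='default' layer and the uint8 variant to the impl='fast' layer so both dtype paths are covered.

# Illustrative sketch of the mask shapes used by the tests above (not commit code).
import torch

seq_len, batch = 80, 10  # matches the tests' (seq_len, batch, hidden) input layout

# Causal "time" mask: (seq_len, seq_len); nonzero entries strictly above the
# diagonal mark future positions a query may not attend to.
time_mask_byte = torch.triu(torch.ones(seq_len, seq_len, device="cuda", dtype=torch.uint8), 1)
time_mask_bool = time_mask_byte.to(torch.bool)

# Key-padding mask: (batch, seq_len), one row per sequence; the lower-triangular
# pattern is arbitrary test data that simply exercises the padding path.
pad_mask_byte = torch.tril(torch.ones(batch, seq_len, device="cuda", dtype=torch.uint8), 1)
pad_mask_bool = pad_mask_byte.to(torch.bool)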