Commit 40d76358 authored by Christina Floristean's avatar Christina Floristean
Browse files

Minor change to tests

parent 336c899a
......@@ -106,8 +106,6 @@ def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype, requires_grad=False).cuda()
z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype, requires_grad=requires_grad).cuda()
mask_bias = inf * (mask - 1)
if requires_grad:
mask_bias = mask_bias.detach().clone().requires_grad_()
biases = [mask_bias, z_bias]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment