Unverified Commit 6ee1a4fd authored by Vasudev Gupta's avatar Vasudev Gupta Committed by GitHub
Browse files

add everything (#11651)

parent 57b6a80d
...@@ -647,13 +647,13 @@ class BigBirdBlockSparseAttention(nn.Module): ...@@ -647,13 +647,13 @@ class BigBirdBlockSparseAttention(nn.Module):
[ [
to_mask[:, :, :, : 3 * to_block_size], to_mask[:, :, :, : 3 * to_block_size],
to_mask[:, :, :, -to_block_size:], to_mask[:, :, :, -to_block_size:],
first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
], ],
dim=3, dim=3,
) )
second_rand_pad = torch.cat( second_rand_pad = torch.cat(
[ [
first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, 0], rand_mask[:, :, 0],
], ],
dim=3, dim=3,
...@@ -781,13 +781,13 @@ class BigBirdBlockSparseAttention(nn.Module): ...@@ -781,13 +781,13 @@ class BigBirdBlockSparseAttention(nn.Module):
[ [
to_mask[:, :, :, :to_block_size], to_mask[:, :, :, :to_block_size],
to_mask[:, :, :, -3 * to_block_size :], to_mask[:, :, :, -3 * to_block_size :],
context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
], ],
dim=3, dim=3,
) )
second_last_rand_pad = torch.cat( second_last_rand_pad = torch.cat(
[ [
context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, -1], rand_mask[:, :, -1],
], ],
dim=3, dim=3,
......
...@@ -475,13 +475,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module): ...@@ -475,13 +475,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
[ [
to_mask[:, :, :, : 3 * to_block_size], to_mask[:, :, :, : 3 * to_block_size],
to_mask[:, :, :, -to_block_size:], to_mask[:, :, :, -to_block_size:],
first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
], ],
dim=3, dim=3,
) )
second_rand_pad = torch.cat( second_rand_pad = torch.cat(
[ [
first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, 0], rand_mask[:, :, 0],
], ],
dim=3, dim=3,
...@@ -609,13 +609,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module): ...@@ -609,13 +609,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
[ [
to_mask[:, :, :, :to_block_size], to_mask[:, :, :, :to_block_size],
to_mask[:, :, :, -3 * to_block_size :], to_mask[:, :, :, -3 * to_block_size :],
context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
], ],
dim=3, dim=3,
) )
second_last_rand_pad = torch.cat( second_last_rand_pad = torch.cat(
[ [
context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, -1], rand_mask[:, :, -1],
], ],
dim=3, dim=3,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment