Unverified commit 88399476, authored by Bartosz Szmelczynski, committed by GitHub

Fix bigbird random attention (#21023)

* switch np.random.permutation to jax.random.permutation

* remove comments

* remove leftover comment

* skip similarity tests

* modify indices_prng_key usage, add deterministic behaviour

* update style

* remove unused import

* remove copy statement since classes are not identical

* remove numpy import

* revert removing copied from statements

* make style from copied

* remove copied from statement

* update copied from statement to include only np.ndarray

* add deterministic args, unittest.skip equivalence tests
parent 27b66bea
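
The gist of the fix: Flax BigBird's block-sparse attention previously drew its random block indices from NumPy's global RNG (`np.random.seed` plus `np.random.permutation`), which is invisible to JAX tracing and not reproducible through JAX's PRNG system. After this commit the indices come from an explicit `indices` PRNG stream, and no randomness is used at all when `deterministic=True`. A minimal sketch of how a caller might thread the new `indices_rng` through a stochastic forward pass (the input text and padding length are illustrative):

```python
import jax
from transformers import AutoTokenizer, FlaxBigBirdModel

model = FlaxBigBirdModel.from_pretrained("google/bigbird-roberta-base")
tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
# Pad to a multiple of the block size so the block-sparse path is exercised.
inputs = tokenizer(
    "The quick brown fox jumps over the lazy dog.",
    padding="max_length",
    max_length=1024,
    return_tensors="np",
)

# train=True switches the module out of deterministic mode, so block-sparse
# attention samples random blocks and needs an explicit "indices" PRNG key.
dropout_rng, indices_rng = jax.random.split(jax.random.PRNGKey(0))
outputs = model(**inputs, train=True, dropout_rng=dropout_rng, indices_rng=indices_rng)
```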
@@ -19,7 +19,6 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import numpy as np
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen import partitioning as nn_partitioning
@@ -459,6 +458,10 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         key_layer = self.transpose_for_scores(self.key(hidden_states), n_heads, head_size)
         value_layer = self.transpose_for_scores(self.value(hidden_states), n_heads, head_size)

+        indices_prng_key = None
+        if not deterministic:
+            indices_prng_key = self.make_rng("indices")
+
         attn_output, attn_weights = self.bigbird_block_sparse_attention(
             query_layer,
             key_layer,
@@ -470,6 +473,8 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
             blocked_encoder_mask,
             n_heads,
             head_size,
+            indices_prng_key=indices_prng_key,
+            deterministic=deterministic,
             plan_from_length=None,
             plan_num_rand_blocks=None,
             output_attentions=output_attentions,
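
`self.make_rng("indices")` above pulls the next key from the `indices` RNG stream that callers supply via `rngs={"indices": ...}` when the module is initialized or applied. A toy Flax module demonstrating the same pattern (the module and its names are illustrative, not part of this diff):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class NoisyIdentity(nn.Module):
    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        if not deterministic:
            # Pull the next key from the "indices" stream supplied at apply() time.
            key = self.make_rng("indices")
            x = x + 0.01 * jax.random.normal(key, x.shape)
        return x


module = NoisyIdentity()
x = jnp.ones((2, 3))
variables = module.init(jax.random.PRNGKey(0), x)  # deterministic path: no "indices" key needed
y = module.apply(variables, x, deterministic=False, rngs={"indices": jax.random.PRNGKey(1)})
```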
@@ -528,6 +533,8 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         to_blocked_mask,
         n_heads,
         head_size,
+        indices_prng_key: Optional[jax.random.PRNGKey] = None,
+        deterministic: Optional[bool] = True,
         plan_from_length=None,
         plan_num_rand_blocks=None,
         output_attentions=None,
@@ -571,12 +578,18 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         rsqrt_d = 1 / jnp.sqrt(head_size)
         attn_mask_penalty = -10000.0

-        np.random.seed(self.block_sparse_seed)
         if from_seq_len in [1024, 3072, 4096]:  # old plans used in paper
             max_seqlen = self.config.max_position_embeddings
             rand_attn = [
                 self._bigbird_block_rand_mask(
-                    max_seqlen, max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024
+                    max_seqlen,
+                    max_seqlen,
+                    from_block_size,
+                    to_block_size,
+                    n_rand_blocks,
+                    indices_prng_key=indices_prng_key,
+                    deterministic=deterministic,
+                    last_idx=1024,
                 )[: (from_seq_len // from_block_size - 2)]
                 for _ in range(n_heads)
             ]
@@ -585,7 +598,6 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
             plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan(
                 from_seq_len, from_block_size, n_rand_blocks
             )
-
             rand_attn = self._bigbird_block_rand_mask_with_head(
                 from_seq_length=from_seq_len,
                 to_seq_length=to_seq_len,
@@ -594,6 +606,7 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
                 num_heads=n_heads,
                 plan_from_length=plan_from_length,
                 plan_num_rand_blocks=plan_num_rand_blocks,
+                indices_prng_key=indices_prng_key,
             )
         rand_attn = jnp.stack(rand_attn, axis=0)
@@ -942,7 +955,14 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):

     @staticmethod
     def _bigbird_block_rand_mask(
-        from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
+        from_seq_length,
+        to_seq_length,
+        from_block_size,
+        to_block_size,
+        num_rand_blocks,
+        indices_prng_key: Optional[jax.random.PRNGKey] = None,
+        deterministic: Optional[bool] = True,
+        last_idx: Optional[int] = -1,
     ):
         """
         Create adjacency list of random attention.
@@ -953,6 +973,8 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
             from_block_size: int. size of block in from sequence.
             to_block_size: int. size of block in to sequence.
             num_rand_blocks: int. Number of random chunks per row.
+            indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations.
+            deterministic: bool. When False, random attention will be used.
             last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence,
                 if positive then num_rand_blocks blocks chosen only up to last_idx.
@@ -963,9 +985,12 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         if from_seq_length // from_block_size != to_seq_length // to_block_size:
             raise ValueError("Error the number of blocks needs to be same!")
+        rand_attn = jnp.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=jnp.int32)
+        # deterministic: no randomness, return the all-zero adjacency list
+        if deterministic:
+            return rand_attn
-        rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
-        middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
+        middle_seq = jnp.arange(1, to_seq_length // to_block_size - 1, dtype=jnp.int32)
         last = to_seq_length // to_block_size - 1
         if last_idx > (2 * to_block_size):
             last = (last_idx // to_block_size) - 1
@@ -975,25 +1000,31 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
             start = i - 2
             end = i
             if i == 1:
-                rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r]
+                seq_values = jax.random.permutation(indices_prng_key, middle_seq[2:last])[:r]
+                rand_attn = rand_attn.at[i - 1].set(seq_values)
             elif i == 2:
-                rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
+                seq_values = jax.random.permutation(indices_prng_key, middle_seq[3:last])[:r]
+                rand_attn = rand_attn.at[i - 1].set(seq_values)
             elif i == from_seq_length // from_block_size - 3:
-                rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
+                seq_values = jax.random.permutation(indices_prng_key, middle_seq[:last])[:r]
+                rand_attn = rand_attn.at[i - 1].set(seq_values)
                 # Missing -3: should have been sliced till last-3
             elif i == from_seq_length // from_block_size - 2:
-                rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
+                seq_values = jax.random.permutation(indices_prng_key, middle_seq[:last])[:r]
+                rand_attn = rand_attn.at[i - 1].set(seq_values)
                 # Missing -4: should have been sliced till last-4
             else:
                 if start > last:
                     start = last
-                    rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
+                    seq_values = jax.random.permutation(indices_prng_key, middle_seq[:start])[:r]
+                    rand_attn = rand_attn.at[i - 1].set(seq_values)
                 elif (end + 1) == last:
-                    rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
+                    seq_values = jax.random.permutation(indices_prng_key, middle_seq[:start])[:r]
+                    rand_attn = rand_attn.at[i - 1].set(seq_values)
                 else:
-                    rand_attn[i - 1, :] = np.random.permutation(
-                        np.concatenate((middle_seq[:start], middle_seq[end + 1 : last]))
-                    )[:r]
+                    concat_values = jnp.concatenate((middle_seq[:start], middle_seq[end + 1 : last]))
+                    seq_values = jax.random.permutation(indices_prng_key, concat_values)[:r]
+                    rand_attn = rand_attn.at[i - 1].set(seq_values)
         return rand_attn
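
Two mechanical changes repeat throughout the rewrite above: `np.random.permutation` reads hidden global RNG state, whereas `jax.random.permutation` takes an explicit key; and JAX arrays are immutable, so NumPy's in-place slice assignment becomes a functional `.at[...].set(...)` update that returns a new array. The two styles side by side, reduced to a toy example:

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy style (before): globally seeded state, in-place slice assignment.
np.random.seed(0)
a = np.zeros((3, 2), dtype=np.int32)
a[0, :] = np.random.permutation(np.arange(1, 7, dtype=np.int32))[:2]

# JAX style (after): explicit PRNG key, functional update returning a new array.
key = jax.random.PRNGKey(0)
b = jnp.zeros((3, 2), dtype=jnp.int32)
vals = jax.random.permutation(key, jnp.arange(1, 7, dtype=jnp.int32))[:2]
b = b.at[0].set(vals)
```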

     def _bigbird_block_rand_mask_with_head(
@@ -1005,6 +1036,8 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         num_heads,
         plan_from_length,
         plan_num_rand_blocks,
+        indices_prng_key: Optional[jax.random.PRNGKey] = None,
+        deterministic: Optional[bool] = True,
         window_block_left=1,
         window_block_right=1,
         global_block_top=1,
@@ -1023,6 +1056,8 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
             num_heads: int. total number of heads.
             plan_from_length: list. plan from length where num_random_blocks are chosen from.
             plan_num_rand_blocks: list. number of rand blocks within the plan.
+            indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations.
+            deterministic: bool. When False, random attention will be used.
             window_block_left: int. number of blocks of window to left of a block.
             window_block_right: int. number of blocks of window to right of a block.
             global_block_top: int. number of blocks at the top.
@@ -1045,15 +1080,22 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         # Total number of blocks in the mask
         num_blocks = from_seq_length // from_block_size
         # Number of blocks per plan
-        plan_block_length = np.array(plan_from_length) // from_block_size
+        plan_block_length = jnp.array(plan_from_length) // from_block_size
         # till when to follow plan
         max_plan_idx = plan_from_length.index(from_seq_length)
         # Random Attention adjacency list
         rand_attn = [
-            np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32)
+            jnp.zeros((num_blocks, sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=jnp.int32)
             for i in range(num_heads)
         ]
+        # deterministic: return the all-zero adjacency lists, trimmed of the global blocks
+        if deterministic:
+            for nh in range(num_heads):
+                rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :]
+            return rand_attn
         # We will go iteratively over the plan blocks and pick random number of
         # attention blocks from the legally allowed blocks
         for plan_idx in range(max_plan_idx + 1):
@@ -1064,11 +1106,11 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
                 # column index starts from plan_block_length[plan_idx-1] and ends at
                 # plan_block_length[plan_idx]
                 if plan_num_rand_blocks[plan_idx] > 0:
-                    rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx]))
-                    curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1]))
+                    rnd_r_cnt = int(sum(plan_num_rand_blocks[:plan_idx]))
+                    curr_r_cnt = int(sum(plan_num_rand_blocks[: plan_idx + 1]))
                     for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]):
                         for h in range(num_heads):
-                            rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention(
+                            single_block_row_attention = self._get_single_block_row_attention(
                                 block_id=blk_rw_idx,
                                 to_start_block_id=plan_block_length[plan_idx - 1],
                                 to_end_block_id=plan_block_length[plan_idx],
@@ -1077,6 +1119,10 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
                                 window_block_right=window_block_right,
                                 global_block_left=global_block_left,
                                 global_block_right=global_block_right,
+                                indices_prng_key=indices_prng_key,
+                            )
+                            rand_attn[h] = (
+                                rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention)
                             )
                 for pl_id in range(plan_idx):
@@ -1086,11 +1132,11 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
                         rnd_r_cnt = 0
                         to_start_block_id = 0
                         if pl_id > 0:
-                            rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id]))
+                            rnd_r_cnt = int(sum(plan_num_rand_blocks[:pl_id]))
                             to_start_block_id = plan_block_length[pl_id - 1]
-                        curr_r_cnt = int(np.sum(plan_num_rand_blocks[: pl_id + 1]))
+                        curr_r_cnt = int(sum(plan_num_rand_blocks[: pl_id + 1]))
                         for h in range(num_heads):
-                            rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention(
+                            single_block_row_attention = self._get_single_block_row_attention(
                                 block_id=blk_rw_idx,
                                 to_start_block_id=to_start_block_id,
                                 to_end_block_id=plan_block_length[pl_id],
@@ -1099,21 +1145,24 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
                                 window_block_right=window_block_right,
                                 global_block_left=global_block_left,
                                 global_block_right=global_block_right,
+                                indices_prng_key=indices_prng_key,
+                            )
+                            rand_attn[h] = (
+                                rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention)
                             )

             if plan_num_rand_blocks[plan_idx] == 0:
                 continue
-            curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1]))
+            curr_r_cnt = int(sum(plan_num_rand_blocks[: plan_idx + 1]))
             from_start_block_id = global_block_top
             to_start_block_id = 0
             if plan_idx > 0:
-                rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx]))
+                rnd_r_cnt = int(sum(plan_num_rand_blocks[:plan_idx]))
                 from_start_block_id = plan_block_length[plan_idx - 1]
                 to_start_block_id = plan_block_length[plan_idx - 1]
             for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]):
                 for h in range(num_heads):
-                    rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention(
+                    single_block_row_attention = self._get_single_block_row_attention(
                         block_id=blk_rw_idx,
                         to_start_block_id=to_start_block_id,
                         to_end_block_id=plan_block_length[plan_idx],
@@ -1122,11 +1171,12 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
                         window_block_right=window_block_right,
                         global_block_left=global_block_left,
                         global_block_right=global_block_right,
+                        indices_prng_key=indices_prng_key,
                     )
+                    rand_attn[h] = rand_attn[h].at[blk_rw_idx, rnd_r_cnt:curr_r_cnt].set(single_block_row_attention)

         for nh in range(num_heads):
             rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :]
         return rand_attn

     @staticmethod
@@ -1135,6 +1185,7 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         to_start_block_id,
         to_end_block_id,
         num_rand_blocks,
+        indices_prng_key: Optional[jax.random.PRNGKey] = None,
         window_block_left=1,
         window_block_right=1,
         global_block_left=1,
@@ -1148,6 +1199,7 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
             to_start_block_id: int. random attention column start id.
             to_end_block_id: int. random attention column end id.
             num_rand_blocks: int. number of random blocks to be selected.
+            indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations.
             window_block_left: int. number of blocks of window to left of a block.
             window_block_right: int. number of blocks of window to right of a block.
             global_block_left: int. Number of blocks globally used to the left.
@@ -1157,9 +1209,9 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
             row containing the random attention vector of size num_rand_blocks.
         """
         # list of to_blocks from which to choose random attention
-        to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32)
+        to_block_list = jnp.arange(to_start_block_id, to_end_block_id, dtype=jnp.int32)
         # permute the blocks
-        perm_block = np.random.permutation(to_block_list)
+        perm_block = jax.random.permutation(indices_prng_key, to_block_list)
         # illegal blocks for the current block id, using window
         illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1))
@@ -1176,14 +1228,14 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
         if block_id == to_end_block_id - 2:
             illegal_blocks.append(1)

-        selected_random_blokcs = []
+        selected_random_blocks = []
         for i in range(to_end_block_id - to_start_block_id):
             if perm_block[i] not in illegal_blocks:
-                selected_random_blokcs.append(perm_block[i])
+                selected_random_blocks.append(perm_block[i])
-            if len(selected_random_blokcs) == num_rand_blocks:
+            if len(selected_random_blocks) == num_rand_blocks:
                 break
-        return np.array(selected_random_blokcs, dtype=np.int32)
+        return jnp.array(selected_random_blocks, dtype=jnp.int32)
 # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->BigBird
@@ -1507,11 +1559,11 @@ class FlaxBigBirdPredictionHeadTransform(nn.Module):
         return self.LayerNorm(hidden_states)

-# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->BigBird
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->BigBird, np.ndarray->jnp.ndarray
 class FlaxBigBirdLMPredictionHead(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32
-    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
+    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

     def setup(self):
         self.transform = FlaxBigBirdPredictionHeadTransform(self.config, dtype=self.dtype)
@@ -1594,7 +1646,6 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
             gradient_checkpointing=True,
         )

-    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -1603,8 +1654,8 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
         attention_mask = jnp.ones_like(input_ids)
         head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
+        params_rng, dropout_rng, indices_rng = jax.random.split(rng, num=3)
+        rngs = {"params": params_rng, "dropout": dropout_rng, "indices": indices_rng}

         if self.config.add_cross_attention:
             encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
@@ -1622,7 +1673,13 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
             )
         else:
             module_init_outputs = self.module.init(
-                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
+                rngs,
+                input_ids,
+                attention_mask,
+                token_type_ids,
+                position_ids,
+                head_mask,
+                return_dict=False,
             )

         random_params = module_init_outputs["params"]
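
`init_weights` now splits the root key three ways so that parameter initialization, dropout, and random-attention index sampling each consume an independent stream; reusing a single key for two purposes would correlate their randomness. The pattern in isolation:

```python
import jax

params_rng, dropout_rng, indices_rng = jax.random.split(jax.random.PRNGKey(42), num=3)
rngs = {"params": params_rng, "dropout": dropout_rng, "indices": indices_rng}
```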
@@ -1658,7 +1715,6 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
             return unfreeze(init_variables["cache"])

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->BigBird
     def __call__(
         self,
         input_ids,
@@ -1669,7 +1725,8 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         params: dict = None,
-        dropout_rng: jax.random.PRNGKey = None,
+        dropout_rng: Optional[jax.random.PRNGKey] = None,
+        indices_rng: Optional[jax.random.PRNGKey] = None,
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1697,6 +1754,9 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
         # Handle any PRNG if needed
         rngs = {}
+        if indices_rng is not None:
+            rngs["indices"] = indices_rng
+
         if dropout_rng is not None:
             rngs["dropout"] = dropout_rng
@@ -2382,7 +2442,8 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
         head_mask=None,
         question_lengths=None,
         params: dict = None,
-        dropout_rng: jax.random.PRNGKey = None,
+        dropout_rng: Optional[jax.random.PRNGKey] = None,
+        indices_rng: Optional[jax.random.PRNGKey] = None,
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -2428,6 +2489,9 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
         if dropout_rng is not None:
             rngs["dropout"] = dropout_rng

+        if indices_rng is not None:
+            rngs["indices"] = indices_rng
+
         return self.module.apply(
             {"params": params or self.params},
             jnp.array(input_ids, dtype="i4"),
@@ -2459,7 +2523,6 @@ append_call_sample_docstring(
 )

-# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLMModule with Bert->BigBird
 class FlaxBigBirdForCausalLMModule(nn.Module):
     config: BigBirdConfig
     dtype: jnp.dtype = jnp.float32
@@ -2491,11 +2554,11 @@ class FlaxBigBirdForCausalLMModule(nn.Module):
     ):
         # Model
         outputs = self.bert(
-            input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            head_mask,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             init_cache=init_cache,
...
@@ -14,10 +14,8 @@
 import unittest

-import numpy as np
-
 from transformers import BigBirdConfig, is_flax_available
-from transformers.testing_utils import require_flax, slow
+from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow

 from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@@ -129,7 +127,11 @@ class FlaxBigBirdModelTester(unittest.TestCase):
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         config, input_ids, token_type_ids, attention_mask = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
+        inputs_dict = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_mask": attention_mask,
+        }
         return config, inputs_dict
@@ -180,8 +182,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
     def test_model_from_pretrained(self):
         for model_class_name in self.all_model_classes:
             model = model_class_name.from_pretrained("google/bigbird-roberta-base")
-            outputs = model(np.ones((1, 1)))
-            self.assertIsNotNone(outputs)
+            self.assertIsNotNone(model)

     def test_attention_outputs(self):
         if self.test_attn_probs:
@@ -220,3 +221,17 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 return
             else:
                 super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
+
+    @is_pt_flax_cross_test
+    @unittest.skip(
+        reason="Current PyTorch implementation has a bug with random attention -> it is always used, no matter whether we are in eval/train mode"
+    )
+    def test_equivalence_flax_to_pt(self):
+        pass
+
+    @is_pt_flax_cross_test
+    @unittest.skip(
+        reason="Current PyTorch implementation has a bug with random attention -> it is always used, no matter whether we are in eval/train mode"
+    )
+    def test_equivalence_pt_to_flax(self):
+        pass
...
@@ -158,7 +158,7 @@ class FlaxModelTesterMixin:
         if "ForMultipleChoice" in model_class.__name__:
             inputs_dict = {
                 k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
-                if isinstance(v, (jnp.ndarray, np.ndarray))
+                if isinstance(v, (jnp.ndarray, np.ndarray)) and k != "indices_prng_key"
                 else v
                 for k, v in inputs_dict.items()
             }
@@ -629,7 +629,6 @@ class FlaxModelTesterMixin:
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
-
             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
...