Commit 8c430b98 authored by Zihan Wang

fix docstrings

parent cadde4ee
@@ -30,9 +30,11 @@ class LongformerEncoderConfig(encoders.BertEncoderConfig):
   Args:
     attention_window: list of ints representing the window size for each layer.
+    global_attention_size: the size of global attention used for each token.
+    pad_token_id: the token id for the pad token.
   '''
   attention_window: List[int] = dataclasses.field(default_factory=list)
   global_attention_size: int = 0
   pad_token_id: int = 1
 @gin.configurable
 @base_config.bind(LongformerEncoderConfig)
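For orientation, here is a minimal sketch of how the three fields documented above might be set. The values and the import path are illustrative assumptions, not part of this commit.

```python
# Assumed import path inside TF Model Garden; adjust to the real module location.
from official.projects.longformer.longformer import LongformerEncoderConfig

config = LongformerEncoderConfig(
    attention_window=[128] * 12,  # one local window size per transformer layer
    global_attention_size=1,      # e.g. only the leading [CLS]-style token attends globally
    pad_token_id=1,               # matches the dataclass default shown above
)
```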
@@ -65,6 +65,9 @@ class LongformerEncoder(tf.keras.layers.Layer):
   Args:
     vocab_size: The size of the token vocabulary.
+    attention_window: A list of ints representing the window size for each layer.
+    global_attention_size: The size of global attention used for each token.
+    pad_token_id: The token id for the pad token.
     hidden_size: The size of the transformer hidden layers.
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
@@ -120,23 +123,10 @@ class LongformerEncoder(tf.keras.layers.Layer):
       embedding_layer: Optional[tf.keras.layers.Layer] = None,
       norm_first: bool = False,
       **kwargs):
-    # Pops kwargs that are used in V1 implementation.
-    if 'dict_outputs' in kwargs:
-      kwargs.pop('dict_outputs')
-    if 'return_all_encoder_outputs' in kwargs:
-      kwargs.pop('return_all_encoder_outputs')
-    if 'intermediate_size' in kwargs:
-      inner_dim = kwargs.pop('intermediate_size')
-    if 'activation' in kwargs:
-      inner_activation = kwargs.pop('activation')
-    if 'dropout_rate' in kwargs:
-      output_dropout = kwargs.pop('dropout_rate')
-    if 'attention_dropout_rate' in kwargs:
-      attention_dropout = kwargs.pop('attention_dropout_rate')
     super().__init__(**kwargs)
-    # Longformer
+    # Longformer args
     self._attention_window = attention_window
-    self.global_attention_size = global_attention_size
+    self._global_attention_size = global_attention_size
     self._pad_token_id = pad_token_id
     activation = tf.keras.activations.get(inner_activation)
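The deleted block above was the usual V1-to-V2 keyword shim: legacy argument names were popped off kwargs and mapped onto the newer parameter names before super().__init__(**kwargs), so tf.keras.layers.Layer would not see unknown keys. Below is a standalone sketch of that pattern on a stand-in layer; the names come from the deleted lines, but the class itself is hypothetical.

```python
import tensorflow as tf


class _LegacyKwargsShimExample(tf.keras.layers.Layer):
  """Hypothetical layer illustrating the kwargs shim removed in this commit."""

  def __init__(self, inner_dim=3072, inner_activation="gelu", **kwargs):
    # Translate legacy V1 names to the newer ones, then drop them from kwargs
    # so tf.keras.layers.Layer.__init__ does not reject unknown arguments.
    if "intermediate_size" in kwargs:
      inner_dim = kwargs.pop("intermediate_size")
    if "activation" in kwargs:
      inner_activation = kwargs.pop("activation")
    kwargs.pop("dict_outputs", None)               # accepted-but-ignored V1 flags
    kwargs.pop("return_all_encoder_outputs", None)
    super().__init__(**kwargs)
    self._inner_dim = inner_dim
    self._inner_activation = tf.keras.activations.get(inner_activation)


layer = _LegacyKwargsShimExample(intermediate_size=1024, activation="relu")
assert layer._inner_dim == 1024
```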
@@ -227,6 +217,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
         'norm_first': norm_first,
         # Longformer
         'attention_window': attention_window,
+        'global_attention_size': global_attention_size,
         'pad_token_id': pad_token_id,
     }
     self.inputs = dict(
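Recording global_attention_size in the stored config dict is what lets the argument survive get_config()/from_config() round-trips. A minimal sketch of that Keras pattern on a hypothetical layer follows; it is not the encoder's actual get_config, which is outside this diff.

```python
import tensorflow as tf


class _ConfigRoundTripExample(tf.keras.layers.Layer):
  """Hypothetical layer showing why new arguments must be added to the config dict."""

  def __init__(self, attention_window=128, global_attention_size=0, **kwargs):
    super().__init__(**kwargs)
    # Only what is stored here comes back from get_config(), so an argument
    # missing from this dict would silently reset to its default whenever the
    # layer is cloned or deserialized.
    self._config = {
        "attention_window": attention_window,
        "global_attention_size": global_attention_size,
    }

  def get_config(self):
    config = super().get_config()
    config.update(self._config)
    return config


layer = _ConfigRoundTripExample(global_attention_size=1)
clone = _ConfigRoundTripExample.from_config(layer.get_config())
assert clone.get_config()["global_attention_size"] == 1
```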
@@ -273,15 +264,16 @@ class LongformerEncoder(tf.keras.layers.Layer):
     batch_size, seq_len = shape_list(mask)
     # create masks with fixed len global_attention_size
-    mask = tf.transpose(tf.concat(values=[tf.ones((self.global_attention_size, batch_size), tf.int32) * 2,
-                                          tf.transpose(mask)[self.global_attention_size:]], axis=0))
+    mask = tf.transpose(tf.concat(values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
+                                          tf.transpose(mask)[self._global_attention_size:]], axis=0))
     is_index_masked = tf.math.less(mask, 1)
     is_index_global_attn = tf.transpose(tf.concat(values=[
-        tf.ones((self.global_attention_size, batch_size), tf.bool), tf.zeros((seq_len - self.global_attention_size, batch_size), tf.bool)
+        tf.ones((self._global_attention_size, batch_size), tf.bool), tf.zeros((seq_len - self._global_attention_size,
+                                                                               batch_size), tf.bool)
     ], axis=0))
-    is_global_attn = self.global_attention_size > 0
+    is_global_attn = self._global_attention_size > 0
     # Longformer
     attention_mask = mask
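To make the mask arithmetic in this hunk concrete, here is a small standalone version of the same recipe with illustrative shapes: the first global_attention_size columns of every row are overwritten with 2 (global attention), padding positions keep 0, and the boolean helpers mark the masked and global slots. The values below are examples only; the real code reads batch_size and seq_len through shape_list.

```python
import tensorflow as tf

global_attention_size = 2
# Two sequences of length 6; 1 = real token, 0 = padding.
mask = tf.constant([[1, 1, 1, 1, 0, 0],
                    [1, 1, 1, 0, 0, 0]], tf.int32)
batch_size, seq_len = mask.shape

# Same construction as above: force the first global_attention_size columns
# to 2 and keep the remaining columns of the original mask.
mask = tf.transpose(
    tf.concat([tf.ones((global_attention_size, batch_size), tf.int32) * 2,
               tf.transpose(mask)[global_attention_size:]], axis=0))
# mask -> [[2, 2, 1, 1, 0, 0],
#          [2, 2, 1, 0, 0, 0]]

is_index_masked = tf.math.less(mask, 1)   # True only at padding positions
is_index_global_attn = tf.transpose(
    tf.concat([tf.ones((global_attention_size, batch_size), tf.bool),
               tf.zeros((seq_len - global_attention_size, batch_size), tf.bool)],
              axis=0))                    # True for the first two columns of each row
is_global_attn = global_attention_size > 0
```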
@@ -347,11 +339,11 @@ class LongformerEncoder(tf.keras.layers.Layer):
   def _pad_to_window_size(
       self,
-      word_ids,  # input_ids
-      mask,  # attention_mask
-      type_ids,  # token_type_ids
-      word_embeddings,  # inputs_embeds
-      pad_token_id,  # pad_token_id
+      word_ids,
+      mask,
+      type_ids,
+      word_embeddings,
+      pad_token_id,
   ):
     """A helper function to pad tokens and mask to work with the implementation of Longformer self-attention."""
     # padding
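The helper's body sits mostly outside this hunk; its job is to right-pad each batch so that seq_len becomes a multiple of attention_window, padding word_ids with pad_token_id and the mask with zeros so the padded positions stay ignored. Below is a hedged, eager-mode sketch of that arithmetic, a stand-in rather than the function's actual body.

```python
import tensorflow as tf


def pad_to_window_size_sketch(word_ids, mask, attention_window, pad_token_id):
  """Stand-in for _pad_to_window_size: pad seq_len up to a multiple of attention_window."""
  seq_len = int(tf.shape(word_ids)[1])
  padding_len = (attention_window - seq_len % attention_window) % attention_window
  paddings = [[0, 0], [0, padding_len]]
  word_ids = tf.pad(word_ids, paddings, constant_values=pad_token_id)
  mask = tf.pad(mask, paddings, constant_values=0)  # padded tokens stay masked out
  return padding_len, word_ids, mask


padding_len, padded_ids, padded_mask = pad_to_window_size_sketch(
    word_ids=tf.constant([[5, 6, 7]]),
    mask=tf.constant([[1, 1, 1]]),
    attention_window=4,
    pad_token_id=1)
# padding_len == 1, padded_ids == [[5, 6, 7, 1]], padded_mask == [[1, 1, 1, 0]]
```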
@@ -361,8 +353,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
     assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
-    # input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
-    input_shape = word_ids.shape if word_ids is not None else word_embeddings.shape
+    input_shape = shape_list(word_ids) if word_ids is not None else shape_list(word_embeddings)
     batch_size, seq_len = input_shape[:2]
     if seq_len is not None:
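The last change replaces .shape, whose entries can be None for symbolic tensors inside tf.function or Keras functional models, with shape_list, the usual helper that returns static dimensions where known and dynamic scalars otherwise. The sketch below is an assumption about that helper's behaviour, not the project's exact definition.

```python
import tensorflow as tf


def shape_list_sketch(x):
  """Return each dimension as a Python int when static, else a scalar tensor."""
  static = x.shape.as_list()
  dynamic = tf.shape(x)
  return [dim if dim is not None else dynamic[i] for i, dim in enumerate(static)]


@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32)])
def demo(word_ids):
  # Here word_ids.shape is (None, None), but shape_list_sketch still yields
  # usable batch_size and seq_len values.
  batch_size, seq_len = shape_list_sketch(word_ids)
  return tf.zeros([batch_size, seq_len], tf.int32)


out = demo(tf.constant([[1, 2, 3]], tf.int32))  # works with a fully dynamic signature
```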