"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "fd232c2c79cbc45d5ba2683ca1f5568763ade254"
Commit 8c430b98 authored by Zihan Wang

fix docstrings

parent cadde4ee
...
@@ -30,9 +30,11 @@ class LongformerEncoderConfig(encoders.BertEncoderConfig):
   Args:
     attention_window: list of ints representing the window size for each layer.
     global_attention_size: the size of global attention used for each token.
+    pad_token_id: the token id for the pad token
   '''
   attention_window: List[int] = dataclasses.field(default_factory=list)
   global_attention_size: int = 0
+  pad_token_id: int = 1

 @gin.configurable
 @base_config.bind(LongformerEncoderConfig)
...
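For readers unfamiliar with the config, here is a minimal, self-contained sketch of how the three Longformer-specific fields fit together. It mirrors only the fields shown in the hunk above rather than the real `LongformerEncoderConfig` (which extends `encoders.BertEncoderConfig`), and the values are purely illustrative:

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class LongformerConfigSketch:
  """Mirrors only the Longformer-specific fields from the hunk above."""
  # One attention window size per transformer layer.
  attention_window: List[int] = dataclasses.field(default_factory=list)
  # Number of leading tokens that receive global attention.
  global_attention_size: int = 0
  # Token id used for padding (the field documented by this commit).
  pad_token_id: int = 1


# Illustrative values: 12 layers with a window of 128, one global token.
config = LongformerConfigSketch(
    attention_window=[128] * 12,
    global_attention_size=1,
    pad_token_id=1,
)
print(config)
```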
@@ -65,6 +65,9 @@ class LongformerEncoder(tf.keras.layers.Layer):

   Args:
     vocab_size: The size of the token vocabulary.
+    attention_window: list of ints representing the window size for each layer.
+    global_attention_size: the size of global attention used for each token.
+    pad_token_id: the token id for the pad token
     hidden_size: The size of the transformer hidden layers.
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
...
@@ -120,23 +123,10 @@ class LongformerEncoder(tf.keras.layers.Layer):
       embedding_layer: Optional[tf.keras.layers.Layer] = None,
       norm_first: bool = False,
       **kwargs):
-    # Pops kwargs that are used in V1 implementation.
-    if 'dict_outputs' in kwargs:
-      kwargs.pop('dict_outputs')
-    if 'return_all_encoder_outputs' in kwargs:
-      kwargs.pop('return_all_encoder_outputs')
-    if 'intermediate_size' in kwargs:
-      inner_dim = kwargs.pop('intermediate_size')
-    if 'activation' in kwargs:
-      inner_activation = kwargs.pop('activation')
-    if 'dropout_rate' in kwargs:
-      output_dropout = kwargs.pop('dropout_rate')
-    if 'attention_dropout_rate' in kwargs:
-      attention_dropout = kwargs.pop('attention_dropout_rate')
     super().__init__(**kwargs)
-    # Longformer
+    # Longformer args
     self._attention_window = attention_window
-    self.global_attention_size = global_attention_size
+    self._global_attention_size = global_attention_size
     self._pad_token_id = pad_token_id

     activation = tf.keras.activations.get(inner_activation)
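For context on the block removed above: `tf.keras.layers.Layer.__init__` rejects keyword arguments it does not recognize, so a layer that still accepts legacy V1 argument names has to translate or drop them before calling `super().__init__`. A standalone sketch of that pattern follows; the class name and defaults are illustrative, while the kwarg mapping mirrors the removed lines:

```python
import tensorflow as tf


class EncoderSketch(tf.keras.layers.Layer):
  """Illustrative layer showing the V1-kwargs translation pattern."""

  def __init__(self, inner_dim=3072, inner_activation='gelu', **kwargs):
    # Map legacy argument names onto the current ones, if present.
    if 'intermediate_size' in kwargs:
      inner_dim = kwargs.pop('intermediate_size')
    if 'activation' in kwargs:
      inner_activation = kwargs.pop('activation')
    # Flags with no direct equivalent are simply dropped.
    kwargs.pop('dict_outputs', None)
    kwargs.pop('return_all_encoder_outputs', None)
    super().__init__(**kwargs)  # only standard Layer kwargs (e.g. name) remain
    self._inner_dim = inner_dim
    self._inner_activation = tf.keras.activations.get(inner_activation)


layer = EncoderSketch(intermediate_size=1024, dict_outputs=True, name='enc')
print(layer._inner_dim)  # 1024
```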
...
@@ -227,6 +217,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
         'norm_first': norm_first,
         # Longformer
         'attention_window': attention_window,
+        'global_attention_size': global_attention_size,
         'pad_token_id': pad_token_id,
     }
     self.inputs = dict(
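The new `'global_attention_size'` entry matters because a Keras layer is typically rebuilt from its config dict when a model is saved and reloaded; any constructor argument missing from that dict silently falls back to its default on reload. A minimal sketch of that round trip, using an illustrative layer rather than the real encoder:

```python
import tensorflow as tf


class TinyLayerSketch(tf.keras.layers.Layer):
  """Illustrative layer whose config carries the three Longformer fields."""

  def __init__(self, attention_window=(64,), global_attention_size=0,
               pad_token_id=1, **kwargs):
    super().__init__(**kwargs)
    self._attention_window = list(attention_window)
    self._global_attention_size = global_attention_size
    self._pad_token_id = pad_token_id

  def get_config(self):
    config = super().get_config()
    config.update({
        'attention_window': self._attention_window,
        'global_attention_size': self._global_attention_size,
        'pad_token_id': self._pad_token_id,
    })
    return config


layer = TinyLayerSketch(attention_window=[128], global_attention_size=2)
rebuilt = TinyLayerSketch.from_config(layer.get_config())
print(rebuilt._global_attention_size)  # 2, preserved across the round trip
```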
...
@@ -273,15 +264,16 @@ class LongformerEncoder(tf.keras.layers.Layer):
     batch_size, seq_len = shape_list(mask)

     # create masks with fixed len global_attention_size
-    mask = tf.transpose(tf.concat(values=[tf.ones((self.global_attention_size, batch_size), tf.int32) * 2,
-                                          tf.transpose(mask)[self.global_attention_size:]], axis=0))
+    mask = tf.transpose(tf.concat(values=[tf.ones((self._global_attention_size, batch_size), tf.int32) * 2,
+                                          tf.transpose(mask)[self._global_attention_size:]], axis=0))

     is_index_masked = tf.math.less(mask, 1)
     is_index_global_attn = tf.transpose(tf.concat(values=[
-        tf.ones((self.global_attention_size, batch_size), tf.bool), tf.zeros((seq_len - self.global_attention_size, batch_size), tf.bool)
+        tf.ones((self._global_attention_size, batch_size), tf.bool), tf.zeros((seq_len - self._global_attention_size,
+                                                                               batch_size), tf.bool)
     ], axis=0))
-    is_global_attn = self.global_attention_size > 0
+    is_global_attn = self._global_attention_size > 0

     # Longformer
     attention_mask = mask
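To make the mask bookkeeping above concrete, here is a standalone run on a tiny made-up batch: after the rewrite, positions with value 2 attend globally, 1 marks ordinary tokens, and 0 marks padding (`global_attention_size`, the batch, and the sequence length are invented for the example):

```python
import tensorflow as tf

global_attention_size = 2
# (batch=1, seq_len=6) padding mask: 1 = real token, 0 = padding.
mask = tf.constant([[1, 1, 1, 1, 0, 0]], tf.int32)
batch_size, seq_len = mask.shape

# Overwrite the first `global_attention_size` columns with 2, keeping the rest,
# via the same transpose/concat/transpose trick as the encoder code above.
mask = tf.transpose(tf.concat(
    values=[tf.ones((global_attention_size, batch_size), tf.int32) * 2,
            tf.transpose(mask)[global_attention_size:]], axis=0))

is_index_masked = tf.math.less(mask, 1)  # True only at padding positions
is_index_global_attn = tf.transpose(tf.concat(
    values=[tf.ones((global_attention_size, batch_size), tf.bool),
            tf.zeros((seq_len - global_attention_size, batch_size), tf.bool)],
    axis=0))
is_global_attn = global_attention_size > 0

print(mask.numpy())                  # [[2 2 1 1 0 0]]
print(is_index_masked.numpy())       # [[False False False False  True  True]]
print(is_index_global_attn.numpy())  # [[ True  True False False False False]]
```

Because only the first `global_attention_size` positions are treated as global, inputs are expected to place the globally attending tokens at the front of the sequence.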
...
@@ -347,11 +339,11 @@ class LongformerEncoder(tf.keras.layers.Layer):

   def _pad_to_window_size(
       self,
-      word_ids,  # input_ids
-      mask,  # attention_mask
-      type_ids,  # token_type_ids
-      word_embeddings,  # inputs_embeds
-      pad_token_id,  # pad_token_id
+      word_ids,
+      mask,
+      type_ids,
+      word_embeddings,
+      pad_token_id,
   ):
     """A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
     # padding
...
@@ -361,8 +353,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
     assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
-    # input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
-    input_shape = word_ids.shape if word_ids is not None else word_embeddings.shape
+    input_shape = shape_list(word_ids) if word_ids is not None else shape_list(word_embeddings)
     batch_size, seq_len = input_shape[:2]

     if seq_len is not None:
...
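The last hunk swaps a static `.shape` lookup for the `shape_list` helper, since static dimensions can be `None` under `tf.function` and the dynamic sequence length must then come from `tf.shape()`. Below is a hedged sketch of such a helper together with the pad-to-a-multiple-of-the-window idea this method implements; the helper body, window size, and pad id are assumptions, not necessarily what the encoder uses:

```python
import tensorflow as tf

ATTENTION_WINDOW = 4  # assumed even window size, per the assert in the diff
PAD_TOKEN_ID = 1      # assumed pad id, matching the config default above


def shape_list_sketch(x):
  """Returns static dims as ints, falling back to tf.shape for dynamic ones."""
  static = x.shape.as_list()
  dynamic = tf.shape(x)
  return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32)])
def pad_to_window_sketch(word_ids):
  batch_size, seq_len = shape_list_sketch(word_ids)  # seq_len is dynamic here
  # Pad the sequence length up to the next multiple of the attention window.
  padding_len = (ATTENTION_WINDOW - seq_len % ATTENTION_WINDOW) % ATTENTION_WINDOW
  paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
  return tf.pad(word_ids, paddings, constant_values=PAD_TOKEN_ID)


padded = pad_to_window_sketch(tf.constant([[5, 6, 7, 8, 9, 10]], tf.int32))
print(padded.shape)  # (1, 8): 6 tokens padded to the next multiple of 4
```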