Commit f2adc5ef authored by Zihan Wang

fix argument passing

parent 5ad1d93f
@@ -54,10 +54,10 @@ def get_encoder(encoder_cfg: LongformerEncoderConfig):
       hidden_size=encoder_cfg.hidden_size,
       num_layers=encoder_cfg.num_layers,
       num_attention_heads=encoder_cfg.num_attention_heads,
-      intermediate_size=encoder_cfg.intermediate_size,
-      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
-      dropout_rate=encoder_cfg.dropout_rate,
-      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+      inner_dim=encoder_cfg.intermediate_size,
+      inner_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+      output_dropout=encoder_cfg.dropout_rate,
+      attention_dropout=encoder_cfg.attention_dropout_rate,
       max_sequence_length=encoder_cfg.max_position_embeddings,
       type_vocab_size=encoder_cfg.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
...
@@ -106,19 +106,6 @@ class LongformerEncoder(tf.keras.layers.Layer):
                embedding_layer: Optional[tf.keras.layers.Layer] = None,
                norm_first: bool = False,
                **kwargs):
-    # Pops kwargs that are used in V1 implementation.
-    if 'dict_outputs' in kwargs:
-      kwargs.pop('dict_outputs')
-    if 'return_all_encoder_outputs' in kwargs:
-      kwargs.pop('return_all_encoder_outputs')
-    if 'intermediate_size' in kwargs:
-      inner_dim = kwargs.pop('intermediate_size')
-    if 'activation' in kwargs:
-      inner_activation = kwargs.pop('activation')
-    if 'dropout_rate' in kwargs:
-      output_dropout = kwargs.pop('dropout_rate')
-    if 'attention_dropout_rate' in kwargs:
-      attention_dropout = kwargs.pop('attention_dropout_rate')
     super().__init__(**kwargs)
     # Longformer args
     self._attention_window = attention_window
...
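
The deleted block was a compatibility shim: it silently accepted the V1 keyword names and discarded V1-only output flags. With the call site in get_encoder now passing the V2 names (inner_dim, inner_activation, output_dropout, attention_dropout) directly, the shim is unnecessary. For illustration only, a standalone helper equivalent to what the removed lines did; the helper name and structure are hypothetical and not part of the repo:

# Hypothetical helper mirroring the removed compatibility block: it drops
# V1-only flags and renames V1 keyword arguments to their V2 counterparts.
_V1_TO_V2_NAMES = {
    'intermediate_size': 'inner_dim',
    'activation': 'inner_activation',
    'dropout_rate': 'output_dropout',
    'attention_dropout_rate': 'attention_dropout',
}


def translate_v1_kwargs(kwargs):
  """Returns a copy of `kwargs` with V1 names mapped to V2 names."""
  kwargs = dict(kwargs)
  # V1-only flags have no V2 equivalent and are simply discarded.
  kwargs.pop('dict_outputs', None)
  kwargs.pop('return_all_encoder_outputs', None)
  return {_V1_TO_V2_NAMES.get(name, name): value
          for name, value in kwargs.items()}

A caller holding legacy keyword arguments could then write LongformerEncoder(**translate_v1_kwargs(legacy_kwargs)); in this commit the translation is instead done once, at the get_encoder call site.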