encoders.py 13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
"""Transformer Encoders.

Includes configurations and factory methods.
"""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
20
from typing import Optional
Hongkun Yu's avatar
Hongkun Yu committed
21
22

from absl import logging
23
import dataclasses
Hongkun Yu's avatar
Hongkun Yu committed
24
import gin
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
25
import tensorflow as tf
26

Hongkun Yu's avatar
Hongkun Yu committed
27
from official.modeling import hyperparams
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
28
from official.modeling import tf_utils
29
from official.nlp import keras_nlp
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
30
from official.nlp.modeling import networks
Hongkun Yu's avatar
Hongkun Yu committed
31
from official.nlp.projects.bigbird import encoder as bigbird_encoder
32
33
34


@dataclasses.dataclass
class BertEncoderConfig(hyperparams.Config):
  """Configuration for the BERT-style Transformer encoder.

  `build_encoder` falls back to this field schema for any encoder type that
  has no dedicated branch (including "bert"). Its `EncoderScaffold` path reads
  the same field names.

  Attributes:
    vocab_size: size of the token vocabulary.
    hidden_size: width of the Transformer hidden states.
    num_layers: number of Transformer layers.
    num_attention_heads: number of attention heads per layer.
    hidden_activation: name of the intermediate-layer activation, resolved
      with `tf_utils.get_activation`.
    intermediate_size: width of the intermediate (feed-forward) layer.
    dropout_rate: dropout rate for the hidden layers.
    attention_dropout_rate: dropout rate used inside attention.
    max_position_embeddings: maximum input length; handed to the encoder as
      `max_sequence_length`.
    type_vocab_size: number of token-type (segment) ids.
    initializer_range: stddev of the truncated-normal weight initializer.
    embedding_size: optional embedding width, handed to the encoder as
      `embedding_width`; `None` by default.
    return_all_encoder_outputs: whether the encoder should return the outputs
      of every layer.
  """
  # Model size.
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  # Regularization.
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Inputs, embeddings and initialization.
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  # Output options.
  return_all_encoder_outputs: bool = False
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
50
51


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
52
53
54
55
56
57
58
59
60
61
62
63
@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
  """MobileBERT encoder configuration.

  `build_encoder` forwards every field to the MobileBERT network under the
  same keyword name. The one exception is `hidden_activation`, which is
  passed as `intermediate_act_fn`.

  Attributes:
    word_vocab_size: number of words in the vocabulary.
    word_embed_size: word embedding size.
    type_vocab_size: number of word types.
    max_sequence_length: maximum length of the input sequence.
    num_blocks: number of transformer blocks in the encoder model.
    hidden_size: hidden size of each transformer block.
    num_attention_heads: number of attention heads in each transformer block.
    intermediate_size: size of the "intermediate" (a.k.a. feed-forward) layer.
    hidden_activation: non-linear activation applied to the output of the
      intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: size of the bottleneck.
    initializer_range: stddev of the truncated_normal_initializer used to
      initialize all weight matrices.
    key_query_shared_bottleneck: whether keys and queries share the same
      linear transformation.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: type of normalization. Only 'no_norm' and
      'layer_norm' are supported. 'no_norm' is the element-wise linear
      transformation used by the student model, as suggested by the original
      MobileBERT paper; 'layer_norm' is used by the teacher model.
    classifier_activation: whether to apply the tanh activation to the final
      representation of the [CLS] token in fine-tuning.
  """
  # Embeddings and inputs.
  word_vocab_size: int = 30522
  word_embed_size: int = 128
  type_vocab_size: int = 2
  max_sequence_length: int = 512
  # Transformer stack.
  num_blocks: int = 24
  hidden_size: int = 512
  num_attention_heads: int = 4
  intermediate_size: int = 4096
  hidden_activation: str = "gelu"
  # Regularization.
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  # MobileBERT-specific structure.
  intra_bottleneck_size: int = 1024
  initializer_range: float = 0.02
  key_query_shared_bottleneck: bool = False
  num_feedforward_networks: int = 1
  normalization_type: str = "layer_norm"
  classifier_activation: bool = True


Chen Chen's avatar
Chen Chen committed
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@dataclasses.dataclass
class AlbertEncoderConfig(hyperparams.Config):
  """Configuration for the ALBERT encoder.

  Attributes:
    vocab_size: size of the token vocabulary.
    embedding_width: width of the token embeddings.
    hidden_size: width of the Transformer hidden states.
    num_layers: number of Transformer layers.
    num_attention_heads: number of attention heads per layer.
    hidden_activation: activation name, resolved with
      `tf_utils.get_activation` before being passed as `activation`.
    intermediate_size: width of the intermediate (feed-forward) layer.
    dropout_rate: dropout rate for the hidden layers (0.0 by default).
    attention_dropout_rate: dropout rate inside attention (0.0 by default).
    max_position_embeddings: maximum input length; passed as
      `max_sequence_length`.
    type_vocab_size: number of token-type (segment) ids.
    initializer_range: stddev of the truncated-normal weight initializer.
  """
  # Model size.
  vocab_size: int = 30000
  embedding_width: int = 128
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  # Regularization.
  dropout_rate: float = 0.0
  attention_dropout_rate: float = 0.0
  # Inputs and initialization.
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02


Hongkun Yu's avatar
Hongkun Yu committed
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
@dataclasses.dataclass
class BigBirdEncoderConfig(hyperparams.Config):
  """Configuration for the BigBird encoder.

  Attributes:
    vocab_size: size of the token vocabulary.
    hidden_size: width of the Transformer hidden states.
    num_layers: number of Transformer layers.
    num_attention_heads: number of attention heads per layer.
    hidden_activation: activation name, resolved with
      `tf_utils.get_activation` before being passed as `activation`.
    intermediate_size: width of the intermediate (feed-forward) layer.
    dropout_rate: dropout rate for the hidden layers.
    attention_dropout_rate: dropout rate inside attention.
    max_position_embeddings: maximum input length; passed as
      `max_sequence_length`.
    num_rand_blocks: forwarded unchanged as the encoder's `num_rand_blocks`.
    block_size: forwarded unchanged as the encoder's `block_size`.
    type_vocab_size: number of token-type (segment) ids.
    initializer_range: stddev of the truncated-normal weight initializer.
    embedding_size: optional embedding width, passed as `embedding_width`;
      `None` by default.
  """
  # Model size.
  vocab_size: int = 50358
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  # Regularization.
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Sparse-attention layout.
  max_position_embeddings: int = 4096
  num_rand_blocks: int = 3
  block_size: int = 64
  # Inputs, embeddings and initialization.
  type_vocab_size: int = 16
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None


Allen Wang's avatar
Allen Wang committed
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
  """Configuration for the XLNet encoder.

  The field names match keyword arguments of the XLNet network built by
  `build_encoder`. The exception is `initializer_range`, which is used as
  the stddev of a random-normal initializer.
  """
  # Model size.
  vocab_size: int = 32000
  num_layers: int = 24
  hidden_size: int = 1024
  num_attention_heads: int = 16
  head_size: int = 64
  inner_size: int = 4096
  inner_activation: str = "gelu"
  # Regularization.
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Attention and memory behaviour.
  attention_type: str = "bi"
  bi_data: bool = False
  tie_attention_biases: bool = False
  memory_length: int = 0
  same_length: bool = False
  clamp_length: int = -1
  reuse_length: int = 0
  use_cls_mask: bool = False
  # Embeddings, initialization and streams.
  embedding_width: int = 1024
  initializer_range: float = 0.02
  two_stream: bool = False


Hongkun Yu's avatar
Hongkun Yu committed
164
165
166
167
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration.

  A one-of config: `type` names the active sub-config, and `build_encoder`
  reads that sub-config via `get()`.

  Each sub-config default is produced by `default_factory`. That gives every
  `EncoderConfig` instance its own sub-config objects, rather than one
  instance per field created at class-definition time and shared by all
  `EncoderConfig` objects. Instance defaults of that kind can also be rejected
  as mutable/unhashable dataclass defaults on newer Python versions (3.11+).
  """
  type: Optional[str] = "bert"
  albert: AlbertEncoderConfig = dataclasses.field(
      default_factory=AlbertEncoderConfig)
  bert: BertEncoderConfig = dataclasses.field(
      default_factory=BertEncoderConfig)
  bigbird: BigBirdEncoderConfig = dataclasses.field(
      default_factory=BigBirdEncoderConfig)
  mobilebert: MobileBertEncoderConfig = dataclasses.field(
      default_factory=MobileBertEncoderConfig)
  xlnet: XLNetEncoderConfig = dataclasses.field(
      default_factory=XLNetEncoderConfig)
Hongkun Yu's avatar
Hongkun Yu committed
173
174
175


# Lookup table used by `build_encoder`: `EncoderConfig.type` value -> the
# network class to instantiate when no explicit `encoder_cls` is given.
ENCODER_CLS = dict(
    bert=networks.BertEncoder,
    mobilebert=networks.MobileBERTEncoder,
    albert=networks.AlbertEncoder,
    bigbird=bigbird_encoder.BigBirdEncoder,
    xlnet=networks.XLNetBase,
)


@gin.configurable
def build_encoder(
    config: EncoderConfig,
    embedding_layer: Optional[keras_nlp.layers.OnDeviceEmbedding] = None,
    encoder_cls=None,
    bypass_config: bool = False):
  """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder. Only
      the default (BERT-schema) path at the end of this function forwards it.
      The EncoderScaffold, MobileBERT, ALBERT, BigBird and XLNet paths do not
      use it.
    encoder_cls: an external encoder cls not included in the supported encoders,
      usually used by gin.configurable. When `None`, the class is looked up in
      `ENCODER_CLS` by `config.type`.
    bypass_config: whether to ignore config instance to create the object with
      `encoder_cls`.

  Returns:
    An encoder instance.

  Raises:
    KeyError: if `encoder_cls` is not given and `config.type` is not a key of
      `ENCODER_CLS`.
  """
  encoder_type = config.type
  encoder_cfg = config.get()
  if encoder_cls is None:
    if encoder_type not in ENCODER_CLS:
      # Same exception type as a plain dict lookup, with a readable message.
      raise KeyError("Unsupported encoder type %r; expected one of %s." %
                     (encoder_type, sorted(ENCODER_CLS)))
    encoder_cls = ENCODER_CLS[encoder_type]
  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
  if bypass_config:
    return encoder_cls()

  # EncoderScaffold is matched by class name, so a gin-injected scaffold class
  # is handled here. It reads BERT-schema field names from the config.
  if encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True)
    return encoder_cls(**kwargs)

  if encoder_type == "mobilebert":
    return encoder_cls(
        word_vocab_size=encoder_cfg.word_vocab_size,
        word_embed_size=encoder_cfg.word_embed_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        max_sequence_length=encoder_cfg.max_sequence_length,
        num_blocks=encoder_cfg.num_blocks,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_act_fn=encoder_cfg.hidden_activation,
        hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
        intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
        initializer_range=encoder_cfg.initializer_range,
        key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
        num_feedforward_networks=encoder_cfg.num_feedforward_networks,
        normalization_type=encoder_cfg.normalization_type,
        classifier_activation=encoder_cfg.classifier_activation)

  if encoder_type == "albert":
    return encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_width,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dict_outputs=True)

  if encoder_type == "bigbird":
    return encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        num_rand_blocks=encoder_cfg.num_rand_blocks,
        block_size=encoder_cfg.block_size,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        embedding_width=encoder_cfg.embedding_size)

  if encoder_type == "xlnet":
    return encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        num_layers=encoder_cfg.num_layers,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        head_size=encoder_cfg.head_size,
        inner_size=encoder_cfg.inner_size,
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        attention_type=encoder_cfg.attention_type,
        bi_data=encoder_cfg.bi_data,
        two_stream=encoder_cfg.two_stream,
        tie_attention_biases=encoder_cfg.tie_attention_biases,
        memory_length=encoder_cfg.memory_length,
        # Forwarded so the config's `same_length` setting actually reaches the
        # network instead of being silently dropped.
        same_length=encoder_cfg.same_length,
        clamp_length=encoder_cfg.clamp_length,
        reuse_length=encoder_cfg.reuse_length,
        inner_activation=encoder_cfg.inner_activation,
        use_cls_mask=encoder_cfg.use_cls_mask,
        embedding_width=encoder_cfg.embedding_width,
        initializer=tf.keras.initializers.RandomNormal(
            stddev=encoder_cfg.initializer_range))

  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return encoder_cls(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer,
      return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
      dict_outputs=True)