encoders.py 8.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
"""Transformer Encoders.

Includes configurations and factory methods.
"""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
20
from typing import Optional
Hongkun Yu's avatar
Hongkun Yu committed
21
22

from absl import logging
23
import dataclasses
Hongkun Yu's avatar
Hongkun Yu committed
24
import gin
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
25
import tensorflow as tf
26

Hongkun Yu's avatar
Hongkun Yu committed
27
from official.modeling import hyperparams
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
28
from official.modeling import tf_utils
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
29
from official.nlp.modeling import layers
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
30
from official.nlp.modeling import networks
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
31
from official.nlp.projects.mobilebert import modeling
32
33
34


@dataclasses.dataclass
class BertEncoderConfig(hyperparams.Config):
  """BERT encoder configuration.

  The descriptions below state how `build_encoder` in this file consumes each
  field.

  Attributes:
    vocab_size: forwarded to the encoder as `vocab_size`.
    hidden_size: forwarded as `hidden_size` (and as `pooled_output_dim` when
      building an `EncoderScaffold`).
    num_layers: forwarded as `num_layers` (`num_hidden_instances` for an
      `EncoderScaffold`).
    num_attention_heads: forwarded as `num_attention_heads`.
    hidden_activation: name of the activation; resolved to a callable with
      `tf_utils.get_activation` before being passed to the encoder.
    intermediate_size: forwarded as `intermediate_size`.
    dropout_rate: forwarded as `dropout_rate`.
    attention_dropout_rate: forwarded as `attention_dropout_rate`.
    max_position_embeddings: forwarded as the encoder's `max_sequence_length`
      (`max_seq_length` in the `EncoderScaffold` embedding config).
    type_vocab_size: forwarded as `type_vocab_size`.
    initializer_range: stddev of the `TruncatedNormal` initializer handed to
      the encoder.
    embedding_size: forwarded as `embedding_width`; defaults to None.
  """
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
49
50


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
  """MobileBERT encoder configuration.

  Attributes:
    word_vocab_size: number of words in the vocabulary.
    word_embed_size: word embedding size.
    type_vocab_size: number of word types.
    max_sequence_length: maximum length of the input sequence.
    num_blocks: number of transformer blocks in the encoder model.
    hidden_size: the hidden size for the transformer block.
    num_attention_heads: number of attention heads in the transformer block.
    intermediate_size: the size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: the non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: the size of the bottleneck.
    initializer_range: the stddev of the truncated_normal_initializer for
      initializing all weight matrices. NOTE(review): `build_encoder` in this
      file does not forward this field to the encoder -- confirm whether that
      is intended.
    key_query_shared_bottleneck: whether to share the linear transformation
      for keys and queries.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization; only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear
      transformation for the student model, as suggested by the original
      MobileBERT paper. 'layer_norm' is used for the teacher model.
    classifier_activation: whether to use the tanh activation for the final
      representation of the [CLS] token in fine-tuning.
    return_all_layers: whether to return all layer outputs.
    return_attention_score: whether to return attention scores for each layer.
  """
  word_vocab_size: int = 30522
  word_embed_size: int = 128
  type_vocab_size: int = 2
  max_sequence_length: int = 512
  num_blocks: int = 24
  hidden_size: int = 512
  num_attention_heads: int = 4
  intermediate_size: int = 4096
  intermediate_act_fn: str = "gelu"
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  intra_bottleneck_size: int = 1024
  initializer_range: float = 0.02
  key_query_shared_bottleneck: bool = False
  num_feedforward_networks: int = 1
  normalization_type: str = "layer_norm"
  classifier_activation: bool = True
  return_all_layers: bool = False
  return_attention_score: bool = False


Hongkun Yu's avatar
Hongkun Yu committed
106
107
108
109
110
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration.

  A one-of config: `type` names the sub-config that `build_encoder` reads
  through `config.get()`.

  Attributes:
    type: name of the selected encoder; `build_encoder` also uses it as the
      key into `ENCODER_CLS` when no encoder class is supplied.
    bert: parameters for the "bert" encoder.
    mobilebert: parameters for the "mobilebert" encoder.
  """
  type: Optional[str] = "bert"
  # Use default_factory instead of an instance as the default: a non-frozen
  # dataclass instance is unhashable, and Python 3.11+ rejects unhashable
  # field defaults with a ValueError at class-definition time.
  bert: BertEncoderConfig = dataclasses.field(
      default_factory=BertEncoderConfig)
  mobilebert: MobileBertEncoderConfig = dataclasses.field(
      default_factory=MobileBertEncoderConfig)
Hongkun Yu's avatar
Hongkun Yu committed
112
113
114
115


# Maps the `type` field of `EncoderConfig` to the encoder class that
# `build_encoder` instantiates when no explicit `encoder_cls` is supplied.
ENCODER_CLS = {
    "bert": networks.TransformerEncoder,
    "mobilebert": modeling.MobileBERTEncoder,
}


@gin.configurable
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
  """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder. It is
      only forwarded on the default (BERT-schema) path; the `EncoderScaffold`
      and "mobilebert" paths do not use it.
    encoder_cls: an external encoder cls not included in the supported encoders,
      usually used by gin.configurable.
    bypass_config: whether to ignore config instance to create the object with
      `encoder_cls`.

  Returns:
    An encoder instance.

  Raises:
    KeyError: if `encoder_cls` is not provided and `config.type` is not a key
      of `ENCODER_CLS`.
  """
  encoder_type = config.type
  encoder_cfg = config.get()
  # An explicitly supplied class wins over the registry lookup.
  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
  if bypass_config:
    # The config is ignored entirely; the class is constructed with no
    # arguments (presumably configured externally, e.g. through gin).
    return encoder_cls()

  # Matched by class name, so any class called "EncoderScaffold" takes this
  # path. The sub-config is read with the BERT field names.
  if encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range))
    return encoder_cls(**kwargs)

  if encoder_type == "mobilebert":
    # NOTE(review): `MobileBertEncoderConfig.initializer_range` is declared but
    # not forwarded here -- confirm whether the encoder should receive it.
    return encoder_cls(
        word_vocab_size=encoder_cfg.word_vocab_size,
        word_embed_size=encoder_cfg.word_embed_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        max_sequence_length=encoder_cfg.max_sequence_length,
        num_blocks=encoder_cfg.num_blocks,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_act_fn=encoder_cfg.intermediate_act_fn,
        hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
        intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
        key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
        num_feedforward_networks=encoder_cfg.num_feedforward_networks,
        normalization_type=encoder_cfg.normalization_type,
        classifier_activation=encoder_cfg.classifier_activation,
        return_all_layers=encoder_cfg.return_all_layers,
        return_attention_score=encoder_cfg.return_attention_score)

  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return encoder_cls(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer)