encoders.py 4.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
"""Transformer Encoders.

Includes configurations and factory methods.
"""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
20
from typing import Optional
Hongkun Yu's avatar
Hongkun Yu committed
21
22

from absl import logging
23
import dataclasses
Hongkun Yu's avatar
Hongkun Yu committed
24
import gin
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
25
import tensorflow as tf
26

Hongkun Yu's avatar
Hongkun Yu committed
27
from official.modeling import hyperparams
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
28
from official.modeling import tf_utils
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
29
from official.nlp.modeling import layers
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
30
from official.nlp.modeling import networks
31
32
33


@dataclasses.dataclass
class BertEncoderConfig(hyperparams.Config):
  """BERT encoder configuration.

  Holds the hyperparameters that `build_encoder` forwards to the encoder
  class (by default `networks.TransformerEncoder`, or `EncoderScaffold`).
  """
  # Size of the token vocabulary.
  vocab_size: int = 30522
  # Width of the encoder hidden states; also used as the pooled output
  # dimension when building an EncoderScaffold.
  hidden_size: int = 768
  # Number of transformer layers (`num_hidden_instances` for EncoderScaffold).
  num_layers: int = 12
  num_attention_heads: int = 12
  # Activation name, resolved through `tf_utils.get_activation`.
  hidden_activation: str = "gelu"
  # Width of the intermediate layer inside each transformer layer.
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Forwarded as `max_sequence_length` (TransformerEncoder) or
  # `max_seq_length` (EncoderScaffold embedding config).
  max_position_embeddings: int = 512
  # Number of token-type ids accepted by the embedding.
  type_vocab_size: int = 2
  # Standard deviation of the TruncatedNormal weight initializers.
  initializer_range: float = 0.02
  # Forwarded as `embedding_width`; None leaves the choice to the encoder
  # class. NOTE(review): presumably that defaults to hidden_size -- confirm.
  embedding_size: Optional[int] = None
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
48
49


Hongkun Yu's avatar
Hongkun Yu committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration.

  A one-of config: `type` selects which sub-config is active. `build_encoder`
  calls `get()` on it and reads the encoder hyperparameters from the result.
  """
  # Name of the active sub-config. Unless `build_encoder` receives an explicit
  # `encoder_cls`, this must also be a key of `ENCODER_CLS`.
  type: Optional[str] = "bert"
  # A default_factory gives every EncoderConfig its own BertEncoderConfig
  # instead of sharing one mutable default instance. Python 3.11+ dataclasses
  # reject such unhashable defaults outright.
  bert: BertEncoderConfig = dataclasses.field(
      default_factory=BertEncoderConfig)


# Built-in encoder classes, keyed by `EncoderConfig.type`. `build_encoder`
# consults this registry when no explicit `encoder_cls` is given.
ENCODER_CLS = dict(bert=networks.TransformerEncoder)


def _truncated_normal_initializer(encoder_cfg):
  """Returns a new TruncatedNormal initializer using `initializer_range`."""
  # Deliberately builds a fresh object on every call so each layer argument
  # gets its own initializer instance, as it did before this was factored out.
  return tf.keras.initializers.TruncatedNormal(
      stddev=encoder_cfg.initializer_range)


def _warn_if_embedding_layer_ignored(embedding_layer, reason):
  """Logs a warning when a caller-supplied embedding layer will go unused."""
  if embedding_layer is not None:
    logging.warning(
        "`embedding_layer` is ignored when building the encoder with %s.",
        reason)


def _encoder_scaffold_kwargs(encoder_cfg):
  """Maps BERT-style encoder config fields onto EncoderScaffold kwargs."""
  embedding_cfg = dict(
      vocab_size=encoder_cfg.vocab_size,
      type_vocab_size=encoder_cfg.type_vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      max_seq_length=encoder_cfg.max_position_embeddings,
      initializer=_truncated_normal_initializer(encoder_cfg),
      dropout_rate=encoder_cfg.dropout_rate,
  )
  hidden_cfg = dict(
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      intermediate_activation=tf_utils.get_activation(
          encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      kernel_initializer=_truncated_normal_initializer(encoder_cfg),
  )
  return dict(
      embedding_cfg=embedding_cfg,
      hidden_cfg=hidden_cfg,
      num_hidden_instances=encoder_cfg.num_layers,
      pooled_output_dim=encoder_cfg.hidden_size,
      pooler_layer_initializer=_truncated_normal_initializer(encoder_cfg))


@gin.configurable
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
  """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder. Only
      the default construction path forwards it. It is ignored (and a warning
      is logged) when `bypass_config` is set or when building an
      `EncoderScaffold`.
    encoder_cls: an external encoder cls not included in the supported encoders,
      usually used by gin.configurable. When unset, the class is looked up in
      `ENCODER_CLS` by `config.type`.
    bypass_config: whether to ignore config instance to create the object with
      `encoder_cls`.

  Returns:
    An encoder instance.

  Raises:
    KeyError: if `encoder_cls` is not given and `config.type` is not a key of
      `ENCODER_CLS`.
  """
  encoder_type = config.type
  encoder_cfg = config.get()
  if not encoder_cls:
    if encoder_type not in ENCODER_CLS:
      # Same exception type as a plain dict lookup, but with a message that
      # tells the caller what is supported and how to go beyond it.
      raise KeyError(
          "Unsupported encoder type %r; supported types are %s. Pass "
          "`encoder_cls` to build an encoder outside this registry." %
          (encoder_type, sorted(ENCODER_CLS)))
    encoder_cls = ENCODER_CLS[encoder_type]
  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
  if bypass_config:
    _warn_if_embedding_layer_ignored(embedding_layer, "bypass_config=True")
    return encoder_cls()

  # Matched by name rather than issubclass, so subclasses of EncoderScaffold
  # keep using the default keyword schema below.
  if encoder_cls.__name__ == "EncoderScaffold":
    # This path does not forward `embedding_layer` to the scaffold.
    _warn_if_embedding_layer_ignored(embedding_layer, "EncoderScaffold")
    return encoder_cls(**_encoder_scaffold_kwargs(encoder_cfg))

  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return encoder_cls(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=_truncated_normal_initializer(encoder_cfg),
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer)