bert.py 2.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
18
19
"""Multi-head BERT encoder network with classification heads.

Includes configurations and instantiation methods.
"""
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List, Optional, Text

import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer


@dataclasses.dataclass
class ClsHeadConfig(base_config.Config):
  """Settings for one classification head.

  Within this module the fields are expanded, via `as_dict()`, into keyword
  arguments for `layers.ClassificationHead`; that layer defines what each
  value means.
  """
  # NOTE(review): presumably the width of the head's hidden projection --
  # confirm against layers.ClassificationHead.
  inner_dim: int = 0
  num_classes: int = 2
  # Activation is referenced by name; None is accepted by the annotation.
  activation: Optional[str] = "tanh"
  dropout_rate: float = 0.0
  # NOTE(review): looks like the sequence position of the [CLS] token --
  # confirm against layers.ClassificationHead.
  cls_token_idx: int = 0
  name: Optional[str] = None


@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
  """BERT pretrainer configuration.

  Attributes:
    num_masked_tokens: Value forwarded as the first positional argument of
      `bert_pretrainer.BertPretrainerV2`.
    encoder: Configuration of the encoder network. Its `hidden_activation` and
      `initializer_range` are also used for the masked-LM head.
    cls_heads: Configurations of the classification heads to attach; empty by
      default.
  """
  num_masked_tokens: int = 76
  # Use a factory rather than a config instance created at class-definition
  # time: a single instance would be shared by every BertPretrainerConfig, and
  # Python 3.11+ `dataclasses` rejects unhashable default values with
  # `ValueError: mutable default ... use default_factory`.
  # NOTE(review): assumes TransformerEncoderConfig is a regular (non-frozen)
  # dataclass and therefore unhashable -- confirm.
  encoder: encoders.TransformerEncoderConfig = dataclasses.field(
      default_factory=encoders.TransformerEncoderConfig)
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
51
52
53
54
55
56
57
58
def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
  """Builds one `layers.ClassificationHead` per head config.

  Args:
    cls_head_configs: Head configurations. A falsy value (empty list or None)
      yields no heads.

  Returns:
    A list of classification heads in the same order as the configs; an empty
    list when no configs are given.
  """
  if not cls_head_configs:
    return []

  heads = []
  for head_cfg in cls_head_configs:
    # Every config field becomes a keyword argument of the layer.
    head_kwargs = head_cfg.as_dict()
    heads.append(layers.ClassificationHead(**head_kwargs))
  return heads


def instantiate_bertpretrainer_from_cfg(
    config: BertPretrainerConfig,
    encoder_network: Optional[tf.keras.Model] = None
    ) -> bert_pretrainer.BertPretrainerV2:
  """Instantiates a BertPretrainer from the config.

  Args:
    config: Pretrainer configuration. Its `encoder`, `num_masked_tokens` and
      `cls_heads` fields are read.
    encoder_network: Optional pre-built encoder to wrap. When None, a new
      encoder is created from `config.encoder`.

  Returns:
    A `bert_pretrainer.BertPretrainerV2` wrapping the encoder, with one
    classification head per entry in `config.cls_heads`.
  """
  encoder_cfg = config.encoder
  if encoder_network is None:
    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)

  # The masked-LM activation and initializer always come from `config.encoder`,
  # even when the caller supplies its own `encoder_network`.
  return bert_pretrainer.BertPretrainerV2(
      config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      encoder_network=encoder_network,
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads))