bert.py 2.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
18
19
"""Multi-head BERT encoder network with classification heads.

Includes configurations and instantiation methods.
"""
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import List, Optional, Text

import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer


@dataclasses.dataclass
class ClsHeadConfig(base_config.Config):
  """Configuration of a single classification head.

  `instantiate_classification_heads_from_cfgs` expands every field of this
  config (via `as_dict()`) into keyword arguments of
  `layers.ClassificationHead`, so the field names must match that layer's
  constructor arguments.
  """
  # NOTE(review): the meaning and valid range of each field are defined by
  # `layers.ClassificationHead`, which is not visible here -- confirm there.
  inner_dim: int = 0
  num_classes: int = 2
  activation: Optional[Text] = "tanh"
  dropout_rate: float = 0.0
  cls_token_idx: int = 0
  name: Optional[Text] = None


@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
  """BERT pretrainer configuration.

  Attributes:
    encoder: Configuration of the encoder network. It is used to instantiate
      the encoder when none is supplied, and its `hidden_activation` and
      `initializer_range` are read when building the pretrainer.
    cls_heads: Configurations of the classification heads attached to the
      pretrainer. Defaults to no heads.
  """
  # Use a factory so that every config gets its own encoder config instead of
  # sharing one instance created at class-definition time. Python 3.11+ also
  # rejects unhashable default values in dataclass fields.
  encoder: encoders.TransformerEncoderConfig = dataclasses.field(
      default_factory=encoders.TransformerEncoderConfig)
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
50
51
52
53
54
55
56
def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
  """Builds one `layers.ClassificationHead` per head config.

  Args:
    cls_head_configs: Head configurations. A falsy value (for example an empty
      list or None) yields an empty list.

  Returns:
    The instantiated heads, in the same order as `cls_head_configs`.
  """
  if not cls_head_configs:
    return []

  heads = []
  for head_cfg in cls_head_configs:
    # Every config field is forwarded as a keyword argument of the layer.
    head_kwargs = head_cfg.as_dict()
    heads.append(layers.ClassificationHead(**head_kwargs))
  return heads


Hongkun Yu's avatar
Hongkun Yu committed
57
def instantiate_pretrainer_from_cfg(
    config: BertPretrainerConfig,
    encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
  """Builds a `BertPretrainerV2` from a pretrainer config.

  Args:
    config: Pretrainer configuration. Its `encoder` and `cls_heads` fields are
      read.
    encoder_network: Optional encoder to use as-is. When None, an encoder is
      instantiated from `config.encoder`.

  Returns:
    A `bert_pretrainer.BertPretrainerV2` wrapping the encoder, with the
    classification heads described by `config.cls_heads`.
  """
  enc_cfg = config.encoder

  network = encoder_network
  if network is None:
    network = encoders.instantiate_encoder_from_cfg(enc_cfg)

  # The masked-LM settings are taken from the encoder configuration.
  activation_fn = tf_utils.get_activation(enc_cfg.hidden_activation)
  initializer = tf.keras.initializers.TruncatedNormal(
      stddev=enc_cfg.initializer_range)
  heads = instantiate_classification_heads_from_cfgs(config.cls_heads)

  return bert_pretrainer.BertPretrainerV2(
      mlm_activation=activation_fn,
      mlm_initializer=initializer,
      encoder_network=network,
      classification_heads=heads)