"SpeechT5/fairseq/docs/command_line_tools.rst" did not exist on "417b607b2a622da9321c932a5b3bc0f6b0ece56b"
bert.py 5.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
18
19
"""Multi-head BERT encoder network with classification heads.

Includes configurations and instantiation methods.
"""
20
21
22
23
24
25
26
from typing import List, Optional, Text

import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
Hongkun Yu's avatar
Hongkun Yu committed
27
from official.modeling.hyperparams import config_definitions as cfg
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer


@dataclasses.dataclass
class ClsHeadConfig(base_config.Config):
  """Configuration for a single classification head.

  `instantiate_classification_heads_from_cfgs` splats this config's
  `as_dict()` into `layers.ClassificationHead(**...)`, so the field names
  here must stay in sync with that layer's constructor keyword arguments.
  """
  inner_dim: int = 0
  num_classes: int = 2
  activation: Optional[Text] = "tanh"
  dropout_rate: float = 0.0
  # Sequence position whose encoder output feeds the head; presumably 0
  # selects the leading [CLS] token -- confirm against ClassificationHead.
  cls_token_idx: int = 0
  name: Optional[Text] = None


@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
  """BERT pretrainer configuration: encoder, masked-LM and classification heads.

  Consumed by `instantiate_bertpretrainer_from_cfg`.
  """
  # Passed as the first positional argument of BertPretrainerV2; presumably
  # the number of masked positions predicted per sequence (same default as
  # BertPretrainDataConfig.max_predictions_per_seq) -- keep them in sync.
  num_masked_tokens: int = 76
  # Build a fresh encoder config per instance. A single instance used as the
  # default would be shared (aliased) by every BertPretrainerConfig, and since
  # it is an unhashable dataclass, `dataclasses` rejects it as a mutable
  # default on Python 3.11+.
  encoder: encoders.TransformerEncoderConfig = dataclasses.field(
      default_factory=encoders.TransformerEncoderConfig)
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
52
53
54
55
56
57
58
59
def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
  """Builds one `layers.ClassificationHead` per head config.

  Args:
    cls_head_configs: head configs. Each config's `as_dict()` is passed as the
      keyword arguments of `layers.ClassificationHead`. A falsy value (None or
      an empty list) produces no heads.

  Returns:
    The instantiated heads, in the same order as `cls_head_configs`.
  """
  if not cls_head_configs:
    return []
  # `head_cfg` rather than `cfg`, so the name does not shadow the module
  # alias `cfg` (config_definitions) used elsewhere in this file.
  return [layers.ClassificationHead(**head_cfg.as_dict())
          for head_cfg in cls_head_configs]


def instantiate_bertpretrainer_from_cfg(
    config: BertPretrainerConfig,
    encoder_network: Optional[tf.keras.Model] = None
    ) -> bert_pretrainer.BertPretrainerV2:
  """Instantiates a `BertPretrainerV2` from the config.

  Args:
    config: pretrainer config carrying the encoder config, the number of
      masked tokens and the classification head configs.
    encoder_network: optional prebuilt encoder. When None, one is built from
      `config.encoder` via `encoders.instantiate_encoder_from_cfg`.

  Returns:
    A `bert_pretrainer.BertPretrainerV2` wrapping the encoder, its masked-LM
    head and the configured classification heads.
  """
  encoder_cfg = config.encoder
  if encoder_network is None:
    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
  return bert_pretrainer.BertPretrainerV2(
      config.num_masked_tokens,
      # The masked-LM head takes its activation and initializer scale from
      # `config.encoder`, even when a prebuilt `encoder_network` is supplied.
      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      encoder_network=encoder_network,
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads))
Hongkun Yu's avatar
Hongkun Yu committed
75
76
77
78


@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  input_path: str = ""
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  # Same default (76) as BertPretrainerConfig.num_masked_tokens; presumably
  # the two should be kept in sync -- confirm with tasks/masked_lm.
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False


@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
  # All other fields (seq_length, max_predictions_per_seq, ...) are inherited
  # from BertPretrainDataConfig; only the fields below are redeclared.
  input_path: str = ""
  global_batch_size: int = 512
  is_training: bool = False
95
96
97


@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction)."""
  input_path: str = ""
  global_batch_size: int = 32
  is_training: bool = True
  seq_length: int = 128


@dataclasses.dataclass
class SentencePredictionDevDataConfig(cfg.DataConfig):
  """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
  input_path: str = ""
  global_batch_size: int = 32
  is_training: bool = False
  seq_length: int = 128
  # Presumably False so the final partial batch is still evaluated -- confirm
  # against the input pipeline that reads this flag.
  drop_remainder: bool = False
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128


@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
  """Data config for question answering task (tasks/question_answering)."""
  # Training-split settings; QADevDataConfig holds the is_training=False
  # counterpart with the extra evaluation-time fields.
  input_path: str = ""
  global_batch_size: int = 48
  is_training: bool = True
  seq_length: int = 384


@dataclasses.dataclass
class QADevDataConfig(cfg.DataConfig):
  """Dev Data config for question answering (tasks/question_answering)."""
  input_path: str = ""
  # NOTE(review): presumably the location of already-preprocessed eval
  # features/examples -- confirm with tasks/question_answering.
  input_preprocessed_data_path: str = ""
  # NOTE(review): presumably toggles SQuAD v2.0-style unanswerable
  # questions -- confirm with the task.
  version_2_with_negative: bool = False
  doc_stride: int = 128
  global_batch_size: int = 48
  is_training: bool = False
  seq_length: int = 384
  query_length: int = 64
  drop_remainder: bool = False
  vocab_file: str = ""
  tokenization: str = "WordPiece"  # WordPiece or SentencePiece
  do_lower_case: bool = True
Chen Chen's avatar
Chen Chen committed
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158


@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Data config for tagging (tasks/tagging)."""
  # Training-split settings; TaggingDevDataConfig is the is_training=False
  # counterpart.
  input_path: str = ""
  global_batch_size: int = 48
  is_training: bool = True
  seq_length: int = 384


@dataclasses.dataclass
class TaggingDevDataConfig(cfg.DataConfig):
  """Dev Data config for tagging (tasks/tagging)."""
  input_path: str = ""
  global_batch_size: int = 48
  is_training: bool = False
  seq_length: int = 384
  # Presumably False so the final partial batch is still evaluated -- confirm
  # against the input pipeline that reads this flag.
  drop_remainder: bool = False