Commit 35aa1f31 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Refactor] BertConfig -> bert/configs.py

PiperOrigin-RevId: 294810822
parent 8bf0193d
@@ -31,7 +31,7 @@ import tensorflow as tf
 # pylint: enable=g-bad-import-order
 from official.benchmark import bert_benchmark_utils as benchmark_utils
-from official.nlp import bert_modeling as modeling
+from official.nlp.bert import configs
 from official.nlp.bert import run_classifier
 from official.utils.misc import distribution_utils
 from official.utils.testing import benchmark_wrappers
@@ -63,7 +63,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
       input_meta_data = json.loads(reader.read().decode('utf-8'))
-    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+    bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
     epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
     if self.num_steps_per_epoch:
       steps_per_epoch = self.num_steps_per_epoch
...
@@ -20,10 +20,10 @@ from __future__ import print_function
 import six
-from official.nlp import bert_modeling
+from official.nlp.bert import configs
-class AlbertConfig(bert_modeling.BertConfig):
+class AlbertConfig(configs.BertConfig):
   """Configuration for `ALBERT`."""
   def __init__(self,
...
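Usage sketch (not part of the commit): after this change `AlbertConfig` inherits directly from the relocated `configs.BertConfig`, so its keyword arguments flow through to the base class. The sizes below are illustrative values, not defaults taken from the repository.

from official.nlp.albert import configs as albert_configs

# Illustrative sizes only; `embedding_size` is ALBERT-specific, the remaining
# keyword arguments are forwarded to the BertConfig base class.
albert_config = albert_configs.AlbertConfig(
    embedding_size=128,
    vocab_size=30000,
    hidden_size=768,
    num_attention_heads=12)
print(albert_config.to_json_string())  # inherited from BertConfig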
@@ -22,8 +22,8 @@ import tensorflow as tf
 import tensorflow_hub as hub
 from official.modeling import tf_utils
-from official.nlp import bert_modeling
 from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import configs
 from official.nlp.modeling import losses
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier
@@ -114,7 +114,7 @@ def get_transformer_encoder(bert_config,
     kwargs['embedding_width'] = bert_config.embedding_size
     return networks.AlbertTransformerEncoder(**kwargs)
   else:
-    assert isinstance(bert_config, bert_modeling.BertConfig)
+    assert isinstance(bert_config, configs.BertConfig)
     return networks.TransformerEncoder(**kwargs)
...
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The main BERT model and related functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import six
import tensorflow as tf
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
backward_compatible=True):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
backward_compatible: Boolean, whether the variables shape are compatible
with checkpoints converted from TF 1.x BERT.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.io.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
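A quick usage sketch of the relocated class (illustrative, not part of the commit): the dict/JSON helpers round-trip a config, and `from_json_file` reads the same layout that `to_json_string` writes, e.g. the `bert_config.json` shipped with pretrained checkpoints. The vocabulary size below is just an example value.

from official.nlp.bert import configs

config = configs.BertConfig(vocab_size=30522)     # remaining fields use defaults
as_dict = config.to_dict()                        # plain Python dict copy
restored = configs.BertConfig.from_dict(as_dict)  # rebuild from the dict
assert restored.hidden_size == 768

print(config.to_json_string())                    # JSON form, sorted keys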
@@ -22,9 +22,8 @@ from absl import app
 from absl import flags
 import tensorflow as tf
 from typing import Text
-from official.nlp import bert_modeling
 from official.nlp.bert import bert_models
+from official.nlp.bert import configs
 FLAGS = flags.FLAGS
@@ -37,7 +36,7 @@ flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
-def create_bert_model(bert_config: bert_modeling.BertConfig) -> tf.keras.Model:
+def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
   """Creates a BERT keras core model from BERT configuration.
   Args:
@@ -64,7 +63,7 @@ def create_bert_model(bert_config: bert_modeling.BertConfig) -> tf.keras.Model:
       outputs=[pooled_output, sequence_output]), transformer_encoder
-def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
+def export_bert_tfhub(bert_config: configs.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
                       vocab_file: Text):
   """Restores a tf.keras.Model and saves for TF-Hub."""
@@ -79,7 +78,7 @@ def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
 def main(_):
   assert tf.version.VERSION.startswith('2.')
-  bert_config = bert_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
                     FLAGS.vocab_file)
...
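Sketch of how the updated export path fits together after the import change (illustrative; the file paths are placeholders, and in the script itself these values come from the --bert_config_file, --model_checkpoint_path, --export_path and --vocab_file flags):

from official.nlp.bert import configs
from official.nlp.bert import export_tfhub

# Placeholder paths, for illustration only.
bert_config = configs.BertConfig.from_json_file("/path/to/bert_config.json")
export_tfhub.export_bert_tfhub(
    bert_config,
    model_checkpoint_path="/path/to/bert_model.ckpt",
    hub_destination="/path/to/hub_module",
    vocab_file="/path/to/vocab.txt")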
@@ -24,7 +24,7 @@ import numpy as np
 import tensorflow as tf
 import tensorflow_hub as hub
-from official.nlp import bert_modeling
+from official.nlp.bert import configs
 from official.nlp.bert import export_tfhub
@@ -32,7 +32,7 @@ class ExportTfhubTest(tf.test.TestCase):
   def test_export_tfhub(self):
     # Exports a savedmodel for TF-Hub
-    bert_config = bert_modeling.BertConfig(
+    bert_config = configs.BertConfig(
         vocab_size=100,
         hidden_size=16,
         intermediate_size=32,
...
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""BERT classification finetuning runner in tf2.0."""
+"""BERT classification finetuning runner in TF 2.x."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -27,18 +26,18 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
-# pylint: disable=g-import-not-at-top,redefined-outer-name,reimported
 from official.modeling import model_training_utils
-from official.nlp import bert_modeling as modeling
 from official.nlp import optimization
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
+from official.nlp.bert import configs as bert_configs
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 flags.DEFINE_enum(
     'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
     'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
@@ -290,7 +289,7 @@ def run_bert(strategy,
              eval_input_fn=None):
   """Run BERT training."""
   if FLAGS.model_type == 'bert':
-    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+    bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   else:
     assert FLAGS.model_type == 'albert'
     bert_config = albert_configs.AlbertConfig.from_json_file(
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Run masked LM/next sentence masked_lm pre-training for BERT in tf2.0."""
+"""Run masked LM/next sentence masked_lm pre-training for BERT in TF 2.x."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -22,16 +22,14 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
-# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
 from official.modeling import model_training_utils
-from official.nlp import bert_modeling as modeling
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
+from official.nlp.bert import configs
 from official.nlp.bert import input_pipeline
-from official.nlp.bert import model_saving_utils
 from official.utils.misc import distribution_utils
-from official.utils.misc import tpu_lib
 flags.DEFINE_string('input_files', None,
                     'File path to retrieve training data for pre-training.')
@@ -135,7 +133,7 @@ def run_customized_training(strategy,
 def run_bert_pretrain(strategy):
   """Runs BERT pre-training."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   if not strategy:
     raise ValueError('Distribution strategy is not specified.')
...
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in tf2.0."""
+"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -26,23 +25,20 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
-# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
 from official.modeling import model_training_utils
-from official.nlp import bert_modeling as modeling
 from official.nlp import optimization
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
+from official.nlp.bert import configs as bert_configs
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
-# word-piece tokenizer based squad_lib
 from official.nlp.bert import squad_lib as squad_lib_wp
-# sentence-piece tokenizer based squad_lib
 from official.nlp.bert import squad_lib_sp
 from official.nlp.bert import tokenization
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
-from official.utils.misc import tpu_lib
 flags.DEFINE_enum(
     'mode', 'train_and_predict',
@@ -99,7 +95,7 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS
 MODEL_CLASSES = {
-    'bert': (modeling.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
+    'bert': (bert_configs.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
     'albert': (albert_configs.AlbertConfig, squad_lib_sp,
                tokenization.FullSentencePieceTokenizer),
 }
...
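For orientation, a self-contained sketch of the dispatch pattern behind MODEL_CLASSES: a model-type string selects the config class used to parse the JSON file. The helper name and the reduced mapping below are illustrative, not code from this commit; the script's real MODEL_CLASSES also carries the matching squad_lib module and tokenizer class.

from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs as bert_configs

# Illustrative mapping: model-type string -> config class.
CONFIG_CLASSES = {
    'bert': bert_configs.BertConfig,
    'albert': albert_configs.AlbertConfig,
}


def load_config(model_type, config_json_path):
  """Picks the config class for `model_type` and parses the given JSON file."""
  return CONFIG_CLASSES[model_type].from_json_file(config_json_path)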
@@ -28,7 +28,7 @@ from absl import flags
 import tensorflow as tf
 from official.modeling import activations
-from official.nlp import bert_modeling as modeling
+from official.nlp.bert import configs
 from official.nlp.bert import tf1_checkpoint_converter_lib
 from official.nlp.modeling import networks
@@ -101,7 +101,7 @@ def main(_):
   assert tf.version.VERSION.startswith('2.')
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   convert_checkpoint(bert_config, output_path, v1_checkpoint)
...
@@ -19,131 +19,12 @@ from __future__ import division
 from __future__ import print_function
 import copy
-import json
 import math
-import six
 import tensorflow as tf
 from tensorflow.python.util import deprecation
 from official.modeling import tf_utils
+from official.nlp.bert import configs
-class BertConfig(object):
-  """Configuration for `BertModel`."""
-  def __init__(self,
-               vocab_size,
-               hidden_size=768,
-               num_hidden_layers=12,
-               num_attention_heads=12,
-               intermediate_size=3072,
-               hidden_act="gelu",
-               hidden_dropout_prob=0.1,
-               attention_probs_dropout_prob=0.1,
-               max_position_embeddings=512,
-               type_vocab_size=16,
-               initializer_range=0.02,
-               backward_compatible=True):
-    """Constructs BertConfig.
-    Args:
-      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
-      hidden_size: Size of the encoder layers and the pooler layer.
-      num_hidden_layers: Number of hidden layers in the Transformer encoder.
-      num_attention_heads: Number of attention heads for each attention layer in
-        the Transformer encoder.
-      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
-        layer in the Transformer encoder.
-      hidden_act: The non-linear activation function (function or string) in the
-        encoder and pooler.
-      hidden_dropout_prob: The dropout probability for all fully connected
-        layers in the embeddings, encoder, and pooler.
-      attention_probs_dropout_prob: The dropout ratio for the attention
-        probabilities.
-      max_position_embeddings: The maximum sequence length that this model might
-        ever be used with. Typically set this to something large just in case
-        (e.g., 512 or 1024 or 2048).
-      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
-        `BertModel`.
-      initializer_range: The stdev of the truncated_normal_initializer for
-        initializing all weight matrices.
-      backward_compatible: Boolean, whether the variables shape are compatible
-        with checkpoints converted from TF 1.x BERT.
-    """
-    self.vocab_size = vocab_size
-    self.hidden_size = hidden_size
-    self.num_hidden_layers = num_hidden_layers
-    self.num_attention_heads = num_attention_heads
-    self.hidden_act = hidden_act
-    self.intermediate_size = intermediate_size
-    self.hidden_dropout_prob = hidden_dropout_prob
-    self.attention_probs_dropout_prob = attention_probs_dropout_prob
-    self.max_position_embeddings = max_position_embeddings
-    self.type_vocab_size = type_vocab_size
-    self.initializer_range = initializer_range
-    self.backward_compatible = backward_compatible
-  @classmethod
-  def from_dict(cls, json_object):
-    """Constructs a `BertConfig` from a Python dictionary of parameters."""
-    config = BertConfig(vocab_size=None)
-    for (key, value) in six.iteritems(json_object):
-      config.__dict__[key] = value
-    return config
-  @classmethod
-  def from_json_file(cls, json_file):
-    """Constructs a `BertConfig` from a json file of parameters."""
-    with tf.io.gfile.GFile(json_file, "r") as reader:
-      text = reader.read()
-    return cls.from_dict(json.loads(text))
-  def to_dict(self):
-    """Serializes this instance to a Python dictionary."""
-    output = copy.deepcopy(self.__dict__)
-    return output
-  def to_json_string(self):
-    """Serializes this instance to a JSON string."""
-    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
-class AlbertConfig(BertConfig):
-  """Configuration for `ALBERT`."""
-  def __init__(self,
-               embedding_size,
-               num_hidden_groups=1,
-               inner_group_num=1,
-               **kwargs):
-    """Constructs AlbertConfig.
-    Args:
-      embedding_size: Size of the factorized word embeddings.
-      num_hidden_groups: Number of group for the hidden layers, parameters in
-        the same group are shared. Note that this value and also the following
-        'inner_group_num' has to be 1 for now, because all released ALBERT
-        models set them to 1. We may support arbitary valid values in future.
-      inner_group_num: Number of inner repetition of attention and ffn.
-      **kwargs: The remaining arguments are the same as above 'BertConfig'.
-    """
-    super(AlbertConfig, self).__init__(**kwargs)
-    self.embedding_size = embedding_size
-    # TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
-    # in the released ALBERT. Support other values in AlbertTransformerEncoder
-    # if needed.
-    if inner_group_num != 1 or num_hidden_groups != 1:
-      raise ValueError("We only support 'inner_group_num' and "
-                       "'num_hidden_groups' as 1.")
-  @classmethod
-  def from_dict(cls, json_object):
-    """Constructs a `AlbertConfig` from a Python dictionary of parameters."""
-    config = AlbertConfig(embedding_size=None, vocab_size=None)
-    for (key, value) in six.iteritems(json_object):
-      config.__dict__[key] = value
-    return config
 @deprecation.deprecated(None, "The function should not be used any more.")
@@ -174,7 +55,7 @@ class BertModel(tf.keras.layers.Layer):
     input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
     input_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
-    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
+    config = configs.BertConfig(vocab_size=32000, hidden_size=512,
       num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
     pooled_output, sequence_output = modeling.BertModel(config=config)(
@@ -190,7 +71,7 @@ class BertModel(tf.keras.layers.Layer):
   def __init__(self, config, float_type=tf.float32, **kwargs):
     super(BertModel, self).__init__(**kwargs)
     self.config = (
-        BertConfig.from_dict(config)
+        configs.BertConfig.from_dict(config)
         if isinstance(config, dict) else copy.deepcopy(config))
     self.float_type = float_type
...