Commit 06d2681c authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 422638136
parent 7ba713c9
...@@ -21,8 +21,8 @@ from absl.testing import parameterized ...@@ -21,8 +21,8 @@ from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
def decode_record(record, name_to_features): def decode_record(record, name_to_features):
......
...@@ -22,7 +22,6 @@ import os ...@@ -22,7 +22,6 @@ import os
from absl import app from absl import app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib # word-piece tokenizer based squad_lib
...@@ -30,6 +29,7 @@ from official.nlp.data import squad_lib as squad_lib_wp ...@@ -30,6 +29,7 @@ from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib # sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -24,7 +24,7 @@ from absl import flags ...@@ -24,7 +24,7 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.tools import tokenization
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Create LM TF examples for XLNet.""" """Create LM TF examples for XLNet."""
import dataclasses
import json import json
import math import math
import os import os
...@@ -28,11 +29,10 @@ from absl import app ...@@ -28,11 +29,10 @@ from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import dataclasses
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.tools import tokenization
special_symbols = { special_symbols = {
"<unk>": 0, "<unk>": 0,
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
import os import os
from absl import logging from absl import logging
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
class BuccProcessor(classifier_data_lib.DataProcessor): class BuccProcessor(classifier_data_lib.DataProcessor):
......
...@@ -25,7 +25,7 @@ import six ...@@ -25,7 +25,7 @@ import six
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.tools import tokenization
class SquadExample(object): class SquadExample(object):
......
...@@ -28,7 +28,7 @@ from absl import logging ...@@ -28,7 +28,7 @@ from absl import logging
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.tools import tokenization
class SquadExample(object): class SquadExample(object):
......
...@@ -19,8 +19,8 @@ import os ...@@ -19,8 +19,8 @@ import os
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
# A negative label id for the padding label, which will not contribute # A negative label id for the padding label, which will not contribute
# to loss/metrics in training. # to loss/metrics in training.
......
...@@ -19,8 +19,8 @@ import random ...@@ -19,8 +19,8 @@ import random
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import tagging_data_lib from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization
def _create_fake_file(filename, labels, is_test): def _create_fake_file(filename, labels, is_test):
......
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import configs from official.legacy.bert import configs
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.data import dual_encoder_dataloader from official.nlp.data import dual_encoder_dataloader
......
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
# limitations under the License. # limitations under the License.
"""Question answering task.""" """Question answering task."""
import dataclasses
import functools import functools
import json import json
import os import os
from typing import List, Optional from typing import List, Optional
from absl import logging from absl import logging
import dataclasses
import orbit import orbit
import tensorflow as tf import tensorflow as tf
...@@ -27,15 +27,15 @@ from official.core import base_task ...@@ -27,15 +27,15 @@ from official.core import base_task
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import task_factory from official.core import task_factory
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.bert import squad_evaluate_v2_0
from official.nlp.bert import tokenization
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory from official.nlp.data import data_loader_factory
from official.nlp.data import squad_lib as squad_lib_wp from official.nlp.data import squad_lib as squad_lib_wp
from official.nlp.data import squad_lib_sp from official.nlp.data import squad_lib_sp
from official.nlp.modeling import models from official.nlp.modeling import models
from official.nlp.tasks import utils from official.nlp.tasks import utils
from official.nlp.tools import squad_evaluate_v1_1
from official.nlp.tools import squad_evaluate_v2_0
from official.nlp.tools import tokenization
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -71,8 +71,8 @@ from absl import app ...@@ -71,8 +71,8 @@ from absl import app
from absl import flags from absl import flags
import gin import gin
from official.legacy.bert import configs
from official.modeling import hyperparams from official.modeling import hyperparams
from official.nlp.bert import configs
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.tools import export_tfhub_lib from official.nlp.tools import export_tfhub_lib
......
...@@ -28,8 +28,8 @@ import tensorflow as tf ...@@ -28,8 +28,8 @@ import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
# pylint: enable=g-direct-tensorflow-import # pylint: enable=g-direct-tensorflow-import
from official.legacy.bert import configs
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.modeling import layers from official.nlp.modeling import layers
from official.nlp.modeling import models from official.nlp.modeling import models
......
...@@ -24,8 +24,8 @@ import tensorflow_hub as hub ...@@ -24,8 +24,8 @@ import tensorflow_hub as hub
import tensorflow_text as text import tensorflow_text as text
from sentencepiece import SentencePieceTrainer from sentencepiece import SentencePieceTrainer
from official.legacy.bert import configs
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.modeling import layers from official.nlp.modeling import layers
from official.nlp.modeling import models from official.nlp.modeling import models
......
...@@ -25,9 +25,9 @@ from absl import flags ...@@ -25,9 +25,9 @@ from absl import flags
import tensorflow as tf import tensorflow as tf
from official.legacy.albert import configs from official.legacy.albert import configs
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import models from official.nlp.modeling import models
from official.nlp.modeling import networks from official.nlp.modeling import networks
from official.nlp.tools import tf1_bert_checkpoint_converter_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -128,12 +128,12 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint, ...@@ -128,12 +128,12 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
# Create a temporary V1 name-converted checkpoint in the output directory. # Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1") temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt") temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_checkpoint_converter_lib.convert( tf1_bert_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint, checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint, checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads, num_heads=bert_config.num_attention_heads,
name_replacements=ALBERT_NAME_REPLACEMENTS, name_replacements=ALBERT_NAME_REPLACEMENTS,
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS, permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"]) exclude_patterns=["adam", "Adam"])
# Create a V2 checkpoint from the temporary checkpoint. # Create a V2 checkpoint from the temporary checkpoint.
...@@ -144,9 +144,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint, ...@@ -144,9 +144,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
else: else:
raise ValueError("Unsupported converted_model: %s" % converted_model) raise ValueError("Unsupported converted_model: %s" % converted_model)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint, tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
output_path, model, temporary_checkpoint, output_path, checkpoint_model_name)
checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists. # Clean up the temporary checkpoint, if it exists.
try: try:
......
...@@ -25,11 +25,11 @@ from absl import app ...@@ -25,11 +25,11 @@ from absl import app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.legacy.bert import configs
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import models from official.nlp.modeling import models
from official.nlp.modeling import networks from official.nlp.modeling import networks
from official.nlp.tools import tf1_bert_checkpoint_converter_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -111,12 +111,13 @@ def convert_checkpoint(bert_config, ...@@ -111,12 +111,13 @@ def convert_checkpoint(bert_config,
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1") temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt") temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_checkpoint_converter_lib.convert( tf1_bert_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint, checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint, checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads, num_heads=bert_config.num_attention_heads,
name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS, name_replacements=(
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS, tf1_bert_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS),
permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"]) exclude_patterns=["adam", "Adam"])
if converted_model == "encoder": if converted_model == "encoder":
...@@ -127,9 +128,8 @@ def convert_checkpoint(bert_config, ...@@ -127,9 +128,8 @@ def convert_checkpoint(bert_config,
raise ValueError("Unsupported converted_model: %s" % converted_model) raise ValueError("Unsupported converted_model: %s" % converted_model)
# Create a V2 checkpoint from the temporary checkpoint. # Create a V2 checkpoint from the temporary checkpoint.
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint, tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
output_path, model, temporary_checkpoint, output_path, checkpoint_model_name)
checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists. # Clean up the temporary checkpoint, if it exists.
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment