Commit 2e9bb539 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 7bae5317 8fba84f8
......@@ -14,3 +14,5 @@
# ==============================================================================
"""Ops package definition."""
from official.nlp.modeling.ops.beam_search import sequence_beam_search
from official.nlp.modeling.ops.segment_extractor import get_next_sentence_labels
from official.nlp.modeling.ops.segment_extractor import get_sentence_order_labels
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.
This tool creates preprocessor and encoder SavedModels suitable for uploading
to https://tfhub.dev that implement the preprocessor and encoder APIs defined
at https://www.tensorflow.org/hub/common_saved_model_apis/text.
For a full usage guide, see
https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md
Minimal usage examples:
1) Exporting an Encoder from checkpoint and config.
```
export_tfhub \
  --encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
  --model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
  --vocab_file=${BERT_DIR:?}/vocab.txt \
  --export_type=model \
  --export_path=/tmp/bert_model
```
An --encoder_config_file can specify encoder types other than BERT.
For BERT, a --bert_config_file in the legacy JSON format can be passed instead.
Flag --vocab_file (and flag --do_lower_case, whose default value is guessed
from the vocab_file path) captures how BertTokenizer was used in pre-training.
Use flag --sp_model_file instead if SentencepieceTokenizer was used.
Changing --export_type to model_with_mlm additionally creates an `.mlm`
subobject on the exported SavedModel that can be called to produce
the logits of the Masked Language Model task from pretraining.
The help string for flag --model_checkpoint_path explains the checkpoint
formats required for each --export_type.
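For illustration, a consumer of the exported encoder might call it as follows
(a minimal sketch following the encoder API linked above; the path and the
token ids are placeholders):
```
import tensorflow as tf
import tensorflow_hub as hub

encoder = hub.load("/tmp/bert_model")
encoder_inputs = dict(
    input_word_ids=tf.constant([[101, 2023, 102]]),  # Placeholder token ids.
    input_mask=tf.constant([[1, 1, 1]]),             # 1 marks real tokens.
    input_type_ids=tf.constant([[0, 0, 0]]))         # Segment ids.
outputs = encoder(encoder_inputs)
pooled = outputs["pooled_output"]      # Shape [batch_size, hidden_size].
sequence = outputs["sequence_output"]  # [batch_size, seq_length, hidden_size].
# With --export_type=model_with_mlm, the `.mlm` subobject described above is
# also available on the loaded object.
```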
2) Exporting a preprocessor SavedModel
```
export_tfhub \
  --vocab_file ${BERT_DIR:?}/vocab.txt \
  --export_type preprocessing --export_path /tmp/bert_preprocessing
```
Be sure to use flag values that match the encoder and how it has been
pre-trained (see above for --vocab_file vs --sp_model_file).
If your encoder has been trained with text preprocessing for which tfhub.dev
already has a SavedModel, you could guide your users to reuse that one instead
of exporting and publishing your own.
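For illustration, the consumer side of such a preprocessor might look as
follows (a minimal sketch following the preprocessor API linked above; paths
and example strings are placeholders):
```
import tensorflow as tf
import tensorflow_hub as hub

preprocessor = hub.load("/tmp/bert_preprocessing")
sentences = tf.constant(["Hello TensorFlow.", "A second example."])
# One-shot preprocessing, padded/truncated to --default_seq_length:
encoder_inputs = preprocessor(sentences)  # input_word_ids/_mask/_type_ids.
# Or tokenize segments separately and pack sentence pairs to a chosen length:
tokens_a = preprocessor.tokenize(sentences)
tokens_b = preprocessor.tokenize(tf.constant(["Pair one.", "Pair two."]))
packed = preprocessor.bert_pack_inputs([tokens_a, tokens_b], seq_length=128)
```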
TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
`--experimental_disable_assert_in_preprocessing`.
"""
from absl import app
from absl import flags
import gin
from official.modeling import hyperparams
from official.nlp.bert import configs
from official.nlp.configs import encoders
from official.nlp.tools import export_tfhub_lib
FLAGS = flags.FLAGS
flags.DEFINE_enum(
    "export_type", "model",
    ["model", "model_with_mlm", "preprocessing"],
    "The overall type of SavedModel to export. Flags "
    "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
    "control which particular encoder model and preprocessing are exported.")

flags.DEFINE_string(
    "export_path", None,
    "Directory to which the SavedModel is written.")

flags.DEFINE_string(
    "encoder_config_file", None,
    "A yaml file representing `encoders.EncoderConfig` to define the encoder "
    "(BERT or other). "
    "Exactly one of --bert_config_file and --encoder_config_file must be set. "
    "Needed for --export_type model and model_with_mlm.")

flags.DEFINE_string(
    "bert_config_file", None,
    "A JSON file with a legacy BERT configuration to define the BERT encoder. "
    "Exactly one of --bert_config_file and --encoder_config_file must be set. "
    "Needed for --export_type model and model_with_mlm.")

flags.DEFINE_bool(
    "copy_pooler_dense_to_encoder", False,
    "When the model is trained using `BertPretrainerV2`, the pooler layer "
    "of the next-sentence-prediction task lives in the `ClassificationHead` "
    "passed to `BertPretrainerV2`. If True, this pooler's dense layer is "
    "copied to the encoder that is exported by this tool (as in classic "
    "BERT). Using `BertPretrainerV2` and leaving this False exports an "
    "untrained (randomly initialized) pooling layer, which some authors "
    "recommend for subsequent fine-tuning.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to a pre-trained model checkpoint. "
    "For --export_type model, this has to be an object-based (TF2) checkpoint "
    "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
    "for the `encoder` defined by the config file. "
    "(Legacy checkpoints with `model=` instead of `encoder=` are also "
    "supported for now.) "
    "For --export_type model_with_mlm, it must be restorable to "
    "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
    "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
    "accepted.)")

flags.DEFINE_string(
    "vocab_file", None,
    "For encoders trained on BertTokenizer input: "
    "the vocabulary file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file must be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")

flags.DEFINE_string(
    "sp_model_file", None,
    "For encoders trained on SentencepieceTokenizer input: "
    "the SentencePiece .model file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file must be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")

flags.DEFINE_bool(
    "do_lower_case", None,
    "Whether to lowercase before tokenization. "
    "If left as None, and --vocab_file is set, do_lower_case will be enabled "
    "if 'uncased' appears in the name of --vocab_file. "
    "If left as None, and --sp_model_file is set, do_lower_case defaults to "
    "True. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")

flags.DEFINE_integer(
    "default_seq_length", 128,
    "The sequence length of preprocessing results from the "
    "top-level preprocess method. This is also the default "
    "sequence length for the bert_pack_inputs subobject. "
    "Needed for --export_type preprocessing.")

flags.DEFINE_bool(
    "tokenize_with_offsets", False,  # Broken by b/149576200.
    "Whether to export a .tokenize_with_offsets subobject for "
    "--export_type preprocessing.")

flags.DEFINE_multi_string(
    "gin_file", default=None,
    help="List of paths to Gin config files.")

flags.DEFINE_multi_string(
    "gin_params", default=None,
    help="List of Gin bindings.")

flags.DEFINE_bool(  # TODO(b/175369555): Remove this flag and its use.
    "experimental_disable_assert_in_preprocessing", False,
    "Export a preprocessing model without tf.Assert ops. "
    "Usually, that would be a bad idea, except TF2.4 has an issue with "
    "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
    "and omitting the Assert ops lets SavedModels avoid the issue.")
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

  if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
    raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                     "must be specified, but got %s and %s." %
                     (FLAGS.vocab_file, FLAGS.sp_model_file))
  do_lower_case = export_tfhub_lib.get_do_lower_case(
      FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)

  if FLAGS.export_type in ("model", "model_with_mlm"):
    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
      raise ValueError("Exactly one of `bert_config_file` and "
                       "`encoder_config_file` must be specified, but got "
                       "%s and %s." %
                       (FLAGS.bert_config_file, FLAGS.encoder_config_file))
    if FLAGS.bert_config_file:
      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
      encoder_config = None
    else:
      bert_config = None
      encoder_config = encoders.EncoderConfig()
      encoder_config = hyperparams.override_params_dict(
          encoder_config, FLAGS.encoder_config_file, is_strict=True)
    export_tfhub_lib.export_model(
        FLAGS.export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=FLAGS.model_checkpoint_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        with_mlm=FLAGS.export_type == "model_with_mlm",
        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)

  elif FLAGS.export_type == "preprocessing":
    export_tfhub_lib.export_preprocessing(
        FLAGS.export_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        default_seq_length=FLAGS.default_seq_length,
        tokenize_with_offsets=FLAGS.tokenize_with_offsets,
        experimental_disable_assert=
        FLAGS.experimental_disable_assert_in_preprocessing)

  else:
    raise app.UsageError(
        "Unknown value '%s' for flag --export_type" % FLAGS.export_type)


if __name__ == "__main__":
  app.run(main)
......@@ -109,8 +109,8 @@ class Transformer(tf.keras.Model):
sequence. float tensor with shape [batch_size, target_length, vocab_size]
If target is None, then generate the output sequence one token at a time.
returns a dictionary {
outputs: [batch_size, decoded length]
scores: [batch_size, float]}
outputs: int tensor with shape [batch_size, decoded_length]
scores: float tensor with shape [batch_size]}
Even when float16 is used, the output tensor(s) are always float32.
Raises:
......
......@@ -151,14 +151,8 @@ def translate_file(model,
text = distribution_strategy.run(text_as_per_replica)
outputs = distribution_strategy.experimental_local_results(
predict_step(text))
tags, unordered_val_outputs = outputs[0]
tags = [tag.numpy() for tag in tags._values]
unordered_val_outputs = [
val_output.numpy() for val_output in unordered_val_outputs._values]
# pylint: enable=protected-access
val_outputs = [None] * len(tags)
for k in range(len(tags)):
val_outputs[tags[k]] = unordered_val_outputs[k]
val_outputs = [output for _, output in outputs]
val_outputs = np.reshape(val_outputs, [params["decode_batch_size"], -1])
else:
val_outputs, _ = model.predict(text)
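For intuition about the simplified gathering above: each element of `outputs`
is a per-replica `(tag, predictions)` pair, so the list comprehension keeps
just the predictions before `np.reshape` flattens them into one decode batch.
A toy sketch (shapes and values are illustrative only):
```
import numpy as np

# Two replicas, each returning a (tag, predictions) pair from predict_step.
outputs = [(0, np.array([[11, 12], [13, 14]])),
           (1, np.array([[21, 22], [23, 24]]))]
val_outputs = [output for _, output in outputs]  # Drop the replica tags.
val_outputs = np.reshape(val_outputs, [4, -1])   # decode_batch_size == 4.
print(val_outputs.shape)  # (4, 2)
```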
......
......@@ -22,14 +22,15 @@ import random
# Import libraries
from absl import app
from absl import flags
import absl.logging as _logging # pylint: disable=unused-import
from absl import logging
import numpy as np
import tensorflow.google as tf
from official.nlp.xlnet import preprocess_utils
import tensorflow.compat.v1 as tf
import sentencepiece as spm
from official.nlp.xlnet import preprocess_utils
FLAGS = flags.FLAGS
special_symbols = {
......@@ -89,6 +90,7 @@ def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
def _create_data(idx, input_paths):
"""Creates data."""
# Load sentence-piece model
sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.sp_path)
......@@ -98,10 +100,10 @@ def _create_data(idx, input_paths):
for input_path in input_paths:
input_data, sent_ids = [], []
sent_id, line_cnt = True, 0
tf.logging.info("Processing %s", input_path)
logging.info("Processing %s", input_path)
for line in tf.gfile.Open(input_path):
if line_cnt % 100000 == 0:
tf.logging.info("Loading line %d", line_cnt)
logging.info("Loading line %d", line_cnt)
line_cnt += 1
if not line.strip():
......@@ -122,7 +124,7 @@ def _create_data(idx, input_paths):
sent_ids.extend([sent_id] * len(cur_sent))
sent_id = not sent_id
tf.logging.info("Finish with line %d", line_cnt)
logging.info("Finish with line %d", line_cnt)
if line_cnt == 0:
continue
......@@ -132,7 +134,7 @@ def _create_data(idx, input_paths):
total_line_cnt += line_cnt
input_shards.append((input_data, sent_ids))
tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
......@@ -142,8 +144,8 @@ def _create_data(idx, input_paths):
np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
perm_indices = np.random.permutation(len(input_shards))
tf.logging.info("Using perm indices %s for pass %d",
perm_indices.tolist(), FLAGS.pass_id)
logging.info("Using perm indices %s for pass %d",
perm_indices.tolist(), FLAGS.pass_id)
input_data_list, sent_ids_list = [], []
prev_sent_id = None
......@@ -185,6 +187,7 @@ def _create_data(idx, input_paths):
def create_data(_):
"""Creates pretrain data."""
# Validate FLAGS
assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
if not FLAGS.use_tpu:
......@@ -221,16 +224,16 @@ def create_data(_):
# Interleave the work across FLAGS.num_task splits
file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
tf.logging.info("Use glob: %s", FLAGS.input_glob)
tf.logging.info("Find %d files: %s", len(file_paths), file_paths)
logging.info("Use glob: %s", FLAGS.input_glob)
logging.info("Find %d files: %s", len(file_paths), file_paths)
task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
if not task_file_paths:
tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
logging.info("Exit: task %d has no file to process.", FLAGS.task)
return
tf.logging.info("Task %d process %d files: %s",
FLAGS.task, len(task_file_paths), task_file_paths)
logging.info("Task %d process %d files: %s",
FLAGS.task, len(task_file_paths), task_file_paths)
record_info = _create_data(FLAGS.task, task_file_paths)
record_prefix = "record_info-{}-{}-{}".format(
......@@ -253,6 +256,7 @@ def create_data(_):
def batchify(data, bsz_per_host, sent_ids=None):
"""Creates batches."""
num_step = len(data) // bsz_per_host
data = data[:bsz_per_host * num_step]
data = data.reshape(bsz_per_host, num_step)
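A quick numeric illustration of the truncation and reshape in `batchify`
(toy sizes, not real data):
```
import numpy as np

data = np.arange(10)       # 10 tokens with bsz_per_host = 4.
num_step = len(data) // 4  # Two full steps; the 2 leftover tokens are dropped.
data = data[:4 * num_step].reshape(4, num_step)
# Each of the 4 batch rows now holds a contiguous chunk of the stream:
# [[0 1]
#  [2 3]
#  [4 5]
#  [6 7]]
```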
......@@ -270,9 +274,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
data_len = data.shape[0]
if begin_idx + tot_len >= data_len:
tf.logging.info("[_split_a_and_b] returns None: "
"begin_idx %d + tot_len %d >= data_len %d",
begin_idx, tot_len, data_len)
logging.info("[_split_a_and_b] returns None: "
"begin_idx %d + tot_len %d >= data_len %d",
begin_idx, tot_len, data_len)
return None
end_idx = begin_idx + 1
......@@ -284,9 +288,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
end_idx += 1
a_begin = begin_idx
if len(cut_points) == 0 or random.random() < 0.5:
if len(cut_points) == 0 or random.random() < 0.5: # pylint:disable=g-explicit-length-test
label = 0
if len(cut_points) == 0:
if len(cut_points) == 0: # pylint:disable=g-explicit-length-test
a_end = end_idx
else:
a_end = random.choice(cut_points)
......@@ -321,9 +325,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
if extend_target:
if a_end >= data_len or b_end >= data_len:
tf.logging.info("[_split_a_and_b] returns None: "
"a_end %d or b_end %d >= data_len %d",
a_end, b_end, data_len)
logging.info("[_split_a_and_b] returns None: "
"a_end %d or b_end %d >= data_len %d",
a_end, b_end, data_len)
return None
a_target = data[a_begin + 1: a_end + 1]
b_target = data[b_begin: b_end + 1]
......@@ -342,9 +346,7 @@ def _is_start_piece(piece):
def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
"""Sample `goal_num_predict` tokens for partial prediction.
About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
"""Samples `goal_num_predict` tokens for partial prediction."""
seg_len = len(seg)
mask = np.array([False] * seg_len, dtype=np.bool)
......@@ -406,8 +408,7 @@ def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
goal_num_predict=None):
"""Sample `goal_num_predict` tokens for partial prediction.
About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
"""Sample `goal_num_predict` tokens for partial prediction."""
seg_len = len(seg)
mask = np.array([False] * seg_len, dtype=np.bool)
......@@ -474,6 +475,7 @@ def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
bi_data, sp):
"""Creates TFRecords."""
data, sent_ids = data[0], data[1]
num_core = FLAGS.num_core_per_host
......@@ -496,7 +498,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
else:
data, sent_ids = batchify(data, bsz_per_host, sent_ids)
tf.logging.info("Raw data shape %s.", data.shape)
logging.info("Raw data shape %s.", data.shape)
file_name = format_filename(
prefix=basename,
......@@ -512,7 +514,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
)
save_path = os.path.join(save_dir, file_name)
record_writer = tf.python_io.TFRecordWriter(save_path)
tf.logging.info("Start writing %s.", save_path)
logging.info("Start writing %s.", save_path)
num_batch = 0
reuse_len = FLAGS.reuse_len
......@@ -527,7 +529,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
i = 0
while i + seq_len <= data_len:
if num_batch % 500 == 0:
tf.logging.info("Processing batch %d", num_batch)
logging.info("Processing batch %d", num_batch)
all_ok = True
features = []
......@@ -542,7 +544,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
tot_len=seq_len - reuse_len - 3,
extend_target=True)
if results is None:
tf.logging.info("Break out with seq idx %d", i)
logging.info("Break out with seq idx %d", i)
all_ok = False
break
......@@ -600,7 +602,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
i += reuse_len
record_writer.close()
tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
return save_path, num_batch
......@@ -624,6 +626,7 @@ def _convert_example(example, use_bfloat16):
def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
host_id, num_core_per_host, bsz_per_core):
"""Parses files to a dataset."""
# List of file paths.
num_files = len(file_names)
num_files_per_host = num_files // num_hosts
......@@ -632,7 +635,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
if host_id == num_hosts - 1:
my_end_file_id = num_files
file_paths = file_names[my_start_file_id: my_end_file_id]
tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
logging.info("Host %d handles %d files", host_id, len(file_paths))
assert split == "train"
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
......@@ -657,9 +660,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
"""
Sample a permutation of the factorization order, and create an
attention mask accordingly.
"""Samples a permutation of the factorization order, and create a mask.
Args:
inputs: int64 Tensor in shape [seq_len], input ids.
......@@ -669,6 +670,10 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
perm_size: the length of the longest permutation. Could be set to reuse_len.
Should not be larger than reuse_len or there will be data leaks.
seq_len: int, sequence length.
Returns:
The permutation mask, new targets, target mask, and new inputs.
"""
# Generate permutation indices
......@@ -726,6 +731,7 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
num_batch, seq_len, reuse_len, perm_size, mask_alpha,
mask_beta, use_bfloat16=False, num_predict=None):
"""Gets the dataset."""
bsz_per_core = params["batch_size"]
if num_hosts > 1:
......@@ -821,7 +827,7 @@ def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
_convert_example(example, use_bfloat16)
for k, v in example.items():
tf.logging.info("%s: %s", k, v)
logging.info("%s: %s", k, v)
return example
......@@ -855,6 +861,7 @@ def get_input_fn(
num_passes=None,
use_bfloat16=False,
num_predict=None):
"""Gets the input function."""
# Merge all record infos into a single one
record_glob_base = format_filename(
......@@ -872,15 +879,14 @@ def get_input_fn(
record_info = {"num_batch": 0, "filenames": []}
tfrecord_dirs = tfrecord_dir.split(",")
tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
for idx, record_dir in enumerate(tfrecord_dirs):
record_glob = os.path.join(record_dir, record_glob_base)
tf.logging.info("[%d] Record glob: %s", idx, record_glob)
logging.info("[%d] Record glob: %s", idx, record_glob)
record_paths = sorted(tf.gfile.Glob(record_glob))
tf.logging.info("[%d] Num of record info path: %d",
idx, len(record_paths))
logging.info("[%d] Num of record info path: %d", idx, len(record_paths))
cur_record_info = {"num_batch": 0, "filenames": []}
......@@ -890,7 +896,7 @@ def get_input_fn(
fields = record_info_name.split(".")[0].split("-")
pass_id = int(fields[-1])
if len(fields) == 5 and pass_id >= num_passes:
tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
logging.info("Skip pass %d: %s", pass_id, record_info_name)
continue
with tf.gfile.Open(record_info_path, "r") as fp:
......@@ -912,21 +918,19 @@ def get_input_fn(
new_filenames.append(new_filename)
cur_record_info["filenames"] = new_filenames
tf.logging.info("[Dir %d] Number of chosen batches: %s",
idx, cur_record_info["num_batch"])
tf.logging.info("[Dir %d] Number of chosen files: %s",
idx, len(cur_record_info["filenames"]))
tf.logging.info(cur_record_info["filenames"])
logging.info("[Dir %d] Number of chosen batches: %s",
idx, cur_record_info["num_batch"])
logging.info("[Dir %d] Number of chosen files: %s",
idx, len(cur_record_info["filenames"]))
logging.info(cur_record_info["filenames"])
# add `cur_record_info` to global `record_info`
record_info["num_batch"] += cur_record_info["num_batch"]
record_info["filenames"] += cur_record_info["filenames"]
tf.logging.info("Total number of batches: %d",
record_info["num_batch"])
tf.logging.info("Total number of files: %d",
len(record_info["filenames"]))
tf.logging.info(record_info["filenames"])
logging.info("Total number of batches: %d", record_info["num_batch"])
logging.info("Total number of files: %d", len(record_info["filenames"]))
logging.info(record_info["filenames"])
def input_fn(params):
"""docs."""
......@@ -952,8 +956,8 @@ def get_input_fn(
return input_fn, record_info
if __name__ == "__main__":
FLAGS = flags.FLAGS
def define_flags():
"""Defines relevant flags."""
flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
......@@ -991,5 +995,8 @@ if __name__ == "__main__":
flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
"using multiple workers to identify each worker.")
tf.logging.set_verbosity(tf.logging.INFO)
if __name__ == "__main__":
define_flags()
logging.set_verbosity(logging.INFO)
app.run(create_data)
......@@ -17,9 +17,10 @@ r"""Tool to generate api_docs for tensorflow_models/official library.
Example:
python build_docs \
$> pip install -U git+https://github.com/tensorflow/docs
$> python build_docs \
--output_dir=/tmp/api_docs \
--project_short_name=tf_nlp.modeling \
--project_short_name=tfnlp \
--project_full_name="TensorFlow Official Models - NLP Modeling Library"
"""
......@@ -34,7 +35,7 @@ from tensorflow_docs.api_generator import doc_controls
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
from official.nlp import modeling as tf_nlp_modeling
from official.nlp import modeling as tfnlp
FLAGS = flags.FLAGS
......@@ -47,18 +48,15 @@ flags.DEFINE_string(
flags.DEFINE_bool('search_hints', True,
'Include metadata search hints in the generated files')
flags.DEFINE_string('site_path', 'tf_nlp_modeling/api_docs/python',
flags.DEFINE_string('site_path', '/api_docs/python',
'Path prefix in the _toc.yaml')
flags.DEFINE_bool('gen_report', False,
'Generate an API report containing the health of the '
'docstrings of the public API.')
flags.DEFINE_string(
'project_short_name', 'tf_nlp.modeling',
'The project short name referring to the python module to document.')
flags.DEFINE_string('project_full_name',
'TensorFlow Official Models - NLP Modeling Library',
'The main title for the project.')
PROJECT_SHORT_NAME = 'tfnlp'
PROJECT_FULL_NAME = 'TensorFlow Official Models - NLP Modeling Library'
def _hide_module_model_and_layer_methods():
......@@ -104,8 +102,8 @@ def gen_api_docs(code_url_prefix, site_path, output_dir, gen_report,
doc_generator = generate_lib.DocGenerator(
root_title=project_full_name,
py_modules=[(project_short_name, tf_nlp_modeling)],
base_dir=os.path.dirname(tf_nlp_modeling.__file__),
py_modules=[(project_short_name, tfnlp)],
base_dir=os.path.dirname(tfnlp.__file__),
code_url_prefix=code_url_prefix,
search_hints=search_hints,
site_path=site_path,
......@@ -126,8 +124,8 @@ def main(argv):
site_path=FLAGS.site_path,
output_dir=FLAGS.output_dir,
gen_report=FLAGS.gen_report,
project_short_name=FLAGS.project_short_name,
project_full_name=FLAGS.project_full_name,
project_short_name=PROJECT_SHORT_NAME,
project_full_name=PROJECT_FULL_NAME,
search_hints=FLAGS.search_hints)
......
......@@ -6,29 +6,39 @@ TF Vision model garden provides a large collection of baselines and checkpoints
## Image Classification
### ImageNet Baselines
#### Models trained with vanilla settings:
#### ResNet models trained with vanilla settings:
* Models are trained from scratch with a batch size of 4096 and an initial learning rate of 1.6.
* Linear warmup is applied for the first 5 epochs.
* Models are trained with L2 weight regularization and ReLU activation.
| model | resolution | epochs | Top-1 | Top-5 | download |
| ------------ |:-------------:|--------:|--------:|---------:|---------:|
| ResNet-50 | 224x224 | 90 | 76.1 | 92.9 | config |
| ResNet-50 | 224x224 | 200 | 77.1 | 93.5 | config |
| ResNet-101 | 224x224 | 200 | 78.3 | 94.2 | config |
| ResNet-152 | 224x224 | 200 | 78.7 | 94.3 | config |
#### ResNet models trained with additional training features, including:
* Label smoothing 0.1.
* Swish activation.
| model | resolution | epochs | Top-1 | Top-5 | download |
| ------------ |:-------------:| ---------:|--------:|---------:|---------:|
| ResNet-50 | 224x224 | 200 | 78.1 | 93.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml) |
| ResNet-101 | 224x224 | 200 | 79.1 | 94.5 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_tpu.yaml) |
| ResNet-152 | 224x224 | 200 | 79.4 | 94.7 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet152_tpu.yaml) |
| ResNet-200 | 224x224 | 200 | 79.9 | 94.8 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet200_tpu.yaml) |
| ResNet-50 | 224x224 | 90 | 76.1 | 92.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml) |
| ResNet-50 | 224x224 | 200 | 77.1 | 93.5 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml) |
| ResNet-101 | 224x224 | 200 | 78.3 | 94.2 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_tpu.yaml) |
| ResNet-152 | 224x224 | 200 | 78.7 | 94.3 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnet152_tpu.yaml) |
#### ResNet-RS models trained with settings including:
* ResNet-RS architectural changes and Swish activation.
* Regularization methods including Random Augment, 4e-5 weight decay, stochastic depth, label smoothing and dropout.
* New training methods, including a 350-epoch schedule, cosine learning rate decay, and an exponential moving average (EMA) of weights.
* Configs are in this [directory](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification).
model | resolution | params (M) | Top-1 | Top-5 | download
--------- | :--------: | -----: | ----: | ----: | -------:
ResNet-RS-50 | 160x160 | 35.7 | 79.1 | 94.5 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs50_i160.yaml) |
ResNet-RS-101 | 160x160 | 63.7 | 80.2 | 94.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs101_i160.yaml) |
ResNet-RS-101 | 192x192 | 63.7 | 81.3 | 95.6 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs101_i192.yaml) |
ResNet-RS-152 | 192x192 | 86.8 | 81.9 | 95.8 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i192.yaml) |
ResNet-RS-152 | 224x224 | 86.8 | 82.5 | 96.1 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i224.yaml) |
ResNet-RS-152 | 256x256 | 86.8 | 83.1 | 96.3 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs152_i256.yaml) |
ResNet-RS-200 | 256x256 | 93.4 | 83.5 | 96.6 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs200_i256.yaml) |
ResNet-RS-270 | 256x256 | 130.1 | 83.6 | 96.6 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs270_i256.yaml) |
ResNet-RS-350 | 256x256 | 164.3 | 83.7 | 96.7 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i256.yaml) |
ResNet-RS-350 | 320x320 | 164.3 | 84.2 | 96.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i320.yaml) |
## Object Detection and Instance Segmentation
......
......@@ -30,6 +30,8 @@ class ResNet(hyperparams.Config):
stem_type: str = 'v0'
se_ratio: float = 0.0
stochastic_depth_drop_rate: float = 0.0
resnetd_shortcut: bool = False
replace_stem_max_pool: bool = False
@dataclasses.dataclass
......
......@@ -27,3 +27,11 @@ class NormActivation(hyperparams.Config):
use_sync_bn: bool = True
norm_momentum: float = 0.99
norm_epsilon: float = 0.001
@dataclasses.dataclass
class PseudoLabelDataConfig(hyperparams.Config):
"""Psuedo Label input config for training."""
input_path: str = ''
data_ratio: float = 1.0 # Per-batch ratio of pseudo-labeled to labeled data
file_type: str = 'tfrecord'
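For orientation, a `hyperparams.Config` dataclass like the one above is
typically constructed or overridden with keyword arguments; a minimal sketch
with placeholder values:
```
# Hypothetical usage of the config defined above; values are placeholders.
pseudo_cfg = PseudoLabelDataConfig(
    input_path='/data/pseudo/train*.tfrecord',
    data_ratio=0.5,  # Half as much pseudo-labeled as labeled data per batch.
    file_type='tfrecord')
```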